def blvnet_tam_backbone(depth, alpha, beta, num_frames, blending_frames=3, input_channels=3,
imagenet_blnet_pretrained=True):
layers = {
50: [3, 4, 6, 3],
101: [4, 8, 18, 3],
152: [5, 12, 30, 3]
}[depth]
model = bLVNet_TAM_BACKBONE(Bottleneck, layers, alpha, beta, num_frames,
blending_frames=blending_frames, input_channels=input_channels)
if imagenet_blnet_pretrained:
checkpoint = torch.load(model_urls['blresnet{}'.format(depth)], map_location='cpu')
print("loading weights from ImageNet-pretrained blnet, blresnet{}".format(depth),
flush=True)
state_d = OrderedDict()
if input_channels != 3:
print("Convert RGB model to Flow")
for key, value in checkpoint['state_dict'].items():
new_key = key.replace('module.', '')
if "conv1.weight" in key:
o_c, in_c, k_h, k_w = value.shape
else:
o_c, in_c, k_h, k_w = 0, 0, 0, 0
if k_h == 7 and k_w == 7:
new_shape = (o_c, input_channels, k_h, k_w)
new_value = value.mean(dim=1, keepdim=True).expand(new_shape).contiguous()