# Load a pretrained checkpoint and remap its weights so they can be loaded
# into the current model (visual encoder + text encoder/decoder).
if args.checkpoint:
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    state_dict = checkpoint['model']

    # Reshape positional embedding to accommodate for image resolution change.
    pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
    state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
    if not args.evaluate:
        if config['distill']:
            # The momentum encoder (used for distillation) needs the same
            # positional-embedding resize as the main visual encoder.
            m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
            state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped

        # Iterate over a snapshot of the keys: entries are added and deleted
        # while remapping, so the live view cannot be iterated directly.
        for key in list(state_dict.keys()):
            if 'bert' in key:
                # Duplicate BERT weights under keys without the 'bert.' prefix
                # so they match the model's text-encoder parameter names.
                encoder_key = key.replace('bert.', '')
                state_dict[encoder_key] = state_dict[key]
            # Initialize text decoder as multimodal encoder (last 6 layers of model.text_encoder).
            if 'text_encoder' in key:
                if 'layer' in key:
                    encoder_keys = key.split('.')
                    # NOTE(review): assumes keys look like
                    # 'text_encoder.bert.encoder.layer.<n>...' so the layer
                    # index sits at position 4 — confirm against checkpoint format.
                    layer_num = int(encoder_keys[4])
                    if layer_num < 6:
                        # The first 6 encoder layers have no decoder
                        # counterpart; drop them from the state dict.
                        del state_dict[key]
                        continue
                    else:
                        # Shift encoder layers 6..11 down to decoder layers 0..5.
                        decoder_layer_num = layer_num - 6
                        encoder_keys[4] = str(decoder_layer_num)
                        encoder_key = '.'.join(encoder_keys)
                else:
                    encoder_key = key
                decoder_key = encoder_key.replace('text_encoder', 'text_decoder')
                state_dict[decoder_key] = state_dict[key]
                del state_dict[key]