shenghanlin / SeismicFoundationModel

Seismic Foundation Model

Get Embedding from pretrained model #3

Closed ramdhan1989 closed 5 months ago

ramdhan1989 commented 7 months ago

Hi, thank you for your repo. Do you have example code showing how to get embeddings from the pretrained model? I hit shape incompatibilities when loading the weights:

```python
import torch
import pathlib
from pathlib import Path

import models_mae
from util.pos_embed import interpolate_pos_embed

device = torch.device('cpu')

model = models_mae.__dict__['mae_vit_large_patch16']()
model.to(device)

# checkpoint was pickled on Linux; remap PosixPath so it loads on Windows
pathlib.PosixPath = pathlib.WindowsPath

checkpoint = torch.load('/output_dir/SFM-Large.pth', map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()

# drop classifier-head weights whose shapes don't match
for k in ['head.weight', 'head.bias']:
    if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
        print(f"Removing key {k} from pretrained checkpoint")
        del checkpoint_model[k]

interpolate_pos_embed(model, checkpoint_model)
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)
```

This is the error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[37], line 23
     21 from util.pos_embed import interpolate_pos_embed
     22 interpolate_pos_embed(model, checkpoint_model)
---> 23 msg = model.load_state_dict(checkpoint_model, strict=False)
     24 print(msg)

File ~\anaconda3\envs\py39\lib\site-packages\torch\nn\modules\module.py:1223, in Module.load_state_dict(self, state_dict, strict)
   1218 error_msgs.insert(
   1219     0, 'Missing key(s) in state_dict: {}. '.format(
   1220         ', '.join('"{}"'.format(k) for k in missing_keys)))
   1222 if len(error_msgs) > 0:
-> 1223     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1224         self.__class__.__name__, "\n\t".join(error_msgs)))
   1225 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for MaskedAutoencoderViT:
    size mismatch for mask_token: copying a param with shape torch.Size([1, 1, 256]) from checkpoint, the shape in current model is torch.Size([1, 1, 512]).
    size mismatch for decoder_pos_embed: copying a param with shape torch.Size([1, 197, 256]) from checkpoint, the shape in current model is torch.Size([1, 197, 512]).
    size mismatch for decoder_embed.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([512, 1024]).
    size mismatch for decoder_embed.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_blocks.0.norm1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_blocks.0.norm1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_blocks.0.attn.qkv.weight: copying a param with shape torch.Size([768, 256]) from checkpoint, the shape in current model is torch.Size([1536, 512]).
    size mismatch for decoder_blocks.0.attn.qkv.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([1536]).
    size mismatch for decoder_blocks.0.attn.proj.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([512, 512]).
    size mismatch for decoder_blocks.0.attn.proj.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_blocks.0.norm2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_blocks.0.norm2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_blocks.0.mlp.fc1.weight: copying a param with shape torch.Size([1024, 256]) from checkpoint, the shape in current model is torch.Size([2048, 512]).
    size mismatch for decoder_blocks.0.mlp.fc1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([2048]).
    size mismatch for decoder_blocks.0.mlp.fc2.weight: copying a param with shape torch.Size([256, 1024]) from checkpoint, the shape in current model is torch.Size([512, 2048]).
    size mismatch for decoder_blocks.0.mlp.fc2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    [... identical size mismatches repeat for decoder_blocks.1 through decoder_blocks.3 ...]
    size mismatch for decoder_norm.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_norm.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([512]).
    size mismatch for decoder_pred.weight: copying a param with shape torch.Size([256, 256]) from checkpoint, the shape in current model is torch.Size([256, 512]).
```

Thanks, Regards
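Every mismatched key above belongs to the decoder: the checkpoint stores 256-wide decoder weights, while the model built by `mae_vit_large_patch16` allocates 512. A minimal sketch for checking which decoder width a checkpoint actually carries, assuming the layout used above (weights stored under the `'model'` key):

```python
import torch

# Load only the state dict and inspect the stored decoder shapes;
# the last dimension of mask_token is the decoder embedding width.
checkpoint = torch.load('/output_dir/SFM-Large.pth', map_location='cpu')
state = checkpoint['model']
print('decoder width in checkpoint:', state['mask_token'].shape[-1])  # 256 here

# Listing a few decoder keys makes the expected architecture explicit.
for k in list(state):
    if k.startswith(('mask_token', 'decoder_embed', 'decoder_norm')):
        print(k, tuple(state[k].shape))
```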

shenghanlin commented 5 months ago

It seems the checkpoint you're attempting to load is the "Large-512" version. We will check whether there has been a naming error; in the meantime, you can try loading another version of the large model. Regarding embeddings, we will soon write a Jupyter notebook or a Python program to make obtaining them straightforward.
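Until that notebook is available, here is a rough sketch of both steps. It assumes `models_mae` follows the reference MAE implementation, i.e. `MaskedAutoencoderViT` accepts `decoder_embed_dim`/`decoder_depth` keyword arguments and `forward_encoder(x, mask_ratio)` returns the latent tokens; the constructor values and the dummy input shape below are inferred from the error message and should be checked against `models_mae.py`:

```python
import torch
import models_mae

# Build a model whose decoder width matches the checkpoint (256 instead of 512).
# These kwargs mirror the usual MAE "large" recipe and are assumptions --
# verify them against the builder definitions in models_mae.py.
model = models_mae.MaskedAutoencoderViT(
    patch_size=16, embed_dim=1024, depth=24, num_heads=16,
    decoder_embed_dim=256, decoder_depth=4, decoder_num_heads=16,
)

checkpoint = torch.load('/output_dir/SFM-Large.pth', map_location='cpu')
msg = model.load_state_dict(checkpoint['model'], strict=False)
print(msg)  # should report no size mismatches once the decoder widths agree

# Embeddings come from the encoder alone; mask_ratio=0.0 keeps every patch.
model.eval()
x = torch.randn(1, 1, 224, 224)  # dummy input; channels/size must match patch_embed
with torch.no_grad():
    latent, mask, ids_restore = model.forward_encoder(x, mask_ratio=0.0)

embedding = latent[:, 0]  # CLS token; or latent[:, 1:].mean(1) to pool patch tokens
print(embedding.shape)
```

The pooling choice is a judgment call: the CLS token is the usual single-vector summary, while averaging the patch tokens preserves more spatial information.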