a-r-j / ProteinWorkshop

Benchmarking framework for protein representation learning. Includes a large number of pre-training and downstream task datasets, models and training/task utilities. (ICLR 2024)
https://proteins.sh/
MIT License

Mismatch of dimensions for PTM data #61

Closed asaksager closed 9 months ago

asaksager commented 10 months ago

Hi, thanks for your cool workshop!

I have been trying to run the PTM data as a test before running on my own data. After training, I would like to extract the model and evaluate it with some of the metrics. I can see that finetune has a test/evaluation setting; however, when loading the checkpoint file I get an error that the dimensions don't match.

What could cause this? Is there a way to test/evaluate without finetuning the model?

workshop train dataset=ptm encoder=schnet task=multiclass_node_classification trainer=gpu env.paths.data=/dtu/blackhole/17/126583/Post env.paths.output_dir=/dtu/blackhole/17/126583/Post/output_test

(Note: I am using epoch_000.ckpt because the training didn't finish; I ran out of memory. But that should not be the problem, right?)

workshop finetune dataset=ptm encoder=schnet task=multiclass_node_classification trainer=gpu env.paths.data=/dtu/blackhole/17/126583/Post env.paths.output_dir=/dtu/blackhole/17/126583/Post/output_test ckpt_path=/dtu/blackhole/17/126583/Post/output_test/checkpoints/epoch_000.ckpt

The following is printed by finetune after the config tree:

DEBUG    Requested GPUs: None.                                                                                                                                                                                                                 config.py:248
                    WARNING  You are not using early stopping.                                                                                                                                                                                                     config.py:164
Seed set to 52
                    INFO     Instantiating datamodule:...                                                                                                                                                                                                         finetune.py:34
                    INFO     Instantiating model:...                                                                                                                                                                                                              finetune.py:39
[2023-11-29 15:36:26,347][torch.distributed.nn.jit.instantiator][INFO] - Created a temporary directory at /tmp/tmpn530817v
[2023-11-29 15:36:26,347][torch.distributed.nn.jit.instantiator][INFO] - Writing /tmp/tmpn530817v/_remote_module_non_scriptable.py
                    INFO     Instantiating encoder...                                                                                                                                                                                                                base.py:407
                    INFO     SchNetModel(hidden_channels=512, num_filters=128, num_interactions=6, num_gaussians=50, cutoff=10.0)                                                                                                                                    base.py:409
                    INFO     Instantiating decoders...                                                                                                                                                                                                               base.py:411
                    INFO     Building node_label decoder. Output dim 13                                                                                                                                                                                              base.py:245
                    INFO     {'_target_': 'proteinworkshop.models.decoders.mlp_decoder.MLPDecoder', 'hidden_dim': [128, 128], 'dropout': 0.0, 'activations': ['relu', 'relu', 'none'], 'skip': 'concat', 'out_dim': '${dataset.num_classes}', 'input': 'node_embedding'}   base.py:248
                    INFO     Using skip connection in decoder.                                                                                                                                                                                                mlp_decoder.py:123
                    INFO     ModuleDict(                                                                                                                                                                                                                             base.py:413
                               (node_label): MLPDecoder(                                                                                                                                                                                                                        
                                 (layers): LinearSkipBlock(                                                                                                                                                                                                                     
                                   (layers): ModuleList(                                                                                                                                                                                                                        
                                     (0-1): 2 x LazyLinear(in_features=0, out_features=128, bias=True)                                                                                                                                                                          
                                     (2): LazyLinear(in_features=0, out_features=13, bias=True)                                                                                                                                                                                 
                                   )                                                                                                                                                                                                                                            
                                   (activations): ModuleList(                                                                                                                                                                                                                   
                                     (0-1): 2 x ReLU()                                                                                                                                                                                                                          
                                     (2): Identity()                                                                                                                                                                                                                            
                                   )                                                                                                                                                                                                                                            
                                   (dropout_layers): ModuleList(                                                                                                                                                                                                                
                                     (0-1): 2 x Dropout(p=0.0, inplace=False)                                                                                                                                                                                                   
                                   )                                                                                                                                                                                                                                            
                                 )                                                                                                                                                                                                                                              
                               )                                                                                                                                                                                                                                                
                             )                                                                                                                                                                                                                                                  
                    INFO     Instantiating losses...                                                                                                                                                                                                                 base.py:415
                    INFO     Using losses: {'node_label': CrossEntropyLoss()}                                                                                                                                                                                        base.py:417
                    INFO     Not using aux loss scaling                                                                                                                                                                                                              base.py:424
                    INFO     Configuring metrics...                                                                                                                                                                                                                  base.py:426
                    INFO     ['accuracy', 'f1_score', 'f1_max']                                                                                                                                                                                                      base.py:428
                    INFO     Instantiating featuriser...                                                                                                                                                                                                             base.py:430
                    INFO     ProteinFeaturiser(representation=CA, scalar_node_features=['amino_acid_one_hot'], vector_node_features=[], edge_types=['knn_16'], scalar_edge_features=['edge_distance'], vector_edge_features=[])                                      base.py:432
                    INFO     Instantiating task transform...                                                                                                                                                                                                         base.py:434
                    INFO     None                                                                                                                                                                                                                                    base.py:438
[11/29/23 15:36:53] INFO     Instantiating callbacks...                                                                                                                                                                                                           finetune.py:42
                    INFO     Instantiating callback <lightning.pytorch.callbacks.ModelCheckpoint>                                                                                                                                                                callbacks.py:31
                    INFO     Instantiating callback <lightning.pytorch.callbacks.EarlyStopping>                                                                                                                                                                  callbacks.py:31
                    INFO     Instantiating callback <lightning.pytorch.callbacks.RichModelSummary>                                                                                                                                                               callbacks.py:31
                    INFO     Instantiating callback <lightning.pytorch.callbacks.RichProgressBar>                                                                                                                                                                callbacks.py:31
                    INFO     Instantiating callback <lightning.pytorch.callbacks.LearningRateMonitor>                                                                                                                                                            callbacks.py:31
                    INFO     Instantiating callback <lightning.pytorch.callbacks.EarlyStopping>                                                                                                                                                                  callbacks.py:31
                    INFO     Instantiating loggers:...                                                                                                                                                                                                            finetune.py:47
                    INFO     Instantiating logger <lightning.pytorch.loggers.csv_logs.CSVLogger>                                                                                                                                                                   loggers.py:31
                    INFO     Instantiating trainer <lightning.pytorch.trainer.Trainer>                                                                                                                                                                            finetune.py:50
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[11/29/23 15:36:54] INFO     Initializing lazy layers...                                                                                                                                                                                                          finetune.py:59
[11/29/23 15:36:55] INFO     Found 43903 examples in train                                                                                                                                                                                                            ptm.py:223
[11/29/23 15:36:57] INFO     Downloading 789 PDBs...                                                                                                                                                                                                                  ptm.py:172
100%| 789/789 [00:08<00:00, 93.15it/s]
[11/29/23 15:37:06] INFO     Unavailable structures: 793                                                                                                                                                                                                              ptm.py:148
                    INFO     Found 2511 examples in test                                                                                                                                                                                                              ptm.py:223
                    INFO     Downloading 64 PDBs...                                                                                                                                                                                                                   ptm.py:172
100%| 64/64 [00:03<00:00, 18.72it/s]
[11/29/23 15:37:10] INFO     Unavailable structures: 857                                                                                                                                                                                                              ptm.py:148
                    INFO     Found 2393 examples in val                                                                                                                                                                                                               ptm.py:223
                    INFO     Downloading 40 PDBs...                                                                                                                                                                                                                   ptm.py:172
100%| 40/40 [00:03<00:00, 11.38it/s]
[11/29/23 15:37:14] INFO     Unavailable structures: 897                                                                                                                                                                                                              ptm.py:148
                    INFO     Caching unavailable structure list to /dtu/blackhole/17/126583/Post/PostTranslationalModification/ptm_13/unavailable_structures.txt                                                                                                      ptm.py:154
                    INFO     Found 2353 examples in val                                                                                                                                                                                                               ptm.py:223
2353it [00:00, 50776.10it/s]
2353
                    INFO     All structures already processed and overwrite=False. Skipping download.                                                                                                                                                                base.py:304
                    INFO     All structures already processed and overwrite=False. Skipping download.                                                                                                                                                                base.py:335
[11/29/23 15:37:16] INFO     Unfeaturized batch: DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[18125, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[18125], chains=[18125], node_y=[18125, 13], x=[18125], amino_acid_one_hot=[18125, 23], batch=[18125], ptr=[33])   finetune.py:63
                    INFO     Featurized batch: DataProteinBatch(fill_value=1e-05, atom_list=[37], coords=[18125, 37, 3], residues=[32], id=[32], residue_id=[32], residue_type=[18125], chains=[18125], node_y=[18125, 13], x=[18125, 23], amino_acid_one_hot=[18125, 23], batch=[18125], ptr=[33], pos=[18125, 3], edge_index=[2, 290000], edge_type=[1, 290000], num_relation=1, edge_attr=[290000, 1])   finetune.py:65
[11/29/23 15:37:22] INFO     Model output: {'node_embedding': tensor[18125, 32] n=580000 (2.2Mb) x∈[-0.252, 0.204] μ=0.007 σ=0.071, 'graph_embedding': tensor[32, 32] n=1024 (4Kb) x∈[-193.717, 170.779] μ=4.195 σ=40.578, 'node_label': tensor[18125, 13] n=235625 (0.9Mb) x∈[-0.137, 0.115] μ=0.002 σ=0.045}   finetune.py:67
                    INFO     Loading weights from checkpoint /dtu/blackhole/17/126583/Post/output_test/checkpoints/epoch_000.ckpt...                                                                                                                              finetune.py:72
                    INFO     Loading encoder weights: OrderedDict([('embedding.weight', tensor[512, 39] n=19968 (78Kb) x∈[-0.488, 0.417] μ=-0.000 σ=0.117 cuda:0), ('embedding.bias', tensor[512] 2Kb x∈[-0.186, 0.205] μ=0.005 σ=0.083 cuda:0), <remaining SchNet interaction- and output-layer tensors elided for brevity>])   finetune.py:80
Error executing job with overrides: ['encoder=schnet', 'task=multiclass_node_classification', 'trainer=gpu', 'env.paths.data=/dtu/blackhole/17/126583/Post', 'env.paths.output_dir=/dtu/blackhole/17/126583/Post/output_test', 'ckpt_path=/dtu/blackhole/17/126583/Post/output_test/checkpoints/epoch_000.ckpt']
Traceback (most recent call last):
  File "/zhome/ce/4/118546/deeplearning/env/lib/python3.10/site-packages/proteinworkshop/finetune.py", line 161, in _main
    finetune(cfg)
  File "/zhome/ce/4/118546/deeplearning/env/lib/python3.10/site-packages/proteinworkshop/finetune.py", line 81, in finetune
    err = model.encoder.load_state_dict(encoder_weights, strict=False)
  File "/zhome/ce/4/118546/deeplearning/env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 2041, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SchNetModel:
    size mismatch for embedding.weight: copying a param with shape torch.Size([512, 39]) from checkpoint, the shape in current model is torch.Size([512, 23]).
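(Note: strict=False in load_state_dict only tolerates missing or unexpected keys; a shape mismatch on a key present in both models still raises, which is why the call in finetune.py fails despite passing strict=False. A minimal sketch of the same failure mode, using hypothetical stand-in modules rather than the actual encoder:)

import torch.nn as nn

# Hypothetical stand-ins: an encoder trained on 39-dim features vs. a
# finetune model built for 23-dim features.
src = nn.Linear(39, 512)
dst = nn.Linear(23, 512)

try:
    # strict=False skips missing/unexpected keys, but NOT shape mismatches.
    dst.load_state_dict(src.state_dict(), strict=False)
except RuntimeError as err:
    print(err)  # size mismatch for weight: checkpoint [512, 39] vs. current model [512, 23]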
a-r-j commented 10 months ago

Hi @asaksager, the problem is on our end. We neglected to include the label transform in this release.

You can see it here: https://github.com/a-r-j/ProteinWorkshop/blob/758a292b2d3087ed6ccb8a3ff8bcf47c1a3edcc2/proteinworkshop/tasks/multihot_label_encoding.py#L9

It just needs to be added as a transform to the task: https://github.com/a-r-j/ProteinWorkshop/blob/758a292b2d3087ed6ccb8a3ff8bcf47c1a3edcc2/proteinworkshop/config/task/multilabel_graph_classification.yaml#L14

https://github.com/a-r-j/ProteinWorkshop/blob/758a292b2d3087ed6ccb8a3ff8bcf47c1a3edcc2/proteinworkshop/config/transforms/multihot_label_encoding.yaml#L1

Apologies for this. We're likely to push a new release in the next week or so with these fixes.
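For intuition, the transform turns integer class labels into binary multi-hot vectors. A rough sketch of the idea in plain PyTorch (not the repo's exact implementation; multihot_encode is a hypothetical helper):

import torch

def multihot_encode(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Map integer class labels of shape [N] or [N, K] to a binary
    # multi-hot matrix of shape [N, num_classes].
    out = torch.zeros(labels.shape[0], num_classes)
    out.scatter_(1, labels.view(labels.shape[0], -1).long(), 1.0)
    return out

# e.g. multihot_encode(torch.tensor([[0, 2], [1, 1]]), 13) gives a [2, 13]
# matrix with ones at columns {0, 2} in row 0 and column {1} in row 1.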

a-r-j commented 10 months ago

Regarding the shape mismatch, it would seem that the 39 input features the checkpoint was trained with (features=ca_seq: 23 for AA type plus 16 for positional encoding) differ from the 23 features the finetune model is built with (features=ca_base).

Ultimately, I think we'd like to remove this complexity so that the feature config is pulled from the checkpoint and mismatches like this can't happen.
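As a sketch of that idea (hypothetical helpers, not current repo behaviour), one could stash the featuriser config in the checkpoint at save time and verify it before loading the encoder weights:

import torch

def save_checkpoint(path: str, model: torch.nn.Module, feature_cfg: dict) -> None:
    # Persist the feature config alongside the weights.
    torch.save({"state_dict": model.state_dict(), "feature_cfg": feature_cfg}, path)

def load_encoder_checked(path: str, model: torch.nn.Module, feature_cfg: dict):
    ckpt = torch.load(path, map_location="cpu")
    saved_cfg = ckpt.get("feature_cfg")
    if saved_cfg is not None and saved_cfg != feature_cfg:
        raise ValueError(
            f"Featuriser mismatch: checkpoint trained with {saved_cfg}, "
            f"current run configured with {feature_cfg}."
        )
    return model.load_state_dict(ckpt["state_dict"], strict=False)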

It seems we have set different defaults for train and finetune:

https://github.com/a-r-j/ProteinWorkshop/blob/ceabdaec3bf7e61292f033507bd092bff0d7c61a/proteinworkshop/config/finetune.yaml#L12

https://github.com/a-r-j/ProteinWorkshop/blob/ceabdaec3bf7e61292f033507bd092bff0d7c61a/proteinworkshop/config/train.yaml#L12
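In the meantime, a workaround is to pin the feature set explicitly at finetune time so it matches the checkpoint (assuming the checkpoint really was trained with the ca_seq default, per the above):

workshop finetune dataset=ptm encoder=schnet task=multiclass_node_classification features=ca_seq trainer=gpu env.paths.data=/dtu/blackhole/17/126583/Post env.paths.output_dir=/dtu/blackhole/17/126583/Post/output_test ckpt_path=/dtu/blackhole/17/126583/Post/output_test/checkpoints/epoch_000.ckpt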

asaksager commented 10 months ago

Thank you for the quick response. Yes, aligning the defaults would be nice.

I will try it out.
I am trying to do node classification; should I also add multihot_label_encoding to multiclass_node_classification.yaml?

a-r-j commented 10 months ago

Yes, it would be required for node classification too.