snap-stanford / GEARS

GEARS is a geometric deep learning model that predicts outcomes of novel multi-gene perturbations
MIT License

Error when trying to load the pre-trained model #32

Closed: moghra closed this issue 10 months ago

moghra commented 10 months ago

Hi! Thank you for providing an opportunity to use your amazing tool!

I ran into an issue while running the tutorial for using the trained model. I'm loading my own data instead of the Norman dataset, but creating the dataloader works without problems and I don't modify any other code. However, I get an error when I load the pre-trained model with `load_pretrained`. Here is the relevant part of the code:

```python
from zipfile import ZipFile

from gears import GEARS
from gears.utils import dataverse_download

# pert_data is a PertData object built from my own dataset (its dataloader is already prepared)

# Download model from dataverse
dataverse_download('https://dataverse.harvard.edu/api/access/datafile/6979956', 'model.zip')

# Extract and set up model directory
with ZipFile('model.zip', 'r') as zf:
    zf.extractall(path='./')
model_name = 'gears_misc_umi_no_test'

gears_model = GEARS(pert_data, device='cpu',
                    weight_bias_track=False,
                    proj_name='gears',
                    exp_name=model_name)
gears_model.load_pretrained('./model_ckpt')
```

And here is the error:

```
TypeError                                 Traceback (most recent call last)
<ipython-input-32-2b23a346f02a> in <cell line: 7>()
      5                         proj_name = 'gears',
      6                         exp_name = model_name)
----> 7 gears_model.load_pretrained('./model_ckpt')

/content/GEARS/gears/gears.py in load_pretrained(self, path)
    255 
    256         del config['device'], config['num_genes'], config['num_perts']
--> 257         self.model_initialize(**config)
    258         self.config = config
    259 

TypeError: GEARS.model_initialize() got an unexpected keyword argument 'cell_fitness_pred'
```
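This `TypeError` is Python's generic complaint when a call passes a keyword that the callee does not declare. As a diagnostic, here is a minimal sketch that compares a saved config dict against a function's signature with `inspect.signature`, so stale keys such as `'cell_fitness_pred'` can be listed without hand-editing the pickle. The helper `split_config`, the stand-in `model_initialize_stub` and its parameters are hypothetical illustrations, not GEARS's real `model_initialize`. Dropping keys only gets past this particular error; it does nothing about whether the saved weights fit the new model.

```python
import inspect
from typing import Any, Callable, Dict, List, Tuple

# Parameter kinds that can be supplied by name as keyword arguments.
_KEYWORD_KINDS = (
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    inspect.Parameter.KEYWORD_ONLY,
)


def split_config(func: Callable, config: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
    """Split config into kwargs that func accepts by name and the keys it would reject."""
    params = inspect.signature(func).parameters
    # A **kwargs parameter accepts any keyword, so nothing would be rejected.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(config), []
    accepted_names = {name for name, p in params.items() if p.kind in _KEYWORD_KINDS}
    accepted = {k: v for k, v in config.items() if k in accepted_names}
    rejected = sorted(k for k in config if k not in accepted_names)
    return accepted, rejected


# Hypothetical stand-in for an initializer; NOT the real GEARS.model_initialize signature.
def model_initialize_stub(hidden_size=64, uncertainty=False):
    return hidden_size, uncertainty


# Hypothetical saved config containing one key the stub does not declare.
saved_config = {"hidden_size": 64, "uncertainty": False, "cell_fitness_pred": False}
kwargs, stale = split_config(model_initialize_stub, saved_config)
print(stale)  # ['cell_fitness_pred']
print(model_initialize_stub(**kwargs))  # (64, False)
```

With the library installed, a bound method such as `gears_model.model_initialize` could be passed as `func`; `inspect.signature` omits `self` for bound methods.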

It seems to me that the problem is the following: the dictionary stored in `config.pkl` (part of the trained model directory) contains the item `'cell_fitness_pred': False`, but `cell_fitness_pred` is not among the arguments of `model_initialize()`. However, simply removing this item from the dictionary gives a different error (here `model_modified_ckpt` is a copy of the `model_ckpt` directory; the only difference is that I removed that item from `config.pkl`):

```python
gears_model.load_pretrained('./model_modified_ckpt')
```

```
RuntimeError                              Traceback (most recent call last)
<ipython-input-36-62ddcd5f7a40> in <cell line: 7>()
      5                         proj_name = 'gears',
      6                         exp_name = model_name)
----> 7 gears_model.load_pretrained('./model_modified_ckpt')

/content/GEARS/gears/gears.py in load_pretrained(self, path)
    268             state_dict = new_state_dict
    269 
--> 270         self.model.load_state_dict(state_dict)
    271         self.model = self.model.to(self.device)
    272         self.best_model = self.model

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for GEARS_Model:
    Unexpected key(s) in state_dict: "cell_fitness_mlp.network.0.weight", "cell_fitness_mlp.network.0.bias", "cell_fitness_mlp.network.1.weight", "cell_fitness_mlp.network.1.bias", "cell_fitness_mlp.network.1.running_mean", "cell_fitness_mlp.network.1.running_var", "cell_fitness_mlp.network.1.num_batches_tracked", "cell_fitness_mlp.network.3.weight", "cell_fitness_mlp.network.3.bias", "cell_fitness_mlp.network.4.weight", "cell_fitness_mlp.network.4.bias", "cell_fitness_mlp.network.4.running_mean", "cell_fitness_mlp.network.4.running_var", "cell_fitness_mlp.network.4.num_batches_tracked", "cell_fitness_mlp.network.6.weight", "cell_fitness_mlp.network.6.bias", "cell_fitness_mlp.network.7.weight", "cell_fitness_mlp.network.7.bias", "cell_fitness_mlp.network.7.running_mean", "cell_fitness_mlp.network.7.running_var", "cell_fitness_mlp.network.7.num_batches_tracked".
    size mismatch for indv_w1: copying a param with shape torch.Size([5054, 64, 1]) from checkpoint, the shape in current model is torch.Size([5000, 64, 1]).
    size mismatch for indv_b1: copying a param with shape torch.Size([5054, 1]) from checkpoint, the shape in current model is torch.Size([5000, 1]).
    size mismatch for indv_w2: copying a param with shape torch.Size([1, 5054, 65]) from checkpoint, the shape in current model is torch.Size([1, 5000, 65]).
    size mismatch for indv_b2: copying a param with shape torch.Size([1, 5054]) from checkpoint, the shape in current model is torch.Size([1, 5000]).
    size mismatch for gene_emb.weight: copying a param with shape torch.Size([5054, 64]) from checkpoint, the shape in current model is torch.Size([5000, 64]).
    size mismatch for emb_pos.weight: copying a param with shape torch.Size([5054, 64]) from checkpoint, the shape in current model is torch.Size([5000, 64]).
    size mismatch for cross_gene_state.network.0.weight: copying a param with shape torch.Size([64, 5054]) from checkpoint, the shape in current model is torch.Size([64, 5000]).
```

Is there some simple way to resolve this issue?

yhr91 commented 10 months ago

Thanks for your question. The pretrained model can only be used with the dataloader that was used to train it. GEARS currently does not have functionality for transferring a trained model to new dataloaders. Having said that, that training dataloader includes a broad range of potential perturbations, so hopefully you can still use it to learn new insights.
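The size mismatches in the second traceback show why: every per-gene tensor listed there (`gene_emb.weight`, `indv_w1`, `cross_gene_state.network.0.weight`, and so on) has 5054 entries along its gene dimension in the checkpoint, while a model built from the new dataloader has 5000. Below is a rough sketch of the three checks a strict `load_state_dict` reports (missing keys, unexpected keys, size mismatches), written over plain shape tuples so it runs without PyTorch. The helper name `compare_shapes` is hypothetical, and the dicts copy only three shapes from the traceback.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def compare_shapes(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> Dict[str, List[str]]:
    """List keys a strict state_dict load would complain about, grouped by error kind."""
    return {
        # Parameters the model expects but the checkpoint lacks.
        "missing": sorted(k for k in model if k not in ckpt),
        # Parameters in the checkpoint that the model does not define.
        "unexpected": sorted(k for k in ckpt if k not in model),
        # Parameters present in both but with different shapes.
        "size_mismatch": sorted(k for k in ckpt if k in model and ckpt[k] != model[k]),
    }


# Shapes copied from the traceback: checkpoint has 5054 genes, the new data has 5000.
checkpoint_shapes = {
    "gene_emb.weight": (5054, 64),
    "indv_b1": (5054, 1),
    "cross_gene_state.network.0.weight": (64, 5054),
}
model_shapes = {
    "gene_emb.weight": (5000, 64),
    "indv_b1": (5000, 1),
    "cross_gene_state.network.0.weight": (64, 5000),
}

report = compare_shapes(checkpoint_shapes, model_shapes)
print(report["size_mismatch"])
# ['cross_gene_state.network.0.weight', 'gene_emb.weight', 'indv_b1']
```

With PyTorch available, the same inputs could be built from real state dicts with `{k: tuple(v.shape) for k, v in sd.items()}`. Because the mismatch sits in the gene dimension itself, removing config keys cannot reconcile it; the checkpoint only fits a model built on the gene set it was trained with.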