jiaor17 / DiffCSP-PP

[ICLR 2024] The implementation for the paper "Space Group Constrained Crystal Generation"

Size mismatch in model architecture #2

Open andrelowky opened 4 months ago

andrelowky commented 4 months ago

Dear team,

I am trying to replicate the results by using the same training script for dataset=perov_5 following instructions here. When loading the model in a Jupyter notebook, I get the following error message:

RuntimeError: Error(s) in loading state_dict for CSPDiffusion:
    size mismatch for decoder.csp_layer_0.edge_mlp.0.weight: copying a param with shape torch.Size([512, 1798]) from checkpoint, the shape in current model is torch.Size([512, 1801]).
    size mismatch for decoder.csp_layer_1.edge_mlp.0.weight: copying a param with shape torch.Size([512, 1798]) from checkpoint, the shape in current model is torch.Size([512, 1801]).
    size mismatch for decoder.csp_layer_2.edge_mlp.0.weight: copying a param with shape torch.Size([512, 1798]) from checkpoint, the shape in current model is torch.Size([512, 1801]).
    size mismatch for decoder.csp_layer_3.edge_mlp.0.weight: copying a param with shape torch.Size([512, 1798]) from checkpoint, the shape in current model is torch.Size([512, 1801]).
    size mismatch for decoder.csp_layer_4.edge_mlp.0.weight: copying a param with shape torch.Size([512, 1798]) from checkpoint, the shape in current model is torch.Size([512, 1801]).
    size mismatch for decoder.csp_layer_5.edge_mlp.0.weight: copying a param with shape torch.Size([512, 1798]) from checkpoint, the shape in current model is torch.Size([512, 1801]).
    size mismatch for decoder.lattice_out.weight: copying a param with shape torch.Size([6, 512]) from checkpoint, the shape in current model is torch.Size([9, 512]).

jiaor17 commented 4 months ago

Hi, this mismatch results from a slight difference in model architecture between DiffCSP++ and the original DiffCSP. In DiffCSP++, we use a 6-dimensional invariant vector to represent the lattice matrix, while DiffCSP uses the inner product, which is flattened into 9 dimensions. To use DiffCSP++, you can re-train a model using the code in this repo.
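To make the dimension count concrete, here is a minimal sketch of why the two representations differ by exactly 3 values, which matches the 1798 vs. 1801 gap in the error above. It assumes the "inner product" is the 3x3 matrix L L^T of the lattice rows. The `upper6` helper is only a stand-in showing that a symmetric 3x3 matrix has 6 independent entries; it is not claimed to be the actual 6-dimensional parameterization used in DiffCSP++. All function names here are hypothetical.

```python
def gram(lattice):
    """Inner-product matrix L L^T for a 3x3 lattice given as row vectors."""
    return [[sum(a * b for a, b in zip(ri, rj)) for rj in lattice] for ri in lattice]


def flatten9(m):
    """Flatten a 3x3 matrix row by row into 9 values."""
    return [x for row in m for x in row]


def upper6(m):
    """Keep only the upper triangle (6 values) of a symmetric 3x3 matrix.

    Illustrative stand-in only, not the DiffCSP++ parameterization.
    """
    return [m[i][j] for i in range(3) for j in range(i, 3)]


# Example lattice: a cubic cell with edge length 4 (illustrative values).
lattice = [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]

g = gram(lattice)
nine = flatten9(g)
six = upper6(g)

# Shapes copied from the error message in this thread.
ckpt_edge_in, model_edge_in = 1798, 1801
ckpt_lattice_out, model_lattice_out = 6, 9

print(len(nine), len(six))           # sizes of the two lattice representations
print(model_edge_in - ckpt_edge_in)  # extra edge-MLP inputs in the current model
```

Under these assumptions, the checkpoint (6-wide lattice output, 1798 edge inputs) looks like a DiffCSP++ model, and the model being constructed at load time (9-wide, 1801 inputs) looks like the original DiffCSP architecture.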

andrelowky commented 4 months ago

Thanks for the explanation!

I have successfully trained the model following the script. The error message appears when I use the load_model function from eval_utils in the original DiffCSP repository.

In particular, I have isolated it to this line of code:

model = model.load_from_checkpoint(ckpt, hparams_file=hparams, strict=False)
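One note that may help here: in PyTorch, `strict=False` only relaxes the check for missing or unexpected keys. A parameter that exists on both sides with a different shape still raises the size-mismatch error. A minimal sketch for listing such parameters before loading is below. The helper name `find_shape_mismatches` is hypothetical, and the shapes are copied from the error message above rather than read from a real checkpoint.

```python
def find_shape_mismatches(ckpt_shapes, model_shapes):
    """Return {name: (ckpt_shape, model_shape)} for parameters that exist
    in both mappings but have different shapes."""
    return {
        name: (tuple(ckpt_shapes[name]), tuple(model_shapes[name]))
        for name in ckpt_shapes
        if name in model_shapes
        and tuple(ckpt_shapes[name]) != tuple(model_shapes[name])
    }


# Shapes copied from the error message in this thread.
ckpt_shapes = {
    "decoder.csp_layer_0.edge_mlp.0.weight": (512, 1798),
    "decoder.lattice_out.weight": (6, 512),
    "hypothetical.matching.weight": (512, 512),  # made-up entry for illustration
}
model_shapes = {
    "decoder.csp_layer_0.edge_mlp.0.weight": (512, 1801),
    "decoder.lattice_out.weight": (9, 512),
    "hypothetical.matching.weight": (512, 512),  # made-up entry for illustration
}

mismatches = find_shape_mismatches(ckpt_shapes, model_shapes)
for name, (ckpt_shape, model_shape) in mismatches.items():
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")

# With PyTorch installed, the two mappings could be built like this
# (assumes a Lightning-style checkpoint with a "state_dict" key):
#   state = torch.load(ckpt, map_location="cpu")["state_dict"]
#   ckpt_shapes = {k: tuple(v.shape) for k, v in state.items()}
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
```

If the mismatches are limited to the lattice-related layers, as they are here, the likely cause is that the model class being instantiated comes from the original DiffCSP code rather than from this repo, so loading with this repo's own code may be worth trying.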