kohjingyu / gill

🐟 Code and models for the NeurIPS 2023 paper "Generating Images with Multimodal Language Models".
https://jykoh.com/gill
Apache License 2.0

Error size mismatch when load decision model #38

Open haunt98 opened 3 months ago

haunt98 commented 3 months ago

After training both the GILL model and the decision model, loading the model failed:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <cell line: 2>:2                                                                              │
│                                                                                                  │
│ /content/gill/gill/models.py:873 in load_gill                                                    │
│                                                                                                  │
│   870 │   decision_model_path = None                                                             │
│   871                                                                                            │
│   872   # Initialize model for inference.                                                        │
│ ❱ 873   model = GILL(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix,              │
│   874 │   │   │      load_sd=True, num_gen_images=1, decision_model_path=decision_model_path)    │
│   875   model = model.eval()                                                                     │
│   876   model = model.bfloat16()                                                                 │
│                                                                                                  │
│ /content/gill/gill/models.py:560 in __init__                                                     │
│                                                                                                  │
│   557 │   │     nn.Linear(768, 2),                                                               │
│   558 │     ])                                                                                   │
│   559 │     mlp_checkpoint = torch.load(decision_model_path)                                     │
│ ❱ 560 │     self.decision_model.load_state_dict(mlp_checkpoint['state_dict'], strict=False)      │
│   561 │     self.decision_model.eval()                                                           │
│   562                                                                                            │
│   563   def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: O   │
│                                                                                                  │
│ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1671 in load_state_dict       │
│                                                                                                  │
│   1668 │   │   │   │   │   │   ', '.join('"{}"'.format(k) for k in missing_keys)))               │
│   1669 │   │                                                                                     │
│   1670 │   │   if len(error_msgs) > 0:                                                           │
│ ❱ 1671 │   │   │   raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(     │
│   1672 │   │   │   │   │   │   │      self.__class__.__name__, "\n\t".join(error_msgs)))         │
│   1673 │   │   return _IncompatibleKeys(missing_keys, unexpected_keys)                           │
│   1674                                                                                           │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: Error(s) in loading state_dict for Sequential:
        size mismatch for 1.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in 
current model is torch.Size([2, 4096]).
haunt98 commented 3 months ago

Looks like you hardcoded 4096 when initializing the decision model.

My hotfix is here, but maybe it should use a param or something?

kohjingyu commented 3 months ago

Glad you solved it. Yeah, I think the right way to do this would be to add this as a param to the model args.
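A minimal sketch of that approach, with `decision_hidden_dim` as a hypothetical parameter name (the layers before the classifier head are truncated in the traceback above, so a `Dropout` stands in for them as a placeholder):

```python
import torch
import torch.nn as nn


def build_decision_model(decision_model_path: str,
                         decision_hidden_dim: int = 768) -> nn.Sequential:
    """Build the decision MLP with a configurable classifier input size
    instead of a hardcoded 4096, then load its checkpoint."""
    model = nn.Sequential(
        nn.Dropout(0.5),  # placeholder: the real preceding layers are truncated in the traceback
        nn.Linear(decision_hidden_dim, 2),
    )
    ckpt = torch.load(decision_model_path, map_location="cpu")
    # strict=False tolerates missing/unexpected keys, but NOT shape
    # mismatches, which is why the hardcoded size still raised a RuntimeError.
    model.load_state_dict(ckpt["state_dict"], strict=False)
    return model.eval()
```

Alternatively, the size could be inferred from the checkpoint itself (e.g. `ckpt["state_dict"]["1.weight"].shape[1]`), which would avoid the extra arg entirely.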