adymaharana / storydalle

MIT License
330 stars, 27 forks

PrefixTuningDalle does not exist #12

Open KyonP opened 1 year ago

KyonP commented 1 year ago

I am trying to collect generated images from your best-performing checkpoint, and I ran into this error:

(ldm) root@2157b047841c:/home/my/storydalle/story-dalle# bash infer_story.sh pororo
Evaluating on Pororo
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/my/storydalle/story-dalle/./infer_t2i.py:24 in <module>                                │
│                                                                                                  │
│    21                                                                                            │
│    22 import logging                                                                             │
│    23 import os, torch                                                                           │
│ ❱  24 from dalle.models import PrefixTuningDalle, StoryDalle, PromptDalle                        │
│    25 import torchvision                                                                         │
│    26 import torchvision.transforms as transforms                                                │
│    27 import pytorch_lightning as pl                                                             │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
ImportError: cannot import name 'PrefixTuningDalle' from 'dalle.models' (/home/my/storydalle/story-dalle/dalle/models/__init__.py)

It seems that dalle/models/__init__.py in the current version doesn't define PrefixTuningDalle.

I may have missed something; if so, I'd appreciate a reply.
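A quick way to confirm the diagnosis is to ask the installed dalle.models which of the three names imported on line 24 of infer_t2i.py it actually defines. The helper below is a hypothetical stdlib sketch (find_missing_names is not part of the repo); run it from the story-dalle directory so that the dalle package is importable.

```python
import importlib


def find_missing_names(module_name, names):
    """Return the entries of `names` that `module_name` does not define."""
    module = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]
```

Calling find_missing_names("dalle.models", ["PrefixTuningDalle", "StoryDalle", "PromptDalle"]) lists every name that will make the import on line 24 fail. If PrefixTuningDalle is the only one listed and the story inference path never uses it, dropping it from that import is one possible workaround; that is an assumption, not something the maintainers have confirmed.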

KyonP commented 1 year ago

Is there a way to save the output images (generated and ground truth) from the best-performing checkpoint?

I looked into your code, but I find it hard to use acc_tensors_to_images in utils.py.

Do I need to train from the pretrained weights first in order to save '.pt' files?

Can you tell me how to use it?
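Since acc_tensors_to_images isn't reproduced in this thread, the sketch below does not try to mimic it. It is a dependency-free illustration of the usual last step: writing each generated and ground-truth frame to an 8-bit image file. It assumes frames arrive as 3 x H x W floats normalized to [-1, 1] (a torch tensor can be converted with .tolist()) and writes binary PPM files; the function names and the file-naming scheme are hypothetical.

```python
from pathlib import Path


def to_uint8(value):
    """Map a float in [-1, 1] to an 8-bit intensity (assumed normalization)."""
    value = min(max(value, -1.0), 1.0)       # clamp out-of-range outputs
    return round((value * 0.5 + 0.5) * 255)  # [-1, 1] -> [0, 1] -> [0, 255]


def save_ppm(image_chw, path):
    """Write a 3 x H x W nested list of floats as a binary PPM (P6) file."""
    if len(image_chw) != 3:
        raise ValueError(f"expected 3 channels, got {len(image_chw)}")
    height, width = len(image_chw[0]), len(image_chw[0][0])
    pixels = bytearray()
    for y in range(height):
        for x in range(width):
            for c in range(3):  # interleave R, G, B for each pixel
                pixels.append(to_uint8(image_chw[c][y][x]))
    header = f"P6\n{width} {height}\n255\n".encode("ascii")
    Path(path).write_bytes(header + bytes(pixels))


def save_story(generated, ground_truth, out_dir, story_id):
    """Save every frame of one story twice: generated and ground truth."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, (gen, gt) in enumerate(zip(generated, ground_truth)):
        save_ppm(gen, out / f"{story_id}_frame{i}_gen.ppm")
        save_ppm(gt, out / f"{story_id}_frame{i}_gt.ppm")
```

With torchvision installed, torchvision.utils.save_image(tensor, path) covers the writing step for tensors in [0, 1], so under the same [-1, 1] assumption you would pass (x * 0.5 + 0.5).clamp(0, 1).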

KyonP commented 1 year ago

Also, your pretrained weights in ./1.3B/ seem to be incompatible with inference.

root@1074e9836478:/home/my/storydalle/story-dalle# bash infer_story.sh pororo test
Evaluating on Pororo
Global seed set to 42
Initializing the Conditional Dalle model
Setting up Cross-attention Layers
Total parameters is 1396275075
./ckpt
path :  ./ckpt/bpe-16k-vocab.json
./ckpt successfully restored..
Loaded tokenizer from finetuned checkpoint
[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41]
Loading model from pretrained checkpoint ./ckpt/25_v1.pth
Traceback (most recent call last):
  File "/home/my/storydalle/story-dalle/dalle/models/__init__.py", line 961, in from_pretrained
    model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
KeyError: 'state_dict'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./infer_t2i.py", line 396, in <module>
    main(args)
  File "./infer_t2i.py", line 126, in main
    model, config = StoryDalle.from_pretrained(args)
  File "/home/my/storydalle/story-dalle/dalle/models/__init__.py", line 963, in from_pretrained
    model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for StoryDalle:
        size mismatch for stage2.tok_emb_txt.weight: copying a param with shape torch.Size([16393, 1536]) from checkpoint, the shape in current model is torch.Size([16384, 1536]).
        size mismatch for stage2.head_txt.weight: copying a param with shape torch.Size([16393, 1536]) from checkpoint, the shape in current model is torch.Size([16384, 1536]).
root@1074e9836478:/home/my/storydalle/story-dalle# 
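The log shows from_pretrained looking for a 'state_dict' key, falling back to 'model_state_dict', and then failing on the text-embedding shape (16393 vs 16384 rows). Listing mismatched shapes before calling load_state_dict makes this kind of failure easier to read. Below is a framework-agnostic sketch over plain name-to-shape dicts; the function names are hypothetical, and treating a checkpoint with neither key as a bare state dict is an assumption.

```python
def extract_state_dict(checkpoint, keys=("state_dict", "model_state_dict")):
    """Return the weights stored under the first matching key.

    Assumption: if none of `keys` is present, `checkpoint` is already a
    bare state dict (as produced by torch.save(model.state_dict(), path)).
    """
    for key in keys:
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def shape_mismatches(checkpoint_shapes, model_shapes):
    """Return (name, checkpoint_shape, model_shape) for every parameter
    present in both dicts whose shapes differ."""
    return [
        (name, tuple(shape), tuple(model_shapes[name]))
        for name, shape in checkpoint_shapes.items()
        if name in model_shapes and tuple(shape) != tuple(model_shapes[name])
    ]


# Usage sketch with PyTorch (not executed here; `model` is the StoryDalle
# instance built before load_state_dict is called):
#   ckpt = torch.load("./ckpt/25_v1.pth", map_location="cpu")
#   ckpt_shapes = {k: tuple(v.shape) for k, v in extract_state_dict(ckpt).items()}
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   for row in shape_mismatches(ckpt_shapes, model_shapes):
#       print(row)
```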
KyonP commented 1 year ago

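For the weight shape mismatch, I found the reason 😓

My checkpoint path didn't contain the string 'pororo', so the check on args.model_name_or_path never added the 9 extra Pororo text tokens, and the model was built with 16384 text embeddings instead of the checkpoint's 16393.

I changed the if-condition to check args.dataset_name instead: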

# story-dalle/dalle/models/__init__.py

        if args.model_name_or_path:
            # Previously this checked args.model_name_or_path for the dataset name,
            # which fails when the checkpoint path doesn't contain it.
            if 'pororo' in args.dataset_name:
                config_update.stage2.vocab_size_txt += 9
            elif 'flintstones' in args.dataset_name:
                config_update.stage2.vocab_size_txt += 7
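To see why the dataset check fixes the error, the rule can be pulled out into a standalone function and compared with the shapes in the traceback: 16384 + 9 = 16393, exactly the row count of stage2.tok_emb_txt.weight in the Pororo checkpoint. The names below are hypothetical, and 16384 is read off the log rather than from the repo's config.

```python
# Base size of stage2.tok_emb_txt in the freshly built model, taken from the
# "shape in current model" part of the error log above.
BASE_TEXT_VOCAB = 16384

# Extra text tokens per dataset, mirroring the +9 / +7 in the patch above.
EXTRA_TEXT_TOKENS = {"pororo": 9, "flintstones": 7}


def text_vocab_size(dataset_name, base=BASE_TEXT_VOCAB):
    """Return the text vocab size a fine-tuned checkpoint should have.

    Like the patched if/elif, this does a substring match on dataset_name.
    """
    for name, extra in EXTRA_TEXT_TOKENS.items():
        if name in dataset_name:
            return base + extra
    return base
```

Keying on the dataset rather than the checkpoint path means renaming the checkpoint file no longer changes the model shape; the trade-off is that the dataset name passed to infer_story.sh now has to match the dataset the checkpoint was fine-tuned on.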
ghost commented 1 year ago

I am trying to collect generated images from the best-performing checkpoint, but I ran into this error:

ImportError: cannot import name 'PrefixTuningDalle' from 'dalle.models' (/home/my/storydalle/story-dalle/dalle/models/__init__.py)

It seems that dalle/models/__init__.py in the current version doesn't have this module.

I may have missed something. If so, I hope to hear from you.

How can I fix this problem?