Open KyonP opened 1 year ago
Is there a way to save the output images (generated and ground truth) from the best-performing checkpoint?
I looked into your code, but I find it hard to use acc_tensors_to_images in utils.py.
Do I need to train from the pretrained weights first in order to save the 'pt' files?
Can you tell me how to use it?
Also, your pre-trained weights in ./1.3B/ seem incompatible with inference.
root@1074e9836478:/home/my/storydalle/story-dalle# bash infer_story.sh pororo test
Evaluating on Pororo
Global seed set to 42
Initializing the Conditional Dalle model
Setting up Cross-attention Layers
Total parameters is 1396275075
./ckpt
path : ./ckpt/bpe-16k-vocab.json
./ckpt successfully restored..
Loaded tokenizer from finetuned checkpoint
[2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41]
Loading model from pretrained checkpoint ./ckpt/25_v1.pth
Traceback (most recent call last):
  File "/home/my/storydalle/story-dalle/dalle/models/__init__.py", line 961, in from_pretrained
    model.load_state_dict(torch.load(args.model_name_or_path)['state_dict'])
KeyError: 'state_dict'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "./infer_t2i.py", line 396, in <module>
    main(args)
  File "./infer_t2i.py", line 126, in main
    model, config = StoryDalle.from_pretrained(args)
  File "/home/my/storydalle/story-dalle/dalle/models/__init__.py", line 963, in from_pretrained
    model.load_state_dict(torch.load(args.model_name_or_path)['model_state_dict'])
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for StoryDalle:
    size mismatch for stage2.tok_emb_txt.weight: copying a param with shape torch.Size([16393, 1536]) from checkpoint, the shape in current model is torch.Size([16384, 1536]).
    size mismatch for stage2.head_txt.weight: copying a param with shape torch.Size([16393, 1536]) from checkpoint, the shape in current model is torch.Size([16384, 1536]).
root@1074e9836478:/home/my/storydalle/story-dalle#
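For anyone hitting the same traceback, here is a minimal sketch of how one might inspect a checkpoint before calling load_state_dict. The two key names come from the traceback above; the helper names `extract_state_dict` and `shape_mismatches` are hypothetical and not part of the story-dalle code. The demo uses stand-in objects with a `.shape` attribute so that it runs without PyTorch.

```python
from types import SimpleNamespace

# Candidate keys taken from the traceback above; other checkpoints may differ.
STATE_DICT_KEYS = ("state_dict", "model_state_dict")


def extract_state_dict(checkpoint):
    """Return the weights mapping from a loaded checkpoint dict."""
    for key in STATE_DICT_KEYS:
        if key in checkpoint:
            return checkpoint[key]
    # Assumption: a checkpoint without either key is itself a bare state dict.
    return checkpoint


def shape_mismatches(checkpoint_sd, model_sd):
    """List (name, checkpoint_shape, model_shape) for parameters whose shapes differ."""
    mismatches = []
    for name, param in checkpoint_sd.items():
        if name in model_sd:
            ckpt_shape = tuple(param.shape)
            model_shape = tuple(model_sd[name].shape)
            if ckpt_shape != model_shape:
                mismatches.append((name, ckpt_shape, model_shape))
    return mismatches


# Stand-in objects that mimic the shapes reported in the error message.
ckpt = {"model_state_dict": {"stage2.tok_emb_txt.weight": SimpleNamespace(shape=(16393, 1536))}}
model = {"stage2.tok_emb_txt.weight": SimpleNamespace(shape=(16384, 1536))}
print(shape_mismatches(extract_state_dict(ckpt), model))
```

With real weights, the first argument would presumably come from torch.load(path, map_location="cpu") and the second from model.state_dict(); tensor shapes convert to tuples the same way.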
For the weight shape mismatch, I found the reason 😓
The code only enlarges the text vocabulary when the checkpoint path contains the string 'pororo', and my path did not.
I changed the if-condition to check dataset_name instead.
# story-dalle/dalle/models/__init__.py
if args.model_name_or_path:
    if 'pororo' in args.dataset_name:  # was: args.model_name_or_path
        config_update.stage2.vocab_size_txt = config_update.stage2.vocab_size_txt + 9
    elif 'flintstones' in args.dataset_name:  # was: args.model_name_or_path
        config_update.stage2.vocab_size_txt = config_update.stage2.vocab_size_txt + 7
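As a sanity check on the numbers, the adjustment can be sketched as a small standalone function. The offsets 9 and 7 are from the snippet above and the base size 16384 is from the error message; the names `EXTRA_TEXT_TOKENS` and `adjusted_vocab_size` are hypothetical, made up for this illustration.

```python
# Extra text tokens added per dataset, as in the snippet above.
EXTRA_TEXT_TOKENS = {"pororo": 9, "flintstones": 7}


def adjusted_vocab_size(base_vocab_size, dataset_name):
    """Return the text vocabulary size after adding dataset-specific tokens."""
    for key, extra in EXTRA_TEXT_TOKENS.items():
        if key in dataset_name:
            return base_vocab_size + extra
    # Unknown dataset: leave the vocabulary size unchanged.
    return base_vocab_size


print(adjusted_vocab_size(16384, "pororo"))  # 16393
```

16384 + 9 = 16393 matches the checkpoint shape reported in the size mismatch, which is consistent with the Pororo branch having been skipped when the model was built.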
I am trying to gather generated images from your best-performing checkpoint, and I ran into this error.
(ldm) root@2157b047841c:/home/my/storydalle/story-dalle# bash infer_story.sh pororo
Evaluating on Pororo
Traceback (most recent call last):
  /home/my/storydalle/story-dalle/./infer_t2i.py:24 in <module>

    21
    22 import logging
    23 import os, torch
  ❱ 24 from dalle.models import PrefixTuningDalle, StoryDalle, PromptDalle
    25 import torchvision
    26 import torchvision.transforms as transforms
    27 import pytorch_lightning as pl

ImportError: cannot import name 'PrefixTuningDalle' from 'dalle.models' (/home/my/storydalle/story-dalle/dalle/models/__init__.py)
It seems that dalle/models/__init__.py in the current version does not define PrefixTuningDalle.
I may have missed something. If so, I would appreciate a reply.
How can I fix this problem?
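One quick way to confirm which of the imported names are actually missing is to check the package's attributes directly. This is only a sketch: `missing_names` is a hypothetical helper, and the demo runs against the standard-library `math` module so that it works anywhere.

```python
import importlib


def missing_names(module_name, names):
    """Return the names that are not attributes of the imported module."""
    module = importlib.import_module(module_name)
    return [name for name in names if not hasattr(module, name)]


# Demonstrated on a standard-library module; 'PrefixTuningDalle' is not in math.
print(missing_names("math", ["sqrt", "PrefixTuningDalle"]))  # ['PrefixTuningDalle']
```

Run from the repository root, a call such as missing_names("dalle.models", ["PrefixTuningDalle", "StoryDalle", "PromptDalle"]) would show which of the three names the current __init__.py fails to provide. Note that it only checks attributes of the package and does not look for submodules.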