callsys / GenPromp

[ICCV 2023] Generative Prompt Model for Weakly Supervised Object Localization
Apache License 2.0

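Hello, and thank you for your contributions to this field! I tried to reproduce your results by testing with the weights you provided, and got the following error. How can I fix it? Thanks. #13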

Open jjxyhb opened 3 days ago

jjxyhb commented 3 days ago

python main.py --function test --config configs/cub_stage2.yml --opt "{'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}"

Namespace(config='configs/cub_stage2.yml', function='test', opt={'test': {'load_token_path': 'ckpts/cub983/tokens/', 'load_unet_path': 'ckpts/cub983/unet/', 'save_log_path': 'ckpts/cub983/log.txt'}}, seed=1234)
An error occurred while trying to fetch ckpts/pretrains/stable-diffusion-v1-4: Error no file named diffusion_pytorch_model.safetensors found in directory ckpts/pretrains/stable-diffusion-v1-4.
Defaulting to unsafe serialization. Pass allow_pickle=False to raise an error instead.
An error occurred while trying to fetch ckpts/cub983/unet/: Error no file named diffusion_pytorch_model.safetensors found in directory ckpts/cub983/unet/.
Defaulting to unsafe serialization. Pass allow_pickle=False to raise an error instead.
INFO: CUBDataset: load data.
INFO: CUBDataset: init samples.
INFO: CUBDataset: init text encoders.
INFO: Test Save: [log: ckpts/cub983/log.txt] [vis: None]
INFO: Test CheckPoint: [token: ckpts/cub983/tokens/] [unet: ckpts/cub983/unet/]
INFO: Test Class [0-199]: [dataset: cub] [eval mode: top1] [cam thr: 0.23] [combine ratio: 0.6]
0%| | 0/2897 [00:04<?, ?it/s]
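The two "An error occurred while trying to fetch ..." messages are warnings, not the failure itself: diffusers looked for diffusion_pytorch_model.safetensors, did not find it, and fell back to the pickle-based weights file, after which the run continued to the test loop. A minimal sketch, using only the standard library, to confirm which weight files each checkpoint directory actually contains. The helper name `weight_formats` is hypothetical, and the assumption that the fallback file is named diffusion_pytorch_model.bin is stated in the code as well.

```python
from pathlib import Path

# Weight file names checked here (assumption: diffusers tries the safetensors
# file first, then falls back to the pickle-based .bin file).
SAFE_NAME = "diffusion_pytorch_model.safetensors"
PICKLE_NAME = "diffusion_pytorch_model.bin"


def weight_formats(directory):
    """Hypothetical helper: report which weight serializations exist in a folder."""
    folder = Path(directory)
    return {
        "safetensors": (folder / SAFE_NAME).is_file(),
        "pickle": (folder / PICKLE_NAME).is_file(),
    }


if __name__ == "__main__":
    # Assumed paths: diffusers prints the checkpoint root in its message, but the
    # UNet weights of a stable-diffusion-v1-4 folder live in its unet/ subfolder.
    for d in ["ckpts/pretrains/stable-diffusion-v1-4/unet", "ckpts/cub983/unet/"]:
        formats = weight_formats(d)
        if formats["safetensors"]:
            status = "safetensors file found"
        elif formats["pickle"]:
            status = "only the .bin file found, so the warning is expected"
        else:
            status = "no weights file found"
        print(f"{d}: {status}")
```

If only the .bin file is present, the warning is harmless and unrelated to the KeyError below.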

Error:

Traceback (most recent call last):
  File "main.py", line 641, in <module>
    eval(args.function)(config)
  File "main.py", line 299, in test
    cams = controller.diffusion_cam(idx=5)
  File "/home/dell/CV408/hb/GenPromp/models/attn.py", line 178, in diffusion_cam
    attention_maps_8_ca = self.cross_attention_map(8, ("up", "mid", "down"), bz, idx=idx)
  File "/home/dell/CV408/hb/GenPromp/models/attn.py", line 163, in cross_attention_map
    attention_maps = self.aggregate_attention(res, from_where, True, bz)
  File "/home/dell/CV408/hb/GenPromp/models/attn.py", line 138, in aggregate_attention
    for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
KeyError: 'up_cross'
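The traceback ends in aggregate_attention, which indexes a dictionary of stored attention maps by keys such as up_cross. A KeyError there means nothing was ever recorded under that key during the forward pass. The sketch below is a simplified, hypothetical store in the style of prompt-to-prompt's attention store; the class and method names are assumptions, not GenPromp's actual models/attn.py. It shows how such keys get populated and how a guard can replace the bare KeyError with a message that points at the real problem.

```python
class AttentionStore:
    """Hypothetical, simplified store keyed by '{location}_{cross|self}'."""

    def __init__(self):
        self.attention_maps = {}

    def record(self, attn_map, is_cross, location):
        # Meant to be called from inside each hooked attention layer during the
        # UNet forward pass; if no hook ever calls it, the dict stays empty.
        key = f"{location}_{'cross' if is_cross else 'self'}"
        self.attention_maps.setdefault(key, []).append(attn_map)

    def aggregate(self, from_where, is_cross):
        """Collect maps for the requested locations, failing loudly if none exist."""
        suffix = "cross" if is_cross else "self"
        collected = []
        for location in from_where:
            key = f"{location}_{suffix}"
            if key not in self.attention_maps:
                raise RuntimeError(
                    f"No attention maps recorded under '{key}'. The attention hooks "
                    "were probably never called; check that they are registered on the UNet."
                )
            collected.extend(self.attention_maps[key])
        return collected


if __name__ == "__main__":
    store = AttentionStore()
    # Placeholder strings stand in for attention tensors in this illustration.
    store.record("up-map-0", True, "up")
    print(store.aggregate(("up",), True))
```

With a guard like this, an empty store fails with a message naming the missing key and its likely cause, instead of a KeyError deep inside the aggregation loop.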

callsys commented 1 day ago

Hi, this may be caused by an update to the diffusers library, which can make it impossible to extract the attention maps from the SD model.
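One way a diffusers upgrade can leave the attention store empty, offered as a hypothesis rather than a diagnosis of GenPromp: prompt-to-prompt-style code often finds the layers to hook by walking the UNet and comparing each module's class name with a hard-coded string such as 'CrossAttention'. Newer diffusers releases renamed that class to Attention, so a check written against the old name matches nothing, no hook runs, and the lookup fails with exactly a KeyError: 'up_cross'. The sketch below reproduces that failure with stand-in classes (all names are illustrative, not the real diffusers or torch classes), counts how many layers each name check would hook, and prints the installed diffusers version via the standard library.

```python
from importlib import metadata


class Module:
    """Stand-in for torch.nn.Module: just holds child modules."""

    def __init__(self, *children):
        self.children = list(children)


class CrossAttention(Module):  # class name used by older diffusers releases
    pass


class Attention(Module):  # class name used by newer diffusers releases
    pass


def count_hooked(module, target_names):
    """Recursively count submodules whose class name is in target_names."""
    hits = 1 if type(module).__name__ in target_names else 0
    return hits + sum(count_hooked(child, target_names) for child in module.children)


def installed_diffusers_version():
    """Return the installed diffusers version string, or None if not installed."""
    try:
        return metadata.version("diffusers")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # A toy "new diffusers" UNet whose attention layers are Attention instances.
    new_unet = Module(Module(Attention(), Attention()), Module(Attention()))
    print("old-style check hooks", count_hooked(new_unet, {"CrossAttention"}), "layers")  # 0
    print("updated check hooks", count_hooked(new_unet, {"CrossAttention", "Attention"}), "layers")  # 3
    print("installed diffusers:", installed_diffusers_version())
```

If zero layers would be hooked, either install the diffusers version the repository was developed against or update the registration check to the new class name.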