THUDM / Inf-DiT

Official implementation of Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer
Apache License 2.0
378 stars 19 forks source link

when set inference_type=ar2, miss lr_imgs #21

Open kidhan1234 opened 4 months ago

kidhan1234 commented 4 months ago

image 如题,当设置inference_type=ar2,cross_attention_forward()函数会报缺少lr_imgs,点进代码(model.py line 426)可以看到确实没有传递lr_imgs,奇怪的是选择full时没有这个bug,请问是什么原因呢 image

xiaom233 commented 3 months ago

我也遇到了相同的问题,图比较大的时候full出现oom,ar1跑的比较慢,ar2存在bug不能使用

yzy-thu commented 3 months ago

可以先把推理脚本里的cross_lr删掉跑下试试,我会尽快修一下这个bug

Carl-Lin-cloud commented 2 months ago

可以先把推理脚本里的cross_lr删掉跑下试试,我会尽快修一下这个bug

删掉了cross_lr跑的时候,只能单卡跑不能多卡并行。

(model.py line 718) output, output_per_layers = self.model_forward(args, hw=[vit_block_bsize, vit_block_bsize], mems=mems, inference=2, **kwargs)

手动添加了 lr_imgs=lr_imgs 依然跑不了,报错k must have shape (batch_size, seqlen_k, num_heads_k, head_size_og).