Closed lyx-JuneSnow closed 3 months ago
qu_o 和 ku_cat/vu_cat 的形状不匹配是正常的,qu_o 就只是一个 batch 的数据,而 ku_cat 是多张图片拼接的,这两个变量能进行矩阵乘法。
问题一开始发生在 'sim_ref = sim_ref + ref_mask.masked_fill(ref_mask == 0, torch.finfo(sim.dtype).min)' 这里,从打印信息来看,应该是 ref_mask 出问题了,你的 mask 文件是不是以灰度文件保存的?mask保存的时候用单通道的,不要用 rgb 格式保存
有两个建议可能可以帮助你排查问题。第一,你为了解决问题,对代码作出修改的部分都撤回,使用原始代码;第二,跑一遍我提供的示例,在最开始报错的位置打上断点,打印一下各个变量的 shape,然后和你的示例对比一下,看看是不是 ref_mask 的 shape 不一致,也可以在读取 mask 的地方(https://github.com/aim-uofa/FreeCustom/blob/main/freecustom_stable_diffusion.py#L51) 打上断点看看你的 mask 文件是不是和我的格式不一致
有两个建议可能可以帮助你排查问题。第一,你为了解决问题,对代码作出修改的部分都撤回,使用原始代码;第二,跑一遍我提供的示例,在最开始报错的位置打上断点,打印一下各个变量的 shape,然后和你的示例对比一下,看看是不是 ref_mask 的 shape 不一致,也可以在读取 mask 的地方(https://github.com/aim-uofa/FreeCustom/blob/main/freecustom_stable_diffusion.py#L51) 打上断点看看你的 mask 文件是不是和我的格式不一致
非常感谢您提供的建议,我在读取ref_mask的地方打了断点输出了ref_mask的格式,我发现的确是因为我的mask通道数是3而导致了bug。 在freecustom_stable_diffusion.py的第51行后我添加了代码ref_mask = ref_mask[:, :1, :, :]解决了这个问题。
以上是我的运行报错 然后我在mrsa.py中添加了如下的打印信息
结果为
qu_o 和 ku_cat/vu_cat 的形状不匹配。特别是 qu_o 的形状是 [8, 1024, 80],而 ku_cat 和 vu_cat 的形状是 [8, 4096, 80]。 做了如下调整
仍然是一样的报错 于是修改mrsa.py中
打印出
sim_ref 和 ref_mask 的形状不匹配。sim_ref 形状为 [8, 1024, 1024],而 ref_mask 的形状为 [1, 3072] 继续修改mrsa.py为
以及
结果是
继续改为
结果为
继续改为
报错变为