Linfeng-Tang / SwinFusion

This is the official PyTorch implementation of "SwinFusion: Cross-domain Long-range Learning for General Image Fusion via Swin Transformer"

I also ran into this problem #2

Open ykykyyk opened 1 year ago

ykykyyk commented 1 year ago

An error occurs when I change the number of channels. Is this caused by the pretrained weights file?

Linfeng-Tang commented 1 year ago

Our current configuration only supports three channels. You can try converting your grayscale images to three channels, or adapt to a different channel count with a simple modification of the data-loading code.
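The first suggestion (converting a grayscale image to three channels before it enters the data loader) can be sketched as follows. This is not part of the repo's code, just a minimal illustration assuming NumPy images in HWC layout:

```python
import numpy as np

def gray_to_3ch(img):
    """Replicate a single-channel image (H, W) or (H, W, 1) across three channels."""
    if img.ndim == 2:
        img = img[:, :, None]  # add a trailing channel axis
    return np.repeat(img, 3, axis=2)
```

The same effect can be had with OpenCV's `cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)` if the loader already uses OpenCV.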

mk12306 commented 1 year ago

Our current configuration only supports three channels. You can try converting your grayscale images to three channels, or adapt to a different channel count with a simple modification of the data-loading code.

It seems the code currently only supports a single channel; three channels throws an error. Is that because the model was trained on single-channel images?

LittlePika commented 1 year ago

Hello, thank you for providing the paper as well as the public code. I get this error message as well. The program runs fine with in_channel = 1, but changing to in_channel = 3 produces the error message attached below.

It looks like the checkpoint published in the repo for VIF was trained with a single input channel, so it has too few weights for in_channel = 3. How can I get weights that work with in_channel = 3?

Thanks in advance.

loading model from ./Model/Infrared_Visible_Fusion/Infrared_Visible_Fusion/models/ in_chans: 3
.local/lib/python3.10/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
Traceback (most recent call last):
  File "./SwinFusion/test_swinfusion.py", line 153, in <module>
    main()
  File "./SwinFusion/test_swinfusion.py", line 50, in main
    model = define_model(args)
  File "./SwinFusion/test_swinfusion.py", line 94, in define_model
    model.load_state_dict(pretrained_model[param_key_g] if param_key_g in pretrained_model.keys() else pretrained_model, strict=True)
  File ".local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SwinFusion:
    size mismatch for conv_first1_A.weight: copying a param with shape torch.Size([30, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([30, 3, 3, 3]).
    size mismatch for conv_first1_B.weight: copying a param with shape torch.Size([30, 1, 3, 3]) from checkpoint, the shape in current model is torch.Size([30, 3, 3, 3]).
    size mismatch for conv_last3.weight: copying a param with shape torch.Size([1, 15, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 15, 3, 3]).
    size mismatch for conv_last3.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([3]).
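The traceback already tells the story: in PyTorch, a conv weight has shape (out_ch, in_ch, kH, kW), so the `[30, 1, 3, 3]` shapes coming from the checkpoint mean it was trained with a single input channel. A quick sketch (not from the repo; the dummy state_dict below is hypothetical) to confirm this on any checkpoint:

```python
import torch

def check_in_channels(state_dict, key="conv_first1_A.weight"):
    """Return the input-channel count of a conv layer in a state_dict.

    PyTorch conv weights are laid out as (out_ch, in_ch, kH, kW),
    so index 1 is the number of input channels the checkpoint expects.
    """
    w = state_dict.get(key)
    return None if w is None else w.shape[1]

# Dummy state_dict shaped like the checkpoint in the traceback above:
dummy = {"conv_first1_A.weight": torch.zeros(30, 1, 3, 3)}
```

With the real file, one would first do `state = torch.load(path, map_location="cpu")` and pass the resulting dict (or its `params` entry, if present) to this helper.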

Amano-Hina commented 1 year ago

> (quoting LittlePika's comment and error message above)

I have met the same error. Did you find out how to solve it?

LittlePika commented 1 year ago

For testing, you can work around this by splitting the visible image into its individual channels and fusing each one with the infrared image separately. The existing single-channel weights can then be reused, and this produces comparable results.

Make the following modification in the test file:

def test(img_ir, img_vis, model):
    # Fuse the infrared image with each visible channel separately,
    # then stack the three single-channel outputs back together.
    img_ir = img_ir[:, 0:1, :, :]
    img_vis_r = img_vis[:, 0:1, :, :]
    img_vis_g = img_vis[:, 1:2, :, :]
    img_vis_b = img_vis[:, 2:3, :, :]

    output_r = model(img_ir, img_vis_r)
    output_g = model(img_ir, img_vis_g)
    output_b = model(img_ir, img_vis_b)

    output = torch.cat((output_r, output_g, output_b), dim=1)
    return output

Hope it helps. However, this workaround does not work for training.
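For training with in_channel = 3, one common option (an assumption on my part, not something the SwinFusion code provides) is to "inflate" the single-channel checkpoint weights to three channels by replicating them, then fine-tune. The key name below is taken from the traceback; the output conv (`conv_last3.weight` / `.bias`) mismatches too and would need to be reinitialized or handled similarly:

```python
import torch

def inflate_first_conv(state_dict, key="conv_first1_B.weight", in_ch=3):
    """Expand a 1-channel conv weight to in_ch channels by replication.

    Dividing by in_ch keeps the initial activations at roughly the same
    scale when the input channels are similar. This is a common weight
    "inflation" trick, not part of the SwinFusion repo.
    """
    w = state_dict[key]                              # (out_ch, 1, kH, kW)
    state_dict[key] = w.repeat(1, in_ch, 1, 1) / in_ch  # (out_ch, in_ch, kH, kW)
    return state_dict
```

After inflating the input convs, load with `strict=False` so the remaining mismatched output layers fall back to their fresh initialization, and fine-tune from there.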

JackNewComer111 commented 10 months ago

For infrared and visible image fusion, is there really no way to train with three channels? What about the other fusion tasks in the code?