Googolxx / STF

PyTorch implementation of the paper "The Devil Is in the Details: Window-based Attention for Image Compression".

Question about setting the number of channels in the two models #18

Closed cezdo closed 1 year ago

cezdo commented 1 year ago

Hello!

I am training hyperspectral image compression on the CAVE dataset (31 channels), so I need to change the number of image channels for training. I would like to ask how the image-channel settings differ between the cnn model and the stf model.

First, in the CNN model it seems that I only need to change the original 3 to 31 in the part around line 31, and training then runs normally.

self.g_a = nn.Sequential(
    conv(31, N, kernel_size=5, stride=2),  # changed from 3 to 31
    ...
)
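
(Side note: in the CompressAI-style CNN model the synthesis transform g_s mirrors g_a and normally ends with a deconv back to the image channel count, so the symmetric edit sketched below may also be needed if it has not been made already. This is only a sketch assuming the repo's cnn.py follows that layout; the names g_s, deconv, and N are taken from that convention.)

self.g_s = nn.Sequential(
    ...,
    deconv(N, 31, kernel_size=5, stride=2),  # last layer: output channels changed from 3 to 31 (sketch)
)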

Next, in the STF model, I tried changing the default value of in_chans to 31 in both places (around lines 351 and 388), but running it still raises a channel-mismatch error.

    def __init__(self, patch_size=4, in_chans=31, embed_dim=96, norm_layer=None):
        super().__init__()
        ...
class SymmetricalTransFormer(CompressionModel):
    def __init__(self,
                pretrain_img_size=256,
                patch_size=2,
                in_chans=31, #input channel num = 31
                ...

Here is the error output:

Namespace(aux_learning_rate=0.001, batch_size=16, checkpoint=None, clip_max_norm=1.0, cuda=True, dataset='CAVE', epochs=1000, learning_rate=0.0001, lmbda=0.0035, model='stf', num_workers=30, patch_size=(256, 256), save=True, save_path='stf/ckpt/', seed=None, test_batch_size=64)
Learning rate: 0.0001
/data2/zhaoshuyi/anaconda3/envs/compressai/lib/python3.7/site-packages/torch/nn/modules/loss.py:528: UserWarning: Using a target size (torch.Size([16, 31, 128, 128])) that is different to the input size (torch.Size([16, 3, 128, 128])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Traceback (most recent call last):
  File "stf/train.py", line 367, in <module>
    main(sys.argv[1:])
  File "stf/train.py", line 342, in main
    args.clip_max_norm,
  File "stf/train.py", line 132, in train_one_epoch
    out_criterion = criterion(out_net, d)
  File "/data2/zhaoshuyi/anaconda3/envs/compressai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "stf/train.py", line 52, in forward
    out["mse_loss"] = self.mse(output["x_hat"], target)
  File "/data2/zhaoshuyi/anaconda3/envs/compressai/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data2/zhaoshuyi/anaconda3/envs/compressai/lib/python3.7/site-packages/torch/nn/modules/loss.py", line 528, in forward
    return F.mse_loss(input, target, reduction=self.reduction)
  File "/data2/zhaoshuyi/anaconda3/envs/compressai/lib/python3.7/site-packages/torch/nn/functional.py", line 2928, in mse_loss
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
  File "/data2/zhaoshuyi/anaconda3/envs/compressai/lib/python3.7/site-packages/torch/functional.py", line 74, in broadcast_tensors
    return _VF.broadcast_tensors(tensors)  # type: ignore
RuntimeError: The size of tensor a (3) must match the size of tensor b (31) at non-singleton dimension 1

So, is there anything else that needs to be modified to support compressing hyperspectral images? Many thanks!

Googolxx commented 1 year ago

You also need to modify this part, changing it to:

self.end_conv = nn.Sequential(
    nn.Conv2d(embed_dim, embed_dim * patch_size ** 2, kernel_size=5, stride=1, padding=2),
    nn.PixelShuffle(patch_size),
    nn.Conv2d(embed_dim, 31, kernel_size=3, stride=1, padding=1),  # output channels: 3 -> 31
)
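
As a quick sanity check once all three edits are applied, a forward pass should now reconstruct 31 channels (a minimal sketch; the import path and constructor call are assumptions and may differ in your checkout):

import torch
from compressai.models import SymmetricalTransFormer  # adjust to the actual import path in this repo

net = SymmetricalTransFormer(in_chans=31)  # with the 31-channel end_conv change applied
x = torch.rand(1, 31, 256, 256)
out = net(x)
print(out["x_hat"].shape)  # should print torch.Size([1, 31, 256, 256]), matching the 31-channel target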

cezdo commented 1 year ago

Thank you very much!