lengmo1996 / ECDM

This is the official repository of Data Generation Scheme for Thermal Modality with Edge-Guided Adversarial Conditional Diffusion Model (ACM MM'24)

Training problem: RuntimeError: Given groups=1, weight of size [128, 6, 3, 3], expected input[1, 2, 512, 640] to have 6 channels, but got 2 channels instead #1

Closed DuanHongxuan closed 2 months ago

DuanHongxuan commented 2 months ago

```
Traceback (most recent call last):
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/main.py", line 144, in <module>
    cli_main()
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/main.py", line 137, in cli_main
    cli = MyLightningCLI(
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 359, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 650, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 532, in fit
    call._call_and_handle_interrupt(
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 571, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 980, in _run
    results = self._run_stage()
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self.fit_loop.run()
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 355, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 133, in run
    self.advance(data_fetcher)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 219, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 188, in run
    self._optimizer_step(kwargs.get("batch_idx", 0), closure)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 266, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 146, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1276, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 161, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 231, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 116, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 391, in wrapper
    out = func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/optim/adamw.py", line 165, in step
    loss = closure()
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 103, in _wrap_closure
    closure_result = closure()
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 142, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 128, in closure
    step_output = self._step_fn()
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 315, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 294, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 380, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 552, in training_step
    loss, loss_dict = self.shared_step(batch)
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 548, in shared_step
    loss, loss_dict = self(x, cond=cond)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 534, in forward
    return self.p_losses(x, t, *args, **kwargs)
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 500, in p_losses
    model_out = self.model(x_noisy, t, cond)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 711, in forward
    out = self.diffusion_model(xc, t)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/modules/diffusionmodules/simple_unet.py", line 296, in forward
    hs = [self.conv_in(x)]
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [128, 6, 3, 3], expected input[1, 2, 512, 640] to have 6 channels, but got 2 channels instead
Epoch 0:   0%| | 0/12025 [00:01<?, ?it/s]
```

The above is the error I hit during training. It looks like an input-dimension mismatch: the config file ecdm_first_stage.yaml sets `in_channels: 6`, but in dataset.py the images are converted to grayscale with `vis_img = vis_img.convert("L")`, `vis_HF_img = vis_HF_img.convert("L")`, `tir_img = tir_img.convert("L")`, `tir_HF_img = tir_HF_img.convert("L")`. That conversion may be why the input ends up with 2 channels, and the infrared images are single-channel to begin with. Should I change the 6 to 2, or should the input images be converted to three-channel instead? It is also possible that I have misunderstood and the cause lies elsewhere. Could you explain the cause and how to fix it? Thanks.
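For context, the channel count in the error comes straight from PIL's conversion mode: an "L" image has one band and an "RGB" image has three, so stacking a pair of images yields 2 vs. 6 channels. The sketch below illustrates that arithmetic; which two images ECDM's dataloader actually concatenates into the condition tensor is an assumption here, not taken from the repo:

```python
from PIL import Image
import numpy as np

# Two dummy stand-ins for a pair of conditioning images at the
# 640x512 resolution seen in the traceback (pairing is hypothetical).
vis_img = Image.new("RGB", (640, 512))
vis_HF_img = Image.new("RGB", (640, 512))

def stacked_channels(mode):
    """Channel count after converting both images to `mode` and stacking."""
    a = np.atleast_3d(np.asarray(vis_img.convert(mode)))   # "L" arrays are
    b = np.atleast_3d(np.asarray(vis_HF_img.convert(mode)))  # 2-D, so add axis
    return np.concatenate([a, b], axis=-1).shape[-1]

print(stacked_channels("L"))    # 2 -> mismatches the conv weight [128, 6, 3, 3]
print(stacked_channels("RGB"))  # 6 -> matches in_channels: 6
```

This reproduces the 2-vs-6 gap the conv layer complains about without running the model itself.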

lengmo1996 commented 2 months ago

Hello, and thank you for your interest in our work!

The mismatch comes from changes we made to the code for the open-source release. You can either change the input channels from 6 to 2, or change the input images to three-channel. The former only requires editing the ecdm_first_stage.yaml config file; the latter can be done with `vis_img = vis_img.convert("RGB")`.

Early on we loaded the input images in RGB mode, mainly for compatibility with some evaluation-metric implementations, but because nothing constrains consistency across the channels, the generated images sometimes show a color cast (that RGB setup is also what the paper used). When open-sourcing the code we switched the inputs to L mode.
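To make the second option concrete, here is a minimal sketch of the change in dataset.py: replace each `.convert("L")` with `.convert("RGB")` so every image carries 3 channels and the stacked condition reaches the expected 6. The helper name and the dummy images below are illustrative, not part of the repo:

```python
from PIL import Image

def to_rgb(*imgs):
    """Hypothetical helper: mirrors swapping .convert("L") for
    .convert("RGB") on each image loaded in dataset.py."""
    return [img.convert("RGB") for img in imgs]

# Dummy stand-ins for the four images loaded in dataset.py.
vis_img, vis_HF_img, tir_img, tir_HF_img = (
    Image.new("L", (640, 512)) for _ in range(4)
)
vis_img, vis_HF_img, tir_img, tir_HF_img = to_rgb(
    vis_img, vis_HF_img, tir_img, tir_HF_img
)
print(vis_img.getbands())  # ('R', 'G', 'B')
```

The first option needs no code change at all: set `in_channels: 2` in ecdm_first_stage.yaml and keep the L-mode loading as released.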