Closed DuanHongxuan closed 2 months ago
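Hello, and thank you for your interest in our work! The mismatch comes from changes we made to the code when we open-sourced it. You can either change the input channels from 6 to 2, or convert the input images to three-channel images. For the former, just edit the ecdm_first_stage.yaml config file; for the latter, use vis_img = vis_img.convert("RGB"). In our early experiments we loaded the input images in RGB mode, mainly for compatibility with some evaluation-metric implementations. However, because that approach places no constraint on consistency across the channels, the generated images sometimes show a color cast (this is also the setup used in the paper). When we open-sourced the code, we switched the input images to L mode.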
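The two fixes above can be checked with a little channel arithmetic before starting a training run. The sketch below is not taken from the ECDM repository: it assumes the UNet input is the channel-wise concatenation of the noisy target and the condition image (which is consistent with the 2 versus 6 channels in the reported error), and that the RGB variant converts both images. `MODE_CHANNELS`, `expected_in_channels`, and `check_in_channels` are hypothetical names.

```python
# Channels produced by the two PIL image modes discussed in this thread.
MODE_CHANNELS = {"L": 1, "RGB": 3}


def expected_in_channels(target_mode: str, cond_mode: str) -> int:
    """Channels seen by the first conv layer, ASSUMING the noisy target and
    the condition image are concatenated along the channel axis."""
    return MODE_CHANNELS[target_mode] + MODE_CHANNELS[cond_mode]


def check_in_channels(in_channels: int, target_mode: str, cond_mode: str) -> None:
    """Raise early if the config value disagrees with the dataset's image mode."""
    expected = expected_in_channels(target_mode, cond_mode)
    if in_channels != expected:
        raise ValueError(
            f"in_channels={in_channels}, but {target_mode}+{cond_mode} inputs "
            f"give {expected} channels"
        )


# Released code: convert("L") everywhere, so in_channels should be 2.
check_in_channels(2, "L", "L")
# Assumed RGB variant: convert("RGB") for both images, so in_channels stays 6.
check_in_channels(6, "RGB", "RGB")
```

Calling `check_in_channels(6, "L", "L")` raises a `ValueError`, which corresponds to the configuration that produced the error in this issue.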
Traceback (most recent call last):
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/main.py", line 144, in <module>
cli_main()
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/main.py", line 137, in cli_main
cli = MyLightningCLI(
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 359, in __init__
self._run_subcommand(self.subcommand)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/cli.py", line 650, in _run_subcommand
fn(**fn_kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 532, in fit
call._call_and_handle_interrupt(
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 571, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 980, in _run
results = self._run_stage()
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
self.fit_loop.run()
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 202, in run
self.advance()
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 355, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 133, in run
self.advance(data_fetcher)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 219, in advance
batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 188, in run
self._optimizer_step(kwargs.get("batch_idx", 0), closure)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 266, in _optimizer_step
call._call_lightning_module_hook(
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 146, in _call_lightning_module_hook
output = fn(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1276, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 161, in step
step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 231, in optimizer_step
return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 116, in optimizer_step
return optimizer.step(closure=closure, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 391, in wrapper
out = func(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
ret = func(self, *args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/optim/adamw.py", line 165, in step
loss = closure()
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision_plugin.py", line 103, in _wrap_closure
closure_result = closure()
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 142, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 128, in closure
step_output = self._step_fn()
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 315, in _training_step
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 294, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 380, in training_step
return self.model.training_step(*args, **kwargs)
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 552, in training_step
loss, loss_dict = self.shared_step(batch)
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 548, in shared_step
loss, loss_dict = self(x, cond=cond)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 534, in forward
return self.p_losses(x, t, *args, **kwargs)
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 500, in p_losses
model_out = self.model(x_noisy, t, cond)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/models/diffusion/ecdm_first_stage.py", line 711, in forward
out = self.diffusion_model(xc, t)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/media/ubuntu/新加卷/dhx_workspace/image_matching/ECDM-master/ecdm/modules/diffusionmodules/simple_unet.py", line 296, in forward
hs = [self.conv_in(x)]
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
return self._conv_forward(input, self.weight, self.bias)
File "/home/ubuntu/anaconda3/envs/dhx/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [128, 6, 3, 3], expected input[1, 2, 512, 640] to have 6 channels, but got 2 channels instead
Epoch 0: 0%| | 0/12025 [00:01<?, ?it/s]
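The above is the error I ran into during training. It looks like an input dimension mismatch. In ecdm_first_stage.yaml I found in_channels: 6, while dataset.py converts the images to grayscale: vis_img = vis_img.convert("L") vis_HF_img = vis_HF_img.convert("L") tir_img = tir_img.convert("L") tir_HF_img = tir_HF_img.convert("L"). That may be why the input has 2 channels; in addition, the infrared images are single-channel. Should I change 6 to 2 here, or should I convert the input images to three-channel images? It is also possible that I have misunderstood and the cause lies elsewhere. Could you please explain the cause of the problem and how to fix it? Thank you.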
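The RuntimeError in the traceback can be reproduced in isolation, which may help confirm the diagnosis. This is a minimal sketch assuming PyTorch is installed. The layer only mirrors the weight shape [128, 6, 3, 3] from the error message and is not the repository's `conv_in` definition; `padding=1` and the 8×8 spatial size are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# Stand-in for the first conv layer: weight shape [128, 6, 3, 3] as in the error.
conv_in = nn.Conv2d(in_channels=6, out_channels=128, kernel_size=3, padding=1)

# Two grayscale ("L") images concatenated on the channel axis give 2 channels.
x_gray = torch.randn(1, 2, 8, 8)
try:
    conv_in(x_gray)
    mismatch = False
except RuntimeError as err:
    mismatch = True
    print(err)  # reports the expected vs. actual channel count

# Two RGB images concatenated on the channel axis give the 6 channels expected.
x_rgb = torch.randn(1, 6, 8, 8)
out = conv_in(x_rgb)
print(out.shape)  # torch.Size([1, 128, 8, 8])
```

Either side of the mismatch can be changed: building the layer with `in_channels=2` accepts `x_gray`, just as feeding 6-channel input satisfies the layer as configured.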