XPixelGroup / DiffBIR

Official codes of DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior
Apache License 2.0

Question on training Stage2 with Real-ESRGAN degradation #92

Open JaehaKim97 opened 6 months ago

JaehaKim97 commented 6 months ago

Hi, thanks for sharing the great work!

I tried to follow the training process, but ran into a problem while training the Stage 2 model.

After filling in the train_cldm.yaml file, I ran python train.py --config configs/train_cldm.yaml, but got the error below:

Traceback (most recent call last):
  File "/home/jaeha/Research/DiffBIR/train.py", line 32, in <module>
    main()
  File "/home/jaeha/Research/DiffBIR/train.py", line 28, in main
    trainer.fit(model, datamodule=data_module)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 553, in fit
    self._run(model)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 918, in _run
    self._dispatch()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 986, in _dispatch
    self.accelerator.start_training(self)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 92, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in start_training
    self._results = trainer.run_stage()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 996, in run_stage
    return self._run_train()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
    self.fit_loop.run()
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
    self.trainer.call_hook(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 189, in on_train_batch_end
    callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/callbacks.py", line 55, in on_train_batch_end
    images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 379, in log_images
    samples = self.sample_log(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/jaeha/Research/DiffBIR/model/cldm.py", line 394, in sample_log
    samples = sampler.sample(
  File "/home/jaeha/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
TypeError: sample() got an unexpected keyword argument 'unconditional_guidance_scale'

I suspect the error comes from here: https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L394

where the sample function of SpacedSampler no longer accepts unconditional_guidance_scale as an argument.
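
For reference, here is the mismatch as I understand it; the new signature below is reconstructed from the updated sampler's call sites, so argument names and defaults are approximate:

# What sample_log in model/cldm.py still calls:
samples = sampler.sample(
    steps, shape, cond,
    unconditional_guidance_scale=1.0,
    unconditional_conditioning=None
)

# What SpacedSampler.sample now seems to expect
# (cond_img is the low-quality conditioning image):
samples = sampler.sample(
    steps=steps, shape=shape, cond_img=cond_img,
    positive_prompt="", negative_prompt="", x_T=x_T,
    cfg_scale=1.0, cond_fn=None, color_fix_type="wavelet"
)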

Could you please let me know how to resolve this?

umbertov commented 6 months ago

A quick look at git log --patch model/spaced_sampler.py shows that the model/spaced_sampler.py file was updated, but model/cldm.py was not updated to reflect the changes in the sampler API. I just bumped into this myself; you have two options:

  1. Reset the repo (or just model/spaced_sampler.py, though you might end up with something inconsistent) to the previous version, in which the spaced sampler still had the old API. From what I can see, d3e29f7 is the last commit where spaced_sampler had the old API (a command sketch follows this list).
  2. Look at the code, scratch your head, and rewrite the necessary code.
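
For option 1, assuming d3e29f7 really is the last old-API commit, restoring just the sampler file would look like:

git checkout d3e29f7 -- model/spaced_sampler.py

though, as said, mixing the old sampler with the current model/cldm.py may leave things inconsistent.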

Let me know if you have already made some progress on this; since you asked a while ago, I'd be interested in something other than git reset --hard d3e29f7.

JaehaKim97 commented 6 months ago

Thanks for sharing the awesome information!

In my solution, I replaced the original sampler.sample call with the other sampler.sample call, the one used in inference.py. (To be precise, I additionally modified it here to add "c_lq" to the condition, and removed the decoding and normalizing steps.)

I'm not sure it is the correct solution, but it now seems to work as I expected.

Windrain7 commented 6 months ago

Thank you for the solution. Could you please provide a more specific code snippet? I now call the function this way:

samples = sampler.sample(
    steps=steps, shape=shape, cond_img=cond["c_concat"][0],
    positive_prompt="", negative_prompt="", cfg_scale=1.0
)

But another error occurred:

Traceback (most recent call last):
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
    self.fit_loop.run()
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
    self.trainer.call_hook(
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 189, in on_train_batch_end
    callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "/data/jt/projects/DiffBIR/model/callbacks.py", line 55, in on_train_batch_end
    images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/data/jt/projects/DiffBIR/model/cldm.py", line 385, in log_images
    x_samples = self.decode_first_stage(samples)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/data/jt/projects/DiffBIR/ldm/models/diffusion/ddpm.py", line 832, in decode_first_stage
    return self.first_stage_model.decode(z)
  File "/data/jt/projects/DiffBIR/ldm/models/autoencoder.py", line 90, in decode
    z = self.post_quant_conv(z)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [4, 4, 1, 1], expected input[8, 3, 512, 512] to have 4 channels, but got 3 channels instead

Does "remove the decoding and normalizing steps" means delete https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L384 where error happened? Can you give me a more detailed solution, thank you!

JaehaKim97 commented 5 months ago

Below is my code snippet. The channel-mismatch error happens because the updated sampler already returns decoded, image-space samples (note the [8, 3, 512, 512] input in your traceback), so feeding them to decode_first_stage, which expects 4-channel latents, fails; that is why the decoding and normalizing steps should be removed. But again, note that it is NOT the official solution.

import torch

from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.util import log_txt_as_img

from .spaced_sampler import SpacedSampler
...

class ControlLDM(LatentDiffusion):
    ...
    @torch.no_grad()
    def log_images(self, batch, sample_steps=50):
        log = dict()
        z, c = self.get_input(batch, self.first_stage_key)
        c_lq = c["lq"][0]
        c_latent = c["c_latent"][0]
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]

        log["hq"] = (self.decode_first_stage(z) + 1) / 2
        log["control"] = c_cat
        log["decoded_control"] = (self.decode_first_stage(c_latent) + 1) / 2
        log["lq"] = c_lq
        log["text"] = (log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) + 1) / 2

        samples = self.sample_log(
            # TODO: remove c_concat from cond
            # cond={"c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]},
            cond={"c_lq": c_lq, "c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]},
            steps=sample_steps
        )
        # The updated sampler already returns decoded, image-space samples,
        # so the old decode/normalize step is no longer needed:
        # x_samples = self.decode_first_stage(samples)
        # log["samples"] = (x_samples + 1) / 2
        log["samples"] = samples

        return log

    @torch.no_grad()
    def sample_log(self, cond, steps, cond_fn=None, color_fix_type="wavelet"):
        sampler = SpacedSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (b, self.channels, h // 8, w // 8)
        x_T = torch.randn(shape, device=self.model.device, dtype=torch.float32)
        # Old call (pre-API-change sampler), kept for reference:
        # samples = sampler.sample(
        #     steps, shape, cond, unconditional_guidance_scale=1.0,
        #     unconditional_conditioning=None
        # )
        samples = sampler.sample(
            steps=steps, shape=shape, cond_img=cond["c_lq"],
            positive_prompt="", negative_prompt="", x_T=x_T,
            cfg_scale=1.0, cond_fn=cond_fn,
            color_fix_type=color_fix_type
        )
        return samples
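
One note on the new arguments: cfg_scale=1.0 plays the same role as the old unconditional_guidance_scale=1.0 (classifier-free guidance is effectively disabled for these validation samples), and x_T is just the initial Gaussian noise, which the old sampler presumably drew internally.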