Open JaehaKim97 opened 6 months ago
A quick look at git log --patch model/spaced_sampler.py shows that model/spaced_sampler.py was updated, but model/cldm.py was not updated to reflect the changes in the sampler API.
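One way to confirm this kind of mismatch without reading the diff is to compare the keyword arguments a caller passes against the callee's current signature. The sketch below uses Python's standard inspect module. NewSampler is a hypothetical stand-in whose parameter names only mirror the keywords quoted later in this thread; it is not the real SpacedSampler, and unsupported_kwargs is an illustrative helper, not part of DiffBIR.

```python
import inspect


class NewSampler:
    """Hypothetical stand-in; parameter names mirror those quoted in this thread."""

    def sample(self, steps, shape, cond_img, positive_prompt, negative_prompt,
               x_T=None, cfg_scale=1.0, cond_fn=None, color_fix_type="none"):
        return None


def unsupported_kwargs(func, kwargs):
    """Return the keyword names in `kwargs` that `func` does not accept."""
    params = inspect.signature(func).parameters
    # A **kwargs catch-all means any keyword is accepted.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return []
    return sorted(name for name in kwargs if name not in params)


# Keywords used by the old-style call in model/cldm.py, per this thread.
old_call = {"unconditional_guidance_scale": 1.0, "unconditional_conditioning": None}
print(unsupported_kwargs(NewSampler().sample, old_call))
```

Running the same helper against the real sampler.sample in your checkout should list whichever old keywords it no longer accepts.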
Just bumped into this myself. You have two solutions: update model/cldm.py to call the new sampler API, or revert (possibly only model/spaced_sampler.py, but you might end up with something inconsistent) to a previous version from when the spaced sampler had the old API. From what I see, d3e29f7 is the last commit where spaced_sampler had the old API.
Let me know if you have already made progress on this, since you asked a while ago; I'd be interested in something other than git reset --hard d3e29f7.
Thanks for sharing the awesome information!
In my solution, I replaced the original sampler.sample call with the form of sampler.sample call that is used in inference.py. (To be precise, I additionally modified the code here to add "c_lq" to the condition, and removed the decoding and normalizing steps.)
I'm not sure it is the correct solution, but it now seems to work as I expected.
Thank you for the solution. Could you please provide a more specific code snippet? I have now changed the function call to the following:
samples = sampler.sample(
    steps=steps, shape=shape, cond_img=cond["c_concat"][0],
    positive_prompt="", negative_prompt="", cfg_scale=1.0
)
But another error occurred:
Traceback (most recent call last):
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1045, in _run_train
    self.fit_loop.run()
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py", line 200, in advance
    epoch_output = self.epoch_loop.run(train_dataloader)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/base.py", line 111, in run
    self.advance(*args, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 149, in advance
    self.trainer.call_hook(
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1217, in call_hook
    trainer_hook(*args, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/trainer/callback_hook.py", line 189, in on_train_batch_end
    callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/pytorch_lightning/utilities/distributed.py", line 48, in wrapped_fn
    return fn(*args, **kwargs)
  File "/data/jt/projects/DiffBIR/model/callbacks.py", line 55, in on_train_batch_end
    images: Dict[str, torch.Tensor] = pl_module.log_images(batch, **self.log_images_kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/data/jt/projects/DiffBIR/model/cldm.py", line 385, in log_images
    x_samples = self.decode_first_stage(samples)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/data/jt/projects/DiffBIR/ldm/models/diffusion/ddpm.py", line 832, in decode_first_stage
    return self.first_stage_model.decode(z)
  File "/data/jt/projects/DiffBIR/ldm/models/autoencoder.py", line 90, in decode
    z = self.post_quant_conv(z)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/jt/anaconda3/envs/diffbir/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [4, 4, 1, 1], expected input[8, 3, 512, 512] to have 4 channels, but got 3 channels instead
Does "remove the decoding and normalizing steps" mean deleting https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L384, where the error happened? Could you give me a more detailed solution? Thank you!
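The shapes in the error message hint at what is going on: the layer that failed has a weight of size [4, 4, 1, 1], so it expects a 4-channel latent, but it received an [8, 3, 512, 512] tensor, which looks like an already-decoded RGB image. Below is a minimal sketch of that check in plain Python on shape tuples. The helper name looks_like_latent and the default of 4 latent channels are assumptions read off the traceback, not DiffBIR code.

```python
def looks_like_latent(shape, latent_channels=4):
    """True if an NCHW shape has the channel count the decoder expects."""
    if len(shape) != 4:
        raise ValueError(f"expected an NCHW shape, got {shape}")
    return shape[1] == latent_channels


sampler_output = (8, 3, 512, 512)  # shape reported in the RuntimeError
latent = (8, 4, 64, 64)            # a 4-channel latent at 1/8 resolution

for name, shape in [("sampler_output", sampler_output), ("latent", latent)]:
    action = "decode" if looks_like_latent(shape) else "skip decode"
    print(name, shape, action)
```

If the new sampler.sample already returns image-space output, calling decode_first_stage on it again would fail in exactly this way, which would be consistent with dropping the decode step.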
Below is my code snippet. But again, note that it is NOT the official solution.
from .spaced_sampler import SpacedSampler
...
class ControlLDM(LatentDiffusion):
    ...
    @torch.no_grad()
    def log_images(self, batch, sample_steps=50):
        log = dict()
        z, c = self.get_input(batch, self.first_stage_key)
        c_lq = c["lq"][0]
        c_latent = c["c_latent"][0]
        c_cat, c = c["c_concat"][0], c["c_crossattn"][0]
        log["hq"] = (self.decode_first_stage(z) + 1) / 2
        log["control"] = c_cat
        log["decoded_control"] = (self.decode_first_stage(c_latent) + 1) / 2
        log["lq"] = c_lq
        log["text"] = (log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16) + 1) / 2
        samples = self.sample_log(
            # TODO: remove c_concat from cond
            # cond={"c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]},
            cond={"c_lq": c_lq, "c_concat": [c_cat], "c_crossattn": [c], "c_latent": [c_latent]},
            steps=sample_steps
        )
        # x_samples = self.decode_first_stage(samples)
        # log["samples"] = (x_samples + 1) / 2
        log["samples"] = samples
        return log

    @torch.no_grad()
    def sample_log(self, cond, steps, cond_fn=None, color_fix_type="wavelet"):
        sampler = SpacedSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (b, self.channels, h // 8, w // 8)
        x_T = torch.randn(shape, device=self.model.device, dtype=torch.float32)
        # samples = sampler.sample(
        #     steps, shape, cond, unconditional_guidance_scale=1.0,
        #     unconditional_conditioning=None
        # )
        samples = sampler.sample(
            steps=steps, shape=shape, cond_img=cond["c_lq"],
            positive_prompt="", negative_prompt="", x_T=x_T,
            cfg_scale=1.0, cond_fn=cond_fn,
            color_fix_type=color_fix_type
        )
        return samples
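For reference, the latent shape that sample_log passes to the sampler follows directly from the size of the conditioning image. The sketch below reproduces only that arithmetic. The function name latent_shape, the default of 4 channels, and the divisibility guard are assumptions added for illustration; the snippet above uses self.channels and plain floor division.

```python
def latent_shape(image_shape, channels=4, downscale=8):
    """Map an NCHW image shape to the latent shape used for sampling."""
    b, _, h, w = image_shape
    # Extra guard, not in the snippet above: floor division would silently
    # drop pixels if the size were not a multiple of the downscale factor.
    if h % downscale or w % downscale:
        raise ValueError("height and width must be divisible by the downscale factor")
    return (b, channels, h // downscale, w // downscale)


print(latent_shape((8, 3, 512, 512)))  # (8, 4, 64, 64)
```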
Hi, thanks for sharing the great work!
I tried to follow the training process, but ran into a problem while training the Stage 2 model.
After filling in the train_cldm.yaml file, I ran python train.py --config configs/train_cldm.yaml, but got an error.
I suspect the error comes from here: https://github.com/XPixelGroup/DiffBIR/blob/7bd5675823c157b9afdd479b59a2bf0a8954ce11/model/cldm.py#L394
where the sample function in SpacedSampler does not accept unconditional_guidance_scale as an input argument.
Could you please let me know the solution for this issue?