SYZhang0805 opened this issue 1 month ago
The current LoRA implementation doesn't seem to save memory, so it won't help with OOM.
The issue happens in the classification branch. If you train with a resolution different from 1024x1024, you need to replace https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L118 with a pooling layer so that it can handle varying sizes.
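For example, something along these lines (just a sketch; the channel count and the rest of cls_pred_branch are assumptions, only the 4x4-kernel conv at the linked line is taken from the error below):

```python
# Sketch: make the classification head resolution-agnostic by pooling to a fixed
# spatial size before the final projection, instead of a 4x4-kernel conv that
# assumes a specific feature-map size. The channel count (1024) is an assumption.
import torch.nn as nn

cls_pred_branch = nn.Sequential(
    nn.AdaptiveAvgPool2d(output_size=1),   # collapses any HxW to 1x1
    nn.Conv2d(1024, 1, kernel_size=1),     # per-sample realism logit
)
```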
When I trained, I used a resolution of 1024, but this error still occurred after adding generator_lora. I also have another question: what is the difference between the sd_vae_latents_laion_500k_lmdb downloaded by download_sdxl.sh and the one downloaded by download_sdv15.sh?
Additionally, when I use generator_lora, there are many parameter mismatches when loading the pre-trained generator, whereas without LoRA the parameters match completely. Could this be the reason for the issue above?
Please check the input resolution at that line (add some print statements). I have never run into this issue.
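For example (a sketch of the kind of print statement meant here, placed just before the failing call in compute_cls_logits):

```python
# Temporary debug print in main/sd_guidance.py, compute_cls_logits, right before the
# cls_pred_branch call, to confirm the spatial size of the feature map.
print(f"cls feature map shape: {rep.shape}")  # needs to be at least 4x4 spatially for the 4x4-kernel conv
logits = self.cls_pred_branch(rep).squeeze(dim=[2, 3])
```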
> When I trained, I used a resolution of 1024, but this error still occurred after adding generator_lora. I also have another question: what is the difference between the sd_vae_latents_laion_500k_lmdb downloaded by download_sdxl.sh and the one downloaded by download_sdv15.sh?
They use different VAE encoders, and the image resolutions are different.
When I try to finetune a 1-step SDXL model with LoRA, I get the following error:

```
Traceback (most recent call last):
  File "main/train_sd.py", line 701, in <module>
    trainer.train()
  File "main/train_sd.py", line 598, in train
    self.train_one_step()
  File "main/train_sd.py", line 385, in train_one_step
    guidance_loss_dict, guidance_log_dict = self.model(
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/hy-tmp/DMD2/main/sd_unified_model.py", line 353, in forward
    loss_dict, log_dict = self.guidance_model(
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 446, in forward
    loss_dict, log_dict = self.guidance_forward(
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 417, in guidance_forward
    clean_cls_loss_dict, clean_cls_log_dict = self.compute_guidance_clean_cls_loss(
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 376, in compute_guidance_clean_cls_loss
    pred_realism_on_real = self.compute_cls_logits(
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 165, in compute_cls_logits
    logits = self.cls_pred_branch(rep).squeeze(dim=[2, 3])
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Calculated padded input size per channel: (2 x 2). Kernel size: (4 x 4). Kernel size can't be greater than actual input size
```
If I remove "--generator_lora" from the command, this error does not occur, but then I run out of memory. I have to use LoRA because of my limited GPU resources. Why does this error occur?
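For reference, the error itself reproduces in isolation with plain PyTorch, which suggests the feature map reaching cls_pred_branch is only 2x2 instead of the size the 4x4-kernel conv expects (a minimal sketch, independent of the DMD2 code; the channel count is arbitrary):

```python
# Minimal sketch reproducing the RuntimeError: a 4x4 kernel cannot slide over a 2x2 input.
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=4, stride=1, padding=0)
rep = torch.randn(1, 1024, 2, 2)   # 2x2 spatial size, as reported in the traceback
try:
    conv(rep)
except RuntimeError as e:
    print(e)  # "Calculated padded input size per channel: (2 x 2). Kernel size: (4 x 4). ..."

# An adaptive pooling layer in front, as suggested above, makes the head size-agnostic:
pooled_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(1024, 1, kernel_size=1))
print(pooled_head(rep).shape)  # torch.Size([1, 1, 1, 1]) for any input spatial size
```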