tianweiy / DMD2


LoRA training #35

Open SYZhang0805 opened 1 month ago

SYZhang0805 commented 1 month ago

When I try to finetune a 1-step SDXL model with LoRA, I get an error:

```
Traceback (most recent call last):
  File "main/train_sd.py", line 701, in <module>
    trainer.train()
  File "main/train_sd.py", line 598, in train
    self.train_one_step()
  File "main/train_sd.py", line 385, in train_one_step
    guidance_loss_dict, guidance_log_dict = self.model(
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/hy-tmp/DMD2/main/sd_unified_model.py", line 353, in forward
    loss_dict, log_dict = self.guidance_model(
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 446, in forward
    loss_dict, log_dict = self.guidance_forward(
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 417, in guidance_forward
    clean_cls_loss_dict, clean_cls_log_dict = self.compute_guidance_clean_cls_loss(
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 376, in compute_guidance_clean_cls_loss
    pred_realism_on_real = self.compute_cls_logits(
  File "/hy-tmp/DMD2/main/sd_guidance.py", line 165, in compute_cls_logits
    logits = self.cls_pred_branch(rep).squeeze(dim=[2, 3])
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/miniconda3/envs/dmd2/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Calculated padded input size per channel: (2 x 2). Kernel size: (4 x 4). Kernel size can't be greater than actual input size
```

If I remove "--generator_lora" from the command, this error goes away, but then I run out of memory. I have to use LoRA because my GPU resources are limited. Why does this error occur?

tianweiy commented 1 month ago

The current LoRA implementation doesn't seem to save memory, so it won't help with the OOM.

The issue happens in the classification branch. If you train at a resolution other than 1024x1024, you need to replace https://github.com/tianweiy/DMD2/blob/0f8a481716539af7b2795740c9763a7d0d05b83b/main/sd_guidance.py#L118 with a pooling layer so that it can handle varying input sizes.
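A minimal sketch of what a pooling-based classification head could look like, assuming PyTorch and a 1280-channel feature map as input (1280 is the SDXL mid-block width). The helper name make_pooled_cls_head and the layer choices are hypothetical and are not taken from sd_guidance.py; check the actual channel count of rep before adapting it.

```python
import torch
import torch.nn as nn


def make_pooled_cls_head(in_channels: int = 1280) -> nn.Sequential:
    """Hypothetical resolution-agnostic realism classifier head.

    Assumption: the feature map passed to cls_pred_branch has `in_channels`
    channels; the real layer at sd_guidance.py#L118 may differ in width and depth.
    """
    return nn.Sequential(
        # 3x3 conv with padding=1 keeps H x W, so it also works on a 1x1 map.
        nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
        nn.GroupNorm(num_groups=32, num_channels=in_channels),
        nn.SiLU(),
        # Average over whatever H x W remains, always producing 1 x 1.
        nn.AdaptiveAvgPool2d(1),
        # 1x1 conv maps the pooled features to a single realism logit.
        nn.Conv2d(in_channels, 1, kernel_size=1),
    )


if __name__ == "__main__":
    head = make_pooled_cls_head()
    for size in (2, 8, 32):
        rep = torch.randn(2, 1280, size, size)
        # flatten(start_dim=1) drops the trailing 1x1 dims, like squeezing dims 2 and 3.
        logits = head(rep).flatten(start_dim=1)
        print(size, tuple(logits.shape))
```

Because nn.AdaptiveAvgPool2d(1) fixes the output size instead of the kernel size, the head no longer assumes a particular feature-map resolution. Note that changing the layers also changes the parameter shapes, so a checkpoint saved with the original head will not load strictly into this one.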

SYZhang0805 commented 1 month ago

I trained at a resolution of 1024, but the error still occurs once I add --generator_lora. I also have another question: what is the difference between the sd_vae_latents_laion_500k_lmdb downloaded by download_sdxl.sh and the one downloaded by download_sdv15.sh?

SYZhang0805 commented 1 month ago

Also, when I use --generator_lora, many parameters don't match when loading the pretrained generator, whereas without LoRA they all match. Could this be the cause of the error above?

tianweiy commented 1 month ago

Please check the input resolution at that line (e.g., by adding some print statements). I have never run into this issue.
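Alongside print statements (for example printing rep.shape right before the cls_pred_branch call), it can help to predict the spatial size each convolution receives. The sketch below implements the standard Conv2d output-size arithmetic, assuming square kernels and dilation 1, and raises on the same condition that produces PyTorch's "Kernel size can't be greater than actual input size" error. The layer list in the example is hypothetical and is not DMD2's actual head configuration.

```python
from typing import List, Sequence, Tuple

# (kernel_size, stride, padding) for one square Conv2d layer, dilation 1.
ConvSpec = Tuple[int, int, int]


def conv_out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    """Spatial output size of a square Conv2d with dilation 1."""
    padded = size + 2 * padding
    if padded < kernel:
        # Same condition that makes PyTorch reject the input.
        raise ValueError(
            f"Calculated padded input size per channel: ({padded} x {padded}). "
            f"Kernel size: ({kernel} x {kernel})."
        )
    return (padded - kernel) // stride + 1


def trace_sizes(size: int, layers: Sequence[ConvSpec]) -> List[int]:
    """Return the spatial size before the first layer and after each layer."""
    sizes = [size]
    for kernel, stride, padding in layers:
        size = conv_out_size(size, kernel, stride, padding)
        sizes.append(size)
    return sizes


if __name__ == "__main__":
    # Hypothetical head layout, for illustration only.
    head = [(4, 2, 1), (4, 4, 0), (1, 1, 0)]
    print(trace_sizes(8, head))  # [8, 4, 1, 1]
    try:
        trace_sizes(4, head)
    except ValueError as err:
        print(err)
```

With this hypothetical layout, an 8x8 feature map shrinks cleanly to 1x1, while a 4x4 map becomes 2x2 after the first layer and then fails on the 4x4 kernel, the same kind of mismatch reported in the traceback.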


As for the two sd_vae_latents_laion_500k_lmdb datasets: they use different VAE encoders, and the image resolutions are different.