tianweiy / DMD2

(NeurIPS 2024 Oral 🔥) Improved Distribution Matching Distillation for Fast Image Synthesis

When I was training with cls_on_clean_image, I got an error: Given groups=1, weight of size [320, 4, 3, 3], expected input[1, 3, 512, 512] to have 4 channels, but got 3 channels instead #38

Closed: koking0 closed this issue 4 months ago

koking0 commented 4 months ago

When training with --cls_on_clean_image, I encountered the following error: Given groups=1, weight of size [320, 4, 3, 3], expected input[1, 3, 512, 512] to have 4 channels, but got 3 channels instead.

Training script for the first stage:

export CHECKPOINT_PATH=/root/workspace/env_run/dmd2/lina-4.3-dmd2-ckpt
export WANDB_ENTITY=
export WANDB_PROJECT=lina-4.3-dmd2
export WANDB_MODE=offline

torchrun --nnodes 1 --nproc_per_node=16 main/train_sd.py \
--generator_lr 1e-5 \
--guidance_lr 1e-5 \
--train_iters 100000 \
--output_path $CHECKPOINT_PATH \
--batch_size 44 \
--grid_size 2 \
--initialie_generator --log_iters 1000 \
--resolution 512 \
--latent_resolution 64 \
--seed 10 \
--real_guidance_scale 1.75 \
--fake_guidance_scale 1.0 \
--max_grad_norm 10.0 \
--model_id "/root/workspace/env_run/lina4-3base" \
--train_prompt_path /root/workspace/env_run/dmd2/prompts/shuffled.txt \
--afs_data_path="/root/workspace/env/2kw_merge_result/" \
--afs_part_list="/root/workspace/env/2kw_part_count/part-00000" \
--log_path /root/workspace/env_run/dmd2/tensorboard_log_lina-4.3 \
--wandb_iters 100 \
--use_fp16 \
--log_loss \
--dfake_gen_update_ratio 10 \
--gradient_checkpointing

Training script for the second stage:

export CHECKPOINT_PATH=/root/workspace/env_run/dmd2/lina-4.3-dmd2-ckpt-cls
export WANDB_ENTITY=
export WANDB_PROJECT=lina-4.3-dmd2
export WANDB_MODE=offline 

accelerate launch main/train_sd.py \
--generator_lr 5e-7 \
--guidance_lr 5e-7 \
--train_iters 50000 \
--output_path $CHECKPOINT_PATH \
--batch_size 1 \
--grid_size 2 \
--initialie_generator --log_iters 1000 \
--resolution 512 \
--latent_resolution 64 \
--seed 10 \
--real_guidance_scale 1.75 \
--fake_guidance_scale 1.0 \
--max_grad_norm 10.0 \
--model_id "/root/workspace/env_run/lina4-3base" \
--train_prompt_path /root/workspace/env_run/dmd2/prompts/shuffled.txt \
--afs_data_path="/root/workspace/env/2kw_merge_result/" \
--afs_part_list="/root/workspace/env/2kw_part_count/part-00000" \
--log_path /root/workspace/env_run/dmd2/tensorboard_log_lina-4.3-cls \
--wandb_iters 100 \
--use_fp16 \
--log_loss \
--dfake_gen_update_ratio 10 \
--gradient_checkpointing \
--cls_on_clean_image \
--gen_cls_loss \
--gen_cls_loss_weight 1e-3 \
--guidance_cls_loss_weight 1e-2 \
--diffusion_gan \
--diffusion_gan_max_timestep 1000 \
--ckpt_only_path /root/workspace/env_run/dmd2/lina-4.3-dmd2-ckpt/time_1721280534_seed10/checkpoint_model_029000

I encountered an error in the second stage of training:

Traceback (most recent call last):
  File "/root/workspace/baidu/personal-code/DMD2/main/train_sd.py", line 753, in <module>
    trainer.train()
  File "/root/workspace/baidu/personal-code/DMD2/main/train_sd.py", line 645, in train
    self.train_one_step()
  File "/root/workspace/baidu/personal-code/DMD2/main/train_sd.py", line 425, in train_one_step
    self.accelerator.backward(guidance_loss)
  File "/usr/local/python3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1987, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/usr/local/python3.9/lib/python3.9/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/python3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Given groups=1, weight of size [320, 4, 3, 3], expected input[1, 3, 512, 512] to have 4 channels, but got 3 channels instead

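In the input shape [1, 3, 512, 512], 1 is the batch size, 3 is the number of image channels, and 512 × 512 is the image resolution (height × width).

So the error means that a convolution weight of size [320, 4, 3, 3] (320 output channels, 4 input channels, a 3 × 3 kernel) expects an input with 4 channels, but the tensor it receives has only 3. This is consistent with the first convolution of a Stable Diffusion-style UNet, which operates on 4-channel VAE latents rather than RGB images.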
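PyTorch's nn.Conv2d stores its weight as [out_channels, in_channels / groups, kernel_height, kernel_width] and requires the input's channel dimension (dim 1 in NCHW) to equal weight.size(1) * groups. The pure-Python sketch below mimics that check without needing torch; the function name `conv2d_channel_error` is hypothetical, and it only reproduces the wording of the message in the traceback, not PyTorch's actual implementation.

```python
def conv2d_channel_error(weight_shape, input_shape, groups=1):
    """Return a PyTorch-style error string if channels mismatch, else None.

    weight_shape: (out_channels, in_channels // groups, kernel_h, kernel_w)
    input_shape:  (batch, channels, height, width), i.e. NCHW
    """
    _, in_per_group, _, _ = weight_shape
    expected = in_per_group * groups  # channels the conv can consume
    got = input_shape[1]              # channels actually supplied
    if got == expected:
        return None
    w = ", ".join(str(d) for d in weight_shape)
    x = ", ".join(str(d) for d in input_shape)
    return (f"Given groups={groups}, weight of size [{w}], expected input[{x}] "
            f"to have {expected} channels, but got {got} channels instead")


# The conv from the traceback: 320 filters over 4 input (latent) channels.
weight = (320, 4, 3, 3)
print(conv2d_channel_error(weight, (1, 4, 64, 64)))    # latent input: None
print(conv2d_channel_error(weight, (1, 3, 512, 512)))  # RGB input: error text
```

The second call reproduces the RuntimeError text above, which points at the real image being fed in pixel space.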

Printing the tensor shapes around the GAN classification step shows two different sizes:

real_image shape: torch.Size([1, 3, 512, 512])
fake_image shape: torch.Size([1, 4, 64, 64])
noisy_image shape: torch.Size([1, 4, 64, 64])
generated_noise shape: torch.Size([1, 4, 64, 64])
generated_image shape: torch.Size([1, 4, 64, 64])

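I think there is a mismatch here: real_image is a 3-channel 512 × 512 pixel-space image, while fake_image and the other tensors are 4-channel 64 × 64 latents.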
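With --resolution 512 and --latent_resolution 64, a pixel batch is [B, 3, 512, 512] and a latent batch is [B, 4, 64, 64], i.e. the 8× spatial downsampling and 4 latent channels of the Stable Diffusion VAE. The small helper below (hypothetical name `tensor_space`) labels each printed shape so the mismatch is explicit; it assumes 3 channels means RGB and 4 channels means latents.

```python
def tensor_space(shape, resolution=512, latent_resolution=64, latent_channels=4):
    """Classify an NCHW shape as 'pixel', 'latent' or 'unknown'.

    Defaults mirror --resolution 512 and --latent_resolution 64 from the
    training scripts; 3 RGB channels vs. 4 latent channels is assumed.
    """
    _, channels, height, width = shape
    if channels == 3 and height == width == resolution:
        return "pixel"
    if channels == latent_channels and height == width == latent_resolution:
        return "latent"
    return "unknown"


# Shapes copied from the printout in this issue.
printed = {
    "real_image": (1, 3, 512, 512),
    "fake_image": (1, 4, 64, 64),
    "noisy_image": (1, 4, 64, 64),
    "generated_noise": (1, 4, 64, 64),
    "generated_image": (1, 4, 64, 64),
}
for name, shape in printed.items():
    print(f"{name:<16} {shape} -> {tensor_space(shape)}")
```

Only real_image is classified as pixel; everything else is latent, so the classifier's 4-channel input conv receives one tensor it cannot process.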

My question is:

  1. I replaced the LMDB data loading with our own image Dataset. Does retrieve_row_from_lmdb in SDImageDatasetLMDB return a real_image with the same shape as fake_image?
  2. If the problem is not the one described in 1, what else might cause this error?
koking0 commented 4 months ago

Yes, it was 1.
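Since the cause was 1, a custom Dataset has to hand the trainer real samples in the same latent space as fake_image ([4, 64, 64] per sample in the printout). A minimal sketch, assuming a per-sample `encode` callable that maps a [3, 512, 512] image to a [4, 64, 64] latent. With diffusers this would typically wrap `vae.encode(x).latent_dist.sample() * vae.config.scaling_factor`, but whether DMD2's own LMDB latents use that scaling is an assumption to check against SDImageDatasetLMDB. The class name, the `encode` argument, and the `"images"` key are all hypothetical and must be matched to what train_sd.py actually reads; `FakeTensor` and `fake_encode` are stand-ins so the sketch runs without torch.

```python
class LatentRealDataset:
    """Hypothetical wrapper that turns an RGB image dataset into latents.

    base:   indexable sequence of (image, prompt) pairs, image shaped
            [3, resolution, resolution]
    encode: callable mapping one image to a latent shaped
            [latent_channels, latent_resolution, latent_resolution]
    """

    def __init__(self, base, encode, latent_resolution=64, latent_channels=4):
        self.base = base
        self.encode = encode
        self.expected = (latent_channels, latent_resolution, latent_resolution)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, prompt = self.base[idx]
        latent = self.encode(image)
        got = tuple(latent.shape)
        if got != self.expected:
            # Fail at load time with a readable message instead of deep
            # inside the classifier's convolution during backward.
            raise ValueError(f"real sample {idx} has shape {got}, "
                             f"expected latent shape {self.expected}")
        return {"images": latent, "prompt": prompt}  # key name is hypothetical


class FakeTensor:
    """Stand-in for torch.Tensor: only carries a shape."""

    def __init__(self, *shape):
        self.shape = shape


def fake_encode(image):
    """Pretend VAE encoder: 4 channels, 8x spatial downsampling."""
    _, h, w = image.shape
    return FakeTensor(4, h // 8, w // 8)


ds = LatentRealDataset([(FakeTensor(3, 512, 512), "a cat")], fake_encode)
print(ds[0]["images"].shape)  # (4, 64, 64)
```

Passing an identity `encode` (i.e. returning the raw RGB image, as the custom Dataset did) makes `__getitem__` raise ValueError immediately. Encoding once when building the dataset, instead of in every `__getitem__`, would avoid repeated VAE passes, which is presumably why precomputed latents are stored in LMDB.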