rmin2000 / WaDiff

A Watermark-Conditioned Diffusion Model for IP Protection (ECCV 2024)

Some questions about this error #2

Open · moonfalling opened this issue 6 days ago

moonfalling commented 6 days ago

I used 5,000 256x256 images to train the first step with this command:

python train.py --data_dir ../dataset256/dataset256 \
    --bit_length 48 --image_resolution 256 --num_epochs 100 --data_size 5000 --cuda 0

This produced checkpoints such as step_7500_decoder.pth.
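To confirm the decoder trained the way I expect, I inspect the saved weights like this (my own snippet, not from the repo, assuming the checkpoint is a plain state dict; with --bit_length 48 the final linear layer should have 48 output features):

import torch

# My own sanity check, not part of the repo: list the decoder's
# parameter shapes. Assumes step_7500_decoder.pth is a plain state
# dict; with --bit_length 48 the last linear layer should map to 48.
sd = torch.load("../checkpoints256/checkpoints/step_7500_decoder.pth", map_location="cpu")
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))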

In step 2, I used this command to fine-tune 256x256_diffusion_uncond.pt, which I downloaded from this link:

MODEL_FLAGS="--wm_length 48 --attention_resolutions 32,16,8 --class_cond False --image_size 256 --num_channels 256 --learn_sigma True --num_head_channels 64 --num_res_blocks 2 --resblock_updown True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 2"
NUM_GPUS=1 python scripts/image_train.py --alpha 0.4 --threshold 400 \
    --wm_decoder_path ../checkpoints256/checkpoints/step_7500_decoder.pth \
    --data_dir ../dataset256/dataset256 \
    --resume_checkpoint models/256x256_diffusion_uncond.pt \
    $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

I fine-tuned until the log folder contained checkpoints like ema_0.9999_060000.pt, model060000.pt, and opt060000.pt.
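(One thing I'm unsure about: as far as I can tell, guided-diffusion conventionally samples from the EMA weights, ema_0.9999_*.pt, rather than model*.pt. Both files should describe the same architecture, which I checked with my own snippet below, assuming both are plain state dicts.)

import torch

# My own check, assuming both checkpoints are plain state dicts: the
# EMA and raw model files should contain identical parameter names.
log = "finetunelog/openai-2024-09-12-20-05-41-150569"
ema = torch.load(f"{log}/ema_0.9999_060000.pt", map_location="cpu")
raw = torch.load(f"{log}/model060000.pt", map_location="cpu")
print(set(ema) == set(raw))  # True if they share one architecture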

And for step 3, I want to generate images from my model. I should use model060000.pt from the log folder, right? My command in generate.sh looks like this:

MODEL_FLAGS="--wm_length 48 --attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True" SAMPLE_FLAGS="--batch_size 4 --num_samples 8 --timestep_respacing 100 --use_ddim True" python scripts/image_sample.py $MODEL_FLAGS --output_path saved_images/ --model_path ./finetunelog/openai-2024-09-12-20-05-41-150569/model060000.pt $SAMPLE_FLAGS
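With --wm_length 48, I'd expect the model to receive a length-48 bit vector per sample, something like this (my own sketch, not the repo's code):

import torch

# Hypothetical sketch of what the watermark input should look like:
# one random 48-bit vector per image in the batch (--batch_size 4,
# --wm_length 48). How image_sample.py builds this tensor is where
# things seem to go wrong below.
fingerprint = torch.randint(0, 2, (4, 48), dtype=torch.float32)
print(fingerprint.shape)  # torch.Size([4, 48])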

I got an error like this:

[rank0]:   File "/home/jovyan/work/wadiff/Wadiff_bibi/guided-diffusion/guided_diffusion/gaussian_diffusion.py", line 301, in p_mean_variance
[rank0]:     model_output = model(x, self._scale_timesteps(t), **model_kwargs)
[rank0]:   File "/home/jovyan/work/wadiff/Wadiff_bibi/guided-diffusion/guided_diffusion/respace.py", line 128, in __call__
[rank0]:     return self.model(x, new_ts, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/jovyan/work/wadiff/Wadiff_bibi/guided-diffusion/guided_diffusion/unet.py", line 664, in forward
[rank0]:     wm_emb = self.secret_dense(fingerprint).view((-1, self.in_channels, self.image_size, self.image_size)).type(self.dtype)
[rank0]:   File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 117, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (768x256 and 48x196608)
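If I'm reading the shapes correctly, mat2 is secret_dense.weight transposed (in_features 48, out_features 196608 = 3 x 256 x 256), while mat1's last dimension is 256, so whatever tensor reaches secret_dense as the fingerprint isn't 48 bits wide. A minimal repro of just the shape arithmetic (my own code):

import torch
import torch.nn.functional as F

# My own repro of the shape mismatch: secret_dense maps 48 watermark
# bits to a 3*256*256 embedding, so its weight is (196608, 48).
weight = torch.zeros(196608, 48)
ok = F.linear(torch.zeros(4, 48), weight)  # a 48-wide input works
print(ok.shape)                            # torch.Size([4, 196608])
F.linear(torch.zeros(768, 256), weight)    # raises: mat1 and mat2 shapes cannot
                                           # be multiplied (768x256 and 48x196608)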

If I change the command to --wm_length 0, it fails like this instead:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for UNetModel:
[rank0]:     Unexpected key(s) in state_dict: "secret_dense.weight", "secret_dense.bias".
[rank0]:     size mismatch for input_blocks.0.0.weight: copying a param with shape torch.Size([256, 6, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 3, 3, 3]).
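That second error makes sense to me: the fine-tuned checkpoint's first conv takes 6 input channels, presumably the 3-channel image concatenated with the secret_dense output reshaped to (3, 256, 256), so the checkpoint can't load into a model built with --wm_length 0. My own illustration of the channel arithmetic:

import torch

# My reading of the size mismatch: the watermark embedding is
# reshaped to a 3-channel image and (presumably) concatenated with
# the 3-channel input, which is why the checkpoint's first conv
# weight is (256, 6, 3, 3) instead of (256, 3, 3, 3).
image = torch.zeros(1, 3, 256, 256)
wm_emb = (torch.zeros(1, 48) @ torch.zeros(48, 3 * 256 * 256)).view(1, 3, 256, 256)
print(torch.cat([image, wm_emb], dim=1).shape)  # torch.Size([1, 6, 256, 256])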

What is wrong with my setup?

rmin2000 commented 5 days ago

Hi, thanks for your interest :). This error was caused by incomplete code in the generation pipeline, and I've updated the code to fix it. Feel free to let me know if you have any further concerns.
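For anyone hitting this on an older checkout, the missing piece was giving the sampler an actual watermark, roughly along these lines (a simplified sketch, not the exact repo code; the kwarg name follows the traceback above, and whether it is passed via model_kwargs is an assumption):

import torch

# Sketch only: draw one 48-bit watermark per sample and pass it
# through model_kwargs, so UNetModel.forward receives a
# (batch, wm_length) fingerprint tensor.
batch_size, wm_length = 4, 48
fingerprint = torch.randint(0, 2, (batch_size, wm_length)).float()
model_kwargs = {"fingerprint": fingerprint}
# samples = diffusion.p_sample_loop(
#     model, (batch_size, 3, 256, 256), model_kwargs=model_kwargs)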