THUDM / CogVideo

text and image to video generation: CogVideoX (2024) and CogVideo (ICLR 2023)
Apache License 2.0

CogVideoX1.5 I2V training: patch_size_t = 2, but the latent frame count for num_frames=49 cannot be divided by 2 #532

Open xilanhua12138 opened 2 days ago

xilanhua12138 commented 2 days ago

System Info

When training with num_frames=49, the VAE compresses the video to 13 latent frames, but in this code:

        p = self.patch_size
        p_t = self.patch_size_t

        image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
        image_embeds = image_embeds.reshape(
            batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
        )

p_t is 2 and 13 // 2 = 6, so the target shape only covers 6 * 2 = 12 of the 13 frames and the reshape raises an error.
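The mismatch can be checked without a GPU by counting elements on both sides of the reshape. This is only a sketch: the helper names are hypothetical, and the latent height, width and channel count (60, 90, 16) are illustrative assumptions, not values taken from the training script.

```python
from math import prod


def patchify_shape(batch_size, num_frames, height, width, channels, p=2, p_t=2):
    """Target shape used by the reshape quoted above."""
    return (batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)


def reshape_is_valid(batch_size, num_frames, height, width, channels, p=2, p_t=2):
    """A reshape is only legal when source and target hold the same number of elements."""
    # Layout after permute(0, 1, 3, 4, 2): (batch, frames, height, width, channels).
    source = (batch_size, num_frames, height, width, channels)
    target = patchify_shape(batch_size, num_frames, height, width, channels, p, p_t)
    return prod(source) == prod(target)


# 13 latent frames (from num_frames=49): 13 // 2 = 6, so one frame is unaccounted for.
print(reshape_is_valid(1, 13, 60, 90, 16))  # False
# 14 latent frames split cleanly into 7 temporal patches.
print(reshape_is_valid(1, 14, 60, 90, 16))  # True
```

Any odd latent frame count fails the same way when p_t is 2.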

Information

Reproduction

{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Train CogVideoX-5b-I2V",
            "type": "debugpy",
            "request": "launch",
            "program": "training/cogvideox_image_to_video_lora.py",
            "console": "integratedTerminal",
            "justMyCode": false,
            "args": [
                "--gradient_checkpointing",
                "--pretrained_model_name_or_path", "/mnt2/shuiyunhao/task/CogVideo/models",
                "--data_root", "data/disney",
                "--caption_column", "prompt.txt",
                "--video_column", "videos.txt",
                "--validation_prompt", "BW_STYLE, A cartoon version of Mickey Mouse is seen in a cozy kitchen setting, surrounded by hanging pots and utensils. The character initially bends to interact with a wooden crate on the floor, then pulls out a pot and places it on the stove. Mickey manipulates a ladle above the pot, stirring vigorously, before raising the ladle in a dynamic pose, indicating completion of their task. The background remains consistent, emphasizing the domestic environment, as Mickey's actions progress from preparation to execution. ::: BW_STYLE, A black and white animated kitchen scene unfolds with two characters: one hit on the back of the head with a frying pan, causing musical notes to scatter. The hit character stands upright, while the other lies on the ground, indicating a comedic interaction. The hit character then appears distressed, but soon adopts a mischievous grin, lifting a heavy basket of bread in preparation for a prank, before realizing the potential consequences and freezing in surprise.",
                "--validation_images", "data/disney/validation/val_disney.png:::data/disney/validation/val_disney_2.png",
                "--validation_prompt_separator", ":::",
                "--num_validation_videos", "1",
                "--validation_epochs", "10",
                "--seed", "42",
                "--rank", "128",
                "--lora_alpha", "128",
                "--mixed_precision", "bf16",
                "--output_dir", "output_models",
                "--id_token", "BW_STYLE",
                "--height_buckets", "480",
                "--width_buckets", "720",
                "--frame_buckets", "49",
                "--max_num_frames", "49",
                "--train_batch_size", "1",
                "--num_train_epochs", "30",
                "--checkpointing_steps", "1000",
                "--gradient_accumulation_steps", "2",
                "--learning_rate", "1e-3",
                "--lr_scheduler", "cosine_with_restarts",
                "--lr_warmup_steps", "400",
                "--lr_num_cycles", "1",
                "--enable_slicing",
                "--enable_tiling",
                "--optimizer", "AdamW",
                "--beta1", "0.9",
                "--beta2", "0.95",
                "--max_grad_norm", "1.0",
                "--allow_tf32",
                "--report_to", "tensorboard"
            ]
        }
    ]
}

Expected behavior

The number of latent frames should be divisible by patch_size_t (2).

zRzRzRzRzRzRzR commented 2 days ago

You need to set it to 81 and clone the source code of the diffusers library, as this is where the implementation of version 1.5 is located.

xilanhua12138 commented 2 days ago

> You need to set it to 81 and clone the source code of the diffusers library, as this is where the implementation of version 1.5 is located.

81 still does not work, because 81 frames are compressed to 21 latent frames, which is also odd.

zRzRzRzRzRzRzR commented 1 day ago

No, it is not 21 in the calculation process, but 22, the first frame will be copied once.

xilanhua12138 commented 1 day ago

> No, it is not 21 in the calculation process, but 22, the first frame will be copied once.

I installed diffusers with pip install git+https://github.com/zRzRzRzRzRzRzR/diffusers.git

but in my run the calculated result is still 21. You can reproduce this bug with the disney dataset and the following script:

#!/bin/bash
# max batch-size  is 2.
DEFAULT_GPUS="0,1,2,3,4,5,6,7"
CUDA_DEVICES=${2:-$DEFAULT_GPUS}
export CUDA_VISIBLE_DEVICES=$CUDA_DEVICES

NUM_PROCESSES=${1:-8}

accelerate launch \
    --config_file finetune/accelerate_config_machine_single.yaml \
    --multi_gpu \
    --num_processes $NUM_PROCESSES \
    --machine_rank 0 \
    --main_process_port 29501 \
    finetune/train_cogvideox_image_to_video_lora.py \
    --gradient_checkpointing \
    --pretrained_model_name_or_path 'models' \
    --enable_tiling \
    --enable_slicing \
    --instance_data_root "data/disney" \
    --caption_column prompt.txt \
    --video_column videos.txt \
    --validation_prompt "A kitchen scene unfolds with Mickey Mouse-like character on the left, startled, while a female mouse character, wearing a hat, holds a gun to her head. Above, a 'KITCHEN' sign is visible. The scene shifts to a black goat character standing alone against a plain background, initially facing away, then turning with a content smile and raised hoof near a guitar. The scene returns to the kitchen, with both mice interacting near the goat, looking surprised or curious about the situation.:::A cartoon version of Mickey Mouse is seen in a cozy kitchen setting, surrounded by hanging pots and utensils. The character initially bends to interact with a wooden crate on the floor, then pulls out a pot and places it on the stove. Mickey manipulates a ladle above the pot, stirring vigorously, before raising the ladle in a dynamic pose, indicating completion of their task. The background remains consistent, emphasizing the domestic environment, as Mickey's actions progress from preparation to execution. ::: A black and white animated kitchen scene unfolds with two characters: one hit on the back of the head with a frying pan, causing musical notes to scatter. The hit character stands upright, while the other lies on the ground, indicating a comedic interaction. The hit character then appears distressed, but soon adopts a mischievous grin, lifting a heavy basket of bread in preparation for a prank, before realizing the potential consequences and freezing in surprise." \
    --validation_prompt_separator ::: \
    --validation_images "data/disney/validation/train_disney.png:::data/disney/validation/val_disney.png:::data/disney/validation/val_disney_2.png" \
    --num_validation_videos 1 \
    --validation_epochs 10 \
    --seed 42 \
    --rank 128 \
    --lora_alpha 128 \
    --mixed_precision bf16 \
    --output_dir "output_models" \
    --height 480 \
    --width 720 \
    --fps 8 \
    --max_num_frames 81 \
    --skip_frames_start 0 \
    --skip_frames_end 0 \
    --train_batch_size 4 \
    --max_train_steps 10000 \
    --num_train_epochs 300 \
    --checkpointing_steps 500 \
    --gradient_accumulation_steps 1 \
    --learning_rate 1e-3 \
    --lr_scheduler cosine_with_restarts \
    --lr_warmup_steps 200 \
    --lr_num_cycles 1 \
    --optimizer AdamW \
    --adam_beta1 0.9 \
    --adam_beta2 0.95 \
    --max_grad_norm 1.0 \
    --allow_tf32 \
    --report_to wandb

xilanhua12138 commented 1 day ago

I found that in the inference pipeline the frame count is indeed padded, so 21 latent frames become 22:

        # 5. Prepare latents
        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1

        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
        patch_size_t = self.transformer.config.patch_size_t
        additional_frames = 0
        if patch_size_t is not None and latent_frames % patch_size_t != 0:
            additional_frames = patch_size_t - latent_frames % patch_size_t
            num_frames += additional_frames * self.vae_scale_factor_temporal

but the training process takes no corresponding step.
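To make the numbers concrete, here is a minimal sketch that mirrors the padding logic of the pipeline snippet above as a standalone function. The function name is hypothetical, and vae_scale_factor_temporal=4 is an assumption inferred from the 49 -> 13 and 81 -> 21 compression reported in this thread.

```python
def pad_for_patch_size_t(num_frames, vae_scale_factor_temporal=4, patch_size_t=2):
    """Return (padded_num_frames, padded_latent_frames), mirroring the pipeline snippet."""
    latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
    additional_frames = 0
    if patch_size_t is not None and latent_frames % patch_size_t != 0:
        # Pad the latent count up to the next multiple of patch_size_t.
        additional_frames = patch_size_t - latent_frames % patch_size_t
        num_frames += additional_frames * vae_scale_factor_temporal
    return num_frames, latent_frames + additional_frames


for requested in (49, 81):
    print(requested, pad_for_patch_size_t(requested))
# 49 (53, 14)
# 81 (85, 22)
```

Whether the training script should pad the input frames or the latents is not settled here; the sketch only shows which frame counts satisfy the divisibility constraint.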

zRzRzRzRzRzRzR commented 1 day ago

@zhangch could you take a look?

xilanhua12138 commented 1 day ago

Is there a quick way to fix this bug? Any suggestions?

TheDenk commented 1 day ago

Hey :) As I can see in prepare_latents:

if self.transformer.config.patch_size_t is not None:
    shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]

there, 13 frames become 14.

I think that for training you need to increase the number of frames (49 -> 53 | 81 -> 85). And then for rotary_positional_embeddings decrease num_frames as num_frames = num_frames // 2. I tested it and training works, but I am not sure it is the right approach :)
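As a side note, the prepare_latents expression quoted above (n + n % p_t) and the pipeline's round-up branch quoted earlier in the thread give the same result when patch_size_t is 2, but they are not equivalent in general. A small sketch, with hypothetical function names and a hypothetical patch_size_t = 4 used only for contrast:

```python
def pad_as_in_prepare_latents(latent_frames, patch_size_t):
    # Expression quoted from prepare_latents: n + n % p_t
    return latent_frames + latent_frames % patch_size_t


def pad_to_multiple(latent_frames, patch_size_t):
    # Round up to the next multiple of p_t, like the pipeline's
    # `patch_size_t - latent_frames % patch_size_t` branch.
    return latent_frames + (-latent_frames) % patch_size_t


for n in (13, 21):
    print(n, pad_as_in_prepare_latents(n, 2), pad_to_multiple(n, 2))
# 13 14 14
# 21 22 22

# Hypothetical patch_size_t = 4, only to show where the two formulas diverge.
print(pad_as_in_prepare_latents(13, 4), pad_to_multiple(13, 4))
# 14 16
```

For the released patch_size_t of 2 this makes no difference, so 49 -> 53 and 81 -> 85 hold either way.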

spacegoing commented 1 day ago

@xilanhua12138 Setting "--max_num_frames", "85" would work.

@zRzRzRzRzRzRzR but it then hits this error:

[rank5]:   File "/workspace/public/users/lichang93/mydocker/cogvx/host_folder/diffusers/src/diffusers/models/transformers/cogvideox_transformer_3d.py", line 132, in forward
[rank5]:     attn_hidden_states, attn_encoder_hidden_states = self.attn1(
[rank5]:   File "/root/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank5]:     return self._call_impl(*args, **kwargs)
[rank5]:   File "/root/mambaforge/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank5]:     return forward_call(*args, **kwargs)
[rank5]:   File "/workspace/public/users/lichang93/mydocker/cogvx/host_folder/diffusers/src/diffusers/models/attention_processor.py", line 530, in forward
[rank5]:     return self.processor(
[rank5]:   File "/workspace/public/users/lichang93/mydocker/cogvx/host_folder/diffusers/src/diffusers/models/attention_processor.py", line 2293, in __call__
[rank5]:     query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
[rank5]:   File "/workspace/public/users/lichang93/mydocker/cogvx/host_folder/diffusers/src/diffusers/models/embeddings.py", line 816, in apply_rotary_emb
[rank5]:     out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
[rank5]: RuntimeError: The size of tensor a (14850) must match the size of tensor b (29700) at non-singleton dimension 2

TheDenk commented 1 day ago

> And then for rotary_positional_embeddings decrease num_frames as num_frames = num_frames // 2

@spacegoing

spacegoing commented 1 day ago

[image]

@TheDenk @xilanhua12138 @zRzRzRzRzRzRzR

FYI, the division should be based on patch_size_t.
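The two sizes in the RuntimeError above are consistent with this explanation. A rough sketch of the arithmetic, assuming a spatial VAE factor of 8 and a spatial patch size of 2 (neither is stated in this thread; they are chosen because they reproduce the reported numbers), with a hypothetical helper name:

```python
def image_seq_length(latent_frames, height=480, width=720,
                     vae_scale_factor_spatial=8, patch_size=2, patch_size_t=1):
    """Number of image tokens for a given temporal patch size."""
    # Tokens per (patchified) frame: 30 * 45 = 1350 for 480x720 under the assumptions above.
    tokens_per_frame = (height // (vae_scale_factor_spatial * patch_size)) * \
                       (width // (vae_scale_factor_spatial * patch_size))
    return (latent_frames // patch_size_t) * tokens_per_frame


latent_frames = 22  # 85 input frames after VAE compression and padding

# Rotary embeddings built from the raw latent frame count.
print(image_seq_length(latent_frames, patch_size_t=1))  # 29700
# Query/key tokens after temporal patchification with patch_size_t = 2.
print(image_seq_length(latent_frames, patch_size_t=2))  # 14850
```

Dividing the temporal size by patch_size_t before building the rotary embeddings brings both sides to 14850, which matches the workaround suggested above.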

TheDenk commented 1 day ago

@spacegoing yep, but patch_size_t = 2 :)

spacegoing commented 1 day ago

> @spacegoing yep, but patch_size_t = 2 :)

Yes, that's a perfect match :D