jiaxilv opened this issue 2 weeks ago
Have you tried the Docker image we provide? This graphics card seems very new and should support bf16.
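Hardware bf16 support can be checked from the GPU's compute capability. Native bf16 arrived with Ampere (sm_80), so anything reporting 8.0 or higher qualifies. As this thread shows, hardware support is necessary but not sufficient, because the CUDA setup also matters. The helper name below is hypothetical. The `compute_cap` query field is assumed to be available, which requires a reasonably recent NVIDIA driver.

```shell
#!/usr/bin/env bash
# Succeed if a compute capability string such as "9.0" indicates native
# bf16 support (Ampere / sm_80 and newer); fail otherwise.
cc_supports_bf16() {
  local major="${1%%.*}"   # "9.0" -> "9"
  [ "$major" -ge 8 ]
}

# Usage (assumes a driver new enough to expose the compute_cap field):
#   cap="$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1)"
#   cc_supports_bf16 "$cap" && echo "bf16: yes" || echo "bf16: no"
```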
Could you post the exception stack?
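A `Floating point exception (core dumped)` crash kills the process with SIGFPE before Python can print a normal traceback. One way to get a stack anyway is Python's built-in faulthandler, enabled through the `PYTHONFAULTHANDLER` environment variable. On fatal signals it dumps the Python-level traceback of all threads to stderr. The wrapper name and the script name in the usage line are placeholders, not files from this repository.

```shell
#!/usr/bin/env bash
# Run any command with Python's faulthandler enabled, so that a fatal signal
# such as SIGFPE prints the Python traceback to stderr before the process dies.
with_faulthandler() {
  PYTHONFAULTHANDLER=1 "$@"
}

# Usage (placeholder script name; substitute the predict script that crashes):
#   with_faulthandler python your_predict_script.py 2> fpe_traceback.txt
```

If the crash happens inside a C/CUDA extension, the traceback only shows the Python line that made the call. That is usually still enough to tell whether the failure is in the transformer, the VAE, or an attention kernel.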
Thanks for your help! I found that the problem was caused by the machine's default CUDA 12.2 installation. After switching to CUDA 11.8 in my conda environment, bf16 works normally.
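Because the fix here was switching CUDA versions, it can help to compare the toolkit found on `PATH` (reported by `nvcc --version`) with the CUDA runtime PyTorch was built against (`torch.version.cuda`). PyTorch binaries bundle their own CUDA runtime, so a mismatch is not automatically an error. It does, however, quickly show which versions are in play. The helper names below are hypothetical.

```shell
#!/usr/bin/env bash
# Print the "X.Y" release number found in `nvcc --version` output (read from stdin).
cuda_release() {
  sed -n 's/.*release \([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p'
}

# Report whether two CUDA versions agree, e.g. the toolkit on PATH
# versus the CUDA runtime PyTorch was built against.
compare_cuda() {
  local system_cuda="$1" torch_cuda="$2"
  if [ "$system_cuda" = "$torch_cuda" ]; then
    echo "match: $system_cuda"
  else
    echo "mismatch: system=$system_cuda torch=$torch_cuda"
  fi
}

# Usage (requires nvcc on PATH and PyTorch installed):
#   compare_cuda "$(nvcc --version | cuda_release)" \
#                "$(python -c 'import torch; print(torch.version.cuda)')"
```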
I had the same problem. Disabling accelerate or using fp16 fixes it, but that doesn't seem like a good solution. This is the fp16 launch command I used:
accelerate launch --mixed_precision="fp16" scripts/train_t2iv.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATASET_NAME \
--train_data_meta=$DATASET_META_NAME \
--config_path "config/easyanimate_video_magvit_motion_module_v2.yaml" \
--image_sample_size=512 \
--video_sample_size=512 \
--video_sample_stride=1 \
--video_sample_n_frames=24 \
--train_batch_size=1 \
--video_repeat=1 \
--gradient_accumulation_steps=1 \
--dataloader_num_workers=8 \
--num_train_epochs=100 \
--checkpointing_steps=500 \
--learning_rate=2e-05 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--seed=42 \
--output_dir="output_dir" \
--enable_xformers_memory_efficient_attention \
--gradient_checkpointing \
--adam_weight_decay=3e-2 \
--adam_epsilon=1e-10 \
--max_grad_norm=1 \
--vae_mini_batch=1 \
--random_frame_crop \
--enable_bucket \
--trainable_modules "transformer_blocks" "proj_out" "pos_embed" "long_connect_fc"
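The precision in the command above is chosen by `accelerate launch --mixed_precision`, which accepts `no`, `fp16`, and `bf16` (newer Accelerate releases also accept `fp8`). Full fp32 training is spelled `no`, not `fp32`. The hypothetical helper below maps the three modes discussed in this thread to those values. With it, bf16, fp16, and fp32 can be compared without editing the launch command. The `PRECISION` variable in the usage line is also an assumption.

```shell
#!/usr/bin/env bash
# Map a precision name to the value accelerate's --mixed_precision expects.
# Only the three modes discussed in this thread are handled.
mixed_precision_flag() {
  case "$1" in
    bf16|fp16) echo "$1" ;;
    fp32|no)   echo "no" ;;   # full precision is "no" for accelerate
    *) echo "unsupported precision: $1" >&2; return 1 ;;
  esac
}

# Usage (hypothetical PRECISION variable, defaulting to bf16):
#   accelerate launch --mixed_precision="$(mixed_precision_flag "${PRECISION:-bf16}")" \
#     scripts/train_t2iv.py ...
```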
Check your CUDA version. Using Docker is recommended!
Thank you very much for your work. I am running the predict script on an H20. Whether I generate an image or a video, torch.bfloat16 causes this error:
Floating point exception (core dumped)
Predicting an image (with the PixArt weights) in torch.float16 works normally. However, video inference in fp16 only produces all-black GIFs, and training in fp16 gives a NaN loss. Training works properly in fp32, but it is too slow. Do you have any suggestions?
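This thread does not confirm the cause of the fp16 symptoms (NaN loss, all-black output, fp32 fine). One plausible explanation is fp16's narrow dynamic range. fp16 has a 5-bit exponent and a 10-bit mantissa, while bf16 keeps fp32's 8-bit exponent with a 7-bit mantissa. Large intermediate activations that are harmless in bf16 or fp32 can therefore overflow to infinity in fp16, and later operations turn the infinity into NaN:

```latex
\begin{aligned}
\text{fp16 max} &= (2 - 2^{-10}) \cdot 2^{15} = 65504 \\
\text{bf16 max} &= (2 - 2^{-7}) \cdot 2^{127} \approx 3.39 \times 10^{38} \\
300^2 &= 90000 > 65504 \;\Rightarrow\; +\infty \text{ in fp16} \\
\infty - \infty &= \text{NaN}, \qquad 0 \cdot \infty = \text{NaN}
\end{aligned}
```

NaN values propagating through the loss or the decoded frames would be consistent with both the NaN loss and the black GIFs. This is why bf16 is usually preferred over fp16 when the hardware and CUDA stack support it.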