jolibrain / joliGEN

Generative AI Image Toolset with GANs and Diffusion for Real-World Applications
https://www.joligen.com

WIP: video_unet_generator_attn #669

Closed: wr0124 closed this 1 week ago

wr0124 commented 1 month ago

UNet = ((ResBlock + Attention) × 2) × 4 for input_blocks

python3 -W ignore::UserWarning train.py \
    --dataroot /data1/juliew/mini_dataset/online_mario2sonic_lite \
    --checkpoints_dir /data1/juliew/checkpoints \
    --name mario \
    --config_json examples/example_ddpm_mario.json \
    --gpu_ids 1 \
    --output_display_env test_mario_unet \
    --output_display_freq 1 \
    --output_print_freq 1 \
    --G_diff_n_timestep_test 5 \
    --G_diff_n_timestep_train 2000 \
    --G_unet_mha_channel_mults 1 2 4 8 \
    --G_unet_mha_res_blocks 2 2 2 2 \
    --train_batch_size 1 \
    --G_unet_mha_attn_res 1 2 4 8 \
    --data_num_threads 1
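
For readability, a small structural sketch of how these flags can be read as ((ResBlock + Attention) × 2) × 4 over the input blocks: two residual+attention pairs per resolution level, across four channel multipliers. The ResBlock/AttentionBlock classes below are simple placeholders, not the actual joliGEN modules.

```python
import torch.nn as nn

# Placeholder modules only: stand-ins for the real joliGEN blocks.
class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        return self.skip(x) + self.conv(x)

class AttentionBlock(nn.Module):
    def __init__(self, ch, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(ch, num_heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        seq = x.flatten(2).transpose(1, 2)            # (b, h*w, c)
        out, _ = self.attn(seq, seq, seq)
        return x + out.transpose(1, 2).reshape(b, c, h, w)

def build_input_blocks(base_ch=64,
                       channel_mults=(1, 2, 4, 8),    # --G_unet_mha_channel_mults
                       res_blocks=(2, 2, 2, 2)):      # --G_unet_mha_res_blocks
    blocks = nn.ModuleList()
    ch = base_ch
    for mult, n_res in zip(channel_mults, res_blocks):
        out_ch = base_ch * mult
        for _ in range(n_res):                        # (ResBlock + Attention) x 2 ...
            blocks.append(nn.Sequential(ResBlock(ch, out_ch), AttentionBlock(out_ch)))
            ch = out_ch
        blocks.append(nn.AvgPool2d(2))                # ... repeated at 4 resolution levels
    return blocks
```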

wr0124 commented 1 month ago

Since joliGEN DDPM's temporal mode (use_temporal) creates tensors of shape (b, f, c, h, w), which differs from the original paper's format of (b, c, f, h, w), all tensors in this version flow in the (b, f, c, h, w) format. For compatibility with the other models in joliGEN, would it be advantageous to treat the tensors in 4D format during training?
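
As a minimal sketch of the two layouts (illustrative PyTorch only, not code from this PR), folding the frame axis into the batch axis is one way to reuse existing 4D code paths during training:

```python
import torch

# Illustrative only: juggling the (b, f, c, h, w) layout used here against
# 4D (b*f, c, h, w) processing and the paper's (b, c, f, h, w) layout.
b, f, c, h, w = 2, 4, 3, 64, 64
x = torch.randn(b, f, c, h, w)                    # joliGEN temporal layout

x_4d = x.reshape(b * f, c, h, w)                  # frames folded into the batch dim,
                                                  # so per-frame 4D modules apply directly
x_back = x_4d.reshape(b, f, c, h, w)              # restore the video layout afterwards

x_paper = x.permute(0, 2, 1, 3, 4).contiguous()   # original paper layout (b, c, f, h, w)
```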

wr0124 commented 1 month ago

python3 -W ignore::UserWarning train.py \
    --dataroot /data1/juliew/dataset/online_mario2sonic_full_mario \
    --checkpoints_dir /data1/juliew/checkpoints \
    --name mario_temporal \
    --config_json examples/example_ddpm_mario.json \
    --gpu_ids 2 \
    --output_display_env test_mario_temporal \
    --output_print_freq 1 \
    --output_display_freq 1 \
    --data_dataset_mode self_supervised_temporal_labeled_mask_online \
    --train_batch_size 1 \
    --train_iter_size 4 \
    --data_temporal_number_frames 4 \
    --data_temporal_frame_step 1 \
    --data_num_threads 1 \
    --train_temporal_criterion \
    --G_diff_n_timestep_test 1000 \
    --G_diff_n_timestep_train 2000 \
    --train_temporal_criterion_lambda 1.0 \
    --G_netG unet_vid \
    --data_online_creation_crop_size_A 64 \
    --data_online_creation_crop_size_B 64 \
    --data_crop_size 64 \
    --data_load_size 64 \
    --G_unet_mha_attn_res 1 2 4 8 \
    --output_verbose
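
For reference, a tiny sketch of the assumed semantics of --data_temporal_number_frames and --data_temporal_frame_step (an illustration, not the joliGEN dataloader):

```python
# Assumed semantics of the temporal sampling flags (illustration only).
def sample_frame_indices(start, number_frames=4, frame_step=1):
    """Indices of the frames that form one temporal window."""
    return [start + i * frame_step for i in range(number_frames)]

print(sample_frame_indices(10))        # [10, 11, 12, 13]  (number_frames=4, step=1)
print(sample_frame_indices(10, 4, 2))  # [10, 12, 14, 16]  (step=2 skips every other frame)
```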

wr0124 commented 1 month ago

Works with this command line:

python3 -W ignore::UserWarning train.py \
    --dataroot /data1/juliew/dataset/online_mario2sonic_full_mario \
    --checkpoints_dir /data1/juliew/checkpoints \
    --name mario_antoine \
    --gpu_ids 2 \
    --output_display_env test_mario_antoine \
    --model_type palette \
    --output_print_freq 1 \
    --output_display_freq 1 \
    --data_dataset_mode self_supervised_temporal_labeled_mask_online \
    --train_batch_size 1 \
    --train_iter_size 1 \
    --model_input_nc 3 \
    --model_output_nc 3 \
    --data_relative_paths \
    --train_G_ema \
    --train_optim adamw \
    --train_temporal_criterion_lambda 1.0 \
    --G_netG unet_vid \
    --data_online_creation_crop_size_A 64 \
    --data_online_creation_crop_size_B 64 \
    --data_crop_size 64 \
    --data_load_size 64 \
    --G_unet_mha_attn_res 16 \
    --data_online_creation_rand_mask_A \
    --train_G_lr 0.0001 \
    --dataaug_no_rotate \
    --G_diff_n_timestep_train 5 \
    --G_diff_n_timestep_test 6 \
    --data_temporal_number_frames 4 \
    --data_temporal_frame_step 1 \
    --data_num_threads 4 \
    --UNetVid

wr0124 commented 1 month ago

python3 -W ignore::UserWarning train.py \
    --dataroot /data1/juliew/dataset/online_mario2sonic_full_mario \
    --checkpoints_dir /data1/juliew/checkpoints \
    --name mario_vid_bs1 \
    --gpu_ids 2 \
    --model_type palette \
    --output_print_freq 1 \
    --output_display_freq 1 \
    --data_dataset_mode self_supervised_temporal_labeled_mask_online \
    --train_batch_size 1 \
    --train_iter_size 4 \
    --model_input_nc 3 \
    --model_output_nc 3 \
    --data_relative_paths \
    --train_G_ema \
    --train_optim adamw \
    --train_temporal_criterion_lambda 1.0 \
    --G_netG unet_vid \
    --data_online_creation_crop_size_A 64 \
    --data_online_creation_crop_size_B 64 \
    --data_crop_size 64 \
    --data_load_size 64 \
    --G_unet_mha_attn_res 1 2 4 8 \
    --data_online_creation_rand_mask_A \
    --train_G_lr 0.0001 \
    --dataaug_no_rotate \
    --G_diff_n_timestep_train 8 \
    --G_diff_n_timestep_test 6 \
    --data_temporal_number_frames 10 \
    --data_temporal_frame_step 1 \
    --data_num_threads 8 \
    --UNetVid \
    --output_verbose
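
If --train_iter_size acts as gradient accumulation (an assumption here, not checked against the joliGEN docs), then train_batch_size 1 with train_iter_size 4 gives an effective batch of 4. A generic PyTorch sketch of that pattern:

```python
import torch

# Generic gradient-accumulation sketch (assuming --train_iter_size behaves like this):
# with batch size 1 and iter_size 4, the effective batch size is 4.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
iter_size = 4

optimizer.zero_grad()
for _ in range(iter_size):
    x, y = torch.randn(1, 8), torch.randn(1, 1)     # one mini-batch of size 1
    loss = torch.nn.functional.mse_loss(model(x), y) / iter_size
    loss.backward()                                  # gradients accumulate across iters
optimizer.step()                                     # single update after iter_size batches
```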

wr0124 commented 1 month ago

Launch inference:

cd scripts/
python3 gen_vid_diffusion.py \
    --model_in_file /data1/juliew/checkpoints/mario_vid_bs1/latest_net_G_A.pth \
    --img_in /data1/juliew/mini_dataset/online_mario2sonic_video/trainA/paths_part.txt \
    --paths_file /data1/juliew/mini_dataset/online_mario2sonic_video/trainA/paths_part.txt \
    --mask_in /data1/juliew/mini_dataset/online_mario2sonic_video/trainA/paths_part.txt \
    --data_root /data1/juliew/mini_dataset/online_mario2sonic_video/ \
    --dir_out ../inference_mario \
    --img_width 128 \
    --img_height 128

wr0124 commented 1 month ago

Create videos with this command line:

cd scripts/
python3 gen_vid_diffusion.py \
    --model_in_file /data1/juliew/checkpoints/test_vid/latest_net_G_A.pth \
    --img_in /data1/paths_part.txt \
    --paths_file /data1/juliew/ori_dataset/online_mario2sonic_full/trainA/paths_part4.txt \
    --mask_in /paths_part.txt \
    --data_root /data1/juliew/ori_dataset/online_mario2sonic_full/ \
    --dir_out ../inference_mario_vid \
    --img_width 128 \
    --img_height 128 \
    --nb_samples 2

wr0124 commented 3 weeks ago

Created one unit test file, test_run_video_diffusion_online.py, for unit testing.
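
For context, a rough sketch of the kind of smoke test meant here; the dataset path is a placeholder and the actual test_run_video_diffusion_online.py in this PR may be organised differently (for example around joliGEN's own test helpers):

```python
# Rough smoke-test sketch (placeholder dataset path; not the actual test file from this PR).
import subprocess
import sys

def test_run_video_diffusion_online(tmp_path):
    cmd = [
        sys.executable, "train.py",
        "--dataroot", "/path/to/online_mario2sonic_lite",   # placeholder dataset
        "--checkpoints_dir", str(tmp_path),
        "--name", "test_vid_diffusion",
        "--model_type", "palette",
        "--G_netG", "unet_vid",
        "--UNetVid",
        "--data_dataset_mode", "self_supervised_temporal_labeled_mask_online",
        "--data_temporal_number_frames", "4",
        "--data_temporal_frame_step", "1",
        "--train_batch_size", "1",
        "--G_diff_n_timestep_train", "5",
        "--G_diff_n_timestep_test", "5",
        "--data_crop_size", "64",
        "--data_load_size", "64",
    ]
    # A short run that exits cleanly is enough for a smoke test.
    assert subprocess.run(cmd, check=False).returncode == 0
```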

wr0124 commented 2 weeks ago

During inference, additional frames beyond the specified opt.data_temporal_number_frames can be added for video generation, but according to the literature this often results in degraded outcomes. The additional_frame value in the gen_vid_diffusion file still needs to be tested when it is negative.
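
A minimal guard of the sort that could wrap that value (hypothetical helper, not the current gen_vid_diffusion code, whose handling of negative values is exactly what remains to be tested):

```python
# Hypothetical guard for additional_frame (not the current gen_vid_diffusion code).
def total_generated_frames(temporal_number_frames: int, additional_frame: int) -> int:
    if additional_frame < 0:
        raise ValueError(f"additional_frame must be >= 0, got {additional_frame}")
    # The literature suggests quality degrades as frames are added beyond the
    # training window (opt.data_temporal_number_frames), so keep this value small.
    return temporal_number_frames + additional_frame
```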