songweige / TATS

Official PyTorch implementation of TATS: A Long Video Generation Framework with Time-Agnostic VQGAN and Time-Sensitive Transformer (ECCV 2022)

Pre-trained model for UCF-101 without class conditioning & Question about CCVS model for Fig.5(a) #11

Closed · Ugness closed this issue 2 years ago

Ugness commented 2 years ago

May I get the pretrained transformer model for the UCF-101 dataset without class conditioning? I found that the checkpoint you attached for UCF-101 requires class labels.

Ugness commented 2 years ago

Also, if possible, can you share the CCVS checkpoint that you used as the baseline? Based on the paper, I tried to find an official CCVS checkpoint trained on UCF-101 at 128x128 resolution, but the official checkpoint was trained at 256x256. Therefore, I'd like to know how you evaluated the long-term generation performance of CCVS.

songweige commented 2 years ago

Hi Yoo, thank you for your interest! Please find the unconditional transformer model for the UCF-101 dataset here: https://drive.google.com/file/d/1pQsMdO2b84m7asp_lNg44c_UE5DUC6pa/view?usp=sharing. We have also added it to the README file.

In terms of CCVS, we used their official model on the UCF-101 dataset at 256x256 resolution. We modified the code to do sliding-window inference for long video generation.
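
For reference, the sliding-window idea looks roughly like this (a minimal sketch with hypothetical names such as `transformer.generate`; not the exact code used for the paper):

```python
import torch

def sliding_window_generate(transformer, init_tokens, total_frames,
                            window_frames, tokens_per_frame):
    """Generate tokens for a long video by repeatedly conditioning on the
    most recent `window_frames` frames. `transformer.generate` is a
    hypothetical autoregressive sampling helper, not the repo's exact API."""
    tokens = init_tokens  # shape: (batch, n_initial_frames * tokens_per_frame)
    target_len = total_frames * tokens_per_frame
    while tokens.shape[1] < target_len:
        # Keep only the most recent window of frame tokens as context.
        context = tokens[:, -window_frames * tokens_per_frame:]
        # Autoregressively predict one more frame's worth of tokens.
        new_tokens = transformer.generate(context, max_new_tokens=tokens_per_frame)
        tokens = torch.cat([tokens, new_tokens[:, -tokens_per_frame:]], dim=1)
    return tokens  # decode back to frames with the VQGAN decoder afterwards
```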

Ugness commented 2 years ago

Thanks a lot! That helps me a lot.

Ugness commented 2 years ago

Hi, may I know the decoding parameters (e.g., temperature, top_k, top_p) for the UCF-101 unconditional model so that I can reproduce the result? I generated videos with top_p=0.8 and top_k=2048, the same options as for the conditional version, but couldn't reproduce the reported score.

songweige commented 2 years ago

Hi, can you try top_p = 0.92 and top_k = 8192, and let me know what results you are getting?
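
For reference, these flags correspond to standard top-k / nucleus (top-p) filtering of the transformer's logits before sampling. A generic sketch (not necessarily the repo's exact implementation) looks like this:

```python
import torch
import torch.nn.functional as F

def top_k_top_p_filtering(logits, top_k=8192, top_p=0.92):
    """Mask logits outside the top-k set and outside the top-p nucleus
    by setting them to -inf, so sampling only considers the kept tokens."""
    if top_k > 0:
        # k-th largest logit per row; everything smaller is masked out.
        kth = torch.topk(logits, min(top_k, logits.size(-1)), dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once cumulative probability exceeds top_p (always keep the first).
        sorted_remove = cum_probs > top_p
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))
    return logits

# Sampling one token per sequence from the filtered distribution:
# probs = F.softmax(top_k_top_p_filtering(logits, top_k=8192, top_p=0.92), dim=-1)
# next_token = torch.multinomial(probs, num_samples=1)
```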

Ugness commented 2 years ago

I generated samples with the options you recommended, but got an FVD of 394. With top_p=0.92 and top_k=2048, I got 392.

songweige commented 1 year ago

Interesting. Could you please help me figure out where the divergence starts? This is the exact script I used:

python sample_vqgan_transformer_short_videos.py --gpt_ckpt /fs/vulcan-projects/contrastive_learning_songweig/TATS/ckpts/ucf101/uncond_gpt_ucf_128_488_29999/epoch=21-step=1349999-train/uncond_gpt_ucf_128_488_29999.ckpt --vqgan_ckpt /fs/vulcan-projects/contrastive_learning_songweig/TATS/ckpts/ucf101/vqgan_ucf_128_488/lightning_logs/version_50507903/checkpoints/epoch=1-step=29999-10000-train/recon_loss=0.27.ckpt --save /fs/vulcan-projects/contrastive_learning_songweig/TATS/results/numpy_files/ucf101/issue_test_uncond/ --data_path /fs/vulcan-projects/contrastive_learning_songweig/TATS/data/ucf101/ --top_k 8192 --top_p 0.92 --dataset ucf101 --compute_fvd --batch_size 16 --run 8

[attached screenshot]

Could you share your script here? Also, to debug this, I uploaded one of the numpy files to Google Drive. Could you download it and compute FVD on your end to check whether it matches my number? That would tell us whether the inconsistency comes from the model side or the data side: https://drive.google.com/file/d/1BAvdzq5Xt-i5CVtntfPZzn2NrSrLg139/view?usp=sharing

Ugness commented 1 year ago

python sample_vqgan_transformer_short_videos.py --gpt_ckpt $1 --vqgan_ckpt $2 --save results/tats_baseline --data_path datasets/vqgan_data/ucf --batch_size 16 --top_k 8192 --top_p 0.92 --dataset ucf101 --compute_fvd
[attached screenshot]

I obtained the FVD with the above script. The one thing I modified was changing train_dataloader() to val_dataloader(), so that FVD is measured against the validation set of UCF-101, roughly as sketched below.
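
A hypothetical sketch of that change; the actual data-module class and variable names in sample_vqgan_transformer_short_videos.py may differ:

```python
def get_real_loader(data_module, use_val_split: bool):
    """Pick which split supplies the real clips that FVD is computed against.
    `data_module` is assumed to be the script's Lightning data module for UCF-101."""
    if use_val_split:
        return data_module.val_dataloader()    # validation split of UCF-101
    return data_module.train_dataloader()      # train split (the usual convention for UCF-101 FVD)
```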

I'll measure the FVD with the npy file you attached ASAP, thanks a lot. Also, may I know the floating-point type that you used for training and inference? Currently, I'm using TF32.

songweige commented 1 year ago

Oh, I think that is probably the reason: people usually report FVD on the UCF-101 training set. For validation FVD, the value you got looks reasonable to me. Can you try computing FVD against the train set?

In terms of the data type, I used fp32. I'm also curious whether another floating-point type would affect the result or not. :)

Ugness commented 1 year ago

Okay, thanks! I'll check these items and share the results ASAP. Also, I'm afraid I may be using the benchmark splits in the wrong way. Can you check the setup below for me?

Ugness commented 1 year ago

May I ask which version of PyTorch you used? I measured FVD with the numpy file you shared, using FP32 and the UCF-101 training set, but I still get a high FVD (383). I think there may be a problem with my dataset or a gap between our environments (including software and hardware). I'm running my code on an A100 machine with torch==1.10.0.

I'll take a look at my environment and the UCF-101 dataset I downloaded. Thanks a lot.
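
For reference, one way to rule out TF32 on Ampere GPUs like the A100 is to disable it explicitly (a minimal sketch; whether TF32 is on by default depends on the PyTorch version):

```python
import torch

# Assumption: on Ampere GPUs (e.g., A100), PyTorch versions around 1.10 enable
# TF32 for matmuls and cuDNN convolutions by default. Disabling both forces
# full fp32 math for an apples-to-apples comparison.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
```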

songweige commented 1 year ago

Thank you for helping debug this! The setup sounds good to me, except that when people train on the train+val sets of UCF-101, they also measure FVD on train+val.

Ugness commented 1 year ago

Hi, may I also get the checkpoints for TATS-Hierarchical on UCF-101, if they are unconditional models?

songweige commented 1 year ago

Hi, unfortunately we don't have those checkpoints. Sorry about that!