Thanks for the interest! Hopefully the below is helpful for you!
VDM
We compare against a version of VDM with the same architecture as FDM, so I should caveat this by saying that the architecture is not exactly the same as in the VDM paper. As in the VDM paper, training it involves training two models: a "frameskip-1" model on sequences of 9 consecutive frames, and a "frameskip-4" model on sequences of 16 evenly spaced frames. Our implementation of this is in this repo. It is badly documented, but the following commands should work to train and test our VDM baseline. You'll need to install it by running pip install . and also install the same requirements as listed in this repo's README.
Training frameskip-1:
python scripts/video_train.py --num_workers 10 --batch_size=2 --max_frames 20 --dataset=carla_no_traffic --num_res_blocks=1 --use_rpe_net True --save_latest_only False --save_interval 50000 --T 9 --max_frames 9 --mask_distribution differently-spaced-groups-no-marg --fake_seed 1
Training frameskip-4:
python scripts/video_train.py --batch_size=1 --num_workers 10 --max_frames 20 --dataset=carla_no_traffic --num_res_blocks=1 --use_rpe_net True --save_latest_only False --save_interval 50000 --T 61 --max_frames 16 --mask_distribution linspace-0-60-16 --fake_seed 1
Sampling:
python scripts/video_sample_google.py --fs1_path <PATH TO FRAMESKIP 1 CHECKPOINTS>/ema_0.9999_000000.pt --fs4_path <PATH TO FRAMESKIP 4 CHECKPOINTS>/ema_0.9999_000000.pt --batch_size 1 --sample_idx 0
The ema_0.9999_000000.pt checkpoints are saved after the first training iteration, so if you just want to check the computational requirements of sampling, you can use these checkpoints without having to train for longer.
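If you want to turn such a check into numbers, a minimal sketch of timing one sampling pass and reading the peak GPU memory is below; sample_fn and num_frames are placeholders for whatever sampling entry point and predicted-frame count you use, not names from this repo:

import time
import torch

def measure_sampling(sample_fn, num_frames):
    # sample_fn: a callable that produces one video sample at batch size 1.
    # num_frames: the number of predicted frames T in that sample.
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    sample_fn()
    torch.cuda.synchronize()  # make sure all GPU work has finished before stopping the timer
    elapsed = time.perf_counter() - start
    fps = num_frames / elapsed                          # T / elapsed time
    peak_gb = torch.cuda.max_memory_allocated() / 1e9   # peak GPU memory in GB
    return fps, peak_gb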
CWVAE
We trained (and tested) CWVAE on a version of CARLA Town 01 downsampled to 64 x 64 resolution, so the FPS and memory cost will be the same as for any of the other datasets implemented in https://github.com/vaibhavsaxena11/cwvae.
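For reference, getting clips down to that resolution is just a resize of each frame; a minimal sketch using torchvision is below (the exact preprocessing used for CARLA Town 01 may differ):

from torchvision.transforms.functional import resize

def downsample_clip(frames):
    # frames: float tensor of shape (T, C, H, W) in [0, 1] at the original
    # CARLA resolution. resize acts on the trailing (H, W) dimensions, so the
    # whole clip can be resized to 64 x 64 in one call.
    return resize(frames, [64, 64])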
TATS
We didn't make major changes from the implementation at https://github.com/SongweiGe/TATS. We sampled using sample_vqgan_transformer_long_videos.py and trained using commands like those in the README.
Hello, regarding TATS: I see the original project only supports the unconditional model and label/stft/text conditions. How can I update the code to support video conditioning?
Like this? In init_cond_stage_from_ckpt in their tats_transformer.py:
def init_cond_stage_from_ckpt(self, args):
    from .download import load_vqgan
    if self.cond_stage_key == 'label' and not self.be_unconditional:
        model = Labelator(n_classes=args.class_cond_dim)
        model = model.eval()
        model.train = disabled_train
        self.cond_stage_model = model
        self.cond_stage_vocab_size = self.class_cond_dim
        ...
    elif self.cond_stage_key == 'video':
        # reuse the first-stage VQGAN so the conditioning frames share its codebook
        self.cond_stage_model = self.first_stage_model
        self.cond_stage_vocab_size = self.first_stage_vocab_size
        ...
    else:
        raise ValueError('conditional model %s is not implemented' % self.cond_stage_key)
And how should I change this condition flag, class_cond?
python sample_vqgan_transformer_long_videos.py \
--gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} \
--dataset ucf101 --class_cond --sample_length 16 --temporal_sample_pos 1 --batch_size 5 --n_sample 5 --save_videos
We've just made our fork of the TATS repo public: https://github.com/whilo/TATS
We implemented conditioning a little differently from that: since TATS is autoregressive, we just modify the sampling script to start with the first few frames of the video and then roll out from there. So our command to call this becomes something like
python sample_vqgan_transformer_long_videos.py \
--gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} \
--dataset ucf101 --class_cond --temporal_pix 16 --obs_frames 1 --batch_size 5 --n_sample 5 --save_videos
where obs_frames is the number of frames we want to condition on (e.g. 36) and temporal_pix is the desired total number of frames (e.g. 1000). We also obviously used a dataset other than the ucf101 in this command, and so I think we must not have passed in the --class_cond option. Sorry, I can't find exactly the command we used anywhere, but hopefully this is helpful!
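For what it's worth, the scheme described above can be sketched roughly as follows; encode_to_tokens, sample_remaining_tokens and decode_from_tokens are hypothetical stand-ins for the VQGAN encoder, the transformer's autoregressive sampling loop and the VQGAN decoder in the TATS code, so this only illustrates the idea and is not the fork's actual implementation:

def conditional_rollout(video, obs_frames, total_frames,
                        encode_to_tokens, sample_remaining_tokens, decode_from_tokens):
    # video: the ground-truth clip to condition on, shape (T, C, H, W).
    # 1. Encode only the observed prefix into discrete VQGAN tokens.
    prefix_tokens = encode_to_tokens(video[:obs_frames])
    # 2. Let the autoregressive transformer continue the token sequence
    #    until it covers total_frames worth of video.
    all_tokens = sample_remaining_tokens(prefix_tokens, total_frames)
    # 3. Decode the full token sequence back into pixels.
    return decode_from_tokens(all_tokens)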
Last question... do you remember this? What is _load_model_state?
https://github.com/SongweiGe/TATS/issues/24
import torch

def load_transformer(gpt_ckpt, vqgan_ckpt, stft_vqgan_ckpt='', device=torch.device('cpu')):
    from pytorch_lightning.utilities.cloud_io import load as pl_load
    # Load the GPT checkpoint and point its saved args at the VQGAN checkpoint(s).
    checkpoint = pl_load(gpt_ckpt)
    checkpoint['hyper_parameters']['args'].vqvae = vqgan_ckpt
    if stft_vqgan_ckpt:
        checkpoint['hyper_parameters']['args'].stft_vqvae = stft_vqgan_ckpt
    # _load_model_state is inherited from Lightning's ModelIO mixin (see below).
    gpt = Net2NetTransformer._load_model_state(checkpoint)
    gpt.eval()
    return gpt
Not sure; IIRC we used the load_transformer function but didn't have to look into its internals.
But it seems that this function is missing from the project, so I can't run the code properly right now. I have also sent an email to the original author. Thank you.
Oh, it looks like this is inherited by Net2NetTransformer from pl.LightningModule?
With the Python environment I used to run TATS I get:
>>> import pytorch_lightning as pl
>>> pl.__version__
'1.6.5'
>>> pl.LightningModule._load_model_state
<bound method ModelIO._load_model_state of <class 'pytorch_lightning.core.lightning.LightningModule'>>
Very helpful. I did not find anything about _load_model_state even in the official Lightning docs. I am using version 1.8.6, so I cannot use this function; I will downgrade the version. Thanks again, you are very patient. I wish you good luck with your research.
>>> import pytorch_lightning as pl
>>> pl.__version__
'1.8.6'
>>> pl.LightningModule._load_model_state
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: type object 'LightningModule' has no attribute '_load_model_state'
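If downgrading Lightning is inconvenient, one possible workaround is to bypass _load_model_state and restore the weights manually; the sketch below assumes the checkpoint has the usual state_dict and hyper_parameters keys that Lightning writes and that Net2NetTransformer can be constructed from its saved args (neither of which has been tested here against TATS):

import torch

def load_transformer_manually(gpt_ckpt, vqgan_ckpt, model_cls):
    # model_cls: the Net2NetTransformer class imported from the TATS package
    # (passed in here to avoid hard-coding its import path).
    checkpoint = torch.load(gpt_ckpt, map_location='cpu')
    args = checkpoint['hyper_parameters']['args']
    args.vqvae = vqgan_ckpt                        # point the model at the VQGAN checkpoint
    gpt = model_cls(args)                          # assumes the class takes its saved args
    gpt.load_state_dict(checkpoint['state_dict'])  # restore the trained weights
    return gpt.eval()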
Question
Dear author, I am very interested in your work on long video prediction. Currently, long video prediction has a high memory cost and a slow prediction rate. For example, when I use my 3090 to train and sample your FDM model, the memory cost is 15.9 GB at batch size 1, and the sampling FPS is about 0.09 frames/s. But I wonder how to do the memory and FPS analysis for VDM, CWVAE and TATS, which are mentioned in your paper. Can you help me calculate the memory cost and FPS? Or can you provide VDM / CWVAE / TATS code which can be trained on the CARLA Town 01 dataset? I hope for your early reply, thanks!
Definition
Memory: the memory cost when training the model, together with the batch size used.
FPS: generate one sample with T predicted frames at batch size 1 and calculate T / spend_time.
Related Code
FPS
Memory
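A minimal sketch of how the Memory number defined above could be recorded, assuming a standard PyTorch training loop; train_step is a placeholder for one forward/backward/optimizer step at the reported batch size:

import torch

def training_peak_memory_gb(train_step):
    # train_step: a callable running one forward/backward/optimizer step
    # at the batch size being reported.
    torch.cuda.reset_peak_memory_stats()
    train_step()
    torch.cuda.synchronize()  # wait for all GPU work before reading the counter
    return torch.cuda.max_memory_allocated() / 1e9  # peak allocation in GB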
Reference
https://github.com/lucidrains/video-diffusion-pytorch
https://github.com/vaibhavsaxena11/cwvae
https://github.com/SongweiGe/TATS