plai-group / flexible-video-diffusion-modeling

MIT License

Memory Cost and Sampling FPS. #3

Closed: JunyaoHu closed this issue 1 year ago

JunyaoHu commented 1 year ago

Question

Dear author, I am very interested in your work on long video prediction. Long video prediction currently has a high memory cost and a slow prediction rate. For example, when I train and sample your FDM model on my RTX 3090, the memory cost is 15.9 GB at batch size 1, and the sampling speed is about 0.09 frames/s. However, I am not sure how to carry out the same memory and FPS analysis for VDM, CWVAE and TATS, which are compared in your paper. Could you help me calculate their memory cost and FPS? Or could you provide VDM / CWVAE / TATS code that can be trained on the CARLA Town 01 dataset? I look forward to your reply, thanks!


Definition

Memory: the memory cost when training the model, together with the batch size used. FPS: generate a single sample with T predicted frames (batch size also 1) and compute T / elapsed time.

Related Code

FPS

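import time
import torch

for i, (batch_x, batch_y) in enumerate(vali):
    # ... (elided)
    t = time.time()
    batch_x, batch_y = batch_x.to(self.device), batch_y.to(self.device)
    # ... (elided)
    res = ...  # model prediction (elided)
    # ... (elided)
    # CUDA kernels run asynchronously; wait for them to finish before reading the clock
    torch.cuda.synchronize()
    # batch_y.shape -> [b, t=pred, c, h, w]
    print(f"fps on {batch_y.shape[1]} frames (frames/s): {batch_y.shape[1] / (time.time() - t):.4f}")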
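Because CUDA kernels are launched asynchronously, wall-clock timing of a GPU sampler is only meaningful if the device is synchronized before the clock is read. The sketch below packages the FPS definition (T predicted frames divided by elapsed time, batch size 1) into a reusable helper. The name `measure_fps` and its `sync` parameter are hypothetical, not part of the FDM repo; when timing on a GPU you would pass `torch.cuda.synchronize` as `sync`.

```python
import time
from typing import Callable, Optional


def measure_fps(
    generate: Callable[[], object],
    num_frames: int,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return predicted frames per second for one call of `generate`.

    `generate` should produce a single sample (batch size 1) containing
    `num_frames` predicted frames. If `sync` is given (e.g.
    torch.cuda.synchronize), it is called before starting and before
    stopping the clock so queued GPU work is attributed correctly.
    """
    if num_frames <= 0:
        raise ValueError("num_frames must be positive")
    if sync is not None:
        sync()  # flush earlier GPU work so it is not counted in this call
    start = time.perf_counter()
    generate()
    if sync is not None:
        sync()  # wait for the sampler's own kernels to finish
    elapsed = time.perf_counter() - start
    return num_frames / elapsed


if __name__ == "__main__":
    # Dummy "sampler" that takes at least 0.5 s to produce 10 frames.
    fps = measure_fps(lambda: time.sleep(0.5), num_frames=10)
    print(f"fps on 10 frames (frames/s): {fps:.4f}")  # slightly below 20
```

At the 0.09 frames/s reported above, each predicted frame takes about 1 / 0.09 ≈ 11 seconds, so even a short rollout is worth timing once with synchronization rather than estimating.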

Memory

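import torch

device = torch.cuda.current_device()
# Device-wide usage as reported by the driver (includes the CUDA context and other processes)
mem_free, mem_total = torch.cuda.mem_get_info()
current_memory = (mem_total - mem_free) / 2 ** 30
print(f"device {device}, memory           {current_memory:.2f} GiB")
# Memory currently occupied by tensors
print(f"device {device}, memory allocated {torch.cuda.memory_allocated() / 2 ** 30:.2f} GiB")
# Memory held by PyTorch's caching allocator (allocated + cached)
print(f"device {device}, memory reserved  {torch.cuda.memory_reserved() / 2 ** 30:.2f} GiB")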
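The calls above give a snapshot at one moment, whereas a "training memory at batch size 1" figure is usually a peak over a whole step. A minimal sketch, assuming a single CUDA device and a hypothetical `train_step` callable that runs one forward/backward/optimizer step at the batch size being reported; `torch.cuda.reset_peak_memory_stats` and `torch.cuda.max_memory_allocated` are standard PyTorch functions, while the helper names are illustrative only.

```python
from typing import Callable


def bytes_to_gib(n_bytes: int) -> float:
    """Convert a byte count to GiB (2**30 bytes), the unit used in this thread."""
    return n_bytes / 2**30


def peak_train_step_memory_gib(train_step: Callable[[], None]) -> float:
    """Peak memory allocated by PyTorch tensors during one training step, in GiB.

    Sketch only: assumes one CUDA device and that `train_step` (hypothetical)
    performs exactly one training iteration at the reported batch size.
    """
    import torch  # imported lazily so bytes_to_gib works without torch

    torch.cuda.reset_peak_memory_stats()  # forget peaks from earlier work
    train_step()
    torch.cuda.synchronize()  # make sure the step has actually finished
    return bytes_to_gib(torch.cuda.max_memory_allocated())
```

Note that this counts tensor allocations only; `nvidia-smi` will show a larger number because it also includes the CUDA context and the allocator's cache.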

Reference

https://github.com/lucidrains/video-diffusion-pytorch

https://github.com/vaibhavsaxena11/cwvae

https://github.com/SongweiGe/TATS

wsgharvey commented 1 year ago

Thanks for the interest! Hopefully the below is helpful for you!

VDM: We compare against a version of VDM with the same architecture as FDM, so I should caveat this by saying that the architecture is not exactly the same as in the VDM paper. As in the VDM paper, training involves two models: a "frameskip-1" model trained on sequences of 9 consecutive frames, and a "frameskip-4" model trained on sequences of 16 evenly spaced frames. Our implementation of this is in this repo. It is badly documented, but the following commands should work to train and test our VDM baseline. You'll need to install it by running pip install . and also install the same requirements as listed in this repo's README.

Training frameskip-1: python scripts/video_train.py --num_workers 10 --batch_size=2 --max_frames 20 --dataset=carla_no_traffic --num_res_blocks=1 --use_rpe_net True --save_latest_only False --save_interval 50000 --T 9 --max_frames 9 --mask_distribution differently-spaced-groups-no-marg --fake_seed 1

Training frameskip-4: python scripts/video_train.py --batch_size=1 --num_workers 10 --max_frames 20 --dataset=carla_no_traffic --num_res_blocks=1 --use_rpe_net True --save_latest_only False --save_interval 50000 --T 61 --max_frames 16 --mask_distribution linspace-0-60-16 --fake_seed 1
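The two training commands differ mainly in --T, --max_frames and --mask_distribution. Reading the flags (an interpretation, not checked against the repo's code), the frameskip-1 model sees 9 consecutive frames, while the frameskip-4 model sees 16 frames spread evenly over a 61-frame window, i.e. indices 0 to 60 in steps of 4. The helper names below are hypothetical and only illustrate which frame indices each model covers.

```python
from typing import List


def frameskip_indices(start: int, skip: int, count: int) -> List[int]:
    """Indices of `count` frames beginning at `start`, spaced `skip` apart."""
    return [start + skip * k for k in range(count)]


def linspace_indices(first: int, last: int, count: int) -> List[int]:
    """`count` evenly spaced integer indices from `first` to `last` inclusive.

    Assumption: this is how the mask name `linspace-0-60-16` reads; it is
    not taken from the repo's implementation.
    """
    if count < 2:
        raise ValueError("count must be at least 2")
    step = (last - first) / (count - 1)
    return [round(first + step * k) for k in range(count)]


# Frameskip-1 model: 9 consecutive frames (--T 9).
print(frameskip_indices(0, 1, 9))  # [0, 1, 2, 3, 4, 5, 6, 7, 8]
# Frameskip-4 model: 16 frames over a 61-frame window (--T 61).
print(linspace_indices(0, 60, 16))
# [0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60]
```

Under this reading, the 16 evenly spaced indices coincide with a frameskip of 4, which matches the model's name.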

Sampling: python scripts/video_sample_google.py --fs1_path <PATH TO FRAMESKIP 1 CHECKPOINTS>/ema_0.9999_000000.pt --fs4_path <PATH TO FRAMESKIP 4 CHECKPOINTS>/ema_0.9999_000000.pt --batch_size 1 --sample_idx 0

The ema_0.9999_000000.pt checkpoints are saved after the first training iteration, so if you just want to check the computational requirements of sampling, you can use these checkpoints without having to train for longer.

CWVAE: We trained (and tested) CWVAE on a version of CARLA Town 01 downsampled to 64 x 64 resolution, so the FPS and memory cost will be the same as for any of the other datasets implemented in https://github.com/vaibhavsaxena11/cwvae.

TATS: We didn't make major changes to the implementation at https://github.com/SongweiGe/TATS. We sampled using sample_vqgan_transformer_long_videos.py and trained using commands like those in its README.

JunyaoHu commented 1 year ago

Hello, I see that the original TATS project only supports an unconditional model and label/stft/text conditioning. How should I update the code to support conditioning on video frames?

Something like this, in the init_cond_stage_from_ckpt method of their tats_transformer.py?

    def init_cond_stage_from_ckpt(self, args):
        from .download import load_vqgan
        if self.cond_stage_key == 'label' and not self.be_unconditional:
            model = Labelator(n_classes=args.class_cond_dim)
            model = model.eval()
            model.train = disabled_train
            self.cond_stage_model = model
            self.cond_stage_vocab_size = self.class_cond_dim
        # ... (other conditioning branches elided)
        elif self.cond_stage_key == 'video':
            # proposed: reuse the first-stage VQGAN to tokenize the conditioning frames
            self.cond_stage_model = self.first_stage_model
            self.cond_stage_vocab_size = self.first_stage_vocab_size
        # ... (other conditioning branches elided)
        else:
            raise ValueError('conditional model %s is not implemented' % self.cond_stage_key)

And how should I change the class_cond condition flag in this command?

python sample_vqgan_transformer_long_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} \
    --dataset ucf101 --class_cond --sample_length 16 --temporal_sample_pos 1 --batch_size 5 --n_sample 5 --save_videos

wsgharvey commented 1 year ago

We've just made our fork of the TATS repo public: https://github.com/whilo/TATS

We implemented conditioning a little differently: since TATS is autoregressive, we just modify the sampling script to start with the first few frames of the video and then roll out from there. So our command becomes something like:

python sample_vqgan_transformer_long_videos.py \
    --gpt_ckpt {GPT-CKPT} --vqgan_ckpt {VQGAN-CKPT} \
    --dataset ucf101 --class_cond --temporal_pix 16 --obs_frames 1 --batch_size 5 --n_sample 5 --save_videos

where obs_frames is the number of frames we want to condition on (e.g. 36) and temporal_pix is the desired total number of frames (e.g. 1000). We also obviously used datasets other than ucf101, so I think we must not have passed the --class_cond option. Sorry, I can't find the exact command we used anywhere, but hopefully this is helpful!
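The conditioning scheme described here (keep the first obs_frames frames fixed, then autoregressively generate until the video reaches temporal_pix frames) can be sketched as a sliding-window rollout. This is a simplified illustration, not the fork's code: real TATS works on VQGAN tokens rather than raw frames, and `predict_next` and `context_len` are hypothetical stand-ins for the transformer and its context window.

```python
from typing import Callable, List, Sequence, TypeVar

Frame = TypeVar("Frame")


def rollout(
    observed: Sequence[Frame],
    total_frames: int,
    context_len: int,
    predict_next: Callable[[List[Frame]], Frame],
) -> List[Frame]:
    """Autoregressively extend `observed` until the video has `total_frames` frames.

    `predict_next` (hypothetical) receives the most recent `context_len`
    frames and returns the next frame.
    """
    if not observed:
        raise ValueError("need at least one observed frame to condition on")
    if total_frames < len(observed):
        raise ValueError("total_frames must be >= the number of observed frames")
    video = list(observed)  # observed frames are kept as-is, never resampled
    while len(video) < total_frames:
        context = video[-context_len:]  # sliding window over the latest frames
        video.append(predict_next(context))
    return video


# Toy "model" on integers: the next frame is the last frame plus one.
print(rollout([0, 1, 2], total_frames=6, context_len=2,
              predict_next=lambda ctx: ctx[-1] + 1))  # [0, 1, 2, 3, 4, 5]
```

In the terms of the command above, len(observed) plays the role of obs_frames and total_frames the role of temporal_pix.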

JunyaoHu commented 1 year ago

One last question: do you remember this? What is _load_model_state?

https://github.com/SongweiGe/TATS/issues/24

def load_transformer(gpt_ckpt, vqgan_ckpt, stft_vqgan_ckpt='', device=torch.device('cpu')):
    from pytorch_lightning.utilities.cloud_io import load as pl_load
    checkpoint = pl_load(gpt_ckpt)
    checkpoint['hyper_parameters']['args'].vqvae = vqgan_ckpt
    if stft_vqgan_ckpt:
        checkpoint['hyper_parameters']['args'].stft_vqvae = stft_vqgan_ckpt
    gpt = Net2NetTransformer._load_model_state(checkpoint)
    gpt.eval()
    return gpt

wsgharvey commented 1 year ago

Not sure; IIRC we used the load_transformer function but didn't have to look into its internals.

JunyaoHu commented 1 year ago

But it seems that this function is missing from the project, so I can't run the code properly right now. I have also emailed the original author. Thank you.

wsgharvey commented 1 year ago

Oh, it looks like this is inherited by Net2NetTransformer from pl.LightningModule?

With the Python environment I used to run TATS I get:

>>> import pytorch_lightning as pl
>>> pl.__version__
'1.6.5'
>>> pl.LightningModule._load_model_state
<bound method ModelIO._load_model_state of <class 'pytorch_lightning.core.lightning.LightningModule'>>

JunyaoHu commented 1 year ago

Very helpful. I couldn't find anything about _load_model_state, even in the official Lightning docs. I am using version 1.8.6, which doesn't have this method, so I will downgrade. Thanks again, you have been very patient. Good luck with your research!

>>> import pytorch_lightning as pl
>>> pl.__version__
'1.8.6'
>>> pl.LightningModule._load_model_state
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
AttributeError: type object 'LightningModule' has no attribute '_load_model_state'
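Since _load_model_state is a private Lightning method that this thread shows present in 1.6.5 and absent in 1.8.6, a sampling script could check for it up front instead of failing deep inside load_transformer. A minimal sketch; the helper name `get_load_model_state` is hypothetical, and in practice you would call it as `get_load_model_state(pl.LightningModule)` before loading the checkpoint.

```python
def get_load_model_state(module_cls: type):
    """Return `module_cls._load_model_state`, or fail with an actionable message.

    `_load_model_state` is a private pytorch_lightning helper; per this
    thread it exists in 1.6.5 but not in 1.8.6.
    """
    loader = getattr(module_cls, "_load_model_state", None)
    if loader is None:
        raise RuntimeError(
            f"{module_cls.__name__} has no _load_model_state; install a "
            "pytorch_lightning version that provides it (1.6.5 is reported "
            "to work in this thread) before running the TATS sampling script."
        )
    return loader
```

Relying on a private method ties the script to specific Lightning versions, so pinning the version in the environment's requirements is the more robust fix.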