songweige / TATS

Official PyTorch implementation of TATS: A Long Video Generation Framework with Time-Agnostic VQGAN and Time-Sensitive Transformer (ECCV 2022)
MIT License

Training on FaceForensics #22

Closed: piotr-komorowski closed this issue 9 months ago

piotr-komorowski commented 1 year ago

Hi! I am currently working on using the code for the FaceForensics dataset. I have been able to train VQGAN, but I encountered an issue while training the TATS-base Transformer. Here is the command I am using for training:

```
python scripts/train_transformer.py --num_workers 32 --val_check_interval 0.5 --progress_bar_refresh_rate 500 \
                        --gpus 4 --sync_batchnorm --batch_size 2 --unconditional \
                        --vqvae exp_1/lightning_logs/version_78873/checkpoints/latest_checkpoint.ckpt \
                        --data_path data/ffs_processed --default_root_dir exp_1_tats --image_folder \
                        --vocab_size 16384 --block_size 1024 --n_layer 24 --n_head 16 --n_embd 1024 \
                        --resolution 256 --sequence_length 16 --max_steps 2000000
```

However, I got the following error:

```
Traceback (most recent call last):
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 77, in <module>
    main()
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 73, in main
    trainer.fit(model, data)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
    self._call_and_handle_interrupt(
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
    return self._run_train()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1306, in _run_train
    self._run_sanity_check(self.lightning_module)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1370, in _run_sanity_check
    self._evaluation_loop.run()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
    dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
    output = self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
    output = self.trainer.accelerator.validation_step(step_kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
    return self.training_type_plugin.validation_step(*step_kwargs.values())
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 444, in validation_step
    return self.model(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 250, in validation_step
    loss, acc1, acc5 = self.shared_step(batch, batch_idx)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 237, in shared_step
    logits, target = self(x, c, cbox)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 126, in forward
    logits, _ = self.transformer(cz_indices[:, :-1], cbox=cbox)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 208, in forward
    assert t <= self.block_size, "Cannot forward, model block size is exhausted."
AssertionError: Cannot forward, model block size is exhausted.
```

where `t=4096` and `self.block_size=1024`. When I tried increasing the `block_size` argument to 4096, I received another error (even with a `batch_size` of 1):

```
Traceback (most recent call last):
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 77, in <module>
    main()
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 73, in main
    trainer.fit(model, data)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
    self._call_and_handle_interrupt(
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
    return self._run_train()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
    result = self._run_optimization(
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
    lightning_module.optimizer_step(
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1651, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
    trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
    self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 163, in optimizer_step
    optimizer.step(closure=closure, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    return func(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/optim/adamw.py", line 119, in step
    loss = closure()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 148, in _wrap_closure
    closure_result = closure()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
    step_output = self._step_fn()
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 439, in training_step
    return self.model(*args, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 81, in forward
    output = self.module.training_step(*inputs, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 243, in training_step
    loss, acc1, acc5 = self.shared_step(batch, batch_idx)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 237, in shared_step
    logits, target = self(x, c, cbox)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 126, in forward
    logits, _ = self.transformer(cz_indices[:, :-1], cbox=cbox)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 218, in forward
    x = self.blocks(x)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 153, in forward
    attn, present = self.attn(self.ln1(x), layer_past=layer_past)
  File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 121, in forward
    att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 1; 31.74 GiB total capacity; 29.71 GiB already allocated; 45.12 MiB free; 30.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
```

Can you suggest any potential solutions that might work?

songweige commented 1 year ago

If you are using the default hyperparameters to train the VQGAN, then it looks like you trained at 256x256 resolution, which results in a latent size of (4, 32, 32) and a block size of 4096. That means the sequence the transformer has to model is 4096 tokens long, which cannot fit into memory. You have several options:
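The arithmetic behind these sizes can be sketched as follows (a minimal check, assuming the default TATS VQGAN downsampling of 4x in time and 8x in space, which is consistent with the (4, 32, 32) latent quoted above):

```python
# The transformer's sequence length equals the number of VQGAN latent codes,
# i.e. the product of the latent dimensions (time, height, width).
frames, res = 16, 256        # --sequence_length and --resolution from the command
t_down, s_down = 4, 8        # assumed default temporal/spatial downsampling factors
latent = (frames // t_down, res // s_down, res // s_down)
tokens = latent[0] * latent[1] * latent[2]
print(latent, tokens)  # (4, 32, 32) 4096 -- far beyond the default block_size of 1024
```

So any fix has to shrink this product (fewer frames, lower resolution, or a more aggressive VQGAN downsampling) or enlarge the transformer's context at a steep memory cost.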

piotr-komorowski commented 1 year ago

Thank you, it's working now! I have another question regarding training. My data is structured so that there are 704 folders (videos), each containing some number of images (frames). To train the VQGAN, I'm using the --image_folder flag. However, I've noticed that the training progress output looks odd. For example, with --batch_size 2, the output shows:

```
Validating: 100%|██████████| 88/88 [01:13<00:00,  1.19it/s]
Epoch 1: 100%|██████████| 176/176 [07:06<00:00,  2.42s/it, loss=1.11, v_num=78873, train/perceptual_loss_step=1.700, train/recon_loss_step=0.278, train/aeloss_step=0.000, train/commitment_loss_step=0.00433, train/perplexity_step=53.50, train/discloss_step=0.000, val/recon_loss=0.446, val/perceptual_loss=1.690, val/perplexity=53.40, val/commitment_loss=0.0475, train/perceptual_loss_epoch=2.640, train/recon_loss_epoch=0.665, train/aeloss_epoch=0.000, train/commitment_loss_epoch=0.012, train/perplexity_epoch=68.70, train/discloss_epoch=0.000]
```

However, I expected the progress bar to show twice as many iterations (176 -> 352) so that the entire dataset is covered in one epoch. Similarly, with --batch_size 1, I'm getting:

```
Validating: 100%|██████████| 176/176 [01:27<00:00,  2.01it/s]
Epoch 1: 100%|██████████| 352/352 [10:10<00:00,  1.73s/it, loss=1.26, v_num=79006, train/perceptual_loss_step=1.750, train/recon_loss_step=0.384, train/aeloss_step=0.000, train/commitment_loss_step=0.0054, train/perplexity_step=50.50, train/discloss_step=0.000, val/recon_loss=0.520, val/perceptual_loss=1.950, val/perplexity=40.20, val/commitment_loss=0.253]
```

I couldn't find the reason for this in the code. Do you have any insight into why this might be happening?

songweige commented 1 year ago

Could it be that you are training on multiple GPUs, e.g. 2?

piotr-komorowski commented 1 year ago

I considered that, but since I'm training on 4 GPUs, the numbers didn't add up.
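One possible explanation of the counts above (a hypothetical reading, assuming Lightning's DDP `DistributedSampler` shards the 704 videos evenly across GPUs, and that the main epoch progress bar counts training and validation batches together):

```python
import math

def epoch_bar_total(n_samples, batch_size, n_gpus, val_batches):
    """Per-GPU training batches under even DDP sharding, plus the validation
    batches that Lightning's epoch progress bar is assumed to include."""
    train_batches = math.ceil(n_samples / (batch_size * n_gpus))
    return train_batches + val_batches

# batch_size=2 on 4 GPUs: 704/(2*4) = 88 train batches; 88 + 88 val = 176
print(epoch_bar_total(704, 2, 4, 88))   # 176
# batch_size=1 on 4 GPUs: 176 train batches; 176 + 176 val = 352
print(epoch_bar_total(704, 1, 4, 176))  # 352
```

Under these assumptions the observed totals (176 and 352) are consistent with 4-GPU sharding after all; the "missing" factor of two would come from the validation batches being folded into the same bar, not from fewer GPUs being used.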