Closed: piotr-komorowski closed this issue 9 months ago
If you are using the default hyperparameters to train the VQGAN, then it looks like you trained at 256x256 resolution, which results in a latent size of (4, 32, 32) and a block size of 4096. That means the sequence to be modeled by the transformer has a length of 4096, which cannot fit into memory. You have several options, e.g. retraining the VQGAN with a larger downsampling rate such as --downsample 4 16 16. This may make the reconstruction quality of the VQGAN worse, though.
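For reference, the length of the token sequence the transformer has to model is just the product of the three latent dimensions. A minimal sketch of that arithmetic (assuming 16-frame clips and that each latent axis is the input size divided by the corresponding --downsample factor; the 16-frame clip length is an assumption, not something stated in this thread):

```python
# Sketch of the sequence-length arithmetic, not code from the repo.
# Assumption: latent size = (frames/dt, H/dh, W/dw) for --downsample dt dh dw.
def latent_seq_len(frames, height, width, downsample):
    dt, dh, dw = downsample
    t, h, w = frames // dt, height // dh, width // dw
    return (t, h, w), t * h * w

print(latent_seq_len(16, 256, 256, (4, 8, 8)))    # ((4, 32, 32), 4096) -- the failing case
print(latent_seq_len(16, 256, 256, (4, 16, 16)))  # ((4, 16, 16), 1024)
```

Dropping the spatial latent from 32x32 to 16x16 is what brings the sequence back down to the block_size=1024 that comes up later in this thread.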
Thank you, it's working now! I have another question regarding training. My data is structured so that there are 704 folders (videos), each containing some number of images (frames). To train the VQGAN, I'm using the --image_folder flag. However, I've noticed that the training progress output is a bit strange. For example, when using --batch_size 2, the output shows:
Validating: 100%|██████████| 88/88 [01:13<00:00, 1.19it/s]
Epoch 1: 100%|██████████| 176/176 [07:06<00:00, 2.42s/it, loss=1.11, v_num=78873, train/perceptual_loss_step=1.700, train/recon_loss_step=0.278, train/aeloss_step=0.000, train/commitment_loss_step=0.00433, train/perplexity_step=53.50, train/discloss_step=0.000, val/recon_loss=0.446, val/perceptual_loss=1.690, val/perplexity=53.40, val/commitment_loss=0.0475, train/perceptual_loss_epoch=2.640, train/recon_loss_epoch=0.665, train/aeloss_epoch=0.000, train/commitment_loss_epoch=0.012, train/perplexity_epoch=68.70, train/discloss_epoch=0.000]
However, I expected the training progress to show twice as many iterations (176 -> 352) to cover the entire dataset in one epoch. Similarly, when using --batch_size 1, I'm getting:
Validating: 100%|██████████| 176/176 [01:27<00:00, 2.01it/s]
Epoch 1: 100%|██████████| 352/352 [10:10<00:00, 1.73s/it, loss=1.26, v_num=79006, train/perceptual_loss_step=1.750, train/recon_loss_step=0.384, train/aeloss_step=0.000, train/commitment_loss_step=0.0054, train/perplexity_step=50.50, train/discloss_step=0.000, val/recon_loss=0.520, val/perceptual_loss=1.950, val/perplexity=40.20, val/commitment_loss=0.253]
I couldn't find the reason for this in the code. Do you have any insights on why this might be happening?
Could it be that you are training on multiple gpus, e.g. 2?
I considered it, but since I'm training on 4 GPUs, it didn't add up.
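For what it's worth, two things can make these numbers look off: under DDP each process only sees its own shard of the data (Lightning inserts a DistributedSampler), and the main "Epoch" progress bar counts training and validation batches together. A minimal sketch of that bookkeeping, assuming drop_last=False and that the same 704-item set is used for both train and val (both are assumptions about this particular setup):

```python
import math

# Sketch of what the Lightning progress bars show per epoch under DDP.
# Assumptions: DistributedSampler with drop_last=False, identical train/val sets.
def bar_totals(num_samples, batch_size, world_size):
    per_rank = math.ceil(num_samples / world_size)    # items handled by each GPU
    train_batches = math.ceil(per_rank / batch_size)
    val_batches = math.ceil(per_rank / batch_size)
    # The main "Epoch N" bar counts train + val batches together.
    return train_batches, val_batches, train_batches + val_batches

print(bar_totals(704, 2, 4))  # (88, 88, 176) -> "Validating 88/88", "Epoch 176/176"
print(bar_totals(704, 1, 4))  # (176, 176, 352)
```

Under that reading the observed counts are internally consistent with 704 items on 4 GPUs; the apparent doubling comes from the validation batches being folded into the same epoch bar, not from the sampler.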
Hi! I am currently working on using the code for the FaceForensics dataset. I have been able to train VQGAN, but I encountered an issue while training the TATS-base Transformer. Here is the command I am using for training:
However, I got an error:

Traceback (most recent call last):
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 77, in <module>
main()
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 73, in main
trainer.fit(model, data)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1306, in _run_train
self._run_sanity_check(self.lightning_module)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1370, in _run_sanity_check
self._evaluation_loop.run()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 109, in advance
dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 122, in advance
output = self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 217, in _evaluation_step
output = self.trainer.accelerator.validation_step(step_kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 236, in validation_step
return self.training_type_plugin.validation_step(*step_kwargs.values())
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 444, in validation_step
return self.model(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 92, in forward
output = self.module.validation_step(*inputs, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 250, in validation_step
loss, acc1, acc5 = self.shared_step(batch, batch_idx)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 237, in shared_step
logits, target = self(x, c, cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 126, in forward
logits, _ = self.transformer(cz_indices[:, :-1], cbox=cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 208, in forward
assert t <= self.block_size, "Cannot forward, model block size is exhausted."
AssertionError: Cannot forward, model block size is exhausted.
where t=4096 and self.block_size=1024. When I tried to increase the block_size argument to 4096, I received another error message (even with a batch_size of 1):

Traceback (most recent call last):
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 77, in <module>
main()
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/train_transformer.py", line 73, in main
trainer.fit(model, data)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in fit
self._call_and_handle_interrupt(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
self._dispatch()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
self._results = trainer.run_stage()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
return self._run_train()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
self.fit_loop.run()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
self.epoch_loop.run(data_fetcher)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
batch_output = self.batch_loop.run(batch, batch_idx)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 145, in run
self.advance(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 215, in advance
result = self._run_optimization(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 266, in _run_optimization
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 378, in _optimizer_step
lightning_module.optimizer_step(
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1651, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 164, in step
trainer.accelerator.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in optimizer_step
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 163, in optimizer_step
optimizer.step(closure=closure, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/optim/optimizer.py", line 113, in wrapper
return func(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/optim/adamw.py", line 119, in step
loss = closure()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 148, in _wrap_closure
closure_result = closure()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
self._result = self.closure(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
step_output = self._step_fn()
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step
return self.training_type_plugin.training_step(*step_kwargs.values())
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 439, in training_step
return self.model(*args, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 81, in forward
output = self.module.training_step(*inputs, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 243, in training_step
loss, acc1, acc5 = self.shared_step(batch, batch_idx)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 237, in shared_step
logits, target = self(x, c, cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/tats_transformer.py", line 126, in forward
logits, _ = self.transformer(cz_indices[:, :-1], cbox=cbox)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 218, in forward
x = self.blocks(x)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
input = module(input)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 153, in forward
attn, present = self.attn(self.ln1(x), layer_past=layer_past)
File "/home/pkomor/anaconda3/envs/tats/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/lu/tetyda/home/pkomor/Projects/VidGen/TATS/scripts/../tats/modules/gpt.py", line 121, in forward
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 1; 31.74 GiB total capacity; 29.71 GiB already allocated; 45.12 MiB free; 30.61 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

Can you suggest any potential solutions that might work?
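As a rough sanity check on why block_size=4096 runs out of memory: the attention in gpt.py (the q @ k.transpose(-2, -1) line in the traceback) materializes a t x t score matrix per head, so going from 1024 to 4096 tokens multiplies that buffer by 16. A back-of-the-envelope sketch, where the head count, batch size, and fp32 activations are assumptions rather than the repo's confirmed defaults:

```python
# Rough size of one q @ k^T attention score tensor (sketch, not the repo's code).
# Assumptions: fp32 activations; n_head and batch are illustrative values.
def attn_scores_gib(seq_len, n_head=16, batch=1, bytes_per_elem=4):
    elems = batch * n_head * seq_len * seq_len   # one (t x t) score matrix per head
    return elems * bytes_per_elem / 1024**3

print(attn_scores_gib(1024))  # 0.0625 GiB per attention layer
print(attn_scores_gib(4096))  # 1.0 GiB per attention layer, before softmax copies and backward
```

Under those assumed values the 4096-token case needs about 1 GiB per attention layer just for the scores, which happens to match the 1024.00 MiB allocation the OOM reports; multiplied across layers, plus activations kept for backward, a 32 GiB card fills up quickly. Keeping the latent sequence around 1024 tokens (e.g. via --downsample 4 16 16, as suggested earlier in the thread) is likely the more practical route than raising block_size.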