wilson1yan / VideoGPT


Tensor Size Mismatch in attention.py #28

Closed: humishum closed this 6 months ago

humishum commented 2 years ago

I'm getting the runtime error below when running train_videogpt.py on a custom dataset. As for the training args, I trained the vqvae with a batch size of 16 and am using the same batch size for videogpt. The traceback is attached. I'm new to PyTorch, so any pointers as to where I could go to fix this? I'm currently only using 1 GPU, since I was getting DDP errors from Lightning with two, but I'll go back to multi-GPU once I can make sure training runs properly. Any idea what this tensor has in dim 3? The dataset I'm using is simply a train/test split of mp4s.

I used this command with basic arguments:

```
python3 scripts/train_videogpt.py --data_path <custom_data_path> --gpus 1 --batch_size 16 --vqvae <custom_data_path/epoch=0-step=11702.ckpt> --max_steps 200000
```

```
Traceback (most recent call last):
  File "scripts/train_videogpt.py", line 43, in <module>
    main()
  File "scripts/train_videogpt.py", line 39, in main
    trainer.fit(model, data)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self._run(model)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
    self.dispatch()
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
    self.accelerator.start_training(self)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
    return self.run_train()
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 842, in run_train
    self.run_sanity_check(self.lightning_module)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1107, in run_sanity_check
    self.run_evaluation()
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 962, in run_evaluation
    output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 174, in evaluation_step
    output = self.trainer.accelerator.validation_step(args)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 226, in validation_step
    return self.training_type_plugin.validation_step(*args)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 161, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
  File "/home/humdaan/Documents/school/AME494/videoGPT/VideoGPT/videogpt/gpt.py", line 158, in validation_step
    loss = self.training_step(batch, batch_idx)
  File "/home/humdaan/Documents/school/AME494/videoGPT/VideoGPT/videogpt/gpt.py", line 154, in training_step
    loss, _ = self(x, targets, cond)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/humdaan/Documents/school/AME494/videoGPT/VideoGPT/videogpt/gpt.py", line 131, in forward
    h = self.attn_stack(h, cond, decode_step, decode_idx)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/humdaan/Documents/school/AME494/videoGPT/VideoGPT/videogpt/attention.py", line 55, in forward
    x = self.pos_embd(x, decode_step, decode_idx)
  File "/home/humdaan/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/humdaan/Documents/school/AME494/videoGPT/VideoGPT/videogpt/attention.py", line 493, in forward
    return x + embs
RuntimeError: The size of tensor a (32) must match the size of tensor b (16) at non-singleton dimension 3
```

wilson1yan commented 2 years ago

I believe the shapes of x and embs should both be B x T x H x W x D, where B is the batch size, T x H x W is the shape of the encoded video, and D is the hidden dimension. So dim 3 should be W. You can print both shapes to confirm.
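A minimal way to do that check, assuming the forward method in videogpt/attention.py that the traceback points at (the line numbers may differ in your copy):

```python
# Hypothetical debug patch just before the failing add in
# videogpt/attention.py (the `return x + embs` around line 493).
# Both tensors are expected to be B x T x H x W x D.
print('x:   ', tuple(x.shape))     # activations entering the position-embedding add
print('embs:', tuple(embs.shape))  # broadcast position embeddings, sized from vqvae.latent_shape
return x + embs
```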

humishum commented 2 years ago

That clears it up a bit, thank you! I didn't change any settings or code aside from the batch size, so I'm not sure why the height and width don't match up; I'll have to dig deeper into it. Printing the two shapes gives:

```
torch.Size([16, 4, 32, 32, 576])
torch.Size([1, 4, 16, 16, 576])
```
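As a standalone illustration of why these two shapes fail: PyTorch broadcasting only expands size-1 dimensions, so the 1 in the second tensor's batch dim is fine, but 16 vs. 32 at the H and W dims is a hard mismatch:

```python
import torch

# Minimal repro of the reported error. Broadcasting compares dims from the
# right: dim 4 (576 vs 576) matches, then dim 3 (32 vs 16) fails because
# neither side is 1.
x = torch.zeros(16, 4, 32, 32, 576)    # B x T x H x W x D, from the GPT's input
embs = torch.zeros(1, 4, 16, 16, 576)  # position embeddings sized from vqvae.latent_shape
x + embs  # RuntimeError: The size of tensor a (32) must match the size of tensor b (16) at non-singleton dimension 3
```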

wilson1yan commented 2 years ago

The shape of the position embeddings comes from vqvae.latent_shape, so that is somehow not matching the actual shape of the VQ-VAE's output. Maybe the video resolution was different when you trained the vqvae vs. the gpt model? Your printouts would fit that: a 32 x 32 encoding against a 16 x 16 latent_shape is exactly the factor-of-2 difference you would get from doubling the input resolution with the same downsampling.
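One way to test that hypothesis is to load the checkpoint and compare vqvae.latent_shape against what the model actually emits for a clip at the resolution the GPT is being trained on. A minimal sketch, assuming the VQVAE class and encode method from this repo (it is a LightningModule, so load_from_checkpoint should work); the checkpoint path, clip length, and resolution below are placeholders for your own:

```python
import torch
from videogpt import VQVAE

# Load the same checkpoint that was passed to --vqvae (path is a placeholder).
vqvae = VQVAE.load_from_checkpoint('<custom_data_path>/epoch=0-step=11702.ckpt').eval()
print('stored latent_shape:', vqvae.latent_shape)  # what the GPT sizes its position embeddings from

# Encode one dummy clip, B x C x T x H x W; swap in the sequence length and
# resolution your GPT training run actually uses.
video = torch.zeros(1, 3, 16, 128, 128)
with torch.no_grad():
    encodings = vqvae.encode(video)
print('actual encoding shape:', encodings.shape)  # B x T' x H' x W'; T' x H' x W' should match latent_shape
```

If the two disagree, retraining one of the models (or resizing the videos) so both stages see the same resolution should resolve the mismatch.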