songweige / TATS

Official PyTorch implementation of TATS: A Long Video Generation Framework with Time-Agnostic VQGAN and Time-Sensitive Transformer (ECCV 2022)
MIT License
267 stars 17 forks

Training on single GPU #7

Closed ndahlqvist closed 2 years ago

ndahlqvist commented 2 years ago

Hi!

Thanks for the great work. I have been trying to train on a single GPU, but it keeps throwing this error:

RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

Is it possible to configure the model to train on a single GPU?

Full error message:

Traceback (most recent call last):
  File "/content/TATS/scripts/train_vqgan.py", line 70, in <module>
    main()
  File "/content/TATS/scripts/train_vqgan.py", line 66, in main
    trainer.fit(model, data)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 738, in fit
    self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 772, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1195, in _run
    self._dispatch()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _dispatch
    self.training_type_plugin.start_training(self)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 202, in start_training
    self._results = trainer.run_stage()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1284, in run_stage
    return self._run_train()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1314, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/fit_loop.py", line 234, in advance
    self.epoch_loop.run(data_fetcher)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 193, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 145, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 219, in advance
    self.optimizer_idx,
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 259, in _run_optimization
    closure()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 160, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 142, in closure
    step_output = self._step_fn()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 435, in _training_step
    training_step_output = self.trainer.accelerator.training_step(step_kwargs)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 216, in training_step
    return self.training_type_plugin.training_step(*step_kwargs.values())
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 213, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/content/TATS/scripts/tats/tats_vqgan.py", line 182, in training_step
    recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward(x, optimizer_idx)
  File "/content/TATS/scripts/tats/tats_vqgan.py", line 118, in forward
    logits_image_fake, pred_image_fake = self.image_discriminator(frames_recon)
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/TATS/scripts/tats/tats_vqgan.py", line 463, in forward
    res.append(model(res[-1]))
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 731, in forward
    world_size = torch.distributed.get_world_size(process_group)
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 867, in get_world_size
    return _get_group_size(group)
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 325, in _get_group_size
    default_pg = _get_default_group()
  File "/usr/local/lib/python3.7/site-packages/torch/distributed/distributed_c10d.py", line 430, in _get_default_group
    "Default process group has not been initialized, "
RuntimeError: Default process group has not been initialized, please make sure to call init_process_group.

songweige commented 2 years ago

Did you set args.gpus=1 in your training script?

ndahlqvist commented 2 years ago

Yes! These are the arguments I use:

!python train_vqgan.py --embedding_dim 256 --n_codes 16384 --n_hiddens 32 --no_random_restart \
    --gpus 1 --sync_batchnorm --batch_size 2 --num_workers 32 --accumulate_grad_batches 6 \
    --progress_bar_refresh_rate 500 --max_steps 2000000 --gradient_clip_val 1.0 --lr 3e-5 \
    --data_path /content/drive/MyDrive/ut --default_root_dir /content/drive/MyDrive/ut \
    --resolution 128 --sequence_length 16 --discriminator_iter_start 10000 --norm_type batch \
    --perceptual_weight 4 --image_gan_weight 1 --video_gan_weight 1 --gan_feat_weight 4

ndahlqvist commented 2 years ago

Solved it with:

import torch.distributed as dist
dist.init_process_group('gloo', init_method='file:///tmp/somefile', rank=0, world_size=1)

songweige commented 2 years ago

Thanks for the updated solution!