dbjhbyun closed this issue 4 years ago
Hi,
Is there any way to use torch.nn.parallel.DataParallel to run the model on multiple GPUs? I naively tried wrapping model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depthT, args.depthG) with model = DataParallel(model, device_ids=[0,1,2,3,4,5,6,7]), but it gave me the following errors:
Also, thank you for sharing your awesome code! I found your work really fascinating!
Thanks in advance!
Hi,
The DataParallel class requires the input to be regular tensors in which one dimension is the batch dimension. In JT-VAE, each batch (say, batch_size=10) is treated as one big graph with 10 disconnected components, so there is essentially no batch dimension for DataParallel to scatter along. This may cause DataParallel to fail.
I think that since PyTorch 1.4, PyTorch recommends DistributedDataParallel instead of DataParallel. I would guess DistributedDataParallel could work here, since each process runs the full forward pass on its own replica with its own batch, so no splitting along a batch dimension is ever needed.
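For reference, here is a minimal, hypothetical sketch of what that could look like with one process per GPU. The loader, optimizer, and the exact forward signature are placeholders; the real ones should follow the repo's training script:

```python
# Hypothetical sketch: one process per GPU, launched with e.g.
#   torchrun --nproc_per_node=8 train_ddp.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")          # torchrun sets the rendezvous env vars
local_rank = int(os.environ["LOCAL_RANK"])       # set by torchrun for each process
torch.cuda.set_device(local_rank)

# JTNNVAE, vocab, and args come from the JT-VAE repo, as in the question above.
model = JTNNVAE(vocab, args.hidden_size, args.latent_size,
                args.depthT, args.depthG).cuda(local_rank)
model = DDP(model, device_ids=[local_rank])

# Each rank iterates over its own shard of the data, so the "one big
# disconnected graph" batch is never split by PyTorch itself.
for batch in loader:                 # loader: per-rank batches (placeholder)
    optimizer.zero_grad()
    loss = model(batch)              # exact forward signature per the repo's train script
    loss.backward()                  # DDP all-reduces gradients across ranks here
    optimizer.step()
```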