wengong-jin / icml18-jtnn

Junction Tree Variational Autoencoder for Molecular Graph Generation (ICML 2018)
MIT License
509 stars 190 forks

DataParallel issues #59

Closed dbjhbyun closed 4 years ago

dbjhbyun commented 4 years ago

Hi,

Is there any way to use torch.nn.parallel.DataParallel to run the model on multiple GPUs? I naively tried wrapping model = JTNNVAE(vocab, args.hidden_size, args.latent_size, args.depthT, args.depthG) with model = DataParallel(model, device_ids=[0,1,2,3,4,5,6,7]), but it gave me the following errors:

[Screenshot of the error traceback, dated 2020-07-20]
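Concretely, the wrapping looked roughly like this (vocab and args are set up as in the repo's training script):

```python
from torch.nn import DataParallel

model = JTNNVAE(vocab, args.hidden_size, args.latent_size,
                args.depthT, args.depthG)
model = DataParallel(model, device_ids=[0, 1, 2, 3, 4, 5, 6, 7])
```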

Also, thank you for sharing your awesome code! I found your work really fascinating!

Thanks in advance!

wengong-jin commented 4 years ago

Hi,

The DataParallel class requires its inputs to be regular tensors in which one dimension is the batch dimension. In JT-VAE, each batch (say batch_size=10) is viewed as one big graph with 10 disconnected components, so there is essentially no batch dimension. This is probably what causes DataParallel to fail.
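As a rough illustration of the mismatch (plain PyTorch, not code from this repo):

```python
import torch

# DataParallel scatters each input tensor along dim 0 and sends one
# chunk to each replica. A tensor with a real batch dimension splits
# cleanly:
x = torch.randn(10, 128)         # 10 molecules, one row each
a, b = torch.chunk(x, 2, dim=0)  # two (5, 128) chunks for 2 GPUs

# JT-VAE instead merges a batch of 10 molecules into one big graph,
# so its tensors have shapes like (total_num_atoms, feature_dim),
# together with index tensors that reference nodes anywhere in the
# merged graph. Chunking along dim 0 would cut molecules apart and
# invalidate those cross-references, which is why DataParallel's
# scatter step cannot handle these inputs.
```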

I think since PyTorch 1.4, the PyTorch documentation has recommended DistributedDataParallel over DataParallel. I'd guess DistributedDataParallel would work here.
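A minimal sketch of what that route could look like (untested; it assumes one process per GPU, launched e.g. via torchrun or torch.distributed.launch --use_env, with vocab and args set up as in your post):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # The launcher sets MASTER_ADDR/MASTER_PORT/LOCAL_RANK etc. for us.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Build the model as in your post (vocab and args are assumed to
    # come from the training script) and move it to this process's
    # GPU before wrapping it.
    model = JTNNVAE(vocab, args.hidden_size, args.latent_size,
                    args.depthT, args.depthG).cuda(local_rank)
    model = DDP(model, device_ids=[local_rank])

    # ... each process then trains on its own shard of the data;
    # DDP all-reduces gradients across processes during backward().

if __name__ == "__main__":
    main()
```

Unlike DataParallel, DistributedDataParallel never splits a single batch across devices: each process consumes whole batches, so the merged-graph representation stays intact.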