ehoogeboom / e3_diffusion_for_molecules


AssertionError: Variables not masked properly. #24

Closed · sundevil0405 closed 7 months ago

sundevil0405 commented 1 year ago

I get the following error when running main_qm9.py with 8 GPU cards, but the error goes away if I use 4 cards or even a single card. Could you please let me know how to fix it?

```
Entropy of n_nodes: H[N] -2.475700616836548
alphas2 [9.99990000e-01 9.99988000e-01 9.99982000e-01 ... 2.59676966e-05 1.39959211e-05 1.00039959e-05]
gamma [-11.51291546 -11.33059532 -10.92513058 ... 10.55863126 11.17673063 11.51251595]
Training using 8 GPUs
Traceback (most recent call last):
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/main_qm9.py", line 289, in <module>
    main()
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/main_qm9.py", line 241, in main
    train_epoch(args=args, loader=dataloaders['train'], epoch=epoch, model=model, model_dp=model_dp,
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/train_test.py", line 53, in train_epoch
    nll, reg_term, mean_abs_z = losses.compute_loss_and_nll(args, model_dp, nodes_dist,
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/qm9/losses.py", line 23, in compute_loss_and_nll
    nll = generative_model(x, h, node_mask, edge_mask, context)
  File "/root/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 181, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/root/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 89, in parallel_apply
    output.reraise()
  File "/root/anaconda3/lib/python3.9/site-packages/torch/_utils.py", line 543, in reraise
    raise exception
AssertionError: Caught AssertionError in replica 4 on device 4.
Original Traceback (most recent call last):
  File "/root/anaconda3/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 64, in _worker
    output = module(*input, **kwargs)
  File "/root/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/equivariant_diffusion/en_diffusion.py", line 701, in forward
    loss, loss_dict = self.compute_loss(x, h, node_mask, edge_mask, context, t0_always=False)
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/equivariant_diffusion/en_diffusion.py", line 606, in compute_loss
    diffusion_utils.assert_mean_zero_with_mask(z_t[:, :, :self.n_dims], node_mask)
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/equivariant_diffusion/utils.py", line 47, in assert_mean_zero_with_mask
    assert_correctly_masked(x, node_mask)
  File "/disk/nvme1n1/mg/e3_diffusion_for_molecules/equivariant_diffusion/utils.py", line 56, in assert_correctly_masked
    assert (variable * (1 - node_mask)).abs().max().item() < 1e-4, \
AssertionError: Variables not masked properly.
```
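
For reference, the check that fires is `assert_correctly_masked` in equivariant_diffusion/utils.py (last frames of the traceback): it requires that every coordinate belonging to a padded node is (numerically) zero. A minimal standalone sketch of that check, with made-up tensor shapes purely for illustration:

```python
import torch

def assert_correctly_masked(variable, node_mask):
    # Every entry belonging to a padded (masked-out) node must be ~0.
    assert (variable * (1 - node_mask)).abs().max().item() < 1e-4, \
        'Variables not masked properly.'

# Hypothetical example: batch of 2 molecules, up to 3 nodes, 3D coordinates.
node_mask = torch.tensor([[[1.], [1.], [0.]],    # molecule 0 has 2 real nodes
                          [[1.], [1.], [1.]]])   # molecule 1 has 3 real nodes

z_t = torch.randn(2, 3, 3) * node_mask           # padded positions zeroed out
assert_correctly_masked(z_t, node_mask)          # passes

z_bad = torch.randn(2, 3, 3)                     # padding not zeroed
# assert_correctly_masked(z_bad, node_mask)      # would raise the AssertionError above
```

With nn.DataParallel each replica runs this assert on its own slice of the batch, which is why the failure is reported for replica 4 specifically.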

ehoogeboom commented 7 months ago

honestly not sure, sorry. 8 GPUs is a lot for qm9 though ;)
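
A guess rather than a confirmed fix, for anyone who lands here with the same error: since the assert fires inside a DataParallel replica, one cheap thing to try is making sure every replica always receives a full, evenly sized chunk of the batch (for example by dropping the last incomplete batch), or simply limiting the run to the 4-GPU or 1-GPU setup that already works. Sketch below; the dataset construction is a placeholder, not the repo's actual QM9 dataloader:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder dataset standing in for the repo's QM9 dataloaders.
dataset = TensorDataset(torch.randn(1024, 29, 3))

# drop_last=True keeps every batch at the full batch_size, so
# nn.DataParallel splits it evenly across all replicas instead of
# handing the last replica a small (or empty) remainder chunk.
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

# Alternatively, restrict the run to fewer GPUs before launching:
#   CUDA_VISIBLE_DEVICES=0,1,2,3 python main_qm9.py ...
```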