ml-struct-bio / cryodrgn

Neural networks for cryo-EM reconstruction
http://cryodrgn.cs.princeton.edu
GNU General Public License v3.0

AssertionError when using train_vae #389

Open oleuns opened 3 months ago

oleuns commented 3 months ago

Hey, I am encountering an AssertionError on one of our workstations, but only when I train on more than one GPU (the workstation has four RTX 3080 Ti cards).

The command I am using is:

cryodrgn train_vae particles.128.mrcs --poses poses.pkl --ctf ctf.pkl --zdim 8 -n 50 --multigpu -o 128_it0

I also tried selecting specific GPUs with

CUDA_VISIBLE_DEVICES=0,1 cryodrgn train_vae particles.128.mrcs --poses poses.pkl --ctf ctf.pkl --zdim 8 -n 50 --multigpu -o 128_it1

but that didn't resolve the issue. On another workstation (2x RTX 4090), training on the same dataset with the --multigpu flag worked well.
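One way to see whether the failure depends on which cards are used is to repeat the run on every pair of the four GPUs, not just 0,1. The sketch below is a hypothetical helper (gpu_subsets, build_run and the 128_gpus_XY output names are ours, not part of cryoDRGN) that builds the CUDA_VISIBLE_DEVICES setting and argument list for each pair from the command above. By default it only prints the commands.

```python
import itertools
import os
import subprocess

# Arguments copied from the command in this issue; only the -o directory varies per run.
BASE_ARGS = [
    "cryodrgn", "train_vae", "particles.128.mrcs",
    "--poses", "poses.pkl", "--ctf", "ctf.pkl",
    "--zdim", "8", "-n", "50", "--multigpu",
]


def gpu_subsets(n_gpus, size=2):
    """Every combination of `size` GPU indices, e.g. (0, 1), (0, 2), ..., (2, 3)."""
    return list(itertools.combinations(range(n_gpus), size))


def build_run(gpus, outdir_prefix="128_gpus"):
    """Return (argv, env) for a run restricted to `gpus` via CUDA_VISIBLE_DEVICES."""
    visible = ",".join(str(g) for g in gpus)
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=visible)
    outdir = outdir_prefix + "_" + "".join(str(g) for g in gpus)
    return BASE_ARGS + ["-o", outdir], env


def main(dry_run=True):
    for gpus in gpu_subsets(4):
        argv, env = build_run(gpus)
        print("CUDA_VISIBLE_DEVICES=" + env["CUDA_VISIBLE_DEVICES"], " ".join(argv))
        if not dry_run:
            # check=False so one failing pair does not stop the remaining runs
            result = subprocess.run(argv, env=env, check=False)
            print("exit code:", result.returncode)


if __name__ == "__main__":
    main()
```

If every pair fails in the same way, a single faulty card becomes a less likely explanation.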

The error is the following:

(INFO) (train_vae.py) (23-Jul-24 14:02:55) Increasing batch size to 32
Traceback (most recent call last):
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/bin/cryodrgn", line 8, in <module>
    sys.exit(main_commands())
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/command_line.py", line 68, in main_commands
    _get_commands(
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/command_line.py", line 63, in _get_commands
    args.func(args)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/commands/train_vae.py", line 941, in main
    loss, gen_loss, kld = train_batch(
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/commands/train_vae.py", line 380, in train_batch
    z_mu, z_logvar, z, y_recon, mask = run_batch(
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/commands/train_vae.py", line 452, in run_batch
    y_recon = model(lattice.coords[mask] / lattice.extent / 2 @ rot, z).view(B, -1)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 185, in forward
    outputs = self.parallel_apply(replicas, inputs, module_kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/parallel/data_parallel.py", line 200, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/_utils.py", line 705, in reraise
    raise exception
AssertionError: Caught AssertionError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/models.py", line 176, in forward
    return self.decode(*args, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/models.py", line 171, in decode
    retval = decoder(self.cat_z(coords, z) if z is not None else coords)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/supervisor/miniconda3/envs/cryodrgn-env/lib/python3.9/site-packages/cryodrgn/models.py", line 537, in forward
    assert abs(lattice[..., 0:3].mean()) < 1e-4, "{} != 0.0".format(
AssertionError: -0.2916666567325592 != 0.0
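For context, the failing check is the assertion on the last line of the traceback: the decoder insists that the x, y, z columns of its input coordinates average to zero within 1e-4, and replica 1 received coordinates averaging about -0.29. The pure-Python sketch below mirrors that check on a hypothetical stand-in lattice (a flat d × d grid spanning [-0.5, 0.5], chosen for illustration only; cryoDRGN's real lattice, mask and rotations are not reproduced). centered_grid, coords_mean and check_centered are our names, not cryoDRGN functions.

```python
def linspace(start, stop, num):
    """Evenly spaced values from start to stop inclusive (num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def centered_grid(d):
    """Hypothetical stand-in lattice: d x d points (x, y, 0) spanning [-0.5, 0.5]."""
    axis = linspace(-0.5, 0.5, d)
    return [(x, y, 0.0) for y in axis for x in axis]


def coords_mean(points):
    # Mean over every x, y and z entry, analogous to lattice[..., 0:3].mean().
    values = [v for p in points for v in p[:3]]
    return sum(values) / len(values)


def check_centered(points, tol=1e-4):
    # Same condition and message format as the assertion in models.py.
    m = coords_mean(points)
    assert abs(m) < tol, "{} != 0.0".format(m)


if __name__ == "__main__":
    grid = centered_grid(65)
    check_centered(grid)  # passes: the grid is symmetric about zero
    shifted = [(x - 0.3, y, z) for (x, y, z) in grid]
    try:
        check_centered(shifted)
    except AssertionError as err:
        print("AssertionError:", err)
```

A uniform shift of 0.3 in x alone moves the three-column mean by 0.1, which is already a thousand times the tolerance.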

I would be happy if someone could help me out with this!

michal-g commented 3 months ago

I have not seen this error yet, but I will check whether it shows up when I use --multigpu in some test runs on our cluster. In the meantime, can you check whether the CUDA version installed in the environment you run cryoDRGN from (e.g. using conda list | grep 'cuda') matches the one on the 3080 Ti workstation (checked using e.g. nvidia-smi)?

oleuns commented 3 months ago

This issue is also very specific to this workstation; unfortunately, I don't have access to other workstations with more than two GPUs to check whether it is a general issue. Two more things I can add: the error is dataset-independent, and setting up a new conda environment for cryodrgn with

conda create --name cryodrgn-env python=3.9
conda activate cryodrgn-env
pip install cryodrgn

doesn't fix it.
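Because a fresh pip install resolves its own torch build (the nvidia-*-cu12 packages in the conda list output below are marked as coming from PyPI), it can help to record exactly which versions ended up in the new environment when reporting results. A minimal sketch using the standard library's importlib.metadata; the package list is an illustrative choice, not something cryoDRGN prescribes.

```python
from importlib import metadata

# Distributions worth reporting; the selection is illustrative.
PACKAGES = ("cryodrgn", "torch", "nvidia-cuda-runtime-cu12")


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version if version is not None else 'not installed'}")
```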

There seems to be a CUDA mismatch: in the cryodrgn environment, conda reports CUDA 12.1.105

conda list | grep 'cuda'
nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi

while nvidia-smi reports CUDA version 12.3 (NVIDIA-SMI 545.23.08, Driver Version 545.23.08). I will try to fix that and let you know if it helps.
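Comparing the two numbers can be scripted. A minimal sketch, assuming the usual nvidia-smi header format ("CUDA Version: X.Y") and the conda list line layout shown above. The sample strings are reconstructed from the values reported in this comment, not copied output, and the function names are hypothetical. Note that nvidia-smi reports the newest CUDA version the driver supports rather than an installed toolkit, so a lower runtime in the environment (12.1 under a 12.3 driver) is normally fine.

```python
import re


def driver_cuda_version(nvidia_smi_text):
    """(major, minor) from the 'CUDA Version: X.Y' field printed by nvidia-smi."""
    match = re.search(r"CUDA Version:\s*(\d+)\.(\d+)", nvidia_smi_text)
    if match is None:
        raise ValueError("no 'CUDA Version' field in nvidia-smi output")
    return int(match.group(1)), int(match.group(2))


def runtime_cuda_version(conda_list_text):
    """(major, minor) of the nvidia-cuda-runtime-cuXX package in `conda list` output."""
    match = re.search(
        r"^nvidia-cuda-runtime-cu\d+\s+(\d+)\.(\d+)", conda_list_text, re.MULTILINE
    )
    if match is None:
        raise ValueError("no nvidia-cuda-runtime package in conda list output")
    return int(match.group(1)), int(match.group(2))


def driver_covers_runtime(driver, runtime):
    # Conservative check: a runtime at or below the driver's CUDA version is supported.
    # (CUDA minor-version compatibility can also allow some newer 12.x runtimes; ignored here.)
    return runtime <= driver


# Sample strings reconstructed from the values reported in this thread.
SMI = "| NVIDIA-SMI 545.23.08    Driver Version: 545.23.08    CUDA Version: 12.3 |"
CONDA = """nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi"""

if __name__ == "__main__":
    driver = driver_cuda_version(SMI)
    runtime = runtime_cuda_version(CONDA)
    print(driver, runtime, driver_covers_runtime(driver, runtime))
```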

michal-g commented 3 months ago

OK, 12.1 vs. 12.3 would not necessarily break things, but it's worth trying! Can you also try adding --lazy to your command, in case this is caused by an out-of-memory issue?

oleuns commented 3 months ago

I'll try later. Another thing I tested today was cryodrgn version 2.3; however, the error persists. I have a gut feeling that it has something to do with this very specific workstation setup. Adding --lazy did not change anything; the particle stack is only 34 GB. The workstation is fairly new, with 512 GB RAM and an Intel® Xeon(R) w5-3435X × 32, running Ubuntu 22.04.3 LTS.
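As a rough check of the 34 GB figure against 512 GB of RAM, the sketch below assumes the stack is the 128-pixel one stored as uncompressed float32 (4 bytes per pixel, MRC header ignored) and that GB means 10^9 bytes. Under those assumptions each image takes 128 × 128 × 4 = 65,536 bytes, so 34 GB corresponds to roughly 519,000 particles. The actual particle count is not given in this thread, and the helper names are hypothetical.

```python
def stack_bytes(n_particles, box_size, bytes_per_pixel=4):
    """Size of an uncompressed particle stack in bytes (header ignored)."""
    return n_particles * box_size * box_size * bytes_per_pixel


def particles_in_stack(total_bytes, box_size, bytes_per_pixel=4):
    """Whole number of images of the given box size that fit in total_bytes."""
    return total_bytes // (box_size * box_size * bytes_per_pixel)


if __name__ == "__main__":
    per_image = stack_bytes(1, 128)
    n = particles_in_stack(34 * 10**9, 128)
    print(per_image, n, f"{34 / 512:.1%} of RAM")
```

Even fully loaded, a stack of this size would use under a tenth of the workstation's RAM, which is consistent with --lazy making no difference here.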

ts387 commented 2 months ago

Hey @michal-g & @oleuns.

I get the same kind of error, again specifically when using --multigpu.

michal-g commented 1 month ago

Another thing to try, which has resolved memory issues for me when using --multigpu (and these may be upstream of the errors mentioned here), is reducing the batch size to the minimum (-b 1)!
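The log line in the original report ("Increasing batch size to 32") suggests that --multigpu multiplies the per-GPU batch size by the number of GPUs; on four cards that is consistent with a per-GPU value of 8, though this thread does not confirm the default. Under that assumption, the sketch below shows what -b 1 implies, and how a batch would be divided among replicas if it is split along the first dimension with torch.chunk-style sizes (our assumption about how DataParallel scatters its inputs). The helper names are hypothetical.

```python
import math


def effective_batch_size(per_gpu_batch, n_gpus):
    """Images per step, assuming --multigpu multiplies -b by the GPU count."""
    return per_gpu_batch * n_gpus


def replica_split(batch, n_gpus):
    """Rows each replica gets if the batch is split like torch.chunk along dim 0."""
    chunk = math.ceil(batch / n_gpus)
    sizes = []
    remaining = batch
    while remaining > 0:
        sizes.append(min(chunk, remaining))
        remaining -= chunk
    return sizes


if __name__ == "__main__":
    for b in (8, 1):
        total = effective_batch_size(b, 4)
        print(f"-b {b}: {total} images per step, split {replica_split(total, 4)}")
```

With -b 1 on the four 3080 Ti cards, each step would carry four images, one per replica. Under the chunking assumption, a batch that does not divide evenly (for example 5 images on 4 GPUs, split 2, 2, 1) leaves a GPU idle for that step.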