bpowell122 / tomodrgn

Neural networks to analyze structural heterogeneity in cryo-electron sub-tomograms
GNU General Public License v3.0
19 stars, 1 fork

train_vae AttributeError: 'DataParallel' object has no attribute 'zdim' #6

Open TJN25 opened 1 year ago

TJN25 commented 1 year ago

Describe the bug: When running the quicktest.py example with the --multigpu flag added, the following error is produced: AttributeError: 'DataParallel' object has no attribute 'zdim'.

To Reproduce: Run the first command from quicktest.py, modified to add --multigpu and possibly (?) --num-workers 2 (the latter to work around a previous issue):

tomodrgn train_vae data/10076_both_32_sim.star -o output/01_vae_both_sim --zdim 8 --uninvert-data --seed 42 --log-interval 100 --enc-dim-A 64 --enc-layers-A 2 --out-dim-A 64 --enc-dim-B 32 --enc-layers-B 4 --dec-dim 16 --dec-layers 3 -n 5 --num-workers 2 --multigpu

Expected behavior: The script runs to completion without this error.

Additional context: I dug into this, and train_vae.py accesses model.zdim (lines 257 and 258). With the --multigpu flag, the model is a DataParallel object rather than a TiltSeriesHetOnlyVAE. DataParallel keeps the wrapped model as model.module and does not expose the wrapped model's custom attributes such as zdim. Changing the call to model.module.zdim clears the error, but then training fails without --multigpu, because the model is then a plain TiltSeriesHetOnlyVAE, which has no .module attribute.
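The attribute lookup failure can be reproduced without any tomodrgn code. The sketch below uses a hypothetical stand-in class, ToyVAE, that stores zdim as a plain attribute (an assumption about how TiltSeriesHetOnlyVAE stores it). Wrapping it in nn.DataParallel hides zdim, while model.module still exposes it. Constructing DataParallel also works on a CPU-only machine, so the sketch does not need a GPU.

```python
import torch.nn as nn


class ToyVAE(nn.Module):
    """Hypothetical stand-in for TiltSeriesHetOnlyVAE: stores zdim as a plain attribute."""

    def __init__(self, zdim: int = 8):
        super().__init__()
        self.zdim = zdim
        self.encoder = nn.Linear(16, 2 * zdim)  # would output z_mu and z_logvar


def read_zdim(model: nn.Module):
    """Return model.zdim, or the AttributeError message if the lookup fails."""
    try:
        return model.zdim
    except AttributeError as err:
        return str(err)


plain = ToyVAE(zdim=8)
wrapped = nn.DataParallel(plain)  # keeps the original model as wrapped.module

print(read_zdim(plain))         # 8
print(read_zdim(wrapped))       # error message mentioning the missing 'zdim'
print(wrapped.module.zdim)      # 8
print(wrapped.module is plain)  # True
```

The wrapper is itself an nn.Module, and nn.Module only resolves registered parameters, buffers, and submodules through attribute access. It does not resolve arbitrary attributes of the wrapped model.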

Adding the lines:

if isinstance(model, nn.DataParallel):
    model = model.module

resolves the error, but I do not know whether the code still behaves as intended.
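One possible concern with reassigning model = model.module is this: if model is used for forward passes after that point, those passes run on the unwrapped module and are no longer split across GPUs. A safer pattern may be to unwrap only when reading attributes. The helper name unwrap_model below is hypothetical (not part of tomodrgn), and ToyVAE is again a stand-in for TiltSeriesHetOnlyVAE. The helper also covers nn.parallel.DistributedDataParallel, which likewise stores the wrapped model as .module.

```python
import torch.nn as nn


def unwrap_model(model: nn.Module) -> nn.Module:
    """Hypothetical helper: return the underlying model if it is wrapped for data parallelism."""
    if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        return model.module
    return model


class ToyVAE(nn.Module):
    """Hypothetical stand-in for TiltSeriesHetOnlyVAE."""

    def __init__(self, zdim: int = 8):
        super().__init__()
        self.zdim = zdim
        self.encoder = nn.Linear(16, 2 * zdim)


model = nn.DataParallel(ToyVAE(zdim=8))

# Read attributes through the unwrapped view, but leave `model` itself wrapped
# so any later forward passes still go through DataParallel.
zdim = unwrap_model(model).zdim
print(zdim)                               # 8
print(isinstance(model, nn.DataParallel)) # True
print(unwrap_model(ToyVAE(zdim=4)).zdim)  # 4 (unwrapped models pass through unchanged)
```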

TJN25 commented 1 year ago

I found that adding an isinstance(model, nn.DataParallel) check and storing the latent dimension in a variable, zdim_value, for later use works:

if isinstance(model, nn.DataParallel):
    zdim_value = model.module.zdim
else:
    zdim_value = model.zdim
z_mu_all = torch.zeros((data.nptcls, zdim_value), device=device, dtype=torch.half if use_amp else torch.float)  # replaced model.zdim with zdim_value
z_logvar_all = torch.zeros((data.nptcls, zdim_value), device=device, dtype=torch.half if use_amp else torch.float)  # replaced model.zdim with zdim_value
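An alternative that removes the branch is to read the latent dimension once, before the model is wrapped, and reuse it when allocating the per-particle buffers. This is a minimal sketch that assumes train_vae.py constructs the model before wrapping it for --multigpu (not verified against the tomodrgn source). ToyVAE, allocate_latent_buffers, nptcls=100, and the CPU device are hypothetical placeholders for the real model, data.nptcls, and device.

```python
import torch
import torch.nn as nn


class ToyVAE(nn.Module):
    """Hypothetical stand-in for TiltSeriesHetOnlyVAE."""

    def __init__(self, zdim: int = 8):
        super().__init__()
        self.zdim = zdim
        self.encoder = nn.Linear(16, 2 * zdim)


def allocate_latent_buffers(nptcls: int, zdim: int, device: torch.device, use_amp: bool):
    """Hypothetical helper: allocate per-particle z_mu / z_logvar buffers as in the snippet above."""
    dtype = torch.half if use_amp else torch.float
    z_mu_all = torch.zeros((nptcls, zdim), device=device, dtype=dtype)
    z_logvar_all = torch.zeros((nptcls, zdim), device=device, dtype=dtype)
    return z_mu_all, z_logvar_all


model = ToyVAE(zdim=8)
zdim = model.zdim               # read before wrapping, so no isinstance check is needed later
model = nn.DataParallel(model)  # the issue reports that --multigpu leaves the model wrapped like this

device = torch.device("cpu")    # placeholder; train_vae.py selects its own device
z_mu_all, z_logvar_all = allocate_latent_buffers(nptcls=100, zdim=zdim, device=device, use_amp=False)
print(z_mu_all.shape)           # torch.Size([100, 8])
```

Capturing zdim up front keeps the buffer allocation independent of whether the model is wrapped. The trade-off is that zdim must not change after construction, which seems reasonable for a fixed latent dimension.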