ml-struct-bio / cryodrgn

Neural networks for cryo-EM reconstruction
http://cryodrgn.cs.princeton.edu
GNU General Public License v3.0

Gaussian positional encoding not implemented for --domain hartley #202

Status: Open. zhonge opened this issue 1 year ago.

zhonge commented 1 year ago

The default decoder domain for cryodrgn abinit_homo and cryodrgn abinit_het should be hartley. It is currently set to fourier.

When I made this change, I ran into the following error, because random Fourier features (--pe-type gaussian) were never implemented for the Hartley decoder:

Traceback (most recent call last):
  File "/home/zhonge/dev/cryodrgn/testing/../cryodrgn/commands/abinit_het.py", line 1137, in <module>
    main(args)
  File "/home/zhonge/dev/cryodrgn/testing/../cryodrgn/commands/abinit_het.py", line 930, in main
    loss = pretrain(model, lattice, optim, batch, tilt=ps.tilt, zdim=args.zdim)
  File "/home/zhonge/dev/cryodrgn/testing/../cryodrgn/commands/abinit_het.py", line 407, in pretrain
    gen_loss = F.mse_loss(gen_slice(rot), y)
  File "/home/zhonge/dev/cryodrgn/testing/../cryodrgn/commands/abinit_het.py", line 398, in gen_slice
    return _model.decode(lattice.coords[mask] @ R, z).view(B, -1)
  File "/home/zhonge/dev/cryodrgn/cryodrgn/models.py", line 160, in decode
    retval = decoder(self.cat_z(coords, z) if z is not None else coords)
  File "/home/zhonge/.conda/envs/cryodrgn/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/zhonge/dev/cryodrgn/cryodrgn/models.py", line 315, in forward
    return self.decoder(self.positional_encoding_geom(coords))
  File "/home/zhonge/dev/cryodrgn/cryodrgn/models.py", line 284, in positional_encoding_geom
    raise RuntimeError("Encoding type {} not recognized".format(self.enc_type))
RuntimeError: Encoding type gaussian not recognized
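For context, random Fourier features (the gaussian encoding) project each 3D lattice coordinate onto a set of randomly sampled frequency vectors and take the sine and cosine of each projection. The sketch below is a minimal pure-Python illustration under stated assumptions, not cryodrgn's code and not the change made in the fix commit: the class name EncoderSketch, the enc_type strings "geom" and "gaussian", and the parameters n_freqs, sigma and seed are all hypothetical. It mirrors the failure mode in the traceback (a dispatch on enc_type that raises RuntimeError for an unrecognized type) and shows where a gaussian branch would sit.

```python
import math
import random


class EncoderSketch:
    """Hypothetical positional encoder (illustration only, not cryodrgn's)."""

    def __init__(self, enc_type, n_freqs=4, sigma=0.5, seed=0):
        self.enc_type = enc_type
        self.n_freqs = n_freqs
        rng = random.Random(seed)
        # Gaussian random Fourier features: sample n_freqs frequency vectors
        # with i.i.d. N(0, sigma^2) components once, at construction time.
        self.rand_freqs = [
            [rng.gauss(0.0, sigma) for _ in range(3)] for _ in range(n_freqs)
        ]

    def positional_encoding_geom(self, coord):
        """Encode one (x, y, z) coordinate as a flat list of sin/cos features."""
        if self.enc_type == "geom":
            # Axis-aligned geometric frequencies 2*pi*2^i applied to x, y and z.
            ks = [2 * math.pi * 2**i * c for i in range(self.n_freqs) for c in coord]
        elif self.enc_type == "gaussian":
            # Project the coordinate onto each random frequency vector.
            ks = [
                2 * math.pi * sum(b * c for b, c in zip(freq, coord))
                for freq in self.rand_freqs
            ]
        else:
            # Same failure mode as the RuntimeError in the traceback above.
            raise RuntimeError("Encoding type {} not recognized".format(self.enc_type))
        return [math.sin(k) for k in ks] + [math.cos(k) for k in ks]
```

For example, EncoderSketch("gaussian").positional_encoding_geom((0.1, -0.2, 0.3)) returns 8 features (4 sines, then 4 cosines). Because the random frequencies are not tied to the x, y or z axes, the gaussian branch produces n_freqs projections rather than 3 * n_freqs.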

zhonge commented 1 year ago

Fixed in commit ba7ae6d. The default decoder domain was also changed from Fourier to Hartley in 69251ff.
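A minimal sketch of what flipping such a command-line default looks like with argparse. The helper add_decoder_args is hypothetical and is not copied from cryodrgn's abinit_homo or abinit_het parsers; only the --domain flag name and its hartley/fourier values come from this issue.

```python
import argparse


def add_decoder_args(parser):
    """Register a decoder-domain flag (hypothetical helper, illustration only)."""
    parser.add_argument(
        "--domain",
        choices=("hartley", "fourier"),
        default="hartley",  # previously "fourier"
        help="Decoder output domain (default: %(default)s)",
    )
    return parser
```

With this parser, running without --domain yields "hartley", --domain fourier still selects the Fourier decoder, and any other value is rejected by argparse before a model is built.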

TODO: the PositionalDecoder and FTPositionalDecoder models could be refactored to share code, since they duplicate much of the same logic.
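One possible shape for that refactor, sketched under assumptions: the positional encoding and the read-out live once in a base class, and each subclass only declares how many raw outputs it needs and how to turn them into a slice value. The class names PositionalDecoderBase, HartleyDecoderSketch and FourierDecoderSketch, the postprocess hook, and the toy linear "network" are hypothetical stand-ins; the internals of the real PositionalDecoder and FTPositionalDecoder are not shown in this issue.

```python
import math


class PositionalDecoderBase:
    """Hypothetical shared base: encoding and read-out are written once here."""

    out_dim = 1  # number of raw outputs per coordinate; subclasses override

    def __init__(self, n_freqs=3):
        self.n_freqs = n_freqs
        n_feats = 2 * 3 * n_freqs
        # Toy deterministic linear "network" standing in for the real MLP.
        self.weights = [
            [(j + 1) / n_feats for j in range(n_feats)] for _ in range(self.out_dim)
        ]

    def encode(self, coord):
        """Shared geometric sin/cos positional encoding of (x, y, z)."""
        ks = [2 * math.pi * 2**i * c for i in range(self.n_freqs) for c in coord]
        return [math.sin(k) for k in ks] + [math.cos(k) for k in ks]

    def forward(self, coord):
        feats = self.encode(coord)
        raw = [sum(w * f for w, f in zip(row, feats)) for row in self.weights]
        return self.postprocess(raw)

    def postprocess(self, raw):
        """Subclass hook: turn raw network outputs into a slice value."""
        raise NotImplementedError


class HartleyDecoderSketch(PositionalDecoderBase):
    out_dim = 1

    def postprocess(self, raw):
        # Hartley-domain values are real: one output per coordinate.
        return raw[0]


class FourierDecoderSketch(PositionalDecoderBase):
    out_dim = 2

    def postprocess(self, raw):
        # Fourier-domain values are complex: (real, imaginary) outputs.
        return complex(raw[0], raw[1])
```

With this layout, adding an encoding type such as gaussian means touching encode once instead of once per decoder class, which is the kind of duplication that let the Hartley path fall behind in the first place.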