arsedler9 / lfads-torch

A PyTorch implementation of Latent Factor Analysis via Dynamical Systems (LFADS) and AutoLFADS.
https://arxiv.org/abs/2309.01230
Apache License 2.0

Issue with Gaussian and Gamma reconstruction cost #23

Open xiniyiwang opened 1 week ago

xiniyiwang commented 1 week ago

Hi Andrew, I can successfully run the multisession script with the Poisson cost. However, I have a problem using the Gamma reconstruction cost with a calcium imaging dataset; the issue also occurs when I use the Gaussian reconstruction. When I set:

reconstruction:
  _target_: lfads_torch.modules.recons.MultisessionReconstruction
  datafile_pattern: ${datamodule.datafile_pattern}
  recon:
    _target_: lfads_torch.modules.recons.Gamma

there is an error at line 367 of lfads_torch/model.py: 'batch[s].recon_data' has a shape of [21, 25, 92], but 'output[s].output_params' has a shape of [21, 25, 46, 2], where 21 is the number of valid trials, 25 is the number of time steps, and 92 is the number of neurons. Since the Gamma and Gaussian reconstructions each need two parameters, the output is split into two parts. However, in the lfads-torch code the split output is not reshaped back to the original neuron dimension when the loss is calculated. Is this the reason why this issue occurs?
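To illustrate what I think is happening (just the shape arithmetic implied by the error message, not the actual lfads-torch code):

import torch

# Shape arithmetic only; n_trials/n_steps/n_neurons mirror the error message.
n_trials, n_steps, n_neurons = 21, 25, 92

# A readout sized for a one-parameter likelihood emits 92 features per step.
readout_out = torch.randn(n_trials, n_steps, n_neurons)

# Reshaping those 92 features into (neurons, 2) halves the neuron axis ...
params = readout_out.reshape(n_trials, n_steps, -1, 2)
print(params.shape)  # torch.Size([21, 25, 46, 2]) vs. recon_data [21, 25, 92]

# ... whereas a readout emitting 2 * 92 = 184 features reshapes as expected.
readout_out_2p = torch.randn(n_trials, n_steps, 2 * n_neurons)
print(readout_out_2p.reshape(n_trials, n_steps, -1, 2).shape)  # [21, 25, 92, 2]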

Is there anything else I need to do after setting '_target_' to 'lfads_torch.modules.recons.Gamma'? I'd appreciate it if you could give me some advice. Thanks!

arsedler9 commented 1 week ago

Hey @xiniyiwang, in addition to setting the reconstruction._target_ to one of the objects in recons, you'll need to set readout to the following:

readout:
  _target_: lfads_torch.modules.readin_readout.MultisessionReadout
  datafile_pattern: ${datamodule.datafile_pattern}
  in_features: ${model.fac_dim}
  pcr_init: False
  recon_params: 2

You might also find it helpful to replace CoordinatedDropout with TemporalShift in the train_aug_stack for Gamma observations (more detail in the paper). Also tagging first author @lwimala in case he wants to follow along, though I'm not sure if he has worked much with this implementation. Hope that helps!
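For intuition, here is a schematic of the temporal-shift idea: each trial is rolled along the time axis by a small random offset so the network cannot simply copy its input to its output. This is illustrative only; the actual lfads_torch.modules.augmentations.TemporalShift and its parameters may differ.

import torch

def temporal_shift_sketch(batch: torch.Tensor, max_shift: int = 3) -> torch.Tensor:
    """Roll each trial along the time axis by a random offset in
    [-max_shift, max_shift]. batch: (n_trials, n_steps, n_neurons)."""
    shifted = torch.empty_like(batch)
    for i in range(batch.shape[0]):
        offset = int(torch.randint(-max_shift, max_shift + 1, (1,)).item())
        shifted[i] = torch.roll(batch[i], shifts=offset, dims=0)
    return shifted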

xiniyiwang commented 1 week ago

Thanks for your advice. The error mentioned earlier no longer occurs. However, when I set readout as:

readout:
  _target_: lfads_torch.modules.readin_readout.MultisessionReadout
  datafile_pattern: ${datamodule.datafile_pattern}
  in_features: ${model.fac_dim}
  pcr_init: False
  recon_params: 2

and train_aug_stack:

train_aug_stack:
  _target_: lfads_torch.modules.augmentations.AugmentationStack
  transforms:
    - _target_: lfads_torch.modules.augmentations.TemporalShift #CoordinatedDropout 
  batch_order: [0]
  loss_order: [0]
infer_aug_stack:
  _target_: lfads_torch.modules.augmentations.AugmentationStack

and:

reconstruction:
  _target_: lfads_torch.modules.recons.MultisessionReconstruction
  datafile_pattern: ${datamodule.datafile_pattern}
  recon:
    _target_: lfads_torch.modules.recons.Gamma
variational: True
co_prior:
  _target_: lfads_torch.modules.priors.AutoregressiveMultivariateNormal
  tau: 10.0
  nvar: 0.1
  shape: ${model.co_dim}
ic_prior:
  _target_: lfads_torch.modules.priors.MultivariateNormal
  mean: 0
  variance: 0.1
  shape: ${model.ic_dim}
ic_post_var_min: 1.0e-4
  1. The above settings will lead to an error:

    ray.exceptions.RayTaskError(InstantiationException): ray::ImplicitFunc.train() (pid=2766465, ip=192.168.1.102, repr=run_model)
    File "/home/datadisk1/wxy/lfads-torch/lfads_torch/modules/augmentations.py", line 25, in __init__
    assert all([hasattr(t, "process_losses") for t in self.loss_transforms])
    AssertionError
  2. When I use CoordinatedDropout instead:

    train_aug_stack:
      _target_: lfads_torch.modules.augmentations.AugmentationStack
      transforms:
        - _target_: lfads_torch.modules.augmentations.CoordinatedDropout #TemporalShift
          cd_rate: 0.3 # sampled
          cd_pass_rate: 0.0
          ic_enc_seq_len: ${model.ic_enc_seq_len}
      batch_order: [0]
      loss_order: [0]
    infer_aug_stack:
      _target_: lfads_torch.modules.augmentations.AugmentationStack

    another error occurs:

    File "/home/datadisk1/wxy/lfads-torch/lfads_torch/model.py", line 492, in training_step
    return self._shared_step(batch, batch_idx, "train")
    File "/home/datadisk1/wxy/lfads-torch/lfads_torch/model.py", line 362, in _shared_step
    output = self.forward(
    File "/home/datadisk1/wxy/lfads-torch/lfads_torch/model.py", line 236, in forward
    ic_post = self.ic_prior.make_posterior(ic_mean, ic_std)
    File "/home/datadisk1/wxy/lfads-torch/lfads_torch/modules/priors.py", line 30, in make_posterior
    return Independent(Normal(post_mean, post_std), 1)
    File "/home/datadisk/wxy/Anaconda/envs/lfads-torch/lib/python3.9/site-packages/torch/distributions/normal.py", line 56, in __init__
    super(Normal, self).__init__(batch_shape, validate_args=validate_args)
    File "/home/datadisk/wxy/Anaconda/envs/lfads-torch/lib/python3.9/site-packages/torch/distributions/distribution.py", line 56, in __init__
    raise ValueError(
    ValueError: Expected parameter loc (Tensor of shape (40, 100)) of distribution Normal(loc: torch.Size([40, 100]), scale: torch.Size([40, 100])) to satisfy the constraint Real(), but found invalid values:
    tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<SplitBackward0>)

    Is there any step or parameter I didn't set up correctly?

I can run the multisession script successfully with Gaussian reconstruction:

reconstruction:
  _target_: lfads_torch.modules.recons.MultisessionReconstruction
  datafile_pattern: ${datamodule.datafile_pattern}
  recon:
    _target_: lfads_torch.modules.recons.Gaussian #Gamma
variational: True
co_prior:
  _target_: lfads_torch.modules.priors.AutoregressiveMultivariateNormal
  tau: 10.0
  nvar: 0.1
  shape: ${model.co_dim}
ic_prior:
  _target_: lfads_torch.modules.priors.MultivariateNormal
  mean: 0
  variance: 0.1
  shape: ${model.ic_dim}
ic_post_var_min: 1.0e-4
arsedler9 commented 4 days ago

Thanks for reporting this @xiniyiwang. I looked into it, and it looks like we didn't have process_losses implemented for TemporalShift, which was causing the first error. I just merged PR #25, which should fix this issue. The NaNs you're seeing with CoordinatedDropout are likely an optimization failure, so I would recommend reducing the learning rate.
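For anyone following along: the assertion in augmentations.py (line 25) requires every transform listed in loss_order to expose a process_losses method. A schematic of what such a transform needs to provide; the method signatures below are assumptions, not the library's exact API:

# Schematic only -- the AugmentationStack assertion checks
# hasattr(t, "process_losses") for every transform in loss_order, so a
# loss-order transform needs at least a pass-through implementation.
class LossOrderTransformSketch:
    def process_batch(self, batch):
        # modify the input data before the forward pass (e.g. shift or mask)
        return batch

    def process_losses(self, losses, *args, **kwargs):
        # leave the reconstruction losses unchanged
        return losses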

xiniyiwang commented 20 hours ago

The first error is fixed using the new code. However, the NaNs still appear immediately, even when I set the learning rate to 1e-7. Based on my understanding, perhaps the (batch_size, seq_len, or dim) of (ic, ci, or co) should match each other? Here are my settings:

# --------- architecture --------- #
encod_data_dim: 30
encod_seq_len: 37
recon_seq_len: ${model.encod_seq_len}
ext_input_dim: 0
ic_enc_seq_len: 0
ic_enc_dim: 32
ci_enc_dim: 32
ci_lag: 1
con_dim: 32
co_dim: 4
ic_dim: 32
gen_dim: 100
fac_dim: 30

and batch_size = 100.
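One generic way to narrow down where the NaNs first appear (plain PyTorch, independent of lfads-torch internals) is to check that every session's data is finite before training and to enable autograd anomaly detection, so the first op that produces a NaN gradient is reported with a stack trace:

import torch

def check_finite(name: str, tensor: torch.Tensor) -> None:
    # Raise if the tensor contains NaN or +/- inf values.
    if not torch.isfinite(tensor).all():
        n_bad = int((~torch.isfinite(tensor)).sum())
        raise ValueError(f"{name} has {n_bad} non-finite values")

# e.g. check_finite("recon_data", recon_data) for each session before training
torch.autograd.set_detect_anomaly(True)  # slow; enable only while debugging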