arsedler9 / lfads-torch

A PyTorch implementation of Latent Factor Analysis via Dynamical Systems (LFADS) and AutoLFADS.
https://arxiv.org/abs/2309.01230

What are the learning parameters to avoid nan values? #13

Open Riverside-ms opened 9 months ago

Riverside-ms commented 9 months ago

Hi Andrew,

I applied the multisession_PCR analysis from the tutorial to my own dataset and got the following error. It seems to be caused by a training parameter. Below are the architecture parameters from the multisession_PCR.yaml file I used for the test. Could you advise which value to change and how to fix it?

[Error message]

    ERROR trial_runner.py:993 -- Trial run_model_2595b_00009: Error processing event.
    ray.exceptions.RayTaskError(ValueError): ray::ImplicitFunc.train() (pid=2684, ip=127.0.0.1, repr=run_model)
      File "python\ray\_raylet.pyx", line 859, in ray._raylet.execute_task
      File "python\ray\_raylet.pyx", line 863, in ray._raylet.execute_task
      File "python\ray\_raylet.pyx", line 810, in ray._raylet.execute_task.function_executor
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\_private\function_manager.py", line 674, in actor_method_executor
        return method(__ray_actor, *args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\util\tracing\tracing_helper.py", line 466, in _resume_span
        return method(self, *_args, **_kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\trainable.py", line 355, in train
        raise skipped from exception_cause(skipped)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\function_trainable.py", line 325, in entrypoint
        return self._trainable_func(
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\util\tracing\tracing_helper.py", line 466, in _resume_span
        return method(self, *_args, **_kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\function_trainable.py", line 651, in _trainable_func
        output = fn()
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\ray\tune\trainable\util.py", line 365, in inner
        trainable(config, **fn_kwargs)
      File "c:\windows\system32\lfads-torch\lfads_torch\run_model.py", line 78, in run_model
        trainer.fit(
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 771, in fit
        self._call_and_handle_interrupt(
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 724, in _call_and_handle_interrupt
        return trainer_fn(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 812, in _fit_impl
        results = self._run(model, ckpt_path=self.ckpt_path)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1237, in _run
        results = self._run_stage()
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1324, in _run_stage
        return self._run_train()
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1354, in _run_train
        self.fit_loop.run()
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\fit_loop.py", line 269, in advance
        self._outputs = self.epoch_loop.run(self._data_fetcher)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\epoch\training_epoch_loop.py", line 208, in advance
        batch_output = self.batch_loop.run(batch, batch_idx)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\batch\training_batch_loop.py", line 88, in advance
        outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 203, in advance
        result = self._run_optimization(
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 256, in _run_optimization
        self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 369, in _optimizer_step
        self.trainer._call_lightning_module_hook(
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1596, in _call_lightning_module_hook
        output = fn(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\core\lightning.py", line 1625, in optimizer_step
        optimizer.step(closure=optimizer_closure)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\core\optimizer.py", line 168, in step
        step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\strategies\strategy.py", line 193, in optimizer_step
        return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\plugins\precision\precision_plugin.py", line 155, in optimizer_step
        return optimizer.step(closure=closure, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\optim\optimizer.py", line 140, in wrapper
        out = func(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
        return func(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\optim\adamw.py", line 120, in step
        loss = closure()
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\plugins\precision\precision_plugin.py", line 140, in _wrap_closure
        closure_result = closure()
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 148, in __call__
        self._result = self.closure(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 134, in closure
        step_output = self._step_fn()
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\loops\optimization\optimizer_loop.py", line 427, in _training_step
        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1766, in _call_strategy_hook
        output = fn(*args, **kwargs)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\pytorch_lightning\strategies\strategy.py", line 333, in training_step
        return self.model.training_step(*args, **kwargs)
      File "c:\windows\system32\lfads-torch\lfads_torch\model.py", line 487, in training_step
        return self._shared_step(batch, batch_idx, "train")
      File "c:\windows\system32\lfads-torch\lfads_torch\model.py", line 357, in _shared_step
        output = self.forward(
      File "c:\windows\system32\lfads-torch\lfads_torch\model.py", line 231, in forward
        ic_post = self.ic_prior.make_posterior(ic_mean, ic_std)
      File "c:\windows\system32\lfads-torch\lfads_torch\modules\priors.py", line 30, in make_posterior
        return Independent(Normal(post_mean, post_std), 1)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\distributions\normal.py", line 56, in __init__
        super(Normal, self).__init__(batch_shape, validate_args=validate_args)
      File "C:\ProgramData\anaconda3\envs\lfads-torch\lib\site-packages\torch\distributions\distribution.py", line 56, in __init__
        raise ValueError(
    ValueError: Expected parameter loc (Tensor of shape (980, 100)) of distribution Normal(loc: torch.Size([980, 100]), scale: torch.Size([980, 100])) to satisfy the constraint Real(), but found invalid values:
    tensor([[nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan],
            ...,
            [nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan],
            [nan, nan, nan, ..., nan, nan, nan]], grad_fn=)

[Architecture parameters from multisession_PCR.yaml]

    encod_data_dim: 150
    encod_seq_len: 40
    recon_seq_len: ${model.encod_seq_len}
    ext_input_dim: 0
    ic_enc_seq_len: 0
    ic_enc_dim: 100
    ci_enc_dim: 100
    ci_lag: 1
    con_dim: 100
    co_dim: 6
    ic_dim: 100
    gen_dim: 100
    fac_dim: 150
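The error reports NaN in every entry of the (980, 100) `loc` tensor passed to `Normal`. One reason a single bad value can take over a whole tensor is that IEEE 754 NaN propagates through arithmetic, even through multiplication by zero, so one NaN input to a dense layer contaminates every output. The plain-Python sketch that follows illustrates this; `dense` and `nonfinite_positions` are hypothetical helpers for illustration, not part of lfads-torch.

```python
import math


def dense(weights, x):
    """Plain-Python stand-in for a bias-free linear layer: y_i = sum_j W[i][j] * x[j]."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in weights]


def nonfinite_positions(values):
    """Return the indices of NaN or +/-inf entries, e.g. for a fail-fast check."""
    return [i for i, v in enumerate(values) if not math.isfinite(v)]


W = [[0.5, -1.0, 0.0],
     [2.0, 0.0, 1.0]]          # row 1 has a zero weight on the NaN input
x = [1.0, float("nan"), 3.0]   # a single NaN in one input feature

y = dense(W, x)
print(y)                       # [nan, nan]: 0.0 * nan is still nan
print(nonfinite_positions(y))  # [0, 1]
```

In PyTorch itself, `torch.isfinite` can serve the same purpose as `nonfinite_positions`, and `torch.autograd.set_detect_anomaly(True)` can help locate the first operation that produces a NaN during the backward pass.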

arsedler9 commented 9 months ago

Hi! Does this happen immediately or after a few training steps? I suspect the learning rate is too high. I would try running a single model with scripts/run_single.py (using your new config) a few times on the data with lower learning rates to get a sense of which learning rates are appropriate.
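To see why an overly high learning rate can produce non-finite values only after several steps, here is a toy sketch (an assumption for illustration, not LFADS): plain gradient descent on the one-dimensional loss 0.5 * k * x**2, where each update multiplies x by (1 - lr * k). For lr > 2 / k the iterate grows geometrically until the loss overflows to inf; in a real network, inf values then typically turn into NaN (for example, inf - inf is NaN). `train_quadratic` and `is_stable` are hypothetical helper names.

```python
import math


def train_quadratic(lr, curvature=10.0, x0=1.0, steps=2000):
    """Gradient descent on the toy loss 0.5 * curvature * x**2.

    Each update multiplies x by (1 - lr * curvature), so the run diverges
    whenever lr > 2 / curvature. Returns the recorded loss history.
    """
    x = x0
    losses = []
    for _ in range(steps):
        loss = 0.5 * curvature * x * x
        losses.append(loss)
        if not math.isfinite(loss):
            break  # stop once the loss has overflowed, like a NaN/inf guard would
        x -= lr * curvature * x  # the gradient of the loss is curvature * x
    return losses


def is_stable(losses):
    """A run counts as stable if every loss is finite and the loss went down."""
    return all(math.isfinite(v) for v in losses) and losses[-1] < losses[0]


# Sweep from high to low learning rates, as one might with repeated single runs.
for lr in [0.3, 0.1, 0.03, 0.01]:
    losses = train_quadratic(lr)
    status = "stable" if is_stable(losses) else "diverged"
    print(f"lr={lr}: {status} after {len(losses)} recorded steps")
```

In this toy the stability threshold is known analytically (lr = 2 / k = 0.2); for a real model it has to be found empirically, which is what repeated single runs at different learning rates are for.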

Riverside-ms commented 9 months ago

Thanks for the reply. This error occurs after several training steps. Which variable do you mean by learning rate?

arsedler9 commented 9 months ago

The lr_init parameter linked here controls the initial learning rate, which is reduced over the course of training. I'd recommend trying a few different values (1e-3, 3e-4, 1e-4, 3e-5, etc.) and visualizing the loss curves with TensorBoard or wandb to see which gives a quick but stable descent. https://github.com/arsedler9/lfads-torch/blob/e5b540eaff84359e5f5d1a73131a22e7ba9bbbec/configs/model/rouse_multisession_PCR.yaml#L66
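The suggested values step down by roughly a factor of 3 each time (half a decade). A small sketch that generates that grid and turns each value into a Hydra-style `key=value` override string; the `model.lr_init` key path is an assumption inferred from the `${model.encod_seq_len}` interpolation in the config posted above, so check how scripts/run_single.py actually passes overrides before relying on it. `half_decade_grid` is a hypothetical helper.

```python
def half_decade_grid(top_exp=-3, n=4):
    """Learning rates stepping down by roughly 3x: 1e-3, 3e-4, 1e-4, 3e-5, ...

    Values are parsed from decimal strings so they equal literals such as 3e-4 exactly.
    """
    values = []
    exp = top_exp
    while len(values) < n:
        values.append(float(f"1e{exp}"))
        if len(values) < n:
            values.append(float(f"3e{exp - 1}"))
        exp -= 1
    return values


grid = half_decade_grid()
# Hypothetical override strings, one per run; "model.lr_init" is an assumed key path.
overrides = [f"model.lr_init={lr}" for lr in grid]
for line in overrides:
    print(line)
```

Each run's loss curve can then be compared in TensorBoard or wandb, as suggested, to pick the largest value that still descends smoothly.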

Riverside-ms commented 9 months ago

Thank you for your kind suggestion. I have tried various learning rates (1e-4, 1e-5, 1e-6, 1e-7, 1e-8) and still get the same error. Are there any other possible causes?

arsedler9 commented 9 months ago

Hmm… what are your batch size, sequence length, and number of neurons? It looks like the sequence length may be close to 1k time steps; 100-300 steps is more typical. You'd probably find it easier to fit that data with a 3-5x larger bin size.
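Using a 3-5x larger bin size amounts to summing spike counts over non-overlapping groups of adjacent bins. A minimal plain-Python sketch, assuming counts are stored time-major (one list of per-neuron counts per bin) and that leftover bins at the end of a trial are dropped; `rebin` is a hypothetical helper, not an lfads-torch function. For example, a 1,000-bin trial rebinned by a factor of 4 becomes 250 bins, inside the 100-300 range mentioned above.

```python
def rebin(counts, factor):
    """Sum spike counts over non-overlapping windows of `factor` bins.

    counts: list of time bins, each a list of per-neuron spike counts.
    Trailing bins that do not fill a complete window are dropped.
    """
    if factor < 1:
        raise ValueError("factor must be a positive integer")
    n_windows = len(counts) // factor
    rebinned = []
    for w in range(n_windows):
        window = counts[w * factor:(w + 1) * factor]
        # zip(*window) groups the counts neuron by neuron within this window
        rebinned.append([sum(neuron) for neuron in zip(*window)])
    return rebinned


# 6 bins x 2 neurons at the original bin width
counts = [[0, 1], [1, 0], [2, 1], [0, 0], [1, 1], [0, 2]]
print(rebin(counts, 3))  # [[3, 2], [1, 3]]
```

The new bin width is `factor` times the old one, so 20 ms bins would become 60-100 ms bins for factors of 3-5.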

Riverside-ms commented 9 months ago

Those parameters follow the tutorial. The batch size is 1000, the sequence length is 40 bins (40 × 20 ms bins = 800 ms), and the number of neurons varies by session, from 30 to 200. Each brain region has 20 to 80 sessions, and there are 20 or 30 conditions.

arsedler9 commented 9 months ago

Hmm, could you upload your files to Google Drive and send me a link to arsedler9@gmail.com?

arsedler9 commented 8 months ago

Hey @Riverside-ms, just checking in: were you able to resolve this?

Riverside-ms commented 8 months ago

Sorry for the late reply. I have emailed you; please check it.