mlech26l / ncps

PyTorch and TensorFlow implementation of NCP, LTC, and CfC wired neural models
https://www.nature.com/articles/s42256-020-00237-3
Apache License 2.0
1.86k stars 297 forks source link

TypeError: SequenceLearner.optimizer_step() missing 1 required positional argument: 'closure' #50

Open AXYZdong opened 1 year ago

AXYZdong commented 1 year ago

When I try to run the pt_example, the following error occurs:

Error displaying widget: model not found


TypeError                                 Traceback (most recent call last)
Cell In[15], line 2
      1 # Train the model for 400 epochs (= training steps)
----> 2 trainer.fit(model=learn, train_dataloaders=dataloader)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:520, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    518 model = _maybe_unwrap_optimized(model)
    519 self.strategy._lightning_module = model
--> 520 call._call_and_handle_interrupt(
    521     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    522 )

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
     43     else:
---> 44         return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:559, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    549 self._data_connector.attach_data(
    550     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    551 )
    553 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    554     self.state.fn,
    555     ckpt_path,
    556     model_provided=True,
    557     model_connected=self.lightning_module is not None,
    558 )
--> 559 self._run(model, ckpt_path=ckpt_path)
    561 assert self.state.stopped
    562 self.training = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:935, in Trainer._run(self, model, ckpt_path)
    930 self._signal_connector.register_signal_handlers()
    932 # ----------------------------
    933 # RUN THE TRAINER
    934 # ----------------------------
--> 935 results = self._run_stage()
    937 # ----------------------------
    938 # POST-Training CLEAN UP
    939 # ----------------------------
    940 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py:978, in Trainer._run_stage(self)
    976         self._run_sanity_check()
    977     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
--> 978         self.fit_loop.run()
    979     return None
    980 raise RuntimeError(f"Unexpected state {self.state}")

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:201, in _FitLoop.run(self)
    199 try:
    200     self.on_advance_start()
--> 201     self.advance()
    202     self.on_advance_end()
    203     self._restarting = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:354, in _FitLoop.advance(self)
    352 self._data_fetcher.setup(combined_loader)
    353 with self.trainer.profiler.profile("run_training_epoch"):
--> 354     self.epoch_loop.run(self._data_fetcher)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:133, in _TrainingEpochLoop.run(self, data_fetcher)
    131 while not self.done:
    132     try:
--> 133         self.advance(data_fetcher)
    134         self.on_advance_end()
    135         self._restarting = False

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/training_epoch_loop.py:218, in _TrainingEpochLoop.advance(self, data_fetcher)
    215 with trainer.profiler.profile("run_training_batch"):
    216     if trainer.lightning_module.automatic_optimization:
    217         # in automatic optimization, there can only be one optimizer
--> 218         batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
    219     else:
    220         batch_output = self.manual_optimization.run(kwargs)

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:185, in _AutomaticOptimization.run(self, optimizer, kwargs)
    178         closure()
    180 # ------------------------------
    181 # BACKWARD PASS
    182 # ------------------------------
    183 # gradient update with accumulated gradients
    184 else:
--> 185     self._optimizer_step(kwargs.get("batch_idx", 0), closure)
    187 result = closure.consume_result()
    188 if result.loss is None:

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/automatic.py:261, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    258     self.optim_progress.optimizer.step.increment_ready()
    260 # model hook
--> 261 call._call_lightning_module_hook(
    262     trainer,
    263     "optimizer_step",
    264     trainer.current_epoch,
    265     batch_idx,
    266     optimizer,
    267     train_step_and_backward_closure,
    268 )
    270 if not should_accumulate:
    271     self.optim_progress.optimizer.step.increment_completed()

File /opt/conda/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py:142, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    139 pl_module._current_fx_name = hook_name
    141 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 142     output = fn(*args, **kwargs)
    144 # restore current_fx when nested context
    145 pl_module._current_fx_name = prev_fx_name

TypeError: SequenceLearner.optimizer_step() missing 1 required positional argument: 'closure'
mlech26l commented 11 months ago

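pt_example.py is outdated. As the traceback shows, the installed PyTorch Lightning calls `optimizer_step(epoch, batch_idx, optimizer, closure)` with only four positional arguments. The example's `optimizer_step` override expects an additional positional argument, so `closure` is never supplied. Please refer to https://ncps.readthedocs.io/en/latest/examples/torch_first_steps.html for how to use NCPs with PyTorch.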