dllllb / pytorch-lifestream

A library built upon PyTorch for building embeddings on discrete event sequences using self-supervision

Error computing torch metrics #151

Open dzrlva opened 5 months ago

dzrlva commented 5 months ago

I’m using MultiModalSupervisedDataset for data, SequenceToTarget for model initialization. Getting this warning: UserWarning: The compute method of metric AUROC was called before the update method which may lead to errors, as metric states have not yet been updated. warnings.warn(*args, **kwargs) Can someone help ?
 I'll be glad if you provide me demo for MultiModalSupervisedDataset, SequenceToTarget classes.
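For reference, here is roughly how I wire things together. This is only a simplified sketch: the encoder config, the head, and the parameter names follow the library's demo notebooks rather than my exact code, and the dataloaders built from MultiModalSupervisedDataset are omitted.

```python
from functools import partial

import torch
import torchmetrics
import pytorch_lightning as pl

from ptls.nn import TrxEncoder, RnnSeqEncoder
from ptls.frames.supervised import SequenceToTarget

# Placeholder encoder over one categorical and one numeric field
seq_encoder = RnnSeqEncoder(
    trx_encoder=TrxEncoder(
        embeddings={'mcc_code': {'in': 200, 'out': 16}},
        numeric_values={'amount': 'identity'},
    ),
    hidden_size=64,
)

model = SequenceToTarget(
    seq_encoder=seq_encoder,
    head=torch.nn.Linear(64, 1),
    loss=torch.nn.BCEWithLogitsLoss(),
    metric_list=torchmetrics.AUROC(),  # the metric from the warning above
    optimizer_partial=partial(torch.optim.Adam, lr=1e-3),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=1, gamma=0.9),
)

pl_trainer = pl.Trainer(max_epochs=1)

# train_loader / valid_loader come from MultiModalSupervisedDataset (construction omitted)
# pl_trainer.fit(model, train_loader, valid_loader)
```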

Full error log:

ValueError                                Traceback (most recent call last)
Input In [52], in <cell line: 2>()
      1 print(f'logger.version = {pl_trainer.logger.version}')
----> 2 pl_trainer.fit(model, train_loader)
      3 print(pl_trainer.logged_metrics)

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:770, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    751 r"""
    752 Runs the full optimization routine.
    753
   (...)
    767     datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
    768 """
    769 self.strategy.model = model
--> 770 self._call_and_handle_interrupt(
    771     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    772 )

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:723, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    721         return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    722     else:
--> 723         return trainer_fn(*args, **kwargs)
    724 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    725 except KeyboardInterrupt as exception:

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:811, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    807 ckpt_path = ckpt_path or self.resume_from_checkpoint
    808 self._ckpt_path = self.__set_ckpt_path(
    809     ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    810 )
--> 811 results = self._run(model, ckpt_path=self.ckpt_path)
    813 assert self.state.stopped
    814 self.training = False

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1236, in Trainer._run(self, model, ckpt_path)
   1233 self._checkpoint_connector.restore_training_state()
   1235 self._checkpoint_connector.resume_end()
-> 1237 results = self._run_stage()
   1238 log.detail(f"{self.__class__.__name__}: trainer tearing down")

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1323, in Trainer._run_stage(self)
   1321 if self.predicting:
   1322     return self._run_predict()
-> 1323 return self._run_train()

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1345, in Trainer._run_train(self)
   1342 self._pre_training_routine()
   1344 with isolate_rng():
-> 1345     self._run_sanity_check()
   1347 # enable train mode
   1348 self.model.train()

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1413, in Trainer._run_sanity_check(self)
   1411 # run eval step
   1412 with torch.no_grad():
-> 1413     val_loop.run()
   1415 self._call_callback_hooks("on_sanity_check_end")
   1417 # reset logger connector

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/loops/base.py:211, in Loop.run(self, *args, **kwargs)
    208         break
    209 self._restarting = False
--> 211 output = self.on_run_end()
    212 return output

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:183, in EvaluationLoop.on_run_end(self)
    180 self.trainer._logger_connector.epoch_end_reached()
    182 # hook
--> 183 self._evaluation_epoch_end(self._outputs)
    184 self._outputs = []  # free memory
    186 # hook

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:311, in EvaluationLoop._evaluation_epoch_end(self, outputs)
    309     self.trainer._call_lightning_module_hook("test_epoch_end", output_or_outputs)
    310 else:
--> 311     self.trainer._call_lightning_module_hook("validation_epoch_end", output_or_outputs)

File ~/.local/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1595, in Trainer._call_lightning_module_hook(self, hook_name, pl_module, *args, **kwargs)
   1592 pl_module._current_fx_name = hook_name
   1594 with self.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
-> 1595     output = fn(*args, **kwargs)
   1597 # restore current_fx when nested context
   1598 pl_module._current_fx_name = prev_fx_name

File ~/.local/lib/python3.9/site-packages/ptls/frames/supervised/seq_to_target.py:148, in SequenceToTarget.validation_epoch_end(self, outputs)
    147 for name, mf in self.valid_metrics.items():
--> 148     self.log(f'valid/{name}', mf.compute(), prog_bar=True)
    149 for name, mf in self.valid_metrics.items():
    150     mf.reset()

File ~/.local/lib/python3.9/site-packages/torchmetrics/metric.py:531, in Metric._wrap_compute.<locals>.wrapped_func(*args, **kwargs)
    523 # compute relies on the sync context manager to gather the states across processes and apply reduction
    524 # if synchronization happened, the current rank accumulated states will be restored to keep
    525 # accumulation going if should_unsync=True,
    526 with self.sync_context(
    527     dist_sync_fn=self.dist_sync_fn,  # type: ignore
    528     should_sync=self._to_sync,
    529     should_unsync=self._should_unsync,
    530 ):
--> 531     value = compute(*args, **kwargs)
    532     self._computed = _squeeze_if_scalar(value)
    534 return self._computed

File ~/.local/lib/python3.9/site-packages/torchmetrics/classification/auroc.py:172, in AUROC.compute(self)
    171     raise RuntimeError("You have to have determined mode.")
--> 172 preds = dim_zero_cat(self.preds)
    173 target = dim_zero_cat(self.target)
    174 return _auroc_compute(
    175     self.preds,
    176     self.target,
   (...)
    181     self.max_fpr,
    182 )

File ~/.local/lib/python3.9/site-packages/torchmetrics/utilities/data.py:41, in dim_zero_cat(x)
     39 x = [y.unsqueeze(0) if y.numel() == 1 and y.ndim == 0 else y for y in x]
     40 if not x:  # empty list
---> 41     raise ValueError("No samples to concatenate")
     42 return torch.cat(x, dim=0)

ValueError: No samples to concatenate
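From the traceback, the failure happens in SequenceToTarget.validation_epoch_end, where mf.compute() is called on a valid_metrics entry whose internal state is empty, i.e. the AUROC metric never received an update() during the sanity-check validation pass. For what it's worth, the same warning and a similar error can be reproduced with torchmetrics alone; a minimal sketch, assuming the ~0.9 torchmetrics API that the traceback points at:

```python
import torchmetrics

# Binary AUROC with default settings (torchmetrics ~0.9 as in the traceback;
# newer releases require a `task=...` argument instead).
metric = torchmetrics.AUROC()

# No metric.update(preds, target) happens before compute(), e.g. because no
# validation batch ever reached the metric during the sanity check.
metric.compute()
# -> UserWarning: The compute method of metric AUROC was called before the update method ...
# -> followed by an exception (its exact type depends on the torchmetrics version and metric state)
```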