conradkun opened 5 months ago
Hi Conrad,
Thank you for your feedback!
Firstly, would you mind sharing a code snippet of what you have so far so we can have a clearer idea of the problem?
This discussion on the pytorch-lightning repo seems to suggest that calling `trainer.predict` in the `Callback` hook `on_validation_epoch_end` shouldn't cause any issues. However, it is over a year old and might be out of date.
We are currently in the process of removing the image stitching logic from the prediction loop and implementing it as a separate function. This should be merged into main this week. It might make logging predictions throughout training easier, so watch this space!
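For intuition, here is a rough sketch of what the stitching step amounts to, assuming non-overlapping tiles; `stitch_tiles` is a hypothetical helper of my own, and the real CAREamics function also has to handle overlapping tiles:

```python
import numpy as np

def stitch_tiles(tiles, tile_origins, image_shape):
    """Place non-overlapping 2D tiles back into a full image.

    Hypothetical helper for illustration only; overlap handling
    (as done by the actual stitching logic) is omitted.
    """
    image = np.zeros(image_shape, dtype=tiles[0].dtype)
    for tile, (y, x) in zip(tiles, tile_origins):
        h, w = tile.shape
        image[y : y + h, x : x + w] = tile
    return image

# split a 4x4 image into four 2x2 tiles and reassemble it
full = np.arange(16).reshape(4, 4)
tiles = [full[:2, :2], full[:2, 2:], full[2:, :2], full[2:, 2:]]
origins = [(0, 0), (0, 2), (2, 0), (2, 2)]
restored = stitch_tiles(tiles, origins, full.shape)
assert np.array_equal(restored, full)
```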
Some solutions we will consider implementing in the future might include leveraging existing logging frameworks such as Weights & Biases. See their Lightning integration docs, where they implement an example `LogPredictionSamplesCallback` class.
Melisande
Hi Melisande,
Thank you for your reply! Shortly after submitting my issue I came across issue #141, and indeed I am almost certain that it would solve the problem I am having.
I was basing my solution on the PyTorch Lightning discussion you shared, but I cannot do it exactly as they do since, as it stands, I need to use the prediction loop in order to stitch the images together (and all the solutions I could see involve calling `predict_step`). This should not be an issue anymore after the refactoring. In any case, I will wait until it is all merged before I share the (fairly huge) code snippet needed to reproduce it.
The W&B Callback you shared does pretty much the same as what I am doing with mine, but I must use TensorBoard with my project, unfortunately. Do you have any plans to allow custom Callbacks to be passed to the `CAREamist` class? (As of now I am just redefining `_define_callbacks`, which is a bit ugly.)
Hi @conradkun,
Passing custom callbacks to the `CAREamist` is a great idea! We will actually need this very soon for the napari plugin, so here you go: https://github.com/CAREamics/careamics/pull/150
We have not worked much on the loggers so far, but we will try to have some support of both TensorBoard and WandB, with examples of useful cases. Feedback and suggestions always welcome!
EDIT: Fixing link
EDIT2: Regarding the TensorBoard callback, we actually have a way to define TensorBoard as the logger (this is set in `Configuration.training_config.logger`). We have not really tested it, so if you notice that something necessary to make it work is missing, let us know!!
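For reference, selecting TensorBoard through the configuration would then look roughly like this; note this is a sketch, and the accepted string value `"tensorboard"` is an assumption based on the field name, since this path is untested upstream:

```python
# Assumes an existing `config` (a careamics Configuration instance).
# The value "tensorboard" is an assumption, not a documented constant.
config.training_config.logger = "tensorboard"
engine = CAREamist(config)
```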
Hi @conradkun,
We no longer use a custom prediction loop; instead, the outputs of Lightning's `Trainer.predict` are converted at the end of `CAREamist.predict` with the new function `careamics.prediction_utils.convert_output`. `convert_output` needs to know whether the prediction outputs are tiled, and this is determined in the initialisation of `CAREamicsPredictData`; its creation logic has been moved to the function `careamics.prediction_utils.create_pred_datamodule`. You will see this implemented in the `CAREamist.predict` method.
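As a mental model (not the actual implementation), the untiled branch of this conversion essentially flattens `Trainer.predict`'s list of per-batch arrays into a list of single-image arrays. A numpy stand-in, where the function name `flatten_untiled_outputs` is my own:

```python
import numpy as np

def flatten_untiled_outputs(predictions):
    """Rough stand-in for the untiled case: split each batched
    prediction of shape (B, C, H, W) into individual (C, H, W) images."""
    images = []
    for batch in predictions:
        images.extend(np.split(batch, batch.shape[0], axis=0))
    return [image.squeeze(0) for image in images]

# two predict batches, each holding two single-channel 4x4 images
batches = [np.zeros((2, 1, 4, 4)), np.ones((2, 1, 4, 4))]
images = flatten_untiled_outputs(batches)
assert len(images) == 4
assert images[0].shape == (1, 4, 4)
```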
Do not hesitate to get in contact with any questions!
Perfect, thanks for the heads up! I will reimplement what I had using the new setup soon and let you know if I come across any issues. Just saw you also merged the custom callbacks branch, so my life is even simpler! Thank you for the quick work.
Unfortunately, I was not able to make it work, so here I go.
What do I want to achieve?
Given a `CAREamicsPredictData` datamodule, I want to be able to predict on it at specific times: at the end of each training epoch, after every X training steps, etc. The best tool for this is PyTorch Lightning's `Callback`.
I want to use a datamodule that is detached from both the training and the validation datasets found in the `CAREamicsTrainData`, for a number of reasons; a `CAREamicsPredictData` would be ideal.

I think such functionality is really useful in a model like N2V, where metrics are not entirely representative of the perceived performance. Picture a UI where a user is training on an image and selects a patch to see how it develops as the training progresses; this real-time feedback would be great to give the user an idea of what's going on, whether training longer is a good idea, etc.
What have I tried?
With the new `Callback` functionality, I can pass the following directly to the `CAREamist` class (assuming some `config`):
```python
from pytorch_lightning import Callback

from careamics import CAREamist
from careamics.prediction_utils import convert_outputs, create_pred_datamodule


class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule):
        self.pred_datamodule = pred_datamodule

    def setup(self, trainer, pl_module, stage):
        if stage in ("fit", "validate"):
            # set up the predict data for fit/validate, as we will call it
            # during `on_validation_epoch_end`
            # not sure if needed, but doesn't hurt until I get it to work
            self.pred_datamodule.prepare_data()
            self.pred_datamodule.setup("predict")

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return
        predictions = trainer.predict(
            model=pl_module, datamodule=self.pred_datamodule
        )
        return convert_outputs(predictions, self.pred_datamodule.tiled)


pred_datamodule = create_pred_datamodule(
    source="image.tiff",
    config=config,
)
predict_after_val_callback = CustomPredictAfterValidationCallback(
    pred_datamodule=pred_datamodule
)
engine = CAREamist(config, callbacks=[predict_after_val_callback])
```
For some reason, PyTorch Lightning really does not like this setup. In particular, I get the following error after trying to call `fit`:
```
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[32], line 1
----> 1 engine.train(datamodule=data_module)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/careamics/careamist.py:322, in CAREamist.train(self, datamodule, train_source, val_source, train_target, val_target, use_in_memory, val_percentage, val_minimum_split)
    320 # train
    321 if datamodule is not None:
--> 322     self._train_on_datamodule(datamodule=datamodule)
    324 else:
    325     # raise error if target is provided to N2V
    326     if self.cfg.algorithm_config.algorithm == SupportedAlgorithm.N2V.value:

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/careamics/careamist.py:394, in CAREamist._train_on_datamodule(self, datamodule)
    391 # record datamodule
    392 self.train_datamodule = datamodule
--> 394 self.trainer.fit(self.model, datamodule=datamodule)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42 if trainer.strategy.launcher is not None:
     43     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:987, in Trainer._run(self, model, ckpt_path)
    982 self._signal_connector.register_signal_handlers()
    984 # ----------------------------
    985 # RUN THE TRAINER
    986 # ----------------------------
--> 987 results = self._run_stage()
    989 # ----------------------------
    990 # POST-Training CLEAN UP
    991 # ----------------------------
    992 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py:1033, in Trainer._run_stage(self)
   1031 self._run_sanity_check()
   1032 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1033     self.fit_loop.run()
   1034 return None
   1035 raise RuntimeError(f"Unexpected state {self.state}")

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:205, in _FitLoop.run(self)
    203 try:
    204     self.on_advance_start()
--> 205     self.advance()
    206     self.on_advance_end()
    207     self._restarting = False

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:363, in _FitLoop.advance(self)
    361 with self.trainer.profiler.profile("run_training_epoch"):
    362     assert self._data_fetcher is not None
--> 363     self.epoch_loop.run(self._data_fetcher)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:141, in _TrainingEpochLoop.run(self, data_fetcher)
    139 try:
    140     self.advance(data_fetcher)
--> 141     self.on_advance_end(data_fetcher)
    142     self._restarting = False
    143 except StopIteration:

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py:295, in _TrainingEpochLoop.on_advance_end(self, data_fetcher)
    291 if not self._should_accumulate():
    292     # clear gradients to not leave any unused memory during validation
    293     call._call_lightning_module_hook(self.trainer, "on_validation_model_zero_grad")
--> 295 self.val_loop.run()
    296 self.trainer.training = True
    297 self.trainer._logger_connector._first_loop_iter = first_loop_iter

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py:182, in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    180 context_manager = torch.no_grad
    181 with context_manager():
--> 182     return loop_run(self, *args, **kwargs)

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:142, in _EvaluationLoop.run(self)
    140 self._restarting = False
    141 self._store_dataloader_outputs()
--> 142 return self.on_run_end()

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:254, in _EvaluationLoop.on_run_end(self)
    251 self.trainer._logger_connector._evaluation_epoch_end()
    253 # hook
--> 254 self._on_evaluation_epoch_end()
    256 logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
    257 # include any logged outputs on epoch_end

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py:336, in _EvaluationLoop._on_evaluation_epoch_end(self)
    333 call._call_callback_hooks(trainer, hook_name)
    334 call._call_lightning_module_hook(trainer, hook_name)
--> 336 trainer._logger_connector.on_epoch_end()

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:195, in _LoggerConnector.on_epoch_end(self)
    193 def on_epoch_end(self) -> None:
    194     assert self._first_loop_iter is None
--> 195     metrics = self.metrics
    196     self._progress_bar_metrics.update(metrics["pbar"])
    197     self._callback_metrics.update(metrics["callback"])

File ~/miniconda3/envs/dl/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:233, in _LoggerConnector.metrics(self)
    231 """This function returns either batch or epoch metrics."""
    232 on_step = self._first_loop_iter is not None
--> 233 assert self.trainer._results is not None
    234 return self.trainer._results.metrics(on_step)

AssertionError:
```
I found a post with a similar error from April, which suggested using `predict_step` instead of calling `predict` directly on the Lightning module. Incidentally, that is also what the aforementioned discussion from the Lightning forums seems to converge towards. So I tried version 2:
```python
class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule):
        self.pred_datamodule = pred_datamodule

    def setup(self, trainer, pl_module, stage):
        if stage in ("fit", "validate"):
            # set up the predict data for fit/validate, as we will call it
            # during `on_validation_epoch_end`
            # not sure if needed, but doesn't hurt until I get it to work
            self.pred_datamodule.prepare_data()
            self.pred_datamodule.setup("predict")

    def on_validation_epoch_end(self, trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return
        # not entirely sure about how preds are returned (and how they must
        # be concatenated), take as pseudocode
        predictions = []
        for idx, batch in enumerate(self.pred_datamodule.predict_dataloader()):
            preds = pl_module.predict_step(batch, idx)  # breaks here
            predictions += preds
        return convert_outputs(predictions, self.pred_datamodule.tiled)
```
The problem with this approach is how the `predict_step` function in `CAREamicsModule` has changed: referencing `_trainer` directly means that it will be looking at the `CAREamicsTrainData` used for training, and not at the `CAREamicsPredictData` the current batch comes from. And training data modules do not have a `.tiled` attribute.
What to do?
Somehow, `predict_step` should be getting the tiled information from the datamodule that is yielding the current batches. Any ideas?
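One direction would be to thread `tiled` through the call explicitly, instead of letting `predict_step` look it up on the trainer. A plain-Python sketch of the idea, with all names hypothetical and no claim about the actual CAREamics API:

```python
class ExplicitTiledModule:
    """Sketch of a predict step that receives `tiled` from its caller
    rather than reading it from the trainer's (training) datamodule."""

    def __init__(self, forward_fn):
        self.forward_fn = forward_fn

    def predict_step(self, batch, batch_idx, tiled=False):
        output = self.forward_fn(batch)
        # in the tiled case, the step would also return tile information
        return (output, "tile_info") if tiled else output


module = ExplicitTiledModule(lambda x: x * 2)
assert module.predict_step(3, 0) == 6
assert module.predict_step(3, 0, tiled=True) == (6, "tile_info")
```

The caller (here, the callback iterating over the predict dataloader) knows which datamodule produced the batch, so it can supply `tiled` without any trainer lookup.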
Thanks for sharing your code!
We've investigated a little and identified what is basically preventing this approach. We will have a deeper look in the next weeks to see if some refactoring would make the current code base compatible with a prediction callback!
That's great to hear, thank you very much!
Hi both,
I know there has been some work done towards this feature, is it fully working yet? Let me know if I could help implementing it, otherwise :)
Hi Conrad! We are hosting I2K, so we are a bit overwhelmed at the moment. Let us come back to you at the beginning of November!
Hi @conradkun,
Sorry for the delay, life is always busier than expected.
I had a go at it and came up with a hacky way to make it fit together (https://github.com/CAREamics/careamics/actions/runs/11783874603?pr=266); in essence, it looks like this:
```python
import numpy as np
from pytorch_lightning import Callback, Trainer

from careamics import CAREamist, Configuration
from careamics.lightning import PredictDataModule, create_predict_datamodule
from careamics.prediction_utils import convert_outputs

config = Configuration(**minimum_configuration)


class CustomPredictAfterValidationCallback(Callback):
    def __init__(self, pred_datamodule: PredictDataModule):
        self.pred_datamodule = pred_datamodule

        # prepare data and setup
        self.pred_datamodule.prepare_data()
        self.pred_datamodule.setup()
        self.pred_dataloader = pred_datamodule.predict_dataloader()

        self.data = None

    def on_validation_epoch_end(self, trainer: Trainer, pl_module):
        if trainer.sanity_checking:  # optional skip
            return

        # update statistics in the prediction dataset for coherence
        # (they are computed on-line by the training dataset)
        self.pred_datamodule.predict_dataset.image_means = (
            trainer.datamodule.train_dataset.image_stats.means
        )
        self.pred_datamodule.predict_dataset.image_stds = (
            trainer.datamodule.train_dataset.image_stats.stds
        )

        # predict on the dataset
        outputs = []
        for idx, batch in enumerate(self.pred_dataloader):
            batch = pl_module.transfer_batch_to_device(batch, pl_module.device, 0)
            outputs.append(pl_module.predict_step(batch, batch_idx=idx))

        self.data = convert_outputs(outputs, self.pred_datamodule.tiled)
        # save data here


array = np.arange(32 * 32).reshape((32, 32))

pred_datamodule = create_predict_datamodule(
    pred_data=array,
    data_type=config.data_config.data_type,
    axes=config.data_config.axes,
    image_means=[11.8],  # random placeholder
    image_stds=[3.14],
    # can choose tiling here
)

predict_after_val_callback = CustomPredictAfterValidationCallback(
    pred_datamodule=pred_datamodule
)

engine = CAREamist(config, callbacks=[predict_after_val_callback])
engine.train(train_source=array)
```
This raised a few issues pertaining to the logic of CAREamics, which I detail in the PR. We are undergoing a dataset refactoring step, and we will consider the aforementioned points during the refactoring.

Regarding the PR, I need to test it with a real example before merging it (provided that it passes review), as I have only run it through a simple test! You are welcome to try it and comment in the PR; I hope I am not making you try a faulty example.

We would also be super interested in knowing how and why you are using CAREamics! And if you ever want to use it in production, we would be happy to add specific integration tests to make sure the API is stable for your use cases.
edit: Add comments to the code
Hi! I recently came across this repo and it looks very promising for my use case!
As I was playing around I wanted to see how the predictions improved throughout the epochs. Ideally, I would like to have a small separate dataset on which the model could be run after each training epoch, or just a random draw of some images from the validation dataset.
I tried to implement this myself with PyTorch Lightning `Callbacks`, but I don't see a clear way to get around having to call `trainer.predict` inside the callback, and I fear that messes up the `trainer.fit` loop by deleting the validation losses it keeps track of.

Given that you may have a lot more intuition of how CAREamics works, do you have an idea of something that would work? I am happy to implement it myself and create a PR, but currently I do not see how I can avoid calling `trainer.predict`, since the prediction loop is modified to stitch the image together.

Thank you,
Conrad