Closed sjfleming closed 3 months ago
@ordabayevy This is ready to take a look at, I think. There's something real ugly going on with a try / except at the end of `CellariumModule.configure_model()`, but other than that, I think I like it. Not sure what's the right way to check for a Trainer.
Let's consider the scenario when you load the module from the checkpoint in a Jupyter Notebook. `load_from_checkpoint` does three things:

1. Initializes `CellariumModule` using `hparams` in the checkpoint. At this step `self._lightning_training_using_datamodule` will be set to `False`.
2. Calls `configure_model`. At this step the module will not detect the trainer and will keep `self._lightning_training_using_datamodule=False` and will not dispatch `cpu_transforms` to the datamodule.
3. Calls `load_state_dict`. Nothing interesting happens here.

So in this scenario `cpu_transforms` will be run by the `CellariumModule`.
If you use `.fit(..., ckpt_path=ckpt_path)` or `.predict(..., ckpt_path=ckpt_path)` then the steps are:

1. `CellariumModule` is initialized by the user or from the config file. At this step `self._lightning_training_using_datamodule=False` and `configure_model` hasn't been called yet.
2. `configure_model` is called. At this step `cpu_transforms` is dispatched to the datamodule and `self._lightning_training_using_datamodule=True`.
3. `load_state_dict` is called. Nothing interesting.

This case is similar to calling `.fit(...)` without the `ckpt_path` argument.
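The two flows above can be sketched with a toy module (simplified, hypothetical names; not the real `CellariumModule` API) to show when the flag flips:

```python
# Toy sketch of the two scenarios: the dispatch flag only flips when a
# trainer is attached before configure_model() runs.
class ToyModule:
    def __init__(self, cpu_transforms=None):
        self.cpu_transforms = cpu_transforms or []
        self._lightning_training_using_datamodule = False
        self._trainer = None

    def configure_model(self):
        # Step 2 in both scenarios: dispatch only if a trainer is attached.
        if self._trainer is not None:
            self._lightning_training_using_datamodule = True

    def transforms_run_by_module(self):
        if self._lightning_training_using_datamodule:
            return []  # cpu_transforms were handed off to the datamodule
        return self.cpu_transforms


# Scenario 1: load_from_checkpoint in a notebook -- no trainer attached.
notebook = ToyModule(cpu_transforms=["Filter"])
notebook.configure_model()
assert notebook.transforms_run_by_module() == ["Filter"]

# Scenario 2: Trainer.fit(..., ckpt_path=...) -- trainer attached first.
fitted = ToyModule(cpu_transforms=["Filter"])
fitted._trainer = object()
fitted.configure_model()
assert fitted.transforms_run_by_module() == []
```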
I should probably add a test to ensure that what you describe is what's happening. I don't think I am explicitly checking all of those cases yet.
Hm, I think there is a problem I've caught actually... if there are `cpu_transforms`, then `CellariumModule.load_from_checkpoint(ckpt_path)` will throw a `RuntimeError` in tests using `BoringModel`:

```
E       RuntimeError: Error(s) in loading state_dict for CellariumModule:
E           Missing key(s) in state_dict: "module_pipeline.3._dummy_param".
E           Unexpected key(s) in state_dict: "module_pipeline.2._dummy_param".
```
To me, it looks like this is due to different `ModuleList` numbering of the module associated with the model.

A solution can be to have `module_pipeline` and `datamodule_pipeline` as properties which slice and return the pipeline on the fly. Then they shouldn't show up in `CellariumModule`'s `children()`. `state_dict` should only save `pipeline` in the checkpoint.
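The property-based fix can be sketched on a toy module (hypothetical names; not the actual `CellariumModule` code). Because a `@property` is computed on the fly and never registered as a submodule, `state_dict` only contains keys for `pipeline`, so the checkpoint keys no longer depend on how many CPU transforms precede the model:

```python
import torch
from torch import nn


class ToyModule(nn.Module):
    """Toy analogue of the fix: only `pipeline` is a registered submodule."""

    def __init__(self, n_cpu_transforms: int):
        super().__init__()
        self._n_cpu = n_cpu_transforms
        # The full pipeline (cpu transforms + model) is the only stored attribute.
        self.pipeline = nn.ModuleList(
            [nn.Identity() for _ in range(n_cpu_transforms)] + [nn.Linear(4, 4)]
        )

    @property
    def module_pipeline(self) -> nn.ModuleList:
        # Sliced view computed on the fly; never registered, never checkpointed.
        return self.pipeline[self._n_cpu :]


m = ToyModule(n_cpu_transforms=2)
keys = list(m.state_dict().keys())
assert all(k.startswith("pipeline.") for k in keys)  # no "module_pipeline.*" keys
assert len(m.module_pipeline) == 1  # just the model, sliced past the transforms
```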
Expanded `test_cpu_transforms` to cover checkpoint loading, both manually and via `Trainer.fit(..., ckpt_path)`. Getting the additional tests to pass did require making `module_pipeline` a `@property` of `CellariumModule`, rather than an object attribute.
@sjfleming there is one more thing that needs to be fixed. `cpu_transforms` needs to be passed to `cellarium.ml.cli.compute_var_names_g`:

```python
def compute_var_names_g(
    cpu_transforms: list[torch.nn.Module] | None,
    transforms: list[torch.nn.Module] | None,
    data: CellariumAnnDataDataModule,
) -> np.ndarray:
    adata = data.dadc[0]
    batch = {key: field(adata) for key, field in data.batch_keys.items()}
    pipeline = CellariumPipeline(cpu_transforms) + CellariumPipeline(transforms)
```

and add `cpu_transforms` to `LinkArguments`:

```python
LinkArguments(
    ("model.cpu_transforms", "model.transforms", "data"),
    "model.model.init_args.var_names_g",
    compute_var_names_g,
)
```

If we move `Filter` to `cpu_transforms` below it should fail without the changes above:

https://github.com/cellarium-ai/cellarium-ml/blob/main/tests/test_cli.py#L295-L312
@sjfleming is it okay if I push the changes related to `compute_var_names_g`?
Thanks for catching this @ordabayevy ! I'll add the changes now
@ordabayevy I am seeing some unexpected behavior that's leading me to implement it using this inside `compute_var_names_g()`:

```python
pipeline = CellariumPipeline((cpu_transforms if cpu_transforms else []) + (transforms if transforms else []))
```

rather than this:

```python
pipeline = CellariumPipeline(cpu_transforms) + CellariumPipeline(transforms)
```

The issue is that the second implementation complains that `forward` is not implemented. Here's the minimal demo:

```python
>>> from cellarium.ml import CellariumPipeline
>>> m = CellariumPipeline(None) + CellariumPipeline(None)
>>> m.forward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/sfleming/miniconda3/envs/cellarium/lib/python3.10/site-packages/torch/nn/modules/module.py", line 374, in _forward_unimplemented
    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
NotImplementedError: Module [ModuleList] is missing the required "forward" function
```

but this is fine as expected:

```python
>>> from cellarium.ml import CellariumPipeline
>>> m = CellariumPipeline(None)
>>> m.forward(None)
(runs fine)
```

I think this points at a bug, where we really need to override `__add__` in `CellariumPipeline` or something? Right now, it seems that the behavior of adding `CellariumPipeline`s is not properly handled. Maybe a separate issue?
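One possible fix, sketched on a toy `nn.ModuleList` subclass (hypothetical, not the real `CellariumPipeline`): `nn.ModuleList.__add__` constructs a plain `ModuleList`, dropping the subclass and its `forward`, so overriding `__add__` to rebuild the subclass preserves both:

```python
import torch
from torch import nn


class ToyPipeline(nn.ModuleList):
    """Toy stand-in for CellariumPipeline: applies modules in order."""

    def forward(self, x):
        for module in self:
            x = module(x)
        return x

    def __add__(self, other):
        # nn.ModuleList.__add__ builds a plain ModuleList, so the subclass
        # (and its forward) is lost; rebuild the subclass instead.
        return ToyPipeline(list(self) + list(other))


combined = ToyPipeline([nn.Linear(4, 4)]) + ToyPipeline([nn.ReLU()])
assert isinstance(combined, ToyPipeline)  # not demoted to nn.ModuleList
out = combined(torch.randn(2, 4))  # forward still works after addition
```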
Sounds good.
> I think this points at a bug, where we really need to override `__add__` in `CellariumPipeline` or something? Right now, it seems that the behavior of adding `CellariumPipeline`s is not properly handled. Maybe a separate issue?
Agree
Small modifications to allow some optional set of transforms to be performed on CPU by dataloader workers. Main use case would be the `Filter` transform, but it's possible that other transforms like `NormalizeTotal` and `Log1p` are so fast that they can be performed on CPU by dataloader workers as well, freeing up the GPU to just crank away on the model.

This adds the input `cpu_transforms` to `CellariumModule`. If a `CellariumModule` is being used in the context of pytorch lightning training (i.e. if it `hasattr(self, "trainer")`) and if it is using a datamodule that is a `CellariumAnnDataDataModule`, then it will exclude the `cpu_transforms` from its list of applied transforms when calling `training_step()`, `forward()`, and `validation_step()`.
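The trainer check described above can be sketched with toy stand-ins (hypothetical names, not the real library code). In recent PyTorch Lightning versions, accessing `.trainer` on an unattached module raises a `RuntimeError`, which is why a `hasattr` or try/except guard is needed before the datamodule type check:

```python
class ToyDataModule:
    """Stand-in for CellariumAnnDataDataModule."""


class ToyTrainer:
    def __init__(self, datamodule):
        self.datamodule = datamodule


class ToyModule:
    def __init__(self, trainer=None):
        self._trainer = trainer

    @property
    def trainer(self):
        # Mirrors Lightning: accessing .trainer without attachment raises.
        if self._trainer is None:
            raise RuntimeError("not attached to a Trainer")
        return self._trainer

    def should_dispatch_cpu_transforms(self) -> bool:
        try:
            trainer = self.trainer
        except RuntimeError:
            # Standalone use, e.g. load_from_checkpoint in a notebook.
            return False
        return isinstance(getattr(trainer, "datamodule", None), ToyDataModule)


assert ToyModule().should_dispatch_cpu_transforms() is False
assert ToyModule(ToyTrainer(ToyDataModule())).should_dispatch_cpu_transforms() is True
```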