cellarium-ai / cellarium-ml

Distributed single-cell data analysis.
BSD 3-Clause "New" or "Revised" License

CPU transforms on dataloader workers #223

Closed sjfleming closed 3 months ago

sjfleming commented 3 months ago

Small modifications to allow an optional set of transforms to be performed on CPU by dataloader workers. The main use case would be the Filter transform, but other transforms like NormalizeTotal and Log1p may be fast enough to run on CPU in the dataloader workers as well, freeing up the GPU to just crank away on the model.

This adds the input cpu_transforms to CellariumModule. If a CellariumModule is being used for PyTorch Lightning training (i.e. if hasattr(self, "trainer")) with a datamodule that is a CellariumAnnDataDataModule, then it excludes the cpu_transforms from its list of applied transforms when calling training_step(), forward(), and validation_step(), since the dataloader workers have already applied them.
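A minimal, torch-free sketch of the dispatch described above, assuming a batch can be modeled as a plain dict and each transform as a function from batch to batch. The names filter_genes, normalize_total, log1p, run, and ModuleSketch are hypothetical stand-ins, not the cellarium-ml API. What it checks is that splitting the work (cpu_transforms on a worker, the rest on the module) gives the same result as the module applying everything itself.

```python
import math
from typing import Callable, Optional

# Hypothetical batch format: one cell's counts plus the matching gene names.
Batch = dict
Transform = Callable[[Batch], Batch]


def filter_genes(keep: set) -> Transform:
    """Stand-in for the Filter transform: keep only the genes in `keep`."""

    def apply(batch: Batch) -> Batch:
        idx = [i for i, g in enumerate(batch["var_names_g"]) if g in keep]
        return {
            "x_g": [batch["x_g"][i] for i in idx],
            "var_names_g": [batch["var_names_g"][i] for i in idx],
        }

    return apply


def normalize_total(target: float) -> Transform:
    """Stand-in for NormalizeTotal: rescale counts so they sum to `target`."""

    def apply(batch: Batch) -> Batch:
        total = sum(batch["x_g"])
        return {**batch, "x_g": [v * target / total for v in batch["x_g"]]}

    return apply


def log1p(batch: Batch) -> Batch:
    """Stand-in for Log1p: elementwise log(1 + x)."""
    return {**batch, "x_g": [math.log1p(v) for v in batch["x_g"]]}


def run(transforms: list, batch: Batch) -> Batch:
    """Apply transforms in order, like a pipeline's forward pass."""
    for t in transforms:
        batch = t(batch)
    return batch


class ModuleSketch:
    """Hypothetical stand-in for CellariumModule's choice of transforms."""

    def __init__(self, cpu_transforms: Optional[list], transforms: Optional[list]):
        self.cpu_transforms = list(cpu_transforms or [])
        self.transforms = list(transforms or [])
        # True only for Lightning training with a CellariumAnnDataDataModule.
        self.lightning_training_using_datamodule = False

    def module_transforms(self) -> list:
        if self.lightning_training_using_datamodule:
            # Dataloader workers already applied cpu_transforms.
            return self.transforms
        return self.cpu_transforms + self.transforms


batch = {"x_g": [4.0, 0.0, 6.0], "var_names_g": ["A", "B", "C"]}
module = ModuleSketch([filter_genes({"A", "C"})], [normalize_total(10.0), log1p])

# Outside training: the module applies every transform itself.
standalone = run(module.module_transforms(), batch)

# During training: a worker applies cpu_transforms, the module does the rest.
module.lightning_training_using_datamodule = True
on_worker = run(module.cpu_transforms, batch)
trained = run(module.module_transforms(), on_worker)

print(trained == standalone)  # True
```

Because the split only changes where each transform runs and not their order, the two paths agree exactly; this is what makes it safe for the module to skip cpu_transforms only when the datamodule is known to run them.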

sjfleming commented 3 months ago

@ordabayevy This is ready to take a look at, I think. There's something really ugly going on with a try / except at the end of CellariumModule.configure_model(), but other than that, I think I like it. I'm not sure what the right way is to check for a Trainer.

ordabayevy commented 3 months ago

Let's consider the scenario where you load the module from a checkpoint in a Jupyter notebook:

load_from_checkpoint does three things:

  1. Initialize CellariumModule using hparams in the checkpoint. At this step self._lightning_training_using_datamodule will be set to False.
  2. Call configure_model. At this step the module will not detect a trainer, so it keeps self._lightning_training_using_datamodule=False and does not dispatch cpu_transforms to the datamodule.
  3. Call load_state_dict. Nothing interesting happens here.

So in this scenario cpu_transforms will be run by the CellariumModule.
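A sketch of the three load_from_checkpoint steps listed above, assuming configure_model detects training by looking for an attached trainer whose datamodule has the expected type. ModuleSketch, DataModuleSketch, and load_from_checkpoint here are hypothetical, torch-free stand-ins rather than the Lightning or cellarium-ml implementations. With no trainer attached, the flag stays False and the module keeps cpu_transforms in its own pipeline.

```python
from typing import Optional


class DataModuleSketch:
    """Hypothetical stand-in for CellariumAnnDataDataModule."""

    def __init__(self):
        self.cpu_transforms: list = []  # would run on dataloader workers


class ModuleSketch:
    """Hypothetical stand-in for CellariumModule's lifecycle hooks."""

    def __init__(self, cpu_transforms: Optional[list], transforms: Optional[list]):
        # Step 1: built from the checkpoint's hparams; the flag starts False.
        self.cpu_transforms = list(cpu_transforms or [])
        self.transforms = list(transforms or [])
        self.trainer = None  # nothing is attached in a notebook
        self.lightning_training_using_datamodule = False
        self.state: dict = {}

    def configure_model(self) -> None:
        # Step 2: hand cpu_transforms to the datamodule only if a trainer
        # with the expected datamodule type is attached.
        datamodule = getattr(self.trainer, "datamodule", None)
        if isinstance(datamodule, DataModuleSketch):
            datamodule.cpu_transforms = self.cpu_transforms
            self.lightning_training_using_datamodule = True

    def load_state_dict(self, state: dict) -> None:
        # Step 3: restore weights; transform dispatch is unaffected.
        self.state = dict(state)

    def applied_transforms(self) -> list:
        if self.lightning_training_using_datamodule:
            return self.transforms
        return self.cpu_transforms + self.transforms


def load_from_checkpoint(ckpt: dict) -> ModuleSketch:
    """Hypothetical mimic of the three load_from_checkpoint steps."""
    module = ModuleSketch(**ckpt["hparams"])
    module.configure_model()
    module.load_state_dict(ckpt["state_dict"])
    return module


ckpt = {
    "hparams": {"cpu_transforms": ["Filter"], "transforms": ["NormalizeTotal", "Log1p"]},
    "state_dict": {"weight": 1.0},
}
module = load_from_checkpoint(ckpt)
print(module.applied_transforms())  # ['Filter', 'NormalizeTotal', 'Log1p']
```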

ordabayevy commented 3 months ago

If you use .fit(..., ckpt_path=ckpt_path) or .predict(..., ckpt_path=ckpt_path) then the steps are:

  1. CellariumModule is initialized by the user or from the config file. At this step self._lightning_training_using_datamodule=False and configure_model hasn't been called yet.
  2. The module gets attached to the trainer.
  3. Call configure_model. At this step cpu_transforms is dispatched to the datamodule and self._lightning_training_using_datamodule=True.
  4. Call load_state_dict. Nothing interesting.

This case is similar to calling .fit(...) without the ckpt_path argument.
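The .fit(..., ckpt_path=...) ordering can be sketched the same way. This is a hypothetical, torch-free model: TrainerSketch only captures the order of attach, configure_model, and load_state_dict, and ckpt_state stands in for loading ckpt_path. Because the module is attached before configure_model runs, the cpu_transforms end up on the datamodule and the module applies only the remaining transforms.

```python
from typing import Optional


class DataModuleSketch:
    """Hypothetical stand-in for CellariumAnnDataDataModule."""

    def __init__(self):
        self.cpu_transforms: list = []  # run by dataloader workers


class ModuleSketch:
    """Hypothetical stand-in for CellariumModule (transform dispatch only)."""

    def __init__(self, cpu_transforms: Optional[list], transforms: Optional[list]):
        self.cpu_transforms = list(cpu_transforms or [])
        self.transforms = list(transforms or [])
        self.trainer = None
        self.lightning_training_using_datamodule = False
        self.state: dict = {}

    def configure_model(self) -> None:
        datamodule = getattr(self.trainer, "datamodule", None)
        if isinstance(datamodule, DataModuleSketch):
            datamodule.cpu_transforms = self.cpu_transforms
            self.lightning_training_using_datamodule = True

    def load_state_dict(self, state: dict) -> None:
        self.state = dict(state)

    def applied_transforms(self) -> list:
        if self.lightning_training_using_datamodule:
            return self.transforms
        return self.cpu_transforms + self.transforms


class TrainerSketch:
    """Hypothetical trainer that only models the call order of fit()."""

    def __init__(self):
        self.datamodule = None

    def fit(self, module: ModuleSketch, datamodule: DataModuleSketch,
            ckpt_state: Optional[dict] = None) -> None:
        self.datamodule = datamodule
        module.trainer = self  # step 2: attach the module to the trainer
        module.configure_model()  # step 3: cpu_transforms go to the datamodule
        if ckpt_state is not None:  # step 4: stands in for ckpt_path
            module.load_state_dict(ckpt_state)


# Step 1: the user (or a config file) builds the module.
module = ModuleSketch(["Filter"], ["NormalizeTotal", "Log1p"])
datamodule = DataModuleSketch()
TrainerSketch().fit(module, datamodule, ckpt_state={"weight": 1.0})

print(datamodule.cpu_transforms)  # ['Filter']
print(module.applied_transforms())  # ['NormalizeTotal', 'Log1p']
```

Passing ckpt_state=None follows the same path minus step 4, which matches the observation that this case behaves like .fit(...) without a checkpoint.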

sjfleming commented 3 months ago

I should probably add a test to ensure that what you describe is actually happening. I don't think I am explicitly checking all of those cases yet.

sjfleming commented 3 months ago

Hm, I think I've actually caught a problem... if there are cpu_transforms, then CellariumModule.load_from_checkpoint(ckpt_path) throws a RuntimeError in tests using BoringModel:

E           RuntimeError: Error(s) in loading state_dict for CellariumModule:
E               Missing key(s) in state_dict: "module_pipeline.3._dummy_param". 
E               Unexpected key(s) in state_dict: "module_pipeline.2._dummy_param".

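To me, it looks like this is due to the model getting a different index in the ModuleList: during training the cpu_transforms are excluded, so the model is saved as module_pipeline.2, but when the module is loaded outside a Trainer the cpu_transforms are included, so it is expected at module_pipeline.3.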
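The index mismatch can be reproduced without torch by mimicking how nn.ModuleList names parameters: the child's position in the list, then the parameter name. The sketch assumes one CPU transform and two other transforms ahead of the model, which is one composition that yields exactly the indices in the error above; state_dict_keys is a hypothetical helper.

```python
def state_dict_keys(prefix: str, pipeline: list) -> list:
    """Mimic nn.ModuleList parameter naming: '<prefix>.<index>.<param>'.

    `pipeline` holds (name, param_names) pairs. Transforms here have no
    parameters; the model has a single dummy parameter, as in the error.
    """
    return [
        f"{prefix}.{i}.{param}"
        for i, (_, params) in enumerate(pipeline)
        for param in params
    ]


# Assumed composition: one CPU transform and two other transforms.
cpu_transforms = [("Filter", [])]
transforms = [("NormalizeTotal", []), ("Log1p", [])]
model = [("BoringModel", ["_dummy_param"])]

# Saved during Trainer.fit, where module_pipeline skips cpu_transforms.
saved = state_dict_keys("module_pipeline", transforms + model)
# Expected after load_from_checkpoint outside a Trainer, where it includes them.
expected = state_dict_keys("module_pipeline", cpu_transforms + transforms + model)

print("Missing:", sorted(set(expected) - set(saved)))  # ['module_pipeline.3._dummy_param']
print("Unexpected:", sorted(set(saved) - set(expected)))  # ['module_pipeline.2._dummy_param']
```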

ordabayevy commented 3 months ago

One solution could be to make module_pipeline and datamodule_pipeline properties that slice and return the pipeline on the fly. Then they won't show up in CellariumModule.children(), and state_dict will only save pipeline in the checkpoint.
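In PyTorch, assigning an nn.Module (such as an nn.ModuleList) to an attribute registers it as a child, so it appears in children() and state_dict(), whereas a @property is looked up on the class and registers nothing. The torch-free sketch below, with hypothetical names throughout, stores only pipeline and derives module_pipeline and datamodule_pipeline by slicing, so the saved keys are the same whether or not the cpu_transforms are being skipped.

```python
class ModuleSketch:
    """Hypothetical sketch: `pipeline` is the only stored (registered) list.

    module_pipeline and datamodule_pipeline are read-only properties that
    slice `pipeline` on demand, so they never add state_dict keys.
    """

    def __init__(self, cpu_transforms: list, transforms: list, model: tuple):
        self.pipeline = list(cpu_transforms) + list(transforms) + [model]
        self.num_cpu_transforms = len(cpu_transforms)
        self.lightning_training_using_datamodule = False

    @property
    def datamodule_pipeline(self) -> list:
        return self.pipeline[: self.num_cpu_transforms]

    @property
    def module_pipeline(self) -> list:
        if self.lightning_training_using_datamodule:
            return self.pipeline[self.num_cpu_transforms :]
        return self.pipeline

    def state_dict_keys(self) -> list:
        # Mimics state_dict(): only the stored `pipeline` contributes keys,
        # so indices do not depend on which slice is currently in use.
        return [
            f"pipeline.{i}.{param}"
            for i, (_, params) in enumerate(self.pipeline)
            for param in params
        ]


def build() -> ModuleSketch:
    return ModuleSketch(
        [("Filter", [])],
        [("NormalizeTotal", []), ("Log1p", [])],
        ("BoringModel", ["_dummy_param"]),
    )


training = build()
training.lightning_training_using_datamodule = True
notebook = build()

print(training.state_dict_keys() == notebook.state_dict_keys())  # True
print([name for name, _ in training.module_pipeline])  # ['NormalizeTotal', 'Log1p', 'BoringModel']
print([name for name, _ in training.datamodule_pipeline])  # ['Filter']
```

In real PyTorch code the slices would be nn.ModuleList slices, which return a new ModuleList holding the same submodules, so parameters are shared rather than copied.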

sjfleming commented 3 months ago

Expanded test_cpu_transforms to cover checkpoint loading, both manually and via Trainer.fit(..., ckpt_path). Getting the additional tests to pass did require making module_pipeline a @property of CellariumModule, rather than an object attribute.

ordabayevy commented 3 months ago

@sjfleming there is one more thing that needs to be fixed. cpu_transforms needs to be passed to cellarium.ml.cli.compute_var_names_g:

def compute_var_names_g(cpu_transforms: list[torch.nn.Module] | None, transforms: list[torch.nn.Module] | None, data: CellariumAnnDataDataModule) -> np.ndarray:
    adata = data.dadc[0]
    batch = {key: field(adata) for key, field in data.batch_keys.items()}
    pipeline = CellariumPipeline(cpu_transforms) + CellariumPipeline(transforms)

and add cpu_transforms to LinkArguments:

LinkArguments(("model.cpu_transforms", "model.transforms", "data"), "model.model.init_args.var_names_g", compute_var_names_g)

If we move Filter to cpu_transforms in the test below, it should fail without the changes above:

https://github.com/cellarium-ai/cellarium-ml/blob/main/tests/test_cli.py#L295-L312
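A pure-Python sketch of why compute_var_names_g needs cpu_transforms: if Filter runs as a CPU transform, the gene names the model sees are only correct when cpu_transforms are applied before transforms. compute_var_names_g_sketch and filter_genes are hypothetical stand-ins that track var_names_g only; the real function runs a CellariumPipeline on a batch built from data.dadc[0]. Either argument may be None, so each is treated as an empty list before concatenating.

```python
from typing import Callable, Optional

Batch = dict
Transform = Callable[[Batch], Batch]


def filter_genes(keep: set) -> Transform:
    """Stand-in for Filter: drop genes that are not in `keep`."""

    def apply(batch: Batch) -> Batch:
        names = [g for g in batch["var_names_g"] if g in keep]
        return {**batch, "var_names_g": names}

    return apply


def compute_var_names_g_sketch(
    cpu_transforms: Optional[list],
    transforms: Optional[list],
    var_names_g: list,
) -> list:
    """Hypothetical analogue of cli.compute_var_names_g.

    cpu_transforms run first (as on dataloader workers), then transforms;
    the gene names that survive are the ones the model will see.
    """
    pipeline = (cpu_transforms or []) + (transforms or [])
    batch = {"var_names_g": list(var_names_g)}
    for transform in pipeline:
        batch = transform(batch)
    return batch["var_names_g"]


genes = ["A", "B", "C"]
keep_ac = filter_genes({"A", "C"})

# Filter moved to cpu_transforms: still applied when computing var_names_g.
print(compute_var_names_g_sketch([keep_ac], None, genes))  # ['A', 'C']
# Ignoring cpu_transforms, as the old signature did, keeps every gene.
print(compute_var_names_g_sketch(None, None, genes))  # ['A', 'B', 'C']
```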

ordabayevy commented 3 months ago

@sjfleming is it okay if I push the changes related to compute_var_names_g?

sjfleming commented 3 months ago

Thanks for catching this, @ordabayevy! I'll add the changes now.

sjfleming commented 3 months ago

@ordabayevy I am seeing some unexpected behavior that's leading me to implement it using this inside compute_var_names_g():

    pipeline = CellariumPipeline((cpu_transforms if cpu_transforms else []) + (transforms if transforms else []))

rather than this

    pipeline = CellariumPipeline(cpu_transforms) + CellariumPipeline(transforms)

The issue is that the second implementation complains that "forward" is not implemented. Here's the minimal demo:

>>> from cellarium.ml import CellariumPipeline
>>> m = CellariumPipeline(None) + CellariumPipeline(None)
>>> m.forward()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/sfleming/miniconda3/envs/cellarium/lib/python3.10/site-packages/torch/nn/modules/module.py", line 374, in _forward_unimplemented
    raise NotImplementedError(f"Module [{type(self).__name__}] is missing the required \"forward\" function")
NotImplementedError: Module [ModuleList] is missing the required "forward" function

but this is fine as expected:

>>> from cellarium.ml import CellariumPipeline
>>> m = CellariumPipeline(None)
>>> m.forward(None)

(runs fine)

I think this points at a bug, where we really need to override __add__ in CellariumPipeline or something? Right now, it seems that the behavior of adding CellariumPipelines is not properly handled. Maybe a separate issue?
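The traceback shows that the sum of two CellariumPipelines is a plain ModuleList, which suggests that the inherited nn.ModuleList.__add__ builds its result as the base class and so drops the subclass's forward. The torch-free sketch below reproduces that behavior with hypothetical stand-ins (ModuleListLike, PipelineLike) and shows one possible fix: an __add__ override that returns type(self).

```python
class ModuleListLike(list):
    """Stand-in for torch.nn.ModuleList: a container with no forward()."""

    def __init__(self, modules=None):
        super().__init__(modules or [])

    def __add__(self, other):
        # Mirrors the observed behavior: the sum is always the base class.
        return ModuleListLike(list(self) + list(other))

    def forward(self, *args):
        raise NotImplementedError(
            f"Module [{type(self).__name__}] is missing the required \"forward\" function"
        )


class PipelineLike(ModuleListLike):
    """Stand-in for CellariumPipeline: apply each module to the batch in turn."""

    def forward(self, batch):
        for module in self:
            batch = module(batch)
        return batch


class FixedPipeline(PipelineLike):
    """Sketch of the proposed fix: __add__ preserves the subclass."""

    def __add__(self, other):
        return type(self)(list(self) + list(other))


print(PipelineLike(None).forward(None))  # None (an empty pipeline is a no-op)

bad = PipelineLike(None) + PipelineLike(None)
print(type(bad).__name__)  # ModuleListLike, so bad.forward() raises

good = FixedPipeline([lambda b: b + 1]) + FixedPipeline([lambda b: b * 2])
print(type(good).__name__, good.forward(3))  # FixedPipeline 8
```

In the real class, the analogous override would presumably construct a CellariumPipeline from the concatenated modules; whether it should also accept None operands is a separate design choice.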

ordabayevy commented 3 months ago

Sounds good.

> I think this points at a bug, where we really need to override __add__ in CellariumPipeline or something? Right now, it seems that the behavior of adding CellariumPipelines is not properly handled. Maybe a separate issue?

Agree