Project-MONAI / MONAI

AI Toolkit for Healthcare Imaging
https://monai.io/
Apache License 2.0

Standardize the way of getting model weights from Trainer #5371

Open holgerroth opened 1 year ago

holgerroth commented 1 year ago

Is your feature request related to a problem? Please describe. Certain trainers, like GanTrainer, use several models, while the standard SupervisedTrainer only uses one.

Describe the solution you'd like There should be a standardized API (e.g., get_weights/model_states) that returns all model parameters. This would make it easier to integrate with other components, like MonaiAlgo for federated learning.

Having a standardized API would allow MonaiAlgo to easily support bundles that use Trainers other than SupervisedTrainer, e.g. mednist_gan.

Describe alternatives you've considered n/a

Additional context n/a

Nic-Ma commented 1 year ago

Hi @vfdev-5 ,

Do you have any similar feature requirements in the ignite workflows? Engine users may want to get the model weights when a particular event is received.

Thanks in advance.

vfdev-5 commented 1 year ago

@Nic-Ma in the case of ignite, the model can be registered into the engine state manually if the model instance is not available in the same scope as the handler:

from ignite.engine import Engine, Events

trainer = Engine(train_step)

# store the model on the engine state so any handler can reach it later
trainer.state.model = model

# Somewhere else we define a handler that is attached to the engine:
@trainer.on(Events.ITERATION_COMPLETED)
def my_handler(trainer):
    model = trainer.state.model
    # ...

On the other hand, handlers (classes that explicitly use the model, etc.) can hold the models themselves, e.g. Checkpoint({"model1": model1, "model2": model2}, "temp").
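
For illustration, a minimal sketch of that second option, assuming model1, model2 and the trainer from above are already defined; the ignite Checkpoint handler accepts a dict of modules, so multiple models are saved under their own keys:

from ignite.engine import Events
from ignite.handlers import Checkpoint

# the handler is constructed with every module whose weights should be tracked
checkpoint = Checkpoint({"model1": model1, "model2": model2}, "temp")

# attach it to the trainer so the state dicts are written at the end of each epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint)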

holgerroth commented 1 year ago

Currently, MonaiAlgo is using weights = get_state_dict(self.trainer.network). This will not work in the case of the GAN trainer, where there are two models.
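
A rough sketch of the difference, assuming GanTrainer exposes its two networks as g_network and d_network (the attribute names used in its constructor); get_state_dict is the helper MonaiAlgo already uses:

# works only for trainers that have a single `network` attribute (SupervisedTrainer):
weights = get_state_dict(self.trainer.network)

# a GAN-style trainer holds two models, so something like this would be needed instead:
weights = {
    "g_network": self.trainer.g_network.state_dict(),
    "d_network": self.trainer.d_network.state_dict(),
}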

Nic-Ma commented 1 year ago

Hi @vfdev-5 @holgerroth ,

Thanks for sharing the information. Let me think about it more deeply; we may need to add an API to the trainer, e.g. get_network().

Thanks.

ericspod commented 1 year ago

We could add a state_dict/load_state_dict pair of methods to Workflow which by default would deal with a dictionary containing {"model": self.state.model.state_dict()}. Subclasses of Workflow would override these to include more than one model under known keys, but this would only be necessary for those that aren't straightforward single-model workflows.
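
A minimal sketch of what this could look like (not the actual MONAI API; the class and attribute names just follow the description above):

class Workflow:
    # default behaviour for straightforward single-model workflows
    def state_dict(self) -> dict:
        return {"model": self.state.model.state_dict()}

    def load_state_dict(self, state: dict) -> None:
        self.state.model.load_state_dict(state["model"])

class GanWorkflow(Workflow):
    # multi-model workflows override the pair to expose all networks under known keys
    def state_dict(self) -> dict:
        return {"g_network": self.g_network.state_dict(),
                "d_network": self.d_network.state_dict()}

    def load_state_dict(self, state: dict) -> None:
        self.g_network.load_state_dict(state["g_network"])
        self.d_network.load_state_dict(state["d_network"])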

vfdev-5 commented 1 year ago

I like @ericspod's idea about a state_dict method; however, such a method IMO should output the state dicts of everything: models, optimizers, trainer, lr_schedulers, amp, etc. If we want to limit the scope to models only, we could instead consider a Trainer.get_models() -> Dict[str, torch.nn.Module] method (and similarly Trainer.get_optims() etc.).
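
The scoped-accessor alternative could look roughly like this (a sketch only; neither method exists in MONAI at this point):

from typing import Dict
import torch

class Trainer:
    def get_models(self) -> Dict[str, torch.nn.Module]:
        # e.g. {"network": self.network} or {"g_network": ..., "d_network": ...}
        ...

    def get_optims(self) -> Dict[str, torch.optim.Optimizer]:
        # same idea for the optimizers driving each model
        ...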

ericspod commented 1 year ago

@vfdev-5 Yes, other state is important. We could have state_dict take parameters controlling what it outputs, with sensible defaults so that the behaviour best matches that of PyTorch.
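
One way to read that, sketched with a hypothetical include argument (for discussion only); with no arguments it would return everything needed to restore the workflow, similar in spirit to torch.nn.Module.state_dict:

def state_dict(self, include=("models", "optimizers")) -> dict:
    # hypothetical signature; reuses the get_models/get_optims accessors sketched above
    out = {}
    if "models" in include:
        out["models"] = {k: m.state_dict() for k, m in self.get_models().items()}
    if "optimizers" in include:
        out["optimizers"] = {k: o.state_dict() for k, o in self.get_optims().items()}
    return out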

danieltudosiu commented 1 year ago

Any update on this? We are looking at implementing an AdversarialTrainer and would prefer to have it future-proofed to a degree.

ericspod commented 1 year ago

There hasn't been any yet, so we can discuss a solution for this later today.

holgerroth commented 1 year ago

@Nic-Ma, @ericspod, any updates on this topic?

Nic-Ma commented 1 year ago

Hi @holgerroth ,

I am designing a solution; it is included in this draft PR: https://github.com/Project-MONAI/MONAI/pull/5822.

Thanks.