holgerroth opened this issue 1 year ago
Hi @vfdev-5,
Do you have any similar feature requirements in the ignite workflows? `Engine` users may want to get the model weights when certain events are received.
Thanks in advance.
@Nic-Ma in case of ignite, model registration into the state can be done manually if the model instance is not present in the same scope:

```python
trainer = Engine(train_step)
trainer.state.model = model

# Somewhere else we define a handler that is attached to the engine:
def my_handler(trainer):
    model = trainer.state.model
    # ...
```

On the other hand, handlers (classes that explicitly use the model, etc.) can be constructed with the models themselves, e.g. `Checkpoint({"model1": model1, "model2": model2}, "temp")`.
Currently, `MonaiAlgo` uses `weights = get_state_dict(self.trainer.network)`. This will not work for a GAN trainer, where there are two models.
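To make the limitation concrete, a helper that gathers state dicts from whichever networks a trainer exposes might look like the sketch below. The attribute names `g_network` and `d_network` are assumptions modeled on a GAN-style trainer with separate generator and discriminator; this is illustrative, not MONAI API:

```python
import torch.nn as nn

def get_model_states(trainer) -> dict:
    """Collect state dicts from every network a trainer exposes.

    Illustrative sketch only: the attribute names ``network``, ``g_network``
    and ``d_network`` are assumptions, not confirmed MONAI trainer attributes.
    """
    states = {}
    for name in ("network", "g_network", "d_network"):
        net = getattr(trainer, name, None)
        if isinstance(net, nn.Module):
            states[name] = net.state_dict()
    return states
```

A scheme like this still hard-codes attribute names per trainer type, which is why a standardized accessor is being discussed below.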
Hi @vfdev-5 @holgerroth ,
Thanks for sharing the information.
Let me think about it more deeply; we may need to add an API to the trainer, e.g. `get_network()`.
Thanks.
We could add a `state_dict`/`load_state_dict` pair of methods to `Workflow`, which by default would deal with a dictionary containing `{"model": self.state.model.state_dict()}`. Subclasses of `Workflow` would override these to include more than one model under known keys, but this would only be necessary for those that aren't straightforward single-model workflows.
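A minimal sketch of that pair, under the assumption that the base class holds a single `model` attribute (this is illustrative code, not the actual `Workflow` implementation):

```python
import torch
import torch.nn as nn

class Workflow:
    """Simplified stand-in for the proposed API; not actual MONAI code."""

    def __init__(self, model: nn.Module):
        self.model = model

    def state_dict(self) -> dict:
        # Default behaviour: a single model under a known key.
        return {"model": self.model.state_dict()}

    def load_state_dict(self, state: dict) -> None:
        self.model.load_state_dict(state["model"])

class GanWorkflow(Workflow):
    """Override for a workflow that trains two models, e.g. a GAN trainer."""

    def __init__(self, g_net: nn.Module, d_net: nn.Module):
        self.g_net, self.d_net = g_net, d_net

    def state_dict(self) -> dict:
        return {"g_net": self.g_net.state_dict(), "d_net": self.d_net.state_dict()}

    def load_state_dict(self, state: dict) -> None:
        self.g_net.load_state_dict(state["g_net"])
        self.d_net.load_state_dict(state["d_net"])
```

A caller such as `MonaiAlgo` would then only depend on the `state_dict`/`load_state_dict` contract, not on how many networks a particular workflow owns.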
I like @ericspod's idea about a `state_dict` method; however, such a method IMO should output state dicts of everything: models, optimizers, trainer, lr_schedulers, amp, etc. If we want to limit the scope to models only, we could consider a `Trainer.get_models() -> Dict[str, torch.nn.Module]` method (and also, similarly, `Trainer.get_optims()` etc.).
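The scoped accessors could be sketched as follows; `Trainer`, `get_models`, and `get_optims` here are the proposals from this thread, not existing API:

```python
from typing import Dict
import torch

class Trainer:
    """Illustrative trainer holding named models and optimizers by key."""

    def __init__(self, models: Dict[str, torch.nn.Module],
                 optims: Dict[str, torch.optim.Optimizer]):
        self._models, self._optims = models, optims

    def get_models(self) -> Dict[str, torch.nn.Module]:
        # Return a copy so callers cannot mutate the trainer's registry.
        return dict(self._models)

    def get_optims(self) -> Dict[str, torch.optim.Optimizer]:
        return dict(self._optims)
```

With this shape, a GAN trainer would simply register two entries (e.g. generator and discriminator) under distinct keys.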
@vfdev-5 Yes, other state is important. We could have `state_dict` use parameters to control what to output, with sensible defaults so that the behaviour best matches that in PyTorch.
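One possible shape for such parameters (the flag names `include_models`/`include_optimizers` are hypothetical, chosen only to illustrate the idea):

```python
import torch.nn as nn
import torch.optim as optim

class Workflow:
    """Illustrative only; the keyword flags are hypothetical."""

    def __init__(self, model: nn.Module, optimizer: optim.Optimizer):
        self.model, self.optimizer = model, optimizer

    def state_dict(self, include_models: bool = True,
                   include_optimizers: bool = False) -> dict:
        # Sensible defaults: models only, mirroring torch.nn.Module.state_dict.
        state = {}
        if include_models:
            state["model"] = self.model.state_dict()
        if include_optimizers:
            state["optimizer"] = self.optimizer.state_dict()
        return state
```

Callers that need full checkpointing would opt in to the extra state, while weight-only consumers keep the default.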
Any update on this? We are looking at implementing an AdversarialTrainer and would prefer to have it future-proofed to a degree.
There hasn't been, so we can discuss a solution for this later today.
@Nic-Ma, @ericspod, any updates on this topic?
Hi @holgerroth ,
I am designing a solution which is included in this draft PR: https://github.com/Project-MONAI/MONAI/pull/5822.
Thanks.
**Is your feature request related to a problem? Please describe.**
Certain trainers, like `GanTrainer`, use several models, while the standard `SupervisedTrainer` only uses one model.

**Describe the solution you'd like**
There should be a standardized API (e.g., `get_weights`/`model_states`) that returns all model parameters. This would be useful to allow integration with other components, like `MonaiAlgo` for federated learning.

Having a standardized API would allow `MonaiAlgo` to easily support bundles that use trainers other than `SupervisedTrainer`, e.g. `mednist_gan`.

**Describe alternatives you've considered**
n/a

**Additional context**
n/a