graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0
90 stars, 92 forks

Fix save_state_dict #645

Closed: AMHermansen closed this pull request 9 months ago

AMHermansen commented 10 months ago

The current implementation moves the entire model to the CPU whenever save_state_dict is called. This seems like an undesirable side effect of the method. This PR changes save_state_dict so that it only saves the state dict and keeps the model on its current device.
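A minimal PyTorch sketch of the difference, assuming the current method calls `model.cpu()` before saving (the actual graphnet code is not shown in this thread). Both function names are hypothetical. `nn.Module.cpu()` moves parameters and buffers in place and returns the same module, so the first variant changes the caller's model; the second builds CPU versions of the tensors and leaves the model where it is.

```python
import os
import tempfile

import torch
from torch import nn


def save_state_dict_moving_model(model: nn.Module, path: str) -> None:
    """Hypothetical reconstruction of the behaviour described in the PR."""
    model.cpu()  # In place: the caller's model is now on the CPU.
    torch.save(model.state_dict(), path)


def save_state_dict_keeping_device(model: nn.Module, path: str) -> None:
    """Hypothetical sketch of the approach proposed in the PR."""
    # CPU version of each tensor; tensors already on the CPU are returned as-is.
    # The model's own parameters and buffers are not moved.
    cpu_state = {
        key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
        for key, value in model.state_dict().items()
    }
    torch.save(cpu_state, path)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)
with tempfile.TemporaryDirectory() as tmp:
    save_state_dict_keeping_device(model, os.path.join(tmp, "state_dict.pth"))
print(next(model.parameters()).device)  # unchanged: still on `device`
```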

RasmusOrsoe commented 9 months ago

@AMHermansen thanks for this suggestion. I'm not sure we want this change, though. If the state dict is saved to disk from a GPU, loading it again will require the model to be on a GPU, or an error will be thrown; see https://pytorch.org/tutorials/recipes/recipes/save_load_across_devices.html .

Has this been an issue for you?
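The linked PyTorch recipe describes the usual remedy for the cross-device concern: pass `map_location` to `torch.load`, which remaps tensors that were saved on a GPU to the device chosen at load time. A minimal sketch, with a hypothetical file name, that saves from whatever device is available and loads onto the CPU:

```python
import os
import tempfile

import torch
from torch import nn

# Use a GPU if one is present; the loading step below works either way.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(4, 2).to(device)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "state_dict.pth")  # hypothetical file name
    torch.save(model.state_dict(), path)  # tensors keep their original device

    # map_location remaps every stored tensor to the CPU while loading.
    state = torch.load(path, map_location="cpu")

cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(state)
```

Saving CPU copies instead, as the PR proposes, produces a file that loads without `map_location` on machines with or without a GPU.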

AMHermansen commented 9 months ago

The change suggested in this PR only removes a side effect of the current save_state_dict implementation: the model is no longer moved to the CPU when the method is called. It does this by copying the state dict to the CPU and saving the copy. The motivation is to make saving models more streamlined. My current understanding from the example scripts is that save_model_config and save_state_dict are the intended way to save graphnet models. If you want to save a model this way during training, however, you will run into problems, because the model is moved off the accelerator.
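To illustrate the training scenario, here is a minimal sketch of checkpointing inside a plain PyTorch loop by saving a CPU copy of the state dict, so the model stays on the accelerator between epochs. The model, data, and file names are placeholders and not graphnet's training API.

```python
import os
import tempfile

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

# Placeholder model and data, all on the same device.
model = nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs = torch.randn(32, 8, device=device)
targets = torch.randn(32, 1, device=device)

with tempfile.TemporaryDirectory() as tmp:
    for epoch in range(3):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()

        # Save CPU versions of the tensors; the live parameters are not moved.
        cpu_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
        torch.save(cpu_state, os.path.join(tmp, f"epoch_{epoch}.pth"))

        # The next iteration still finds the model on `device`.
        assert next(model.parameters()).device.type == torch.device(device).type

    saved = sorted(os.listdir(tmp))
```

Had the save step called `model.cpu()` instead, a GPU run would fail in the next forward pass with a device-mismatch error, because `inputs` and `targets` remain on the accelerator.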