RasmusBrostroem / ConnectFourRL


Save and load agents #68

Closed jbirkesteen closed 1 year ago

jbirkesteen commented 1 year ago

closes #52 , addresses some of #54

Saves what is necessary for doing inference and for resuming training. In short, the trained part of the model is saved minimally using torch's state_dict functionality, which means that any agent with the same network structure can load the parameters. A similar approach is taken with the optimizer so that training can be resumed, and to this end the loss is also stored (this follows the general approach mentioned in https://github.com/RasmusBrostroem/ConnectFourRL/issues/52#issuecomment-1325446243).
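As a rough sketch of that general pattern (illustrative names and file layout, not the exact API of this PR):

```python
import torch

def save_checkpoint(model, optimizer, loss, path="learned_weights/checkpoint.pt"):
    """Save only what is needed for inference and for resuming training."""
    torch.save({
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
    }, path)
```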

Implements

This PR implements DirectPolicyAgent.save_agent() and DirectPolicyAgent.load_network_weights(), and also modifies Env.check_user_events() to make sure the model is saved if the environment is quit by mistake (see #54). Moreover, these methods are added to the parent Player class to ensure they can always be called on a player.
See the documentation of the methods for further details.
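For illustration, a hypothetical usage sketch (the import path, constructor arguments, exact signatures and saved filename are assumptions; see the method docstrings for the real details):

```python
import torch
from connectfour.agents import DirectPolicyAgent  # assumed import path

# Train, save, then load the weights into a fresh agent
# with the same network structure.
agent = DirectPolicyAgent()
optimizer = torch.optim.Adam(agent.parameters())
# ... training loop ...
agent.save_agent(optimizer=optimizer)  # writes to learned_weights/ by default

new_agent = DirectPolicyAgent()
new_agent.load_network_weights("learned_weights/<saved_file>.pt")  # placeholder path
```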

learned_weights is added to the .gitignore, since this is the default folder for save_agent(). I liked this name better than AgentParameters.

Finally, I've written a short note and a couple of proposed sections for the README.md.

Comments

The saving function currently needs the optimizer as input (just like update_agent() does), since the optimizer isn't directly accessible from within the agent object but is instead part of the training script running the agent. This feels somewhat awkward, and we could consider changing it in the future.
This also caused some problems with saving when the user quits the environment. For now, the state_dict of the optimizer is not saved in this case. This means that quitting only saves the state_dict of the model and the model's current loss, along with the file_name_metadata.json, in which 'optim_name' will always be "NoneType". I chose this approach to avoid making it overly complicated.
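Roughly, the on-quit path might look like this, assuming a pygame-based event loop (method and attribute names are illustrative, not the literal implementation):

```python
import pygame

def check_user_events(self):
    """Save each player's model before shutting down if the window is closed."""
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            for player in self.players:
                # No optimizer is available here, so only the model's
                # state_dict and current loss get saved; the metadata's
                # 'optim_name' then ends up as "NoneType".
                player.save_agent(optimizer=None)
            pygame.quit()
```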

Not stored

Some of the parts mentioned in https://github.com/RasmusBrostroem/ConnectFourRL/issues/52#issuecomment-1325446243 are not stored.
Aside from the current loss, nothing from the log is stored to disk. I decided to assume that users will always use neptune for logging (if they want to log), so it would be redundant to also store this information on disk.
I've also chosen not to store the number of epochs/generations, as it is a) not directly accessible by the agent and b) not strictly necessary for resuming training.
If a user does not use neptune, they'll have to figure out how to log by themselves.

Not implemented

Loading the metadata, the optimizer state dict and the loss is not implemented. This needs to be implemented directly where relevant in training/evaluation scripts, and should not be part of the agents themselves.
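For example, a training script might resume roughly like this, assuming the model weights and the optimizer state live in the same checkpoint file (the actual file layout and key names may differ):

```python
import torch

def resume_training(agent, optimizer, path="learned_weights/agent.pt"):
    """Restore weights via the agent API, then optimizer state and loss by hand."""
    agent.load_network_weights(path)
    checkpoint = torch.load(path)
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["loss"]
```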

One thing mentioned in #54 which is not addressed here is the cleaning of attributes, which I have not implemented.
For the time being, the user needs to be cautious when loading on-quit objects. So the question is whether we are satisfied enough with this implementation to close #54. We could also leave #54 open; one way to address it down the road would be to add a reset/clean method to the DirectPolicyAgent classes. What do you think?

Next up

I suggest that we merge refactor-project into main after this PR has been approved, since we are done with the refactoring (besides documentation).

jbirkesteen commented 1 year ago

Cool. I've replied to your comment - I'll hold off on merging until I need it, which probably won't be today 😅 That also gives you time to reply to my comment, if need be ❤️

jbirkesteen commented 1 year ago

Now, one can pass optimizer=None and still save the model's state_dict as well as loss_sum and metadata without saving the optimizer. `check_user_events()` in Env.py still saves players with "on_quit" as part of the filename to reflect that the data might be corrupted.
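The guarded save could look roughly like this (assumed structure, parameter names and file layout, not the literal implementation):

```python
import json
import os
import torch

def save_agent(self, optimizer=None, folder="learned_weights", name="agent"):
    """Save model weights, loss_sum and metadata; optimizer state only if given."""
    os.makedirs(folder, exist_ok=True)
    checkpoint = {"model_state_dict": self.state_dict(), "loss": self.loss_sum}
    if optimizer is not None:
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()
    torch.save(checkpoint, os.path.join(folder, f"{name}.pt"))
    # 'optim_name' becomes "NoneType" when no optimizer is passed.
    metadata = {"optim_name": type(optimizer).__name__}
    with open(os.path.join(folder, f"{name}_metadata.json"), "w") as f:
        json.dump(metadata, f)
```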