google-research / batch_rl

Offline Reinforcement Learning (aka Batch Reinforcement Learning) on Atari 2600 games
https://offline-rl.github.io/
Apache License 2.0

Save the trained model to hdf5 file format #6

Closed 2bben closed 4 years ago

2bben commented 4 years ago

Hi,

I'm trying to get the trained model produced by the offline agents to work with my online environment, which is written in Go and loads models from HDF5 files. But looking at the source code in this repo, I can't seem to find an easy way to do this.

Is there an "easy" way to save the trained model as an HDF5 file, not just as checkpoints?

agarwl commented 4 years ago

I don't fully understand the question, but the Keras models used in the repo allow for easy saving to HDF5 format.
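
For reference, the standard Keras HDF5 round trip looks roughly like this; the toy model below is just an illustration, not something from this repo:

import tensorflow as tf

# Toy model just to illustrate the round trip; any Functional/Sequential
# Keras model works the same way.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(2),
])

# Save architecture + weights to a single HDF5 file ...
model.save('my_model.h5')

# ... and load it back later without needing the original model code.
restored = tf.keras.models.load_model('my_model.h5')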

Clarification: what do you mean by the "model", as opposed to the network checkpoints?

2bben commented 4 years ago

What I meant by the "model" is the trained weights; I mostly used the word to differentiate it from the checkpoints, since those are already being saved.

If I try to rephrase my problem: I can't seem to find the variable in your code that lets me save the model/weights to an HDF5 file, i.e. the variable I can call model.save('my_model.h5') on, so I can save the model/weights as they do on the web page you linked.

The repo you linked to didn't quite help, as I can't see where it's being called when training from batch_rl/fixed_replay/train.py with the DQN agent.

agarwl commented 4 years ago

I see. The code builds heavily on top of Dopamine, and the online and target Q-networks (Keras models) are given by the variables self.online_convnet and self.target_convnet. Each agent derives from the base DQNAgent, and you can see the models being created here.

You can access these variables from any of the offline agents and simply call save on them -- currently, the saving is done inside the base agent via this saver, but something like this should work:

offline_agent = OfflineAgent()  # create any offline agent (DQN, QR-DQN) from the code
offline_agent.online_convnet.save('my_model.h5')
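
If the full-model HDF5 save complains that the model has to be Functional or Sequential (Dopamine's networks are subclassed Keras models), saving only the weights is a fallback; a rough sketch, with illustrative file names:

# Fallback: save only the variables; the network class itself is re-created
# from code at load time, so just the layer weights go into the HDF5 file.
offline_agent.online_convnet.save_weights('my_model_weights.h5')

# The target network can be saved the same way if needed.
offline_agent.target_convnet.save_weights('my_target_weights.h5')
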
agarwl commented 4 years ago

Hi @2bben, were you able to save the trained models?

2bben commented 4 years ago

Hi,

Yes I was, thanks for the help!

It was still a little problematic, as the model was not sequential (or at least that's what the error message told me) and the names of the layers were the same. This is more of a Dopamine issue than an issue for this repo, since the network is inherited from that library. I only wanted the weights, so it was only necessary to fix the names of the layers:

The networks are instantiated in the agent files, so, for example, for dqn_agent.py I changed line 39 (nature_dqn_network = atari_lib.NatureDQNNetwork) to point to my own file, which had a copy of the network from the Dopamine repo where only the names were changed, e.g.,

self.conv1 = tf.keras.layers.Conv2D(32, [8, 8], strides=4, padding='same', activation=activation_fn, name='Conv')

self.conv2 = tf.keras.layers.Conv2D(64, [4, 4], strides=2, padding='same', activation=activation_fn, name='Conv')

To:

self.conv1 = tf.keras.layers.Conv2D(32, [8, 8], strides=4, padding='same', activation=activation_fn, name='Conv1')

self.conv2 = tf.keras.layers.Conv2D(64, [4, 4], strides=2, padding='same', activation=activation_fn, name='Conv2')

This has to be done for all layers; I also think it would work to just remove the name= argument from every line.
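
For reference, a rough sketch of such a copied network with unique layer names (simplified from Dopamine's NatureDQNNetwork; the kernel initializers and the namedtuple output that Dopamine actually uses are omitted, so treat it as an illustration rather than a drop-in replacement):

import tensorflow as tf


class RenamedNatureDQNNetwork(tf.keras.Model):
  """Copy of the Nature DQN network with unique layer names."""

  def __init__(self, num_actions, name=None):
    super(RenamedNatureDQNNetwork, self).__init__(name=name)
    self.num_actions = num_actions
    activation_fn = tf.keras.activations.relu
    self.conv1 = tf.keras.layers.Conv2D(
        32, [8, 8], strides=4, padding='same', activation=activation_fn,
        name='Conv1')
    self.conv2 = tf.keras.layers.Conv2D(
        64, [4, 4], strides=2, padding='same', activation=activation_fn,
        name='Conv2')
    self.conv3 = tf.keras.layers.Conv2D(
        64, [3, 3], strides=1, padding='same', activation=activation_fn,
        name='Conv3')
    self.flatten = tf.keras.layers.Flatten(name='Flatten')
    self.dense1 = tf.keras.layers.Dense(
        512, activation=activation_fn, name='Dense1')
    self.dense2 = tf.keras.layers.Dense(num_actions, name='Dense2')

  def call(self, state):
    # Atari frames come in as uint8; cast and rescale before the conv stack.
    x = tf.cast(state, tf.float32) / 255.0
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.flatten(x)
    x = self.dense1(x)
    return self.dense2(x)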

Then I could use: offline_agent.online_convnet.save_weights('my_model.h5')

Note: I didn't try to solve the 'sequential' error.
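
For loading the weights back in Python, save_weights only stores the variables, so the network has to be rebuilt from code and built with a forward pass before load_weights is called; a minimal sketch, assuming the renamed network class sketched above:

import numpy as np

# Rebuild the renamed network from code (save_weights stored only the variables).
network = RenamedNatureDQNNetwork(num_actions=4)  # e.g. 4 actions for Breakout

# Subclassed Keras models must be built before load_weights; one forward pass
# on a dummy Atari-shaped observation (84x84 frames, 4-frame stack) does that.
dummy_obs = np.zeros((1, 84, 84, 4), dtype=np.float32)
network(dummy_obs)

network.load_weights('my_model.h5')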

agarwl commented 4 years ago

If it's not too much trouble, could you please post the snippet you used to save and load the Keras model? It might be useful to other people too.

2bben commented 4 years ago

Updated my previous comment