Denys88 / rl_games

RL implementations
MIT License

Externally Load Checkpoint #243

davesarmoury closed 1 year ago

davesarmoury commented 1 year ago

I've trained a network in Isaac Sim/Gym and it seems like it should work. I have a separate control script written in Python for controlling the hardware. How can I take the .pth checkpoint file and restore it outside of rl_games?

Denys88 commented 1 year ago

I recommend using ONNX: https://colab.research.google.com/github/Denys88/rl_games/blob/master/notebooks/train_and_export_onnx_example_continuous.ipynb

davesarmoury commented 1 year ago

I'm trying to follow that, but I keep getting size mismatch errors when the checkpoint is loaded:

size mismatch for running_mean_std.running_mean: copying a param with shape torch.Size([44]) from checkpoint, the shape in current model is torch.Size([3]).

size mismatch for running_mean_std.running_var: copying a param with shape torch.Size([44]) from checkpoint, the shape in current model is torch.Size([3]).

size mismatch for a2c_network.sigma: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([1]).

size mismatch for a2c_network.actor_mlp.0.weight: copying a param with shape torch.Size([256, 44]) from checkpoint, the shape in current model is torch.Size([256, 3]).

size mismatch for a2c_network.mu.weight: copying a param with shape torch.Size([12, 64]) from checkpoint, the shape in current model is torch.Size([1, 64]).

size mismatch for a2c_network.mu.bias: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([1]).

My network has units: [256, 128, 64] and should have 44 inputs and 12 outputs.
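Each message above pairs a shape stored in the checkpoint with the shape the freshly built model expects, so the checkpoint's input and output sizes can be read straight off the error text. A minimal sketch using only the Python standard library; the helper names parse_size_mismatches and infer_io_dims are hypothetical (not part of rl_games or PyTorch), and the parameter names are the ones that appear in the errors in this thread:

```python
import re

# Matches one "size mismatch" message as printed in this thread.
_MISMATCH = re.compile(
    r"size mismatch for (?P<name>[\w.]+): "
    r"copying a param with shape torch\.Size\(\[(?P<ckpt>[\d, ]*)\]\) "
    r"from checkpoint, the shape in current model is "
    r"torch\.Size\(\[(?P<model>[\d, ]*)\]\)"
)


def _to_shape(text):
    """Turn '256, 44' into (256, 44)."""
    return tuple(int(part) for part in text.split(",") if part.strip())


def parse_size_mismatches(log):
    """Return {param_name: (checkpoint_shape, model_shape)} for each message."""
    return {
        m.group("name"): (_to_shape(m.group("ckpt")), _to_shape(m.group("model")))
        for m in _MISMATCH.finditer(log)
    }


def infer_io_dims(mismatches, which=0):
    """Guess (num_obs, num_actions); which=0 is the checkpoint, 1 the model.

    Assumes the parameter names seen in this thread: the observation
    normalizer's running mean has one entry per observation, and the
    mu layer's bias has one entry per action.
    """
    num_obs = mismatches["running_mean_std.running_mean"][which][0]
    num_actions = mismatches["a2c_network.mu.bias"][which][0]
    return num_obs, num_actions


# Three of the messages quoted above, joined into one string.
log = (
    "size mismatch for running_mean_std.running_mean: copying a param with shape "
    "torch.Size([44]) from checkpoint, the shape in current model is torch.Size([3]). "
    "size mismatch for a2c_network.actor_mlp.0.weight: copying a param with shape "
    "torch.Size([256, 44]) from checkpoint, the shape in current model is "
    "torch.Size([256, 3]). "
    "size mismatch for a2c_network.mu.bias: copying a param with shape "
    "torch.Size([12]) from checkpoint, the shape in current model is torch.Size([1])."
)

mismatches = parse_size_mismatches(log)
print(infer_io_dims(mismatches, which=0))  # checkpoint: (44, 12)
print(infer_io_dims(mismatches, which=1))  # current model: (3, 1)
```

Read this way, the checkpoint matches the expected 44 inputs and 12 outputs, while the model being built has 3 inputs and 1 output, which suggests the model is being constructed for a different environment.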

davesarmoury commented 1 year ago

It looks like it's related to the environment selected through envpool. I've tried setting num_obs and num_actions in several ways, but I can't find the correct way to set these values through the env_config variable in the config dictionary.

Denys88 commented 1 year ago

I've got an idea: could you try creating it with num_actors set to 1?

davesarmoury commented 1 year ago

Setting num_actors to 1 has a similar output:

size mismatch for running_mean_std.running_mean: copying a param with shape torch.Size([44]) from checkpoint, the shape in current model is torch.Size([3]).

size mismatch for running_mean_std.running_var: copying a param with shape torch.Size([44]) from checkpoint, the shape in current model is torch.Size([3]).

size mismatch for a2c_network.sigma: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([1]).

size mismatch for a2c_network.actor_mlp.0.weight: copying a param with shape torch.Size([256, 44]) from checkpoint, the shape in current model is torch.Size([256, 3]).

size mismatch for a2c_network.mu.weight: copying a param with shape torch.Size([12, 64]) from checkpoint, the shape in current model is torch.Size([1, 64]).

size mismatch for a2c_network.mu.bias: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([1]).

davesarmoury commented 1 year ago

HERE is a link to my network in case you need a reference. It was trained from the Go1_Horizontal task in THIS REPO

Denys88 commented 1 year ago

Did you update the config from the example? It looks like you are trying to load the network with the wrong config, so it is using the observation and action shapes from the wrong env.

davesarmoury commented 1 year ago

I've kept the default notebook, except I've changed "'train': True" to False, added "run: True", and changed the checkpoint path to mine.

Denys88 commented 1 year ago

This Google Colab is more of an example with a random environment. To get the right obs and action shapes, you need to create the env and use the same config you trained with. I think the fastest way is probably to change the Omniverse code a little bit to do the export.

davesarmoury commented 1 year ago

I modified rl_games to export an ONNX model each time it saves a checkpoint. Not ideal, but it should work for what I'm doing. Thanks for the help.
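The actual change made to rl_games is not shown in this thread. As a rough illustration of the idea only, the sketch below wraps a checkpoint-saving function so that an export callback runs after every save. The names with_onnx_export, save_fn and export_fn are hypothetical; in a real setup export_fn would presumably wrap a call such as torch.onnx.export on the policy, as in the notebook linked earlier, and save_fn would be whatever the training code uses to write the .pth file.

```python
def with_onnx_export(save_fn, export_fn):
    """Return a function that saves a checkpoint, then triggers an export.

    save_fn(path)        -- hypothetical: writes the checkpoint
    export_fn(onnx_path) -- hypothetical: writes the ONNX model
    """
    def save_and_export(path):
        result = save_fn(path)
        # Derive the export path: swap a trailing ".pth" for ".onnx",
        # or just append ".onnx" if the path has no ".pth" suffix.
        base = path[:-4] if path.endswith(".pth") else path
        export_fn(base + ".onnx")
        return result

    return save_and_export


# Usage with stand-in callbacks that only record the paths they receive.
saved, exported = [], []
save = with_onnx_export(saved.append, exported.append)
save("checkpoint.pth")
save("last_ep_100_rew_12.5")

print(saved)     # ['checkpoint.pth', 'last_ep_100_rew_12.5']
print(exported)  # ['checkpoint.onnx', 'last_ep_100_rew_12.5.onnx']
```

Because the export runs in the training process, the model is built from the same env and config that produced the checkpoint, which avoids the shape mismatch discussed above.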