Replicable-MARL / MARLlib

One repository is all that is necessary for Multi-agent Reinforcement Learning (MARL)
https://marllib.readthedocs.io

How to export a trained model as a .pt (PyTorch) or ONNX model #227

Open manaspalaparthi opened 7 months ago

manaspalaparthi commented 7 months ago

How do I export a trained model as a .pt (PyTorch) or ONNX model?

I have fully trained my model and want to deploy it into the Unity ML-Agents environment, so I need to export the trained model in either PyTorch or ONNX format.

The only option I could find in the documentation is "algo.render()".

Aequatio-Space commented 7 months ago

I do not know whether Ray can do that directly, but I unwrapped a Ray checkpoint and worked out its structure. First, load the raw checkpoint file with pickle.load; you get a dictionary whose value for the key 'worker' is a bytes instance containing the model weights. Call pickle.loads on those bytes to get the worker state dictionary, then select the key 'state' followed by 'weights'; that holds the raw parameters of the network. You can manually pack them into a .pt file.
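
A minimal sketch of that unwrapping, assuming a pickle-based Ray checkpoint file at best_model/checkpoint and a policy id of "shared_policy" (the exact path and key names depend on your run and Ray version):

import pickle
import torch

# Assumed checkpoint path; point this at your own run.
with open("best_model/checkpoint", "rb") as f:
    checkpoint = pickle.load(f)  # top-level dict

# The value under 'worker' is itself a pickled bytes blob.
worker_state = pickle.loads(checkpoint["worker"])

# 'state' maps policy ids to policy state; 'weights' holds the raw
# network parameters ("shared_policy" is an assumed policy id).
weights = worker_state["state"]["shared_policy"]["weights"]

# Pack the raw parameters into a .pt file.
torch.save(weights, "model_weights.pt")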

Morphlng commented 7 months ago

RLlib's Policy class has an export_model function, which exports the raw learning-framework model and can optionally save it as an ONNX model instead.

So the problem reduces to loading the checkpoint MARLlib saved. I've written a script that loads the checkpoint together with params.json. You can reuse its load_model function to retrieve the policy and then export it:

from marllib import marl
from eval import load_model

ckpt = load_model(
    {
        "model_path": "best_model/checkpoint",
        "params_path": "best_model/params.json",
    }
)

env = marl.make_env(environment_name=ckpt.env_name, map_name=ckpt.map_name)
env_instance, env_info = env

# Change the policy name accordingly
policy = ckpt.trainer.get_policy("shared_policy")
policy.export_model("/directoty/to/save")

PS: In case anybody wants to know how to use the raw model:

import numpy as np
import torch

model = policy.model
state = policy.get_initial_state()
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_obs(env):
    obs = env.observation_space.sample()
    # Suppose observation is a dict. E.g.
    # obs = {
    #   "action_mask": [0, 0, 1, 0],
    #   "obs": [1, 1, 4, 5, 1, 4],
    # }
    # Add a batch dimension and move each entry to the target device.
    for key in obs:
        obs[key] = torch.from_numpy(np.array([obs[key]])).to(DEVICE)
    return obs

dummy_input = {
    "input_dict": {"obs": get_obs(env_instance)},
    # get_initial_state() returns a list of arrays (one per recurrent
    # state); convert each one to a tensor rather than stacking the list.
    "state": [torch.from_numpy(np.array(s)).to(DEVICE) for s in state],
    "seq_lens": np.array([1]),
}

output = model(**dummy_input)
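
For reference, a torch model call in RLlib returns a tuple of (logits, new recurrent state). A minimal sketch of turning that output into a greedy action, assuming a discrete action space and ignoring any action mask:

# Unpack the model output and take the highest-scoring action.
logits, new_state = output
action = torch.argmax(logits, dim=-1)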