manaspalaparthi opened this issue 7 months ago
Although I don't know whether Ray can do this directly, I unpacked a Ray checkpoint and worked out its structure.
First, load the raw checkpoint file with `pickle.load`. You will get a dictionary whose value for the key `'worker'` is a `bytes` object holding the serialized worker state, which includes the model weights. Call `pickle.loads` on it to get the worker state dictionary. Select the key `'state'`, then your policy ID (e.g. `'shared_policy'`), then `'weights'`; that gives you the raw parameters of the network. You can then manually pack them into a `.pt` file.
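A minimal sketch of those steps, assuming a Ray 1.x-style checkpoint file (the one named like `checkpoint-100` inside a `checkpoint_000100` folder) and a policy ID such as `shared_policy`. The helper names `load_worker_state`, `extract_policy_weights` and `save_as_pt` are hypothetical. Unpickling a real checkpoint may require `ray` and MARLlib to be importable, because the pickled state references their classes. The `.pt` conversion assumes the weights are a name-to-array mapping; as far as I know, Ray 1.x `TorchPolicy.get_weights()` stores the model's `state_dict` as NumPy arrays.

```python
import pickle
from typing import Any, Dict


def load_worker_state(checkpoint_path: str) -> Dict[str, Any]:
    """Unpickle an RLlib checkpoint file and decode its nested 'worker' entry."""
    with open(checkpoint_path, "rb") as f:
        trainer_state = pickle.load(f)
    # The 'worker' value is itself a pickled bytes object.
    return pickle.loads(trainer_state["worker"])


def extract_policy_weights(checkpoint_path: str, policy_id: str) -> Dict[str, Any]:
    """Return the raw parameter mapping stored for one policy (hypothetical helper)."""
    policy_state = load_worker_state(checkpoint_path)["state"][policy_id]
    # Ray 1.x Policy.get_state() wraps the weights as {"weights": ..., ...};
    # fall back to the value itself in case it already is the weights mapping.
    if isinstance(policy_state, dict) and "weights" in policy_state:
        return policy_state["weights"]
    return policy_state


def save_as_pt(weights: Dict[str, Any], out_path: str) -> None:
    """Pack a {param_name: array} mapping into a torch state_dict file."""
    import numpy as np
    import torch  # imported lazily so the helpers above work without torch

    state_dict = {name: torch.as_tensor(np.asarray(value)) for name, value in weights.items()}
    torch.save(state_dict, out_path)
```

For example, `save_as_pt(extract_policy_weights("checkpoint_000100/checkpoint-100", "shared_policy"), "policy.pt")` (the path is only illustrative). The file holds parameters only, so it has to be loaded with `model.load_state_dict(torch.load("policy.pt"))` into a model built with the same architecture.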
RLlib's `Policy` class has an `export_model` method, which exports the model in the underlying learning framework's native format and can optionally save it as an ONNX model.
So the problem comes down to loading the checkpoint that MARLlib saved. I wrote a script that loads the checkpoint together with its `params.json`. You can reuse its `load_model` function to retrieve the policy and then export it:
```python
from marllib import marl

# load_model comes from my eval.py script mentioned above (not part of MARLlib)
from eval import load_model

ckpt = load_model(
    {
        "model_path": "best_model/checkpoint",
        "params_path": "best_model/params.json",
    }
)
env = marl.make_env(environment_name=ckpt.env_name, map_name=ckpt.map_name)
env_instance, env_info = env
# Change the policy name accordingly
policy = ckpt.trainer.get_policy("shared_policy")
policy.export_model("/directory/to/save")
```
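For deployment, the `onnx` argument of `export_model` is the relevant switch. A small wrapper sketch, assuming the Ray 1.x signature `Policy.export_model(export_dir, onnx=None)`: with `onnx=None` a `TorchPolicy` saves the whole model with `torch.save` (a `model.pt` file), and with an integer it exports ONNX using that opset version. `export_policy` is a hypothetical helper that also reports which files were written, so you can check what to hand over to Unity.

```python
import os
from typing import List, Optional


def export_policy(policy, export_dir: str, onnx_opset: Optional[int] = None) -> List[str]:
    """Call policy.export_model and return the names of newly written files."""
    os.makedirs(export_dir, exist_ok=True)
    before = set(os.listdir(export_dir))
    # onnx=None -> native framework format; onnx=<int> -> ONNX with that opset
    policy.export_model(export_dir, onnx=onnx_opset)
    return sorted(set(os.listdir(export_dir)) - before)
```

For example, `export_policy(policy, "/directory/to/save", onnx_opset=11)` with the `policy` from the snippet above; the opset number is only an example and should match what your Unity runtime supports.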
PS: In case anybody wants to know how to use the raw model:
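```python
import numpy as np
import torch

model = policy.model
state = policy.get_initial_state()
# Put inputs on the same device as the model's parameters
DEVICE = next(model.parameters()).device


def get_obs(env):
    obs = env.observation_space.sample()
    # Suppose the observation is a dict, e.g.
    # obs = {
    #     "action_mask": [0, 0, 1, 0],
    #     "obs": [1, 1, 4, 5, 1, 4],
    # }
    for key in obs:
        # Add a batch dimension of 1 and move the tensor to DEVICE
        obs[key] = torch.from_numpy(np.array([obs[key]])).to(DEVICE)
    return obs


dummy_input = {
    "input_dict": {"obs": get_obs(env_instance)},
    # One tensor per RNN state, each with a batch dimension of 1
    "state": [torch.from_numpy(np.array([s])).to(DEVICE) for s in state],
    "seq_lens": np.array([1]),
}
# ModelV2.__call__ returns (model_out, state_out)
with torch.no_grad():
    output = model(**dummy_input)
```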
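To turn the model output into an action (e.g. when stepping the Unity environment), here is a hedged sketch assuming `output[0]` is a batch of action logits and that actions with a 0 in `action_mask` must never be chosen. `greedy_masked_action` is a hypothetical helper; it deterministically picks the highest-scoring allowed action rather than sampling from the action distribution.

```python
import numpy as np


def greedy_masked_action(logits, action_mask) -> np.ndarray:
    """Pick the highest-logit action per batch row, ignoring masked-out actions."""
    logits = np.asarray(logits, dtype=np.float64)
    allowed = np.asarray(action_mask).astype(bool)
    # Masked actions get -inf so argmax never selects them.
    # Assumes every row has at least one allowed action.
    masked_logits = np.where(allowed, logits, -np.inf)
    return masked_logits.argmax(axis=-1)
```

With the snippet above this would look like `greedy_masked_action(output[0].detach().cpu().numpy(), obs["action_mask"].cpu().numpy())`, where `obs` is the dictionary returned by `get_obs(env_instance)`. If you want stochastic behaviour instead, sample from a softmax over the masked logits.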
How to export a trained model as a .pt (PyTorch) or ONNX model
I have fully trained my model and want to deploy it in a Unity ML-Agents environment, so I need to export the trained model in either PyTorch or ONNX format.
The only related option I could find in the documentation is `algo.render()`.