HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

Not able to load model from zip file GAIL #362

Closed farbod1277 closed 2 years ago

farbod1277 commented 2 years ago

Hi, thanks for the awesome project. I've run into the following issue. I trained my agent using GAIL and saved the gen_algo using:

gail_trainer.gen_algo.save(os.path.join(dirname, "gail_model.zip"))

where gail_trainer is an adversarial.GAIL instance. I then load the model in my evaluation script (a separate Python script) using:

model = BaseModel.load(os.path.join(dirname, "gail_model.zip"))

But I get the following error:

Traceback (most recent call last):
  File "/home/farbod/git/gym-pybullet-drones/experiments/imitation/evalute_model.py", line 82, in <module>
    model = BaseModel.load(os.path.relpath(os.path.join(dirname, "gail_model.zip")))
  File "/home/farbod/.local/lib/python3.8/site-packages/stable_baselines3/common/policies.py", line 173, in load
    saved_variables = th.load(path, map_location=device)
  File "/home/farbod/.local/lib/python3.8/site-packages/torch/serialization.py", line 585, in load
    with _open_zipfile_reader(opened_file) as opened_zipfile:
  File "/home/farbod/.local/lib/python3.8/site-packages/torch/serialization.py", line 242, in __init__
    super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: [enforce fail at inline_container.cc:110] . file in archive is not in a subdirectory: data

The zip file contains a file named data, along with some .pth files and the Stable Baselines3 version. Any help would be appreciated. Thanks.

AdamGleave commented 2 years ago

Thanks for the bug report. Could you include a short Python script that I could run that would replicate this problem?

farbod1277 commented 2 years ago

Thank you for your response, below is the script:

import os
import pathlib
import pickle
import tempfile

import seals  # noqa: F401
import stable_baselines3 as sb3

from imitation.algorithms import adversarial, bc
from imitation.data import rollout
from imitation.util import logger, util
from stable_baselines3.common.policies import BaseModel

dirname = os.path.dirname(__file__)

# Load pickled test demonstrations.
with open(os.path.join(dirname, "final.pkl"), "rb") as f:
    # This is a list of `imitation.data.types.Trajectory`, where
    # every instance contains observations and actions for a single expert
    # demonstration.
    trajectories = pickle.load(f)

# Convert List[types.Trajectory] to an instance of `imitation.data.types.Transitions`.
# This is a more general dataclass containing unordered
# (observation, actions, next_observation) transitions.
transitions = rollout.flatten_trajectories(trajectories)

venv = util.make_vec_env("seals/CartPole-v0", n_envs=2)

tempdir = tempfile.TemporaryDirectory(prefix="quickstart")
tempdir_path = pathlib.Path(tempdir.name)
print(f"All Tensorboards and logging are being written inside {tempdir_path}/.")

# Train GAIL on expert data.
# GAIL and AIRL also accept as `expert_data` any PyTorch-style DataLoader that
# iterates over dictionaries containing observations, actions, and next_observations.
gail_logger = logger.configure(tempdir_path / "GAIL/")
gail_trainer = adversarial.GAIL(
    venv,
    expert_data=transitions,
    expert_batch_size=32,
    gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
    custom_logger=gail_logger,
)
gail_trainer.train(total_timesteps=2048)
gail_trainer.gen_algo.save(os.path.join(dirname, "gail_model.zip"))

model = BaseModel.load(os.path.join(dirname, "gail_model.zip"))

Just run this in any directory that contains the final.pkl file with expert trajectories from a CartPole-v0 env. I ran into this issue while working on a custom environment (my code can be found here; there I do the saving and loading in two different scripts), but the same error occurs with the environment used in the script above.

Thanks

AdamGleave commented 2 years ago

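I think the problem is that gen_algo is a BaseAlgorithm (by default, PPO), not a BaseModel. We're saving the whole RL algorithm gen_algo, not just the policy, so the resulting zip is not a file that BaseModel.load (which calls th.load directly, as the traceback shows) can read.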
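To see why the two loaders disagree, it can help to look at the archive layout directly. The sketch below is only a heuristic and is not part of Stable Baselines3 or imitation: `build_archive` and `describe_archive` are hypothetical helper names, and the member names in the demo are illustrative, based on the description earlier in this thread (a `data` file plus `.pth` files at the top level of the zip) and on the error message, which complains that `data` is not in a subdirectory.

```python
import io
import zipfile


def build_archive(member_names):
    """Hypothetical helper: build an in-memory zip with empty members."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as archive:
        for name in member_names:
            archive.writestr(name, b"")
    return buffer


def describe_archive(path_or_file):
    """Guess which kind of loader a saved .zip is meant for.

    Heuristic only: it inspects member names, not their contents.
    Returns "algorithm", "torch", or "unknown".
    """
    with zipfile.ZipFile(path_or_file) as archive:
        names = archive.namelist()

    # Members stored at the root of the archive (no directory component).
    top_level = [name for name in names if "/" not in name]

    if "data" in top_level and any(n.endswith(".pth") for n in top_level):
        # Layout described in this thread for the file written by
        # `gen_algo.save(...)`: load it with the algorithm class.
        return "algorithm"
    if names and not top_level:
        # Every member sits inside a subdirectory, which is what the
        # torch zip reader in the traceback requires.
        return "torch"
    return "unknown"


if __name__ == "__main__":
    # Illustrative member names, not read from a real model file.
    sb3_like = build_archive(["data", "policy.pth", "policy.optimizer.pth"])
    torch_like = build_archive(["archive/data.pkl", "archive/version"])
    print(describe_archive(sb3_like))
    print(describe_archive(torch_like))
```

Given the contents reported above, gail_model.zip should fall into the first branch, which points to loading it with the algorithm class rather than with BaseModel.load. `describe_archive` also accepts a filesystem path, so it can be pointed at the saved file directly.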

Can you try changing BaseModel.load to PPO.load? That fixed it for me, although I had to modify your script a bit to get it to run on master. (By the way, you may want to try using master: the PyPI release is quite out of date at this point, although we intend to make a release soon.)

For completeness, here's the script that works for me:

import os
import pathlib
import pickle
import tempfile

import seals  # noqa: F401
import stable_baselines3 as sb3

from imitation.algorithms.adversarial import gail
from imitation.data import rollout
from imitation.util import logger, util

dirname = os.path.dirname(__file__)

# Load pickled test demonstrations.
with open(os.path.join(dirname, "final.pkl"), "rb") as f:
    # This is a list of `imitation.data.types.Trajectory`, where
    # every instance contains observations and actions for a single expert
    # demonstration.
    trajectories = pickle.load(f)

# Convert List[types.Trajectory] to an instance of `imitation.data.types.Transitions`.
# This is a more general dataclass containing unordered
# (observation, actions, next_observation) transitions.
transitions = rollout.flatten_trajectories(trajectories)

venv = util.make_vec_env("seals/CartPole-v0", n_envs=2)

tempdir = tempfile.TemporaryDirectory(prefix="quickstart")
tempdir_path = pathlib.Path(tempdir.name)
print(f"All Tensorboards and logging are being written inside {tempdir_path}/.")

# Train GAIL on expert data.
# GAIL and AIRL also accept as `demonstrations` any PyTorch-style DataLoader that
# iterates over dictionaries containing observations, actions, and next_observations.
gail_logger = logger.configure(tempdir_path / "GAIL/")
gail_trainer = gail.GAIL(
    demonstrations=transitions,
    demo_batch_size=32,
    venv=venv,
    gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
    custom_logger=gail_logger,
)
gail_trainer.train(total_timesteps=2048)
gail_trainer.gen_algo.save(os.path.join(dirname, "gail_model.zip"))

model = sb3.PPO.load(os.path.join(dirname, "gail_model.zip"))

Closing, but feel free to reopen if this does not address your issue.

farbod1277 commented 2 years ago

That fixed it! Thank you very much for your assistance.