Closed: farbod1277 closed this issue 2 years ago.
Thanks for the bug report. Could you include a short Python script that I could run that would replicate this problem?
Thank you for your response, below is the script:
```python
import os
import pathlib
import pickle
import tempfile

import seals  # noqa: F401
import stable_baselines3 as sb3
from imitation.algorithms import adversarial, bc
from imitation.data import rollout
from imitation.util import logger, util
from stable_baselines3.common.policies import BaseModel

dirname = os.path.dirname(__file__)

# Load pickled test demonstrations.
with open(os.path.join(dirname, "final.pkl"), "rb") as f:
    # This is a list of `imitation.data.types.Trajectory`, where
    # every instance contains observations and actions for a single expert
    # demonstration.
    trajectories = pickle.load(f)

# Convert List[types.Trajectory] to an instance of `imitation.data.types.Transitions`.
# This is a more general dataclass containing unordered
# (observation, actions, next_observation) transitions.
transitions = rollout.flatten_trajectories(trajectories)

venv = util.make_vec_env("seals/CartPole-v0", n_envs=2)

tempdir = tempfile.TemporaryDirectory(prefix="quickstart")
tempdir_path = pathlib.Path(tempdir.name)
print(f"All Tensorboards and logging are being written inside {tempdir_path}/.")

# Train GAIL on expert data.
# GAIL and AIRL also accept as `expert_data` any PyTorch-style DataLoader that
# iterates over dictionaries containing observations, actions, and next_observations.
gail_logger = logger.configure(tempdir_path / "GAIL/")
gail_trainer = adversarial.GAIL(
    venv,
    expert_data=transitions,
    expert_batch_size=32,
    gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
    custom_logger=gail_logger,
)
gail_trainer.train(total_timesteps=2048)

gail_trainer.gen_algo.save(os.path.join(dirname, "gail_model.zip"))
model = BaseModel.load(os.path.join(dirname, "gail_model.zip"))
```
Just run this in any directory containing the `final.pkl` file with the expert trajectories from a `CartPole-v0` env. I ran into this issue while working on a custom environment (my code can be found here; I do the saving and loading in two separate scripts there), but the same error occurs with the environment used in the script above.

Thanks
I think the problem is that `gen_algo` is a `BaseAlgorithm` (by default, `PPO`), not a `BaseModel`. We're saving the whole RL algorithm `gen_algo`, not just the policy.

Can you try changing `BaseModel.load` to `PPO.load`? That fixed it for me, although I had to modify your script a bit to get it to run on `master`. (By the way, you may want to try using `master` -- the PyPI release is quite out of date at this point, although we intend to make a release soon.)
For completeness, here's the script that works for me:
```python
import os
import pathlib
import pickle
import tempfile

import seals  # noqa: F401
import stable_baselines3 as sb3
from imitation.algorithms.adversarial import gail
from imitation.data import rollout
from imitation.util import logger, util

dirname = os.path.dirname(__file__)

# Load pickled test demonstrations.
with open(os.path.join(dirname, "final.pkl"), "rb") as f:
    # This is a list of `imitation.data.types.Trajectory`, where
    # every instance contains observations and actions for a single expert
    # demonstration.
    trajectories = pickle.load(f)

# Convert List[types.Trajectory] to an instance of `imitation.data.types.Transitions`.
# This is a more general dataclass containing unordered
# (observation, actions, next_observation) transitions.
transitions = rollout.flatten_trajectories(trajectories)

venv = util.make_vec_env("seals/CartPole-v0", n_envs=2)

tempdir = tempfile.TemporaryDirectory(prefix="quickstart")
tempdir_path = pathlib.Path(tempdir.name)
print(f"All Tensorboards and logging are being written inside {tempdir_path}/.")

# Train GAIL on expert data.
# GAIL and AIRL also accept as `demonstrations` any PyTorch-style DataLoader that
# iterates over dictionaries containing observations, actions, and next_observations.
gail_logger = logger.configure(tempdir_path / "GAIL/")
gail_trainer = gail.GAIL(
    demonstrations=transitions,
    demo_batch_size=32,
    venv=venv,
    gen_algo=sb3.PPO("MlpPolicy", venv, verbose=1, n_steps=1024),
    custom_logger=gail_logger,
)
gail_trainer.train(total_timesteps=2048)

gail_trainer.gen_algo.save(os.path.join(dirname, "gail_model.zip"))
model = sb3.PPO.load(os.path.join(dirname, "gail_model.zip"))
```
Closing but feel free to re-open if this does not address your issue.
That fixed it! Thank you very much for your assistance.
Hi, thanks for the awesome project. I've got the following issue. I've trained my agent using GAIL and saved the `gen_algo` using:

```python
gail_trainer.gen_algo.save(os.path.join(dirname, "gail_model.zip"))
```

where `gail_trainer` is an `adversarial.GAIL` instance. I then load the model in my evaluation script (a separate Python script) using:

```python
model = BaseModel.load(os.path.join(dirname, "gail_model.zip"))
```

But I get the following error. The zip file contains a file named `data` along with some `.pth` files and the Stable Baselines 3 version. Any help would be appreciated. Thanks