HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

segmentation fault while loading policy #845

chenyangkang closed this issue 7 months ago

chenyangkang commented 7 months ago

Bug description

>>> load_policy(
...     "ppo-huggingface",
...     organization="HumanCompatibleAI",
...     env_name="seals/CartPole-v0",
...     venv=env,
... )

/Users/chenyangkang/miniconda3/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:167: UserWarning: Could not deserialize object learning_rate. Consider using `custom_objects` argument to replace this object.
Exception: code() argument 13 must be str, not int
  warnings.warn(
/Users/chenyangkang/miniconda3/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:167: UserWarning: Could not deserialize object clip_range. Consider using `custom_objects` argument to replace this object.
Exception: code() argument 13 must be str, not int
  warnings.warn(
/Users/chenyangkang/miniconda3/lib/python3.11/site-packages/stable_baselines3/common/save_util.py:167: UserWarning: Could not deserialize object lr_schedule. Consider using `custom_objects` argument to replace this object.
Exception: code() argument 13 must be str, not int
  warnings.warn(
zsh: segmentation fault  python
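The repeated "code() argument 13 must be str, not int" warnings are consistent with a Python-version mismatch. Stable-Baselines3 stores non-JSON attributes such as learning_rate, clip_range and lr_schedule with cloudpickle. Unpickling a lambda rebuilds its code object through the types.CodeType constructor, whose positional layout changed in Python 3.11. In 3.8 to 3.10, argument 13 is firstlineno (an int). In 3.11, qualname (a str) was inserted at that position. The log paths show Python 3.11. The following is a minimal illustration of that positional shift only. It does not touch stable-baselines3. code_param_at and running_interpreter_has_qualname are illustrative helpers, and the snippet assumes the checkpoint's schedules were pickled under an older Python. It accounts for the warnings, not necessarily for the segmentation fault.

```python
import sys

# Leading positional parameters of the types.CodeType constructor that are
# shared by Python 3.8 and later.
_SHARED_PREFIX = (
    "argcount", "posonlyargcount", "kwonlyargcount", "nlocals",
    "stacksize", "flags", "codestring", "constants", "names",
    "varnames", "filename", "name",
)


def code_param_at(position: int, version: tuple) -> str:
    """Name of the 1-based positional CodeType parameter for a (major, minor) version."""
    if version >= (3, 11):
        # Python 3.11 inserted "qualname" (a str) right after "name".
        params = _SHARED_PREFIX + ("qualname", "firstlineno")
    else:
        # Python 3.8 to 3.10 follow "name" with "firstlineno" (an int).
        params = _SHARED_PREFIX + ("firstlineno",)
    return params[position - 1]


def running_interpreter_has_qualname() -> bool:
    """True when code objects in this interpreter expose co_qualname (Python 3.11+)."""
    return hasattr((lambda: None).__code__, "co_qualname")


# An int that was argument 13 in the old layout lands where 3.11 expects a str.
print(code_param_at(13, (3, 8)))   # firstlineno
print(code_param_at(13, (3, 11)))  # qualname
print(running_interpreter_has_qualname() == (sys.version_info >= (3, 11)))  # True
```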

Steps to reproduce

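import os
# Allow duplicate OpenMP runtimes to load (works around "OMP: Error #15" on some macOS/conda setups).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'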

import numpy as np
from imitation.policies.serialize import load_policy
from imitation.util.util import make_vec_env
from imitation.data.wrappers import RolloutInfoWrapper

SEED = 42

env = make_vec_env(
    "seals:seals/CartPole-v0",
    rng=np.random.default_rng(SEED),
    n_envs=8,
    post_wrappers=[
        lambda env, _: RolloutInfoWrapper(env)
    ],  # needed for computing rollouts later
)

expert = load_policy(
    "ppo-huggingface",
    organization="HumanCompatibleAI",
    env_name="seals/CartPole-v0",
    venv=env,
)

Environment

chenyangkang commented 7 months ago

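Likely due to the Python version: the failing run above uses Python 3.11 (per the site-packages paths in the log), while the same code works fine under Python 3.8.

Reference: https://github.com/DLR-RM/stable-baselines3/issues/172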
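The warnings themselves suggest the custom_objects argument. Stable-Baselines3's PPO.load accepts a custom_objects dict, and values given there are used in place of the stored entries instead of unpickling them. The following is a minimal sketch under several assumptions. The expert checkpoint has already been downloaded to a local .zip path, for example with huggingface_sb3's load_from_hub; the exact repo id and filename are not given in this thread. The policy is only needed for inference, so the schedules can be constant placeholders. make_inference_custom_objects and load_ppo_policy_for_inference are hypothetical helper names. This may silence the deserialization warnings, but it is not confirmed here that it also avoids the segmentation fault. Running under Python 3.8, as reported above, is the only workaround confirmed in this thread.

```python
from typing import Any, Dict


def make_inference_custom_objects() -> Dict[str, Any]:
    """Placeholders for schedule objects that may fail to unpickle across Python versions.

    These are only consulted when training continues, so constant zeros are
    acceptable for a policy used purely for inference.
    """
    return {
        "learning_rate": 0.0,
        "lr_schedule": lambda _progress_remaining: 0.0,
        "clip_range": lambda _progress_remaining: 0.0,
    }


def load_ppo_policy_for_inference(path: str, venv: Any):
    """Hypothetical helper: load a PPO checkpoint without unpickling its saved schedules.

    Requires stable-baselines3 to be installed.
    """
    # Imported inside the function so make_inference_custom_objects can be used on its own.
    from stable_baselines3 import PPO

    model = PPO.load(path, env=venv, custom_objects=make_inference_custom_objects())
    return model.policy
```

For example, expert = load_ppo_policy_for_inference(local_zip_path, env), with env from the reproduction script and local_zip_path standing in for wherever the checkpoint was saved. This bypasses imitation's load_policy, so any extra handling that loader performs is skipped.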