HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

[Question] AIRL / GAIL : prediction on state ? #264

Closed · romain-mondelice closed this issue 3 years ago

romain-mondelice commented 3 years ago

Hello everyone, I have a question that may seem silly.

I come directly from stable-baselines. I'm used to using the trained GAIL model to make predictions, like this:


import gym
from stable_baselines import SAC, GAIL
from stable_baselines.gail import ExpertDataset, generate_expert_traj

# Generate expert trajectories with a SAC agent
model = SAC('MlpPolicy', 'Pendulum-v0', verbose=1)
generate_expert_traj(model, 'expert_pendulum', n_timesteps=100, n_episodes=10)

# Load the recorded expert demonstrations
dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1)

# Train GAIL on the expert data, then save and reload it
model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
model.learn(total_timesteps=1000)
model.save("gail_pendulum")

del model

model = GAIL.load("gail_pendulum")

# Roll out the trained policy
env = gym.make('Pendulum-v0')
obs = env.reset()
while True:
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()

But here I get lost, because once GAIL / AIRL is trained I don't know how to use it to make predictions on my test and validation datasets.

from stable_baselines3 import PPO
from imitation.algorithms.adversarial import AIRL  # module path may differ between imitation versions
from imitation.util import logger

logger.configure("D:/TensorLogs/AIRL")

# train_env is the vectorized training environment, transitions the expert demonstrations
airl_trainer = AIRL(
    train_env,
    expert_data=transitions,
    expert_batch_size=32,
    gen_algo=PPO("MlpPolicy", train_env, verbose=1, n_steps=50000),
)

airl_trainer.train(total_timesteps=50000)

I'm stuck here: I don't know how to use my trained model. Is there an equivalent of the .predict(state) call that we have in stable-baselines?

I need your expertise, thank you in advance! Best and kind regards, Romain

NathanGavenski commented 2 years ago

For anyone finding this issue, the solution is to use the generator policy, which is the stable-baselines3 model you passed as gen_algo and therefore exposes the usual .predict method:

airl_trainer.gen_algo.predict(obs)
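
A minimal rollout sketch along those lines (assuming train_env wraps Pendulum-v0 and airl_trainer is the trained instance from the snippet above; the environment name, step count, and save path here are only illustrative):

import gym

# Evaluate the AIRL generator the same way as the stable-baselines GAIL example above.
policy = airl_trainer.gen_algo  # the generator: a stable-baselines3 PPO model

env = gym.make("Pendulum-v0")
obs = env.reset()
for _ in range(1000):
    action, _states = policy.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    env.render()
    if done:
        obs = env.reset()

# The generator can also be saved and reloaded like any stable-baselines3 model:
policy.save("airl_pendulum_policy")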