yamatokataoka / reinforcement-learning-replications

Reinforcement Learning Replications is a set of Pytorch implementations of reinforcement learning algorithms.
MIT License

Rethink training design #59

Closed yamatokataoka closed 2 years ago

yamatokataoka commented 2 years ago

I'm thinking of using rl-replicas in this project: https://github.com/yamatokataoka/learning-from-human-preferences

In that case, the learn function should be customisable so that the reward function provided by an OpenAI Gym environment can be swapped with a learned reward function.
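
For example (just a sketch; LearnedRewardWrapper and learned_reward_fn are hypothetical names, not part of rl-replicas, and the classic gym 4-tuple step API is assumed), the swap could happen in a thin env wrapper so that learn never needs to know which reward it is optimizing:

import gym


class LearnedRewardWrapper(gym.Wrapper):
    """Hypothetical wrapper that replaces the env reward with a learned one."""

    def __init__(self, env: gym.Env, learned_reward_fn):
        super().__init__(env)
        self.learned_reward_fn = learned_reward_fn
        self._last_observation = None

    def reset(self, **kwargs):
        self._last_observation = self.env.reset(**kwargs)
        return self._last_observation

    def step(self, action):
        observation, _, done, info = self.env.step(action)
        # Ignore the env reward and predict one from the previous observation and action.
        reward = self.learned_reward_fn(self._last_observation, action)
        self._last_observation = observation
        return observation, reward, done, info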

I also don't think inheritance is easy to understand or extensible for future RL implementations.

Todos

The point is how close I can bring the two functions in on_policy_algorithm.py and off_policy_algorithm.py together.

yamatokataoka commented 2 years ago

I'm not sure what the final outcome is yet, but let's start with separation of concerns. Designing is a wicked problem.

For now, there are two similar functions with the same name, collect_one_epoch_experience in src/rl_replicas/common/base_algorithms/on_policy_algorithm.py and src/rl_replicas/common/base_algorithms/off_policy_algorithm.py.

The responsibility of these functions is to collect the experience used to train the models. I can start by refactoring them.
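
For example, the raw experience could be held in a small container like this (a sketch only; the actual data structures in rl-replicas may differ):

from dataclasses import dataclass, field
from typing import List

import numpy as np


@dataclass
class Episode:
    """Sketch of a raw-experience container produced by a sampler."""

    observations: List[np.ndarray] = field(default_factory=list)
    actions: List[np.ndarray] = field(default_factory=list)
    rewards: List[float] = field(default_factory=list)

    def append(self, observation, action, reward) -> None:
        self.observations.append(observation)
        self.actions.append(action)
        self.rewards.append(reward)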

yamatokataoka commented 2 years ago

So after this redesign is finished, we can write RL training like this?

Do we need a trainer? It seems only a sampler would be necessary for now.

For on-policy algorithms

# ...omitted policy and value_function instantiation
model: VPG = VPG(policy, value_function, action_dim, state_dim, seed=seed)
sampler = Sampler()
# trainer = Trainer(output_dir=output_dir, tensorboard=True, model_saving=True)

for epoch in range(num_epochs):
    episodes = sampler.sample(model, env, steps)
    model.train(episodes)

For off-policy algorithms

# ...omitted policy and q_function instantiation
model: DDPG = DDPG(policy, q_function, action_dim, state_dim)
sampler = Sampler()
random_sampler = RandomSampler()
# trainer = Trainer(output_dir=output_dir, tensorboard=True, model_saving=True)
replay_buffer = ReplayBuffer()

for epoch in range(num_epochs):
    if start_random_sample:
        episodes = random_sampler.sample(env, steps)
    else:
        episodes = sampler.sample(model, env, steps)
    replay_buffer.add(episodes)
    model.train(replay_buffer)
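
A rough sketch of what Sampler could look like (reusing the Episode sketch above and assuming a model.predict(observation) method and the classic gym step API; both are assumptions, not the current rl-replicas interface):

class Sampler:
    """Sketch: collect raw experience for a fixed number of steps."""

    def sample(self, model, env, num_steps: int):
        episodes = []
        episode = Episode()
        observation = env.reset()
        for _ in range(num_steps):
            action = model.predict(observation)  # assumed interface
            next_observation, reward, done, _ = env.step(action)
            episode.append(observation, action, reward)
            observation = next_observation
            if done:
                episodes.append(episode)
                episode = Episode()
                observation = env.reset()
        # The final partial episode is dropped here to keep the sketch short.
        return episodes

RandomSampler would be the same loop with a random action instead of model.predict.
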
yamatokataoka commented 2 years ago

In on_policy_algorithm.py, the current collect_one_epoch_experience function primarily collects experience by interacting with the env, but it also calculates advantages and discounted returns. I think this function should be responsible only for collecting experience. The rest can be done in the train function as preprocessing of the raw experience data.
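
For example, the discounted-return part could move into train as preprocessing over raw episodes (a sketch of the return computation only; the advantage calculation would follow the same pattern):

import numpy as np


def discounted_returns(rewards, gamma: float = 0.99) -> np.ndarray:
    """Compute discounted returns G_t = r_t + gamma * G_{t+1} for one episode."""
    returns = np.zeros(len(rewards), dtype=np.float64)
    running_return = 0.0
    for t in reversed(range(len(rewards))):
        running_return = rewards[t] + gamma * running_return
        returns[t] = running_return
    return returns

model.train(episodes) could then call this per episode before computing the losses.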

yamatokataoka commented 2 years ago

It probably needs a sampler for both on-policy and off-policy algorithms.

yamatokataoka commented 2 years ago

Differences between the collect_one_epoch_experience functions in on-policy and off-policy algorithms

yamatokataoka commented 2 years ago

I found that the trainer model cannot handle this part. I think it would become complicated if I tried to implement it. https://github.com/yamatokataoka/reinforcement-learning-replications/blob/6936158383e6895494fdf8c7a5bd5aeca3dfbc9f/src/rl_replicas/base_algorithms/off_policy_algorithm.py#L225-L231

yamatokataoka commented 2 years ago

I found that it's possible to handle this using a policy argument on the sample function.

https://github.com/rlworkgroup/garage/blob/3492f446633a7e748f2f79077f6301c5b3ec9281/src/garage/trainer.py#L176
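
So the random-sampling phase could just be a different policy passed to the same sample call, e.g. a hypothetical RandomPolicy that ignores observations:

class RandomPolicy:
    """Sketch: a policy that samples uniformly from the action space."""

    def __init__(self, action_space):
        self.action_space = action_space

    def predict(self, observation):
        return self.action_space.sample()

With the Sampler sketch above, sampler.sample(RandomPolicy(env.action_space), env, steps) would replace the start_random_sample branch.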

yamatokataoka commented 2 years ago

During the implementation, I noticed that the part of DDPG which collects experience using a random policy and then further experience with another policy is like a part of the algorithm itself. So I'm a bit hesitant to pull this implementation out of DDPG.

https://github.com/vwxyzjn/cleanrl/blob/7ce655d8fb7f632957a98d6500fc77a8285d22af/cleanrl/ddpg_continuous_action.py#L170-L177
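
One way to keep that behaviour inside DDPG while still using a single sampler would be to let the algorithm decide which policy the sampler gets (a fragment of a hypothetical learn method reusing the RandomPolicy sketch above, not the current implementation):

class DDPG:
    """Fragment: only the sampling phase of a hypothetical learn method is shown."""

    def learn(self, env, sampler, replay_buffer, num_epochs: int,
              steps_per_epoch: int, random_start_steps: int) -> None:
        total_steps = 0
        for _ in range(num_epochs):
            # The algorithm itself decides when to stop using random actions,
            # so this behaviour stays part of DDPG rather than the caller's loop.
            if total_steps < random_start_steps:
                episodes = sampler.sample(RandomPolicy(env.action_space), env, steps_per_epoch)
            else:
                episodes = sampler.sample(self.policy, env, steps_per_epoch)
            total_steps += steps_per_epoch
            replay_buffer.add(episodes)
            self.train(replay_buffer)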

yamatokataoka commented 2 years ago

All todos are done.