tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Apache License 2.0

Does TF-Agents support delayed rewards? #529

Open greglira opened 3 years ago

greglira commented 3 years ago

I couldn't find any references in the documentation regarding support for learning under delayed feedback (https://sites.ualberta.ca/~szepesva/papers/DelayedOnlineLearning.pdf). For example, in a simple batch-oriented use case with multi-armed bandits, is there a mechanism to store the action log and then use it during training once the rewards are collected?

touqir14 commented 3 years ago

This feature would be really useful (many real-world use cases involve delayed rewards). I was working on an advertisement recommender system for a web platform using Multi-Armed Bandits and was looking for algorithms for the delayed-reward setting. I can definitely look into this if there is interest from the project maintainers.

bartokg commented 3 years ago

While there are no agents specifically designed for delayed feedback in the codebase, you can use any of the agents without major hiccups.

Training and serving can be completely decoupled, and you can construct any training schedule. Instead of using the `dynamic_step_driver`, you can collect training data by logging the actions (and context, if you have any in your problem) and joining them with the rewards when you receive them. Then you can call the `train()` function on your agent with the trajectory you constructed from the logged data.
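Very roughly, a minimal sketch of this decoupled flow could look like the following. Here `agent` stands for any TF-Agents bandit agent; `serve`, `train_on_reward`, `action_log`, and `request_id` are only illustrative bookkeeping, not library API, and depending on the agent you may need to adjust the outer batch/time dimensions of the experience:

```python
import tensorflow as tf
from tf_agents.trajectories import time_step as ts
from tf_agents.trajectories import trajectory

action_log = {}  # request_id -> (time_step, policy_step)

def serve(request_id, observation):
  # `observation` is expected to be batched, e.g. shape [1, context_dim].
  time_step = ts.restart(observation, batch_size=1)
  policy_step = agent.collect_policy.action(time_step)
  # Log the context and the chosen action for later joining with the reward.
  action_log[request_id] = (time_step, policy_step)
  return policy_step.action

def train_on_reward(request_id, reward):
  # Called whenever the (possibly delayed) reward arrives: join it with the
  # logged context/action and train on the reconstructed trajectory.
  time_step, policy_step = action_log.pop(request_id)
  next_time_step = time_step._replace(
      step_type=tf.fill([1], ts.StepType.LAST),
      reward=tf.constant([reward], dtype=tf.float32))
  experience = trajectory.from_transition(time_step, policy_step, next_time_step)
  agent.train(experience)
```

The point is only that action selection (`serve`) and training (`train_on_reward`) can happen at completely different times, with the action log acting as the join point.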

greglira commented 3 years ago

@bartokg just to make sure we're doing it right, because the current solution looks a little bit hacky to me:

  1. Get the actions to save for later use: `actions = agent.collect_policy.action(ts).action`
  2. During training, use the previously persisted actions (`replay_actions`) to construct a trajectory:

     action_step = action_step._replace(action=replay_actions)
     next_time_step = self.env.step(action_step.action)
     time_step = time_step._replace(
         step_type=tf.fill([batch_size], ts.StepType.FIRST))

     traj = trajectory.from_transition(time_step, action_step, next_time_step)

We created a custom `_loop_body_fn` to do this and redefined the `DynamicStepDriver`. We provide observations and rewards separately, from our bandit environment, based on the action log.
Is this the right way to decouple action selection and training, or is something better supported directly by the API? If so, do you have any examples of that?
Thanks!
bartokg commented 3 years ago

In your example I don't see a guarantee that the actions (and the observations) are linked to the right rewards when they arrive. A good starting point might be to implement a replay buffer that looks at some kind of "time stamp" on the reward and does the joining in the `add_batch` function.

One side note: when you are creating a trajectory for a bandit problem, I suggest using `trajectory.single_step()`.
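To make that a bit more concrete, here is a rough sketch of such a joining wrapper. The `DelayedRewardJoiner` class, its `log_action`/`add_reward` methods, and the `request_id`-based join are only illustrative (they are not part of TF-Agents), and the joining here happens just before `add_batch` rather than inside it; the actual library pieces are `TFUniformReplayBuffer.add_batch` and `trajectory.single_step`:

```python
import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory

class DelayedRewardJoiner(object):
  """Joins logged actions with delayed rewards before writing to a buffer."""

  def __init__(self, agent, batch_size=1, max_length=10000):
    self._pending = {}  # request_id -> (observation, action, policy_info)
    self._buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=batch_size,
        max_length=max_length)

  def log_action(self, request_id, observation, policy_step):
    # Called at serving time, right after the collect policy picks an action.
    self._pending[request_id] = (
        observation, policy_step.action, policy_step.info)

  def add_reward(self, request_id, reward):
    # Called when the delayed reward arrives; only now is a complete
    # single-step (bandit) trajectory written into the replay buffer.
    observation, action, policy_info = self._pending.pop(request_id)
    traj = trajectory.single_step(
        observation=observation,
        action=action,
        policy_info=policy_info,
        reward=tf.constant([reward], dtype=tf.float32),
        discount=tf.ones([1], dtype=tf.float32))
    self._buffer.add_batch(traj)

  def as_dataset(self, sample_batch_size):
    # Completed trajectories can then be sampled and passed to agent.train().
    return self._buffer.as_dataset(
        sample_batch_size=sample_batch_size, num_steps=1)
```

You would then periodically sample from the buffer and call `train()` on the sampled trajectories, entirely independently of when the actions were served.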

diskun00 commented 3 years ago

Hi @greglira, have you found a solution for dealing with delayed rewards?