google-research / batch_rl

Offline Reinforcement Learning (aka Batch Reinforcement Learning) on Atari 2600 games
https://offline-rl.github.io/
Apache License 2.0

JAX code #30

Open lucasliunju opened 2 years ago

lucasliunju commented 2 years ago

Hi,

I would like to ask whether there is a JAX-based version of this code.

Also, are there any recommendations for JAX-based offline RL algorithms?

Thanks!

agarwl commented 2 years ago

Releasing the JAX code might take some time, but it should be easy to modify the existing Dopamine agents. In the meantime, here are some tips to get started with JAX agents:


# Imports needed for this snippet (adjust the fixed_replay_buffer import to
# wherever FixedReplayBuffer lives in your setup).
from absl import logging
from batch_rl.fixed_replay.replay_memory import fixed_replay_buffer
from dopamine.jax.agents.dqn import dqn_agent
import gin
import numpy as onp


@gin.configurable
class OfflineJaxDQNAgent(dqn_agent.JaxDQNAgent):
  """A JAX implementation of the Offline DQN agent."""

  def __init__(self,
               num_actions,
               replay_data_dir,
               summary_writer=None):
    """Initializes the agent and constructs the necessary components.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      replay_data_dir: str, log directory from which to load the replay buffer.
      summary_writer: SummaryWriter object for outputting training statistics.
    """
    logging.info('Creating %s agent with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay directory: %s', replay_data_dir)
    self.replay_data_dir = replay_data_dir
    super().__init__(
        num_actions, update_period=1, summary_writer=summary_writer)

  def _build_replay_buffer(self):
    """Creates the fixed replay buffer used by the agent."""
    return fixed_replay_buffer.FixedReplayBuffer(
        data_dir=self.replay_data_dir,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype)

  def reload_data(self):
    # This needs to be called every iteration to subsample a portion of the dataset.
    self._replay.reload_data()

  def step(self, reward, observation):
    """Returns the agent's next action and update agent's state.

    Args:
      reward: float, the reward received from the agent's most recent action.
      observation: numpy array, the most recent observation.

    Returns:
      int, the selected action.
    """
    self._record_observation(observation)
    self._rng, self.action = dqn_agent.select_action(
        self.network_def, self.online_params, self.state, self._rng,
        self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train,
        self.epsilon_decay_period, self.training_steps, self.min_replay_history,
        self.epsilon_fn)
    self.action = onp.asarray(self.action)
    return self.action

  def train_step(self):
    """Exposes the train step for offline learning."""
    super()._train_step()
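
For example, here is a rough sketch of driving this agent directly for offline training (the path, action count, and loop sizes below are placeholders; in practice the Runner also interleaves an evaluation phase):

agent = OfflineJaxDQNAgent(
    num_actions=4,  # action-space size of the chosen Atari game (placeholder)
    replay_data_dir='/path/to/dqn_replay/Pong/1/replay_logs')  # placeholder path

for iteration in range(200):     # number of offline training iterations
  agent.reload_data()            # subsample a fresh slice of the offline dataset
  for _ in range(62_500):        # gradient steps per iteration (update_period=1)
    agent.train_step()
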
lucasliunju commented 2 years ago

Thank you very much! May I ask whether I can also run the code on a TPU VM with JAX?

Best, Lucas

agarwl commented 2 years ago

I think so -- you probably want to use the tfds datasets or much larger batch sizes with the Dopamine codebase.
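
If you go the larger-batch route, the batch size can usually be bumped with a gin binding on the underlying replay buffer. A sketch (the exact binding depends on which buffer class your config instantiates; OutOfGraphReplayBuffer is the standard Dopamine buffer that FixedReplayBuffer wraps):

# In your .gin config; Dopamine's default batch size is 32.
OutOfGraphReplayBuffer.batch_size = 256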

lucasliunju commented 2 years ago

Thank you very much! I'll have a try.

lucasliunju commented 1 year ago

Dear agarwl,

I tried to follow the code you provided and reproduce the offline DQN results with JAX. I find the training speed with JAX is quite slow compared with TensorFlow. May I ask what the possible reason could be? I changed these parts of the vanilla Dopamine code: (1) I rewrote the Runner in dopamine/dopamine/discrete_domains/run_experiment.py based on the code in batch_rl/batch_rl/fixed_replay/run_experiment.py:

# Imports needed for this snippet.
import time

from dopamine.discrete_domains import checkpointer
from dopamine.discrete_domains import iteration_statistics
from dopamine.discrete_domains import run_experiment
import gin
import tensorflow.compat.v1 as tf


@gin.configurable
class FixedReplayRunner(run_experiment.Runner):
  """Object that handles running Dopamine experiments with a fixed replay buffer."""

  def _initialize_checkpointer_and_maybe_resume(self, checkpoint_file_prefix):
    super(FixedReplayRunner, self)._initialize_checkpointer_and_maybe_resume(
        checkpoint_file_prefix)

    # Code for loading a checkpoint at initialization.
    init_checkpoint_dir = self._agent._init_checkpoint_dir  # pylint: disable=protected-access
    if (self._start_iteration == 0) and (init_checkpoint_dir is not None):
      if checkpointer.get_latest_checkpoint_number(self._checkpoint_dir) < 0:
        # No checkpoint loaded yet, read init_checkpoint_dir
        init_checkpointer = checkpointer.Checkpointer(
            init_checkpoint_dir, checkpoint_file_prefix)
        latest_init_checkpoint = checkpointer.get_latest_checkpoint_number(
            init_checkpoint_dir)
        if latest_init_checkpoint >= 0:
          experiment_data = init_checkpointer.load_checkpoint(
              latest_init_checkpoint)
          if self._agent.unbundle(
              init_checkpoint_dir, latest_init_checkpoint, experiment_data):
            if experiment_data is not None:
              assert 'logs' in experiment_data
              assert 'current_iteration' in experiment_data
              self._logger.data = experiment_data['logs']
              self._start_iteration = experiment_data['current_iteration'] + 1
            tf.logging.info(
                'Reloaded checkpoint from %s and will start from iteration %d',
                init_checkpoint_dir, self._start_iteration)

  def _run_train_phase(self):
    """Run training phase."""
    self._agent.eval_mode = False
    start_time = time.time()
    for _ in range(self._training_steps):
      self._agent._train_step()  # pylint: disable=protected-access
    time_delta = time.time() - start_time
    tf.logging.info('Average training steps per second: %.2f',
                    self._training_steps / time_delta)

  def _run_one_iteration(self, iteration):
    """Runs one iteration of agent/environment interaction."""
    statistics = iteration_statistics.IterationStatistics()
    tf.logging.info('Starting iteration %d', iteration)
    # pylint: disable=protected-access
    if not self._agent._replay_suffix:
      # Reload the replay buffer
      self._agent._replay.memory.reload_buffer(num_buffers=5)
    # pylint: enable=protected-access
    self._run_train_phase()

    num_episodes_eval, average_reward_eval = self._run_eval_phase(statistics)

    self._save_tensorboard_summaries(
        iteration, num_episodes_eval, average_reward_eval)
    return statistics.data_lists

  def _save_tensorboard_summaries(self, iteration,
                                  num_episodes_eval,
                                  average_reward_eval):
    """Save statistics as tensorboard summaries.
    Args:
      iteration: int, The current iteration number.
      num_episodes_eval: int, number of evaluation episodes run.
      average_reward_eval: float, The average evaluation reward.
    """
    summary = tf.Summary(value=[
        tf.Summary.Value(tag='Eval/NumEpisodes',
                         simple_value=num_episodes_eval),
        tf.Summary.Value(tag='Eval/AverageReturns',
                         simple_value=average_reward_eval)
    ])
    self._summary_writer.add_summary(summary, iteration)

(2) I created the offline replay buffer: fixed_replay_buffer.py

(3) I created the OfflineJaxDQNAgent:

@gin.configurable
class OfflineJaxDQNAgent(dqn_agent.JaxDQNAgent):
  """A JAX implementation of the Offline DQN agent."""

  def __init__(self,
               num_actions,
               replay_data_dir,
               summary_writer=None):
    """Initializes the agent and constructs the necessary components.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      replay_data_dir: str, log directory from which to load the replay buffer.
      summary_writer: SummaryWriter object for outputting training statistics.
    """
    logging.info('Creating %s agent with the following parameters:',
                 self.__class__.__name__)
    logging.info('\t replay directory: %s', replay_data_dir)
    self.replay_data_dir = replay_data_dir
    super().__init__(
        num_actions, update_period=1, summary_writer=summary_writer)

  def _build_replay_buffer(self):
    """Creates the fixed replay buffer used by the agent."""
    return fixed_replay_buffer.FixedReplayBuffer(
        data_dir=self.replay_data_dir,
        observation_shape=self.observation_shape,
        stack_size=self.stack_size,
        update_horizon=self.update_horizon,
        gamma=self.gamma,
        observation_dtype=self.observation_dtype)

  def reload_data(self):
    # This needs to be called every iteration to subsample a portion of the dataset.
    self._replay.reload_data()

  def step(self, reward, observation):
    """Returns the agent's next action and update agent's state.

    Args:
      reward: float, the reward received from the agent's most recent action.
      observation: numpy array, the most recent observation.

    Returns:
      int, the selected action.
    """
    self._record_observation(observation)
    self._rng, self.action = dqn_agent.select_action(
        self.network_def, self.online_params, self.state, self._rng,
        self.num_actions, self.eval_mode, self.epsilon_eval, self.epsilon_train,
        self.epsilon_decay_period, self.training_steps, self.min_replay_history,
        self.epsilon_fn)
    self.action = onp.asarray(self.action)
    return self.action

  def train_step(self):
    """Exposes the train step for offline learning."""
    super()._train_step()

(4) I compared the JAX code with the vanilla TF code and found that they use different replay buffers (FixedReplayBuffer in JAX vs. WrappedFixedReplayBuffer in TF). I'm not sure whether this is the main reason.

Best

lucasliunju commented 1 year ago

Hi, I find that update_period is 1 in the JAX code while it is 4 in the TF code. Maybe that is the main reason.

agarwl commented 1 year ago

Yeah, update_period 1 corresponds to 1 gradient step every environment step (the default of 4 corresponds to 1 gradient step every 4 environment steps). In each iteration we do 62.5K gradient steps, so with update_period 1 you can also set the number of training steps per iteration to 62.5K.
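
Concretely (a sketch, assuming the usual 250K agent steps per iteration from the batch_rl configs):

# TF setup: 250,000 agent steps per iteration, 1 gradient step every 4 steps.
grad_steps_tf = 250_000 // 4   # = 62,500
# JAX agent above: update_period=1, so run 62,500 training steps to match.
grad_steps_jax = 62_500
assert grad_steps_tf == grad_steps_jax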

lucasliunju commented 1 year ago

Hi @agarwl, thanks for your reply. I will try it. By the way, may I ask whether I can run the TF code on a TPU VM? I find TF is still a little faster.

agarwl commented 1 year ago

Sure -- you may not see much benefit from using TPUs (due to the small batch size and the Dopamine replay buffer), but the code can be run on TPU.

agarwl commented 1 year ago

Here's some JAX code for reference: https://github.com/google/dopamine/tree/master/dopamine/labs/offline_rl

lucasliunju commented 1 year ago

@agarwl Thank you very much! I will have a try. Thanks!