
JAX-CORL

This repository aims to be a JAX version of CORL: clean single-file implementations of offline RL algorithms with solid performance reports.

JAX-CORL complements the single-file RL ecosystem by covering the offline RL x JAX combination.

Algorithms

| Algorithm | Implementation | Training time (CORL) | Training time (ours) | wandb |
|---|---|---|---|---|
| AWAC | algos/awac.py | 4.46h | 11m (24x faster) | link |
| IQL | algos/iql.py | 4.08h | 9m (28x faster) | link |
| TD3+BC | algos/td3_bc.py | 2.47h | 9m (16x faster) | link |
| CQL | algos/cql.py | 11.52h | 56m (12x faster) | link |
| DT | algos/dt.py | 42m | 11m (4x faster) | link |

Training time is measured for 1,000,000 update steps without evaluation on halfcheetah-medium-expert-v2 (the differences between D4RL MuJoCo environments are small). Our training time includes the jit compile time. The computations were performed on four GeForce GTX 1080 Ti GPUs. The PyTorch times are measured with the CORL implementations.

Reports for D4RL mujoco

Normalized Score

Here, we used the D4RL MuJoCo control tasks as the benchmark. We report the mean and standard deviation of the average normalized score over 5 episodes and 5 seeds. We plan to extend the verification to other D4RL benchmarks such as AntMaze. For the source of the hyperparameters and the validity of the performance, please refer to the Wiki.

| env | AWAC | IQL | TD3+BC | CQL | DT |
|---|---|---|---|---|---|
| halfcheetah-medium-v2 | $41.56\pm0.79$ | $43.28\pm0.51$ | $48.12\pm0.42$ | $48.65\pm0.49$ | $42.63\pm0.53$ |
| halfcheetah-medium-expert-v2 | $76.61\pm9.60$ | $92.87\pm0.61$ | $92.99\pm0.11$ | $53.76\pm14.53$ | $70.63\pm14.70$ |
| hopper-medium-v2 | $51.45\pm5.40$ | $52.17\pm2.88$ | $46.51\pm4.57$ | $77.56\pm7.12$ | $60.85\pm6.78$ |
| hopper-medium-expert-v2 | $51.89\pm2.11$ | $53.35\pm5.63$ | $105.47\pm5.03$ | $90.37\pm31.29$ | $109.07\pm4.56$ |
| walker2d-medium-v2 | $68.12\pm12.08$ | $75.33\pm5.2$ | $72.73\pm4.66$ | $80.16\pm4.19$ | $71.04\pm5.64$ |
| walker2d-medium-expert-v2 | $91.36\pm23.13$ | $109.07\pm0.32$ | $109.17\pm0.71$ | $110.03\pm0.72$ | $99.81\pm17.73$ |

How to use this codebase for your research

This codebase can be used independently as a baseline for D4RL projects. It is also designed to be flexible, allowing users to develop new algorithms or adapt them for datasets other than D4RL.

For researchers interested in using this code for their projects, we provide a detailed explanation of the code's shared structure:

Data structure
from typing import NamedTuple

import jax.numpy as jnp


class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray


def get_dataset(...) -> Transition:
    ...
    return dataset

The code includes a Transition class, defined as a NamedTuple, which contains fields for observations, actions, rewards, next observations, and done flags. The get_dataset function is expected to output data in the Transition format, which makes the algorithms adaptable to any dataset that conforms to this structure.
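As a concrete illustration, below is a minimal, hypothetical sketch of producing a Transition dataset from a D4RL-style dict (as returned by d4rl.qlearning_dataset) and sampling a random minibatch from it with jax.tree_util.tree_map. The function names get_dataset_from_d4rl and sample_batch are illustrative; the actual loading code in algos/*.py may differ.

import d4rl
import gym
import jax
import jax.numpy as jnp


def get_dataset_from_d4rl(env_name: str) -> Transition:
    # Load the raw D4RL dict and repack it into the Transition NamedTuple.
    env = gym.make(env_name)
    raw = d4rl.qlearning_dataset(env)
    return Transition(
        observations=jnp.asarray(raw["observations"]),
        actions=jnp.asarray(raw["actions"]),
        rewards=jnp.asarray(raw["rewards"]),
        next_observations=jnp.asarray(raw["next_observations"]),
        dones=jnp.asarray(raw["terminals"]),
    )


def sample_batch(rng: jax.Array, dataset: Transition, batch_size: int) -> Transition:
    # Draw random indices once and apply them to every field of the NamedTuple,
    # which works because NamedTuples are registered pytrees in JAX.
    idx = jax.random.randint(rng, (batch_size,), 0, len(dataset.observations))
    return jax.tree_util.tree_map(lambda x: x[idx], dataset)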

Trainer class
from functools import partial
from typing import NamedTuple

import jax
from flax.training.train_state import TrainState


class AlgoTrainState(NamedTuple):
    actor: TrainState
    critic: TrainState


class Algo(object):
    ...

    def update_actor(self, train_state: AlgoTrainState, batch: Transition, config) -> AlgoTrainState:
        ...
        return train_state

    def update_critic(self, train_state: AlgoTrainState, batch: Transition, config) -> AlgoTrainState:
        ...
        return train_state

    @partial(jax.jit, static_argnums=(0,), static_argnames=("n_jitted_updates",))
    def update_n_times(self, train_state: AlgoTrainState, data, n_jitted_updates, config) -> AlgoTrainState:
        # Perform several gradient updates inside a single jit-compiled call.
        for _ in range(n_jitted_updates):
            batch = data.sample()
            train_state = self.update_actor(train_state, batch, config)
            train_state = self.update_critic(train_state, batch, config)
        return train_state


def create_train_state(...) -> AlgoTrainState:
    # initialize models...
    return AlgoTrainState(
        actor=actor,
        critic=critic,
    )

For all algorithms, we have a TrainState class (e.g. TD3BCTrainState for TD3+BC) which bundles the flax TrainState objects for all models. The update logic is implemented as methods of the Algo classes (e.g. TD3BC). Both the TrainState and Algo classes are versatile and can be used outside of the provided files, as long as the create_train_state function is implemented to meet the specifications of the TrainState class. Note: so far, we have not followed this policy for CQL due to technical issues. This will be handled in the near future.
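For reference, here is a minimal, hypothetical sketch of what such a create_train_state function could look like for a simple actor-critic setup with flax and optax. The network architecture, optimizer settings, and argument names are illustrative assumptions, not the repository's actual defaults.

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training.train_state import TrainState


class MLP(nn.Module):
    # Small two-hidden-layer MLP used for both actor and critic in this sketch.
    out_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(256)(x))
        x = nn.relu(nn.Dense(256)(x))
        return nn.Dense(self.out_dim)(x)


def create_train_state(rng: jax.Array, obs_dim: int, action_dim: int, lr: float = 3e-4) -> AlgoTrainState:
    actor_rng, critic_rng = jax.random.split(rng)
    dummy_obs = jnp.zeros((1, obs_dim))
    dummy_action = jnp.zeros((1, action_dim))

    # Actor: maps observations to actions.
    actor_model = MLP(out_dim=action_dim)
    actor = TrainState.create(
        apply_fn=actor_model.apply,
        params=actor_model.init(actor_rng, dummy_obs),
        tx=optax.adam(lr),
    )

    # Critic: Q-function over concatenated observation-action pairs.
    critic_model = MLP(out_dim=1)
    critic = TrainState.create(
        apply_fn=critic_model.apply,
        params=critic_model.init(critic_rng, jnp.concatenate([dummy_obs, dummy_action], axis=-1)),
        tx=optax.adam(lr),
    )
    return AlgoTrainState(actor=actor, critic=critic)

Algorithm-specific components, such as target networks or a learnable temperature, can be added as extra fields of the algorithm's TrainState NamedTuple and initialized in the same way.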

See also

Great Offline RL libraries

Implementations of offline RL algorithms in JAX

Single-file implementations

Cite JAX-CORL

@article{nishimori2024jaxcorl,
  title={JAX-CORL: Clean Single-file Implementations of Offline RL Algorithms in JAX},
  author={Soichiro Nishimori},
  year={2024},
  url={https://github.com/nissymori/JAX-CORL}
}

Credits