vwxyzjn / cleanrl

High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG)
http://docs.cleanrl.dev

add complex observation atari ppo #359

Open ttumiel opened 1 year ago

ttumiel commented 1 year ago

Description

Added handling of complex observations to atari_ppo.py. Closes #353
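As a rough illustration of what handling complex observations involves, the sketch below stacks per-environment dict observations leaf by leaf. The `map_structure` helper is a hypothetical pure-Python stand-in for `tree.map_structure`, and the `image`/`state` keys and shapes are made up for the example; none of this is the PR's actual code.

```python
import numpy as np


def map_structure(fn, *structures):
    """Hypothetical stand-in for tree.map_structure.

    Applies fn to corresponding leaves of dicts, tuples, and lists.
    """
    first = structures[0]
    if isinstance(first, dict):
        return {k: map_structure(fn, *(s[k] for s in structures)) for k in first}
    if isinstance(first, (tuple, list)):
        return type(first)(map_structure(fn, *items) for items in zip(*structures))
    return fn(*structures)  # leaf, e.g. a NumPy array


# Three fake observations, each with an image and a vector component (assumed layout).
observations = [
    {"image": np.zeros((4, 84, 84), dtype=np.uint8),
     "state": np.full(3, i, dtype=np.float32)}
    for i in range(3)
]

# Stack each leaf along a new leading batch axis.
batch = map_structure(lambda *leaves: np.stack(leaves), *observations)
print({k: v.shape for k, v in batch.items()})
# {'image': (3, 4, 84, 84), 'state': (3, 3)}
```

With a plain array observation the helper falls through to the leaf case, so the same call reduces to `np.stack`.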

I also wrote a JAX version for the #338 branch (I can put it in another PR once #338 is ready). It needs only two changes, both of which use JAX's tree_map: https://gist.github.com/ttumiel/ee746d6292cecb47d390fb97c3ccfa5e

Tests

I wrote some tests for different observation types. I wasn't sure whether they belong in the test folder, since they mostly just demonstrate the functionality.

I also wrote a dummy complex observation wrapper to demonstrate handling a dict space in Atari: https://gist.github.com/ttumiel/c2132b424c49b76a62bafe7efef9923d
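For readers who do not open the gist, a wrapper of this kind might look roughly like the sketch below. It is not the gist's code: `DictObservationWrapper`, `FakeImageEnv`, the `image`/`mean` keys, and the 4-tuple `step` return are all assumptions chosen to keep the example free of dependencies other than NumPy.

```python
import numpy as np


class DictObservationWrapper:
    """Hypothetical wrapper that turns an image observation into a dict."""

    def __init__(self, env):
        self.env = env

    def _convert(self, frame):
        frame = np.asarray(frame)
        # Keep the raw frame and add a small derived vector as a second entry.
        return {"image": frame, "mean": np.array([frame.mean()], dtype=np.float32)}

    def reset(self):
        return self._convert(self.env.reset())

    def step(self, action):
        # Assumes the older (obs, reward, done, info) step signature.
        obs, reward, done, info = self.env.step(action)
        return self._convert(obs), reward, done, info


class FakeImageEnv:
    """Tiny stand-in environment that emits constant 84x84 frames."""

    def reset(self):
        return np.zeros((84, 84), dtype=np.uint8)

    def step(self, action):
        return np.ones((84, 84), dtype=np.uint8), 1.0, False, {}


env = DictObservationWrapper(FakeImageEnv())
obs = env.reset()
next_obs, reward, done, info = env.step(0)
print(sorted(obs), next_obs["mean"])
```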

Speed

The tree.map_structure call adds about 10 µs of overhead per stack compared with calling np.stack directly:

import gym, tree, numpy as np

o = gym.spaces.Box(0, 255, (64, 64))
x = [o.sample() for _ in range(10)]

%%timeit
o=np.stack(x)
# 15.8 µs ± 491 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%%timeit
o=tree.map_structure(lambda *x: np.stack(x), *x)
# 26.5 µs ± 670 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
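Outside IPython, a similar comparison can be sketched with the standard-library `timeit` module. The fallback `map_structure` below is an assumption that only covers flat array inputs when dm-tree is not installed, and the frames are generated with NumPy rather than sampled from a gym space; absolute timings will differ by machine.

```python
import timeit

import numpy as np

try:
    import tree  # dm-tree, as used in the snippet above
    map_structure = tree.map_structure
except ImportError:
    # Assumed fallback: valid only when every input is a single leaf (array).
    def map_structure(fn, *structures):
        return fn(*structures)

rng = np.random.default_rng(0)
frames = [rng.uniform(0, 255, (64, 64)).astype(np.float32) for _ in range(10)]


def stack_plain():
    return np.stack(frames)


def stack_mapped():
    return map_structure(lambda *leaves: np.stack(leaves), *frames)


n = 2000
# Take the best of several repeats and convert to microseconds per call.
plain_us = min(timeit.repeat(stack_plain, number=n, repeat=5)) / n * 1e6
mapped_us = min(timeit.repeat(stack_mapped, number=n, repeat=5)) / n * 1e6
print(f"np.stack: {plain_us:.1f} µs, map_structure: {mapped_us:.1f} µs")
```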

I (surprisingly?) got a slight speed increase when running BreakoutNoFrameskip. Locally I got about 470 SPS (steps per second) with the complex-observation tree_map version vs. 460 SPS on the original.

Checklist:

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.

vercel[bot] commented 1 year ago

The latest updates on your projects.

Name: cleanrl
Status: ✅ Ready
Updated: Feb 15, 2023 at 11:25PM (UTC)