Description
Added handling of complex observations (e.g. `Dict` observation spaces) to `atari_ppo.py`. Closes #353.

I also wrote a JAX version for the #338 branch (I can put it in another PR once #338 is ready?). It needs only two changes, both using JAX's `tree_map`: https://gist.github.com/ttumiel/ee746d6292cecb47d390fb97c3ccfa5e

Tests
I wrote some tests for different observation types. I wasn't sure whether these belong in the test folder, since they mostly just demonstrate the functionality.

I also wrote a dummy complex-observation wrapper to demonstrate handling a dict space in Atari: https://gist.github.com/ttumiel/c2132b424c49b76a62bafe7efef9923d

Speed
The `tree.map_structure` function adds about 10 µs of overhead compared with a plain `np.stack` (measured on 10 observations of shape (64, 64)):

```python
import gym, tree, numpy as np

space = gym.spaces.Box(0, 255, (64, 64))
samples = [space.sample() for _ in range(10)]
```

```python
%%timeit
obs = np.stack(samples)
# 15.8 µs ± 491 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
```

```python
%%timeit
obs = tree.map_structure(lambda *xs: np.stack(xs), *samples)
# 26.5 µs ± 670 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```

I (surprisingly?) got a slight speed increase when running `BreakoutNoFrameskip`: locally, about 470 steps per second (SPS) with the complex-observation `tree_map` version vs. 460 SPS with the original.

Questions
Maybe I should put the complex-observation handling in the `ppo.py` file directly, instead of a new file?

Checklist:
[x] I have ensured `pre-commit run --all-files` passes (required).
[ ] I have updated the documentation and previewed the changes via `mkdocs serve`.
[ ] I have updated the tests accordingly (if applicable).

If you are adding new algorithm variants or your change could result in performance difference, you may need to (re-)run tracked experiments. See https://github.com/vwxyzjn/cleanrl/pull/137 as an example PR.
… `--capture-video` flag toggled on (required).
… `mkdocs serve`.
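Why `tree.map_structure` is worth its roughly 10 µs: the Speed benchmark only stacks plain arrays, but the same call also batches nested observations (for example the dicts produced by a `Dict` space) leaf by leaf. As a rough, dependency-free illustration, the sketch below reimplements that stacking with NumPy alone. `stack_structures` and the `"image"`/`"lives"` keys are hypothetical names chosen for this example; they are not taken from the PR, its gists, or dm-tree. It assumes every observation has the same keys and container layout, and it treats only dicts, plain lists and plain tuples as containers (namedtuples, which dm-tree also supports, are not handled).

```python
import numpy as np


def stack_structures(*structures):
    """Stack matching leaves of several nested observations along a new axis 0.

    Hypothetical stand-in for tree.map_structure(lambda *xs: np.stack(xs), *structures).
    Only dicts, plain lists and plain tuples are treated as containers.
    """
    first = structures[0]
    if isinstance(first, dict):
        # Recurse key by key; every observation is assumed to have the same keys.
        return {key: stack_structures(*(s[key] for s in structures)) for key in first}
    if type(first) in (list, tuple):
        # Recurse position by position and keep the original container type.
        return type(first)(stack_structures(*items) for items in zip(*structures))
    # Leaf: stack one value per observation into a single batch array.
    return np.stack(structures)


# Usage: hypothetical dict observations from 4 environments.
rng = np.random.default_rng(0)
obs_list = [
    {"image": rng.integers(0, 256, (84, 84), dtype=np.uint8), "lives": np.int64(5)}
    for _ in range(4)
]
batch = stack_structures(*obs_list)
print(batch["image"].shape, batch["lives"].shape)  # (4, 84, 84) (4,)
```

With dm-tree installed, `tree.map_structure(lambda *xs: np.stack(xs), *obs_list)` should return an equivalent dict; the hand-written version is only meant to show what happens at each leaf.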