rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License

garage.torch.CNNModule backport to 2021.03 #2259

Closed: krzentner closed this pull request 3 years ago

krzentner commented 3 years ago

It now takes an env spec and automatically computes its output size. It also handles both NHWC (channels-last) and NCHW (channels-first) inputs directly, instead of requiring an environment wrapper.
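A minimal, framework-free sketch of the two ideas above, assuming the env spec provides an unbatched observation shape and the layout is given as the string 'NCHW' or 'NHWC'. The helper names and parameters (to_chw, conv_out_len, cnn_output_dim) are hypothetical illustrations, not garage's actual CNNModule signature; the size formula is the standard one for a convolution with dilation 1.

```python
from typing import Sequence, Tuple


def to_chw(obs_shape: Sequence[int], image_format: str) -> Tuple[int, int, int]:
    """Reorder an unbatched image shape to (channels, height, width)."""
    if image_format == 'NCHW':
        c, h, w = obs_shape
    elif image_format == 'NHWC':
        # Channels-last input. For a batched PyTorch tensor the matching
        # data reordering would be x.permute(0, 3, 1, 2).
        h, w, c = obs_shape
    else:
        raise ValueError(f'image_format must be NCHW or NHWC, got {image_format!r}')
    return c, h, w


def conv_out_len(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of one spatial dim: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def cnn_output_dim(obs_shape: Sequence[int], image_format: str,
                   hidden_channels: Sequence[int],
                   kernel_sizes: Sequence[int],
                   strides: Sequence[int],
                   paddings: Sequence[int]) -> int:
    """Flattened feature size after a stack of conv layers (hypothetical helper)."""
    n_layers = len(hidden_channels)
    if n_layers == 0 or not (len(kernel_sizes) == len(strides) == len(paddings) == n_layers):
        raise ValueError('need one kernel size, stride and padding per layer')
    _, h, w = to_chw(obs_shape, image_format)
    for k, s, p in zip(kernel_sizes, strides, paddings):
        h = conv_out_len(h, k, s, p)
        w = conv_out_len(w, k, s, p)
        if h <= 0 or w <= 0:
            raise ValueError('input image is too small for this stack of convolutions')
    return hidden_channels[-1] * h * w


# Example: an 84x84 image with 4 stacked frames, stored channels-last.
print(cnn_output_dim((84, 84, 4), 'NHWC', (32, 64, 64), (8, 4, 3), (4, 2, 1), (0, 0, 0)))  # 3136
```

The same call with the channels-first shape (4, 84, 84) and 'NCHW' gives the same result, which is the point of deriving the size from the spec rather than hard-coding it.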

This change also fixes several issues with the existing usage of CNNs in PyTorch.

This PR is a backport to 2021.03 (which I believe has not been pushed to PyPI yet).

krzentner commented 3 years ago

Looks good. I wish our version of Python had enum support though, so we could have enums for NCHW rather than strings.

Ah, we actually discussed this about two years ago. We used enums for some things for a while, but eventually found they weren't flexible enough to be worth it. In particular, we often wanted "sub-enums" (enums with a subset of the original enum's options) or, alternatively, "super-enums" (enums with new entries added), and Python enums support neither. They also turned out to be less ergonomic in Python than plain strings, so we decided to use strings to keep the API consistent.
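A short illustration of the limitation described above, using hypothetical names rather than garage code: Python's standard enum module (available since Python 3.4) raises TypeError when you subclass an enum that already defines members, so a "super-enum" cannot be built by inheritance. Plain strings checked against sets compose freely. The extra 'NC' format is an invented example value.

```python
from enum import Enum


class ImageFormat(Enum):
    NCHW = 'NCHW'
    NHWC = 'NHWC'


def can_extend_enum() -> bool:
    """Try to build a "super-enum" by subclassing ImageFormat."""
    try:
        class ExtendedFormat(ImageFormat):
            NC = 'NC'  # hypothetical extra format
    except TypeError:
        # Python refuses to subclass an Enum that already has members.
        return False
    return True


# String-based alternative: the allowed values are plain frozensets, so
# subsets ("sub-enums") and supersets ("super-enums") are set operations.
IMAGE_FORMATS = frozenset({'NCHW', 'NHWC'})
ALL_FORMATS = IMAGE_FORMATS | {'NC'}
CHANNELS_FIRST = frozenset({'NCHW'})


def check_format(value: str, allowed: frozenset) -> str:
    """Validate a format string against an allowed set and return it."""
    if value not in allowed:
        raise ValueError(f'{value!r} is not one of {sorted(allowed)}')
    return value


print(can_extend_enum())                    # False
print(check_format('NHWC', IMAGE_FORMATS))  # NHWC
```

A separate Enum class could mimic a subset, but its members would be distinct objects from the original ones, which is part of why strings ended up more ergonomic.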

codecov[bot] commented 3 years ago

Codecov Report

Merging #2259 (87fecdc) into release-2021.03 (4e32ab2) will increase coverage by 0.06%. The diff coverage is 97.74%.


@@                 Coverage Diff                 @@
##           release-2021.03    #2259      +/-   ##
===================================================
+ Coverage            91.19%   91.26%   +0.06%     
===================================================
  Files                  201      199       -2     
  Lines                10977    10932      -45     
  Branches              1371     1376       +5     
===================================================
- Hits                 10011     9977      -34     
+ Misses                 703      696       -7     
+ Partials               263      259       -4     
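The headline numbers can be reproduced from the Hits, Misses and Partials rows. This sketch assumes coverage = hits / (hits + misses + partials), so partially covered lines count as uncovered, and that values are rounded down to two decimal places (which we believe matches Codecov's default round: down and precision: 2 settings). The helper names are illustrative, not part of any Codecov API, and the rounding assumption is only exercised on non-negative values here.

```python
import math
from fractions import Fraction


def coverage_ratio(hits: int, misses: int, partials: int) -> Fraction:
    """Exact coverage in percent; partially covered lines count as not hit."""
    return Fraction(100 * hits, hits + misses + partials)


def round_down(value: Fraction, precision: int = 2) -> float:
    """Floor to `precision` decimal places (toward negative infinity)."""
    factor = 10 ** precision
    return math.floor(value * factor) / factor


base = coverage_ratio(hits=10011, misses=703, partials=263)  # 10977 lines
head = coverage_ratio(hits=9977, misses=696, partials=259)   # 10932 lines

print(round_down(base))         # 91.19
print(round_down(head))         # 91.26
print(round_down(head - base))  # 0.06
```

The exact base value is about 91.1998%, so conventional rounding would report 91.20%; truncation is what makes the sketch agree with the 91.19% shown, and the +0.06% change is computed from the exact values rather than from the two rounded percentages.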
| Impacted Files | Coverage Δ |
|---|---|
| ...rch/q_functions/discrete_dueling_cnn_q_function.py | 91.30% <88.88%> (-1.56%) ↓ |
| src/garage/torch/modules/cnn_module.py | 97.53% <97.33%> (+17.84%) ↑ |
| src/garage/torch/_functions.py | 95.57% <100.00%> (+0.89%) ↑ |
| src/garage/torch/algos/bc.py | 89.61% <100.00%> (ø) |
| src/garage/torch/algos/ddpg.py | 96.63% <100.00%> (ø) |
| src/garage/torch/algos/dqn.py | 93.02% <100.00%> (ø) |
| src/garage/torch/algos/sac.py | 97.05% <100.00%> (ø) |
| src/garage/torch/algos/td3.py | 94.44% <100.00%> (ø) |
| src/garage/torch/modules/discrete_cnn_module.py | 100.00% <100.00%> (ø) |
| ...rc/garage/torch/policies/categorical_cnn_policy.py | 100.00% <100.00%> (ø) |

... and 5 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Powered by Codecov. Last update 4e32ab2...87fecdc.