rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License
1.86k stars, 310 forks

Reshape input in tf/cnn #2168

yeukfu closed this 3 years ago

yeukfu commented 3 years ago

This should close https://github.com/rlworkgroup/garage/issues/2154.

In tf/cnn.py, unflatten the input if it is flattened.
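The idea can be illustrated with a minimal NumPy sketch. It assumes the model knows the expected image shape, here a tuple named input_dim such as (height, width, channels), and that a flattened input keeps its leading batch (or time) axes. If the trailing axes already match input_dim, the input is returned unchanged; otherwise the last axis is expanded. The helper name unflatten_input is hypothetical. The actual change lives in tf/cnn.py and operates on TensorFlow tensors, where tf.reshape would play the role of reshape.

```python
import math

import numpy as np


def unflatten_input(x, input_dim):
    """Restore the image shape of a batch of observations if it was flattened.

    Hypothetical helper for illustration only; not garage's actual code.
    x: array whose trailing axes are either input_dim, e.g. (N, H, W, C),
       or a single flat axis of size H * W * C, e.g. (N, H * W * C).
    input_dim: image shape (H, W, C) expected by the convolution layers.
    """
    input_dim = tuple(input_dim)
    # Already unflattened: the trailing axes match the image shape.
    if x.shape[-len(input_dim):] == input_dim:
        return x
    flat_size = math.prod(input_dim)
    if x.shape[-1] != flat_size:
        raise ValueError(
            f'last axis has size {x.shape[-1]}, expected {flat_size} '
            f'(flattened {input_dim})')
    # Keep all leading (batch/time) axes and expand only the flat axis.
    return x.reshape(x.shape[:-1] + input_dim)


batch = np.arange(24).reshape(2, 12)  # two flattened 2x3x2 observations
images = unflatten_input(batch, (2, 3, 2))
print(images.shape)  # (2, 2, 3, 2)
```

Checking the trailing axes first means already-shaped inputs pass through untouched, so the same code path can serve callers that flatten observations and callers that do not.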

codecov[bot] commented 3 years ago

Codecov Report

Merging #2168 (8036875) into master (75e4c9a) will decrease coverage by 0.03%. The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master    #2168      +/-   ##
==========================================
- Coverage   91.24%   91.21%   -0.04%     
==========================================
  Files         201      201              
  Lines       10971    10977       +6     
  Branches     1371     1371              
==========================================
+ Hits        10011    10013       +2     
- Misses        700      702       +2     
- Partials      260      262       +2     

Impacted Files                                         Coverage Δ
src/garage/tf/models/cnn_mlp_merge_model.py            100.00% <ø> (ø)
...garage/tf/q_functions/continuous_cnn_q_function.py  100.00% <ø> (ø)
src/garage/torch/modules/discrete_cnn_module.py        100.00% <ø> (ø)
...arage/torch/modules/discrete_dueling_cnn_module.py  100.00% <ø> (ø)
src/garage/tf/baselines/gaussian_cnn_baseline.py       97.24% <100.00%> (ø)
...garage/tf/baselines/gaussian_cnn_baseline_model.py  100.00% <100.00%> (ø)
src/garage/tf/models/categorical_cnn_model.py          100.00% <100.00%> (ø)
src/garage/tf/models/cnn.py                            100.00% <100.00%> (ø)
src/garage/tf/models/cnn_model.py                      100.00% <100.00%> (ø)
src/garage/tf/models/cnn_model_max_pooling.py          100.00% <100.00%> (ø)
... and 6 more


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 75e4c9a...8036875.
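As a check on the report's arithmetic, the sketch below recomputes coverage from the Hits, Misses and Partials rows, assuming Codecov's usual definition coverage = hits / (hits + misses + partials). The unrounded values (about 91.2497% and 91.2180%) reproduce the 91.24% and 91.21% shown only if the displayed figures are truncated rather than rounded to two decimals; that is an inference from these numbers, not documented behavior. The difference of about 0.03 points matches the headline decrease.

```python
import math


def coverage_pct(hits, misses, partials):
    """Percentage of tracked lines that are fully covered."""
    # Every tracked line is exactly one of: hit, miss, or partial.
    lines = hits + misses + partials
    return 100.0 * hits / lines


def truncate_2dp(pct):
    """Drop (rather than round) digits past the second decimal place."""
    return math.floor(pct * 100) / 100


master = coverage_pct(hits=10011, misses=700, partials=260)  # 10971 lines
pr = coverage_pct(hits=10013, misses=702, partials=262)      # 10977 lines

print(f'{master:.4f} {pr:.4f}')                # 91.2497 91.2180
print(truncate_2dp(master), truncate_2dp(pr))  # 91.24 91.21
print(round(pr - master, 2))                   # -0.03
```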

yeukfu commented 3 years ago

@krzentner I have a concern that if we add env_spec to the constructor of tf/cnn_model, it will not be consistent with the Torch module.

yeukfu commented 3 years ago

Quoting @krzentner: "I think that's fine. There are already other minor differences between TF and Torch in the model code."

I pass input_dim to the constructors of the CNN models instead, which I think is more consistent with the current APIs, since some constructors already take output_dim as an argument.
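To make the trade-off concrete, here is a hypothetical sketch of the two constructor styles discussed in this thread. Neither class is garage's real CNN model; the class names, the stand-in spec object, and the filters placeholder are invented for illustration. Taking input_dim keeps the model independent of env_spec, mirroring how output_dim is passed, and leaves the caller (for example a policy that does hold an env_spec) to extract the observation shape.

```python
from types import SimpleNamespace


class CNNModelFromSpec:
    """Style 1 (hypothetical): the model reads its input shape from env_spec."""

    def __init__(self, env_spec, filters):
        self.input_dim = tuple(env_spec.observation_space.shape)
        self.filters = filters


class CNNModelFromInputDim:
    """Style 2 (hypothetical): the caller passes the image shape directly,
    the same way an output_dim argument is passed."""

    def __init__(self, input_dim, filters):
        self.input_dim = tuple(input_dim)
        self.filters = filters


# Stand-in for an env spec whose observation space is a 64x64 RGB image.
env_spec = SimpleNamespace(
    observation_space=SimpleNamespace(shape=(64, 64, 3)))

a = CNNModelFromSpec(env_spec, filters=((32, (3, 3)),))
# With style 2, the object that owns env_spec extracts the shape itself.
b = CNNModelFromInputDim(env_spec.observation_space.shape,
                         filters=((32, (3, 3)),))
print(a.input_dim == b.input_dim)  # True
```

Style 2 also allows building a model with no environment at all, for instance in a unit test, which may be part of why it fits the existing output_dim-style constructors.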