google / evojax

Reproducing benchmark scores #45

Closed by EdoardoPona 2 years ago

EdoardoPona commented 2 years ago

Hello everyone.

I am currently trying to reproduce scores from the benchmarks, specifically for ARS. I am implementing my own version natively in JAX and want to compare it with the wrapper that is already implemented.

For example, I cannot achieve the score posted in the benchmark table (902.107) for ARS on cartpole_easy.

Running python train.py -config configs/ARS/cartpole_easy.yaml yields the following training logs:

cartpole_easy: 2022-09-25 22:45:55,777 [INFO] EvoJAX cartpole_easy
cartpole_easy: 2022-09-25 22:45:55,777 [INFO] ==============================
absl: 2022-09-25 22:45:55,791 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
absl: 2022-09-25 22:45:57,247 [INFO] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
MLPPolicy: 2022-09-25 22:45:59,165 [INFO] MLPPolicy.num_params = 4609
cartpole_easy: 2022-09-25 22:45:59,429 [INFO] use_for_loop=False
cartpole_easy: 2022-09-25 22:45:59,496 [INFO] Start to train for 1000 iterations.
cartpole_easy: 2022-09-25 22:46:10,527 [INFO] Iter=50, size=100, max=399.5886, avg=207.9111, min=0.5843, std=99.0207
cartpole_easy: 2022-09-25 22:46:19,916 [INFO] Iter=100, size=100, max=543.8907, avg=364.9780, min=28.8478, std=141.8982
cartpole_easy: 2022-09-25 22:46:21,143 [INFO] [TEST] Iter=100, #tests=100, max=553.4018 avg=510.5583, min=462.4243, std=15.6930
cartpole_easy: 2022-09-25 22:46:30,627 [INFO] Iter=150, size=100, max=558.2020, avg=314.9279, min=89.8001, std=153.6488
cartpole_easy: 2022-09-25 22:46:40,068 [INFO] Iter=200, size=100, max=562.4118, avg=354.9529, min=47.0048, std=154.1567
cartpole_easy: 2022-09-25 22:46:40,114 [INFO] [TEST] Iter=200, #tests=100, max=570.1135 avg=547.5375, min=508.5795, std=10.0840
cartpole_easy: 2022-09-25 22:46:49,579 [INFO] Iter=250, size=100, max=562.1505, avg=325.3990, min=73.3733, std=161.9460
cartpole_easy: 2022-09-25 22:46:59,073 [INFO] Iter=300, size=100, max=569.5461, avg=370.2641, min=83.7473, std=166.8020
cartpole_easy: 2022-09-25 22:46:59,129 [INFO] [TEST] Iter=300, #tests=100, max=573.5941 avg=545.0388, min=505.8637, std=11.3853
cartpole_easy: 2022-09-25 22:47:08,623 [INFO] Iter=350, size=100, max=579.3894, avg=425.6462, min=82.4907, std=126.6614
cartpole_easy: 2022-09-25 22:47:18,109 [INFO] Iter=400, size=100, max=627.6509, avg=530.2781, min=156.4797, std=76.0956
cartpole_easy: 2022-09-25 22:47:18,160 [INFO] [TEST] Iter=400, #tests=100, max=639.7323 avg=600.9105, min=573.7767, std=10.7564
cartpole_easy: 2022-09-25 22:47:27,653 [INFO] Iter=450, size=100, max=668.2064, avg=546.0261, min=418.5385, std=60.5854
cartpole_easy: 2022-09-25 22:47:37,149 [INFO] Iter=500, size=100, max=684.4142, avg=574.4891, min=446.3126, std=62.5338
cartpole_easy: 2022-09-25 22:47:37,202 [INFO] [TEST] Iter=500, #tests=100, max=693.1522 avg=682.7945, min=638.0387, std=12.1575
cartpole_easy: 2022-09-25 22:47:46,708 [INFO] Iter=550, size=100, max=708.9561, avg=591.0547, min=295.5651, std=73.6026
cartpole_easy: 2022-09-25 22:47:56,212 [INFO] Iter=600, size=100, max=706.8138, avg=599.4783, min=348.7581, std=55.6310
cartpole_easy: 2022-09-25 22:47:56,263 [INFO] [TEST] Iter=600, #tests=100, max=691.0123 avg=680.4677, min=630.2983, std=6.1448
cartpole_easy: 2022-09-25 22:48:05,770 [INFO] Iter=650, size=100, max=707.0887, avg=581.3851, min=418.2251, std=75.9066
cartpole_easy: 2022-09-25 22:48:15,275 [INFO] Iter=700, size=100, max=712.7586, avg=586.4597, min=362.7628, std=71.5669
cartpole_easy: 2022-09-25 22:48:15,326 [INFO] [TEST] Iter=700, #tests=100, max=725.2336 avg=714.1309, min=635.7863, std=9.3471
cartpole_easy: 2022-09-25 22:48:24,849 [INFO] Iter=750, size=100, max=716.1056, avg=602.7747, min=458.0401, std=63.1697
cartpole_easy: 2022-09-25 22:48:34,365 [INFO] Iter=800, size=100, max=709.3475, avg=587.9896, min=393.0367, std=69.2385
cartpole_easy: 2022-09-25 22:48:34,418 [INFO] [TEST] Iter=800, #tests=100, max=732.5553 avg=720.5952, min=648.5032, std=8.3936
cartpole_easy: 2022-09-25 22:48:43,945 [INFO] Iter=850, size=100, max=706.8488, avg=598.3582, min=321.8640, std=75.2542
cartpole_easy: 2022-09-25 22:48:53,482 [INFO] Iter=900, size=100, max=720.0320, avg=596.1929, min=370.6555, std=77.2801
cartpole_easy: 2022-09-25 22:48:53,536 [INFO] [TEST] Iter=900, #tests=100, max=703.5345 avg=692.9500, min=677.6909, std=5.9381
cartpole_easy: 2022-09-25 22:49:03,068 [INFO] Iter=950, size=100, max=716.2341, avg=598.3802, min=422.7760, std=71.7756
cartpole_easy: 2022-09-25 22:49:12,455 [INFO] [TEST] Iter=1000, #tests=100, max=726.0114, avg=719.0803, min=698.4325, std=4.7247
cartpole_easy: 2022-09-25 22:49:12,457 [INFO] Training done, best_score=720.5952
cartpole_easy: 2022-09-25 22:49:12,458 [INFO] Loaded model parameters from ./log/ARS/cartpole_easy/default.
cartpole_easy: 2022-09-25 22:49:12,459 [INFO] Start to test the parameters.
cartpole_easy: 2022-09-25 22:49:12,509 [INFO] [TEST] #tests=100, max=728.9848, avg=720.6152, min=698.9832, std=5.0566

I am not entirely sure whether the result in the benchmark table is intended to be the 720.5952 from the line "Training done, best_score=720.5952" or the max score from the final test (728.9848). Regardless, neither of these matches the score posted in the benchmark table.
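One way to see where best_score comes from is to pull the mean score out of each periodic [TEST] line. The sketch below does that for four lines copied from the log above. In this run the highest mean (Iter=800) equals the reported best_score, which suggests that best_score tracks the best mean test score seen during training. That reading is inferred from this log only, not checked against the EvoJAX source, and the helper name periodic_test_averages is made up for illustration.

```python
import re

# Four periodic [TEST] lines copied verbatim from the training log above.
LOG = """\
cartpole_easy: 2022-09-25 22:48:15,326 [INFO] [TEST] Iter=700, #tests=100, max=725.2336 avg=714.1309, min=635.7863, std=9.3471
cartpole_easy: 2022-09-25 22:48:34,418 [INFO] [TEST] Iter=800, #tests=100, max=732.5553 avg=720.5952, min=648.5032, std=8.3936
cartpole_easy: 2022-09-25 22:48:53,536 [INFO] [TEST] Iter=900, #tests=100, max=703.5345 avg=692.9500, min=677.6909, std=5.9381
cartpole_easy: 2022-09-25 22:49:12,455 [INFO] [TEST] Iter=1000, #tests=100, max=726.0114, avg=719.0803, min=698.4325, std=4.7247
"""

# Matches only periodic test lines, which carry an Iter= field;
# the final standalone test line has no Iter= and is skipped.
TEST_LINE = re.compile(r"\[TEST\] Iter=(\d+),.*?avg=([0-9.]+)")


def periodic_test_averages(log_text):
    """Return {iteration: mean test score} for each periodic [TEST] line."""
    return {int(m.group(1)): float(m.group(2))
            for m in TEST_LINE.finditer(log_text)}


averages = periodic_test_averages(LOG)
best_iter = max(averages, key=averages.get)
print(best_iter, averages[best_iter])  # prints: 800 720.5952
```

The same function can be pointed at a full log file to compare runs by their best mean test score rather than by a single max value.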

Am I doing something wrong when trying to reproduce these scores? Without matching numbers, I cannot compare my own implementation of the algorithm against the wrapper.

Thank you

alantian commented 2 years ago

Hey @EdoardoPona, thanks for your interest in EvoJAX!

Many algorithms are sensitive to hyper-parameters, so hyper-parameter searches have been conducted and recorded in scripts/benchmarks/Readme.md. For example, scripts/benchmarks/figures/ARS/cartpole_easy.png shows that ARS performs well on the Cartpole (Easy) task with init_stdev: 0.1 and lrate_init: 0.1.

You can therefore reproduce the result by setting init_stdev: 0.1 and lrate_init: 0.1 in scripts/benchmarks/configs/ARS/cartpole_easy.yaml and rerunning python train.py -config configs/ARS/cartpole_easy.yaml. Doing so, I got the following result, which is consistent with the table:

cartpole_easy: 2022-09-26 02:21:50,723 [INFO] [TEST] #tests=100, max=933.8828, avg=913.0647, min=168.9905, std=76.7758
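For anyone who prefers to script the change rather than edit the file by hand, here is a minimal sketch that rewrites the two values in the config text. It is only an illustration: the CONFIG_TEXT excerpt is hypothetical (only the two key/value pairs quoted in this thread come from the real file, and the other keys and any nesting are omitted), the helper name override is made up, and it uses a plain regex instead of a YAML library to stay within the standard library.

```python
import re

# Hypothetical excerpt: only these two key/value pairs are taken from the
# thread; the real cartpole_easy.yaml contains more keys and may nest them.
CONFIG_TEXT = """\
init_stdev: 0.03
lrate_init: 0.01
"""


def override(config_text, **values):
    """Replace the value on each `key: value` line, keeping indentation."""
    for key, value in values.items():
        pattern = re.compile(
            rf"^([ \t]*{re.escape(key)}[ \t]*:[ \t]*).*$", re.MULTILINE
        )
        config_text, count = pattern.subn(
            lambda m, v=value: f"{m.group(1)}{v}", config_text
        )
        # Fail loudly instead of silently leaving the config unchanged.
        if count != 1:
            raise KeyError(f"expected exactly one '{key}' entry, found {count}")
    return config_text


tuned = override(CONFIG_TEXT, init_stdev=0.1, lrate_init=0.1)
print(tuned)
```

To apply it to the real file, read scripts/benchmarks/configs/ARS/cartpole_easy.yaml into a string, pass it through override, and write the result back before rerunning train.py.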

Please let me know how it works for you and if there are any further questions!

EdoardoPona commented 2 years ago

Thank you for the clarification! All works correctly now.

My misunderstanding came from assuming that the hyper-parameters in the .yaml files linked from the table, such as scripts/benchmarks/configs/ARS/cartpole_easy.yaml, were the final optimised ones. Instead, that file contains init_stdev: 0.03 and lrate_init: 0.01. I did not read the heatmaps correctly.