google / evojax

Apache License 2.0

Evaluating brax environments other than brax-ant terminates with an error. #38

Closed Surya-77 closed 2 years ago

Surya-77 commented 2 years ago

Information

The issue occurs when running brax environments other than brax-ant. The included humanoid, half cheetah and fetch environments are affected.

Couldn't find any references to this issue in the repo. I could have missed something.

Expected Behavior

/home/<USER>/anaconda3/envs/evojax/bin/python /home/<USER>/evojax/scripts/benchmarks/train.py -config configs/PGPE/brax_halfcheetah.yaml
brax: 2022-06-16 20:41:01,954 [INFO] EvoJAX brax
brax: 2022-06-16 20:41:01,954 [INFO] ==============================
absl: 2022-06-16 20:41:02,137 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-06-16 20:41:02,221 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
MLPPolicy: 2022-06-16 20:41:03,747 [INFO] MLPPolicy.num_params = 3974
brax: 2022-06-16 20:41:03,787 [INFO] use_for_loop=False
brax: 2022-06-16 20:41:03,825 [INFO] Start to train for 1 iterations.
brax: 2022-06-16 20:41:56,024 [INFO] [TEST] Iter=1, #tests=1, max=-9.7476, avg=-9.7476, min=-9.7476, std=0.0000
brax: 2022-06-16 20:41:56,087 [INFO] Training done, best_score=-9.7476
brax: 2022-06-16 20:41:56,093 [INFO] Loaded model parameters from ./log/PGPE/brax/default.
brax: 2022-06-16 20:41:56,093 [INFO] Start to test the parameters.
brax: 2022-06-16 20:42:03,478 [INFO] [TEST] #tests=1, max=-9.9009, avg=-9.9009, min=-9.9009, std=0.0000

Current Behavior

brax: 2022-06-16 20:26:04,657 [INFO] EvoJAX brax
brax: 2022-06-16 20:26:04,657 [INFO] ==============================
absl: 2022-06-16 20:26:04,833 [INFO] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
absl: 2022-06-16 20:26:04,920 [INFO] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
MLPPolicy: 2022-06-16 20:26:06,465 [INFO] MLPPolicy.num_params = 3974
brax: 2022-06-16 20:26:06,504 [INFO] use_for_loop=False
brax: 2022-06-16 20:26:06,541 [INFO] Start to train for 10 iterations.
Traceback (most recent call last):
  File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 88, in <module>
    main(config)
  File "/home/<USER>/evojax/scripts/benchmarks/train.py", line 64, in main
    trainer.run(demo_mode=False)
  File "/home/<USER>/evojax/evojax/trainer.py", line 152, in run
    scores, bds = self.sim_mgr.eval_params(
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 258, in eval_params
    return self._scan_loop_eval(params, test)
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 355, in _scan_loop_eval
    scores, all_obs, masks, final_states = rollout_func(
  File "/home/<USER>/evojax/evojax/sim_mgr.py", line 202, in rollout
    (obs_set, obs_mask)) = jax.lax.scan(
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 165, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 1630, in scan
    _check_tree_and_avals("scan carry output and input",
  File "/home/<USER>/anaconda3/envs/evojax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py", line 2316, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: scan carry output and input must have identical types, got
(State(state=State(qp=QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), obs='ShapedArray(float32[16384,18])', reward='ShapedArray(float32[16384])', done='ShapedArray(float32[16384])', metrics={'reward_ctrl_cost': 'ShapedArray(float32[16384])', 'reward_forward': 'ShapedArray(float32[16384])'}, info={'first_obs': 'ShapedArray(float32[16384,18])', 'first_qp': QP(pos='ShapedArray(float32[16384,8,3])', rot='ShapedArray(float32[16384,8,4])', vel='ShapedArray(float32[16384,8,3])', ang='ShapedArray(float32[16384,8,3])'), 'steps': 'ShapedArray(float32[16384])', 'truncation': 'ShapedArray(float32[16384])'}), obs='ShapedArray(float32[16384,18])', feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])'), PolicyState(keys='ShapedArray(uint32[16384,2])'), 'ShapedArray(float32[16384,3974])', 'ShapedArray(float32[37])', 'ShapedArray(float32[16384])', 'ShapedArray(float32[16384])').

Exact Error:

feet_contact='DIFFERENT ShapedArray(int32[16384,3]) vs. ShapedArray(int32[16384,4])')
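
The mismatch can be reproduced outside EvoJAX. The following is a minimal sketch, not the EvoJAX code: `reset`, `step`, `run_rollout` and `HARD_CODED_FEET` are hypothetical names, and it assumes, based on the discussion in this thread, that the initial state hard-codes four feet while the step output follows the environment. `jax.lax.scan` requires the carry returned by the step function to match the initial carry in shape and dtype, so any foot count other than four fails.

```python
import jax
import jax.numpy as jnp

BATCH = 2
HARD_CODED_FEET = 4  # assumption: the initial state is always sized for Ant


def reset():
    # Initial scan carry: a contact buffer with a fixed number of columns.
    return jnp.zeros((BATCH, HARD_CODED_FEET), dtype=jnp.int32)


def run_rollout(env_num_feet, num_steps=3):
    def step(carry, _):
        # The step output is sized by the environment, not by the constant.
        new_carry = jnp.zeros((carry.shape[0], env_num_feet), dtype=jnp.int32)
        return new_carry, None

    try:
        jax.lax.scan(step, reset(), xs=None, length=num_steps)
        return "ok"
    except TypeError:
        # scan rejects a carry whose shape changes between input and output.
        return "TypeError"


print(run_rollout(4))
print(run_rollout(3))
```

Under these assumptions, `run_rollout(4)` completes while `run_rollout(3)` hits the same class of error as the traceback above. The exact wording of the message varies between JAX versions.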

Failure Information

Context

Based on the commit history, this appears to be due to the changes introduced in #33. Manually altering the feet_contact variable in the reset_fn method in evojax/evojax/task/brax_task.py allows the other environments to run.

Hardware setup details are irrelevant, since the error also occurs on the hosted Colab notebook.

brax                         0.0.13
evojax                       0.2.11               
flax                         0.4.0
jax                          0.3.1
jaxlib                       0.3.0+cuda11.cudnn82

Steps to Reproduce


  1. Run evojax/scripts/benchmarks/train.py with a modified evojax/scripts/benchmarks/configs/<ES> file that selects a non-ant brax environment.
  2. Modify the feet_contact array size and test again.

lerrytang commented 2 years ago

Thank you very much for raising the issue, let me confirm the problem and get back to you.

Surya-77 commented 2 years ago

Hi @lerrytang, any updates on this issue?

lerrytang commented 2 years ago

Sorry for the delay. I've confirmed the error, thanks for the detailed report. The error is due to the fixed number of legs for contact detection. I should have extended the state to include feet_contact ONLY when the task is Ant and the user chooses to use map-elites. Since you discovered the bug and the fix is straightforward, would you be interested in sending us a PR?
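
One way to picture the fix described here is a state whose `feet_contact` entry exists only for Ant with MAP-Elites enabled. This is a rough sketch under assumptions, not the change that was merged: the dict-based state and the names `reset`, `step`, `rollout`, `use_map_elites` and `ANT_NUM_FEET` are all hypothetical. Because the step function mirrors whatever keys the initial state has, the scan carry keeps the same structure for every environment.

```python
import jax
import jax.numpy as jnp

ANT_NUM_FEET = 4  # assumption: only Ant tracks foot contact in this sketch


def reset(env_name, use_map_elites, batch_size=2):
    # Toy state: a dict pytree standing in for the real task state.
    state = {"obs": jnp.zeros((batch_size, 3))}
    if env_name == "ant" and use_map_elites:
        state["feet_contact"] = jnp.zeros(
            (batch_size, ANT_NUM_FEET), dtype=jnp.int32)
    return state


def step(state, _):
    # Mirror the keys of the incoming state so the carry type never changes.
    new_state = {"obs": state["obs"] + 1.0}
    if "feet_contact" in state:
        new_state["feet_contact"] = jnp.ones_like(state["feet_contact"])
    return new_state, None


def rollout(env_name, use_map_elites, num_steps=3):
    init = reset(env_name, use_map_elites)
    final_state, _ = jax.lax.scan(step, init, xs=None, length=num_steps)
    return final_state


print(sorted(rollout("halfcheetah", True).keys()))
print(sorted(rollout("ant", True).keys()))
```

In this sketch, a non-Ant rollout never carries `feet_contact`, so there is no fixed leg count to mismatch.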

Surya-77 commented 2 years ago

Thanks for the update. I'll get a PR ready and send it once it's done.

lerrytang commented 2 years ago

Bug fixed in #39.