uzh-rpg / rpg_flightning

Learning Quadrotor Control From Visual Features Using Differentiable Simulation
GNU General Public License v3.0
35 stars, 1 fork, source link

Error in example "Train the Policy Using BPTT" #6

Open zhw970623 opened 3 weeks ago

zhw970623 commented 3 weeks ago

An error occurs when executing the following code in the example notebook `train_bptt_state.ipynb`:

time_start = time.time()
res_dict = bptt.train(
    env,
    train_state,
    num_epochs=100,
    num_steps_per_epoch=env.max_steps_in_episode,
    num_envs=100,
    key=key_bptt,
)
time_train = time.time() - time_start
print(f"Training time: {time_train}")

JaxStackTraceBeforeTransformation Traceback (most recent call last) File :198, in _run_module_as_main()

File :88, in _run_code()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel_launcher.py:18 16 from ipykernel import kernelapp as app ---> 18 app.launch_new_instance()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance() 1074 app.initialize(argv) -> 1075 app.start()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelapp.py:739, in start() 738 try: --> 739 self.io_loop.start() 740 except KeyboardInterrupt:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/tornado/platform/asyncio.py:205, in start() 204 def start(self) -> None: --> 205 self.asyncio_loop.run_forever()

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:608, in run_forever() 607 while True: --> 608 self._run_once() 609 if self._stopping:

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:1936, in _run_once() 1935 else: -> 1936 handle._run() 1937 handle = None

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/events.py:84, in _run() 83 try: ---> 84 self._context.run(self._callback, *self._args) 85 except (SystemExit, KeyboardInterrupt):

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue() 544 try: --> 545 await self.process_one() 546 except Exception:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:534, in process_one() 533 return --> 534 await dispatch(*args)

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell() 436 if inspect.isawaitable(result): --> 437 await result 438 except Exception:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/ipkernel.py:362, in execute_request() 361 self._associate_new_top_level_threads_with(parent_header) --> 362 await super().execute_request(stream, ident, parent)

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:778, in execute_request() 777 if inspect.isawaitable(reply_content): --> 778 reply_content = await reply_content 780 # Flush output before sending the reply.

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/ipkernel.py:449, in do_execute() 448 if accepts_params["cell_id"]: --> 449 res = shell.run_cell( 450 code, 451 store_history=store_history, 452 silent=silent, 453 cell_id=cell_id, 454 ) 455 else:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/zmqshell.py:549, in run_cell() 548 self._last_traceback = None --> 549 return super().run_cell(*args, **kwargs)

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3075, in run_cell() 3074 try: -> 3075 result = self._run_cell( 3076 raw_cell, store_history, silent, shell_futures, cell_id 3077 ) 3078 finally:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3130, in _run_cell() 3129 try: -> 3130 result = runner(coro) 3131 except BaseException as e:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner() 127 try: --> 128 coro.send(None) 129 except StopIteration as exc:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3334, in run_cell_async() 3331 interactivity = "none" if silent else self.ast_node_interactivity -> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name, 3335 interactivity=interactivity, compiler=compiler, result=result) 3337 self.last_execution_succeeded = not has_raised

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3517, in run_ast_nodes() 3516 asy = compare(code) -> 3517 if await self.runcode(code, result, async=asy): 3518 return True

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3577, in run_code() 3576 else: -> 3577 exec(code_obj, self.user_global_ns, self.user_ns) 3578 finally: 3579 # Reset our crash handler in place

Cell In[9], line 11 9 return train_state.apply_fn(train_state.params, obs) ---> 11 transitions = get_rollouts(env, policy, 10, jax.random.key(3))

Cell In[9], line 4, in get_rollouts() 3 rollout_keys = jax.random.split(key, num_rollouts) ----> 4 transitions = parallel_rollout(env, rollout_keys, policy) 5 return transitions

File ~/reach/rpg_flightning/flightning/envs/env_base.py:135, in rollout() 134 keys_steps = jax.random.split(key, numsteps) --> 135 , transitions = jax.lax.scan(step_fn, (state, obs), keys_steps) 136 # concatenate all transitions

File ~/reach/rpg_flightning/flightning/envs/env_base.py:131, in step_fn() 130 else: --> 131 trans = env._step(env_state, action, key_step) 132 return (trans.state, trans.obs), trans

File ~/reach/rpg_flightning/flightning/envs/wrappers.py:212, in _step() 210 @partial(jax.jit, static_argnums=(0,)) 211 def _step(self, state, action, key) -> EnvTransition: --> 212 transition = self._env._step(state, action, key) 213 obs = normalize(transition.obs, self._obs_min, self._obs_max)

File ~/reach/rpg_flightning/flightning/envs/hovering_state_env.py:163, in _step() 162 f_1, omega_1 = action_1[0], action_1[1:] --> 163 quadrotor_state = self.quadrotor.step( 164 state.quadrotor_state, f_1, omega_1, dt_1 165 ) 167 if self.delay > 0: 168 # 2 step

File ~/reach/rpg_flightning/flightning/objects/quadrotor_obj.py:309, in step() 307 return state_new, state_dot_new --> 309 return _step(state, f_d, omega_d, dt)

JaxStackTraceBeforeTransformation: TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got: primal key[] with tangent key[], expecting tangent ShapedArray(float0[])

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.


The above exception was the direct cause of the following exception:

TypeError Traceback (most recent call last) Cell In[11], line 2 1 time_start = time.time() ----> 2 res_dict = bptt.train( 3 env, 4 train_state, 5 num_epochs=100, 6 num_steps_per_epoch=env.max_steps_in_episode, 7 num_envs=100, 8 key=key_bptt, 9 ) 10 time_train = time.time() - time_start 11 print(f"Training time: {time_train}")

File ~/reach/rpg_flightning/flightning/algos/bptt.py:155, in train(env, train_state, num_epochs, num_steps_per_epoch, num_envs, key) 152 env_state, obs = env.reset(key_reset, None) 153 runner_state = RunnerState(train_state, env_state, obs, key, epoch_idx=0) --> 155 return jax.jit(_train)(runner_state)

[... skipping hidden 11 frame]

File ~/reach/rpg_flightning/flightning/algos/bptt.py:143, in train.._train(runner_state) 140 return epoch_state, loss 142 # run epochs --> 143 runner_state_final, losses = jax.lax.scan( 144 epoch_fn, runner_state, None, num_epochs 145 ) 147 return {"runner_state": runner_state_final, "metrics": losses}

[... skipping hidden 9 frame]

File ~/reach/rpg_flightning/flightning/algos/bptt.py:122, in train.._train..epoch_fn(epoch_state, _unused) 120 # compute reward 121 train_state = epoch_state.train_state --> 122 (loss, epoch_state), grad = loss_fn( 123 train_state.params, epoch_state 124 ) 125 # update params 126 train_state = train_state.apply_gradients(grads=grad)

[... skipping hidden 8 frame]

File ~/reach/rpg_flightning/flightning/algos/bptt.py:116, in train.._train..epoch_fn..loss_fn(params, runner_state) 113 return runner_state, trajectory 115 # collect data --> 116 runner_state, trajectory = rollout(runner_state) 117 loss = -trajectory.reward.sum() / num_envs 118 return loss, runner_state

File ~/reach/rpg_flightning/flightning/algos/bptt.py:110, in train.._train..epoch_fn..loss_fn..rollout(runner_state) 101 runner_state = RunnerState( 102 train_state, env_state, obs, key, epoch_idx 103 ) 105 return ( 106 runner_state, 107 TrajectoryState(reward=reward), 108 ) --> 110 runner_state, trajectory = jax.lax.scan( 111 step_fn, runner_state, None, num_steps_per_epoch 112 ) 113 return runner_state, trajectory

[... skipping hidden 31 frame]

[... skipping similar frames: _jvp_jaxpr at line 685 (2 times), WrappedFun.call_wrapped at line 193 (2 times), eval_jaxpr at line 508 (2 times), jaxpr_as_fun at line 260 (2 times), jvp_jaxpr at line 675 (2 times), trace_to_jaxpr_dynamic at line 2278 (2 times), trace_to_subjaxpr_dynamic at line 2301 (2 times), annotate_function.<locals>.wrapper at line 333 (2 times), _pjit_jvp at line 2045 (1 times), AxisPrimitive.bind at line 2803 (1 times), Primitive.bind_with_trace at line 442 (1 times), JVPTrace.process_primitive at line 302 (1 times)]

[... skipping hidden 25 frame]

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/jax/_src/custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args) 344 msg = ("Custom JVP rule must produce primal and tangent outputs with " 345 "corresponding shapes and dtypes, but got:\n{}") 346 disagreements = ( 347 f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}" 348 for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out) 349 if av_et != av_t) --> 351 raise TypeError(msg.format('\n'.join(disagreements))) 352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got: primal key[] with tangent key[], expecting tangent ShapedArray(float0[])

zhw970623 commented 2 weeks ago

JaxStackTraceBeforeTransformation Traceback (most recent call last) File :198, in _run_module_as_main()

File :88, in _run_code()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel_launcher.py:18 16 from ipykernel import kernelapp as app ---> 18 app.launch_new_instance()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance() 1074 app.initialize(argv) -> 1075 app.start()

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelapp.py:739, in start() 738 try: --> 739 self.io_loop.start() 740 except KeyboardInterrupt:

File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/tornado/platform/asyncio.py:205, in start() 204 def start(self) -> None: --> 205 self.asyncio_loop.run_forever()

File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:608, in run_forever() 607 while True: ... --> 351 raise TypeError(msg.format('\n'.join(disagreements))) 352 yield primals_out + tangents_out, (out_tree, primal_avals)

TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got: primal uint32[2] with tangent uint32[2], expecting tangent ShapedArray(float0[2])

patricksharlow commented 1 week ago

I'm getting the same error.

joheeg commented 1 week ago

This specific error is due to a change in JAX; see https://github.com/jax-ml/jax/discussions/24262. If you are using JAX 0.4.34 or later, try downgrading to JAX 0.4.33.