Open zhw970623 opened 3 weeks ago
JaxStackTraceBeforeTransformation Traceback (most recent call last)
File
File
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel_launcher.py:18 16 from ipykernel import kernelapp as app ---> 18 app.launch_new_instance()
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance() 1074 app.initialize(argv) -> 1075 app.start()
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelapp.py:739, in start() 738 try: --> 739 self.io_loop.start() 740 except KeyboardInterrupt:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/tornado/platform/asyncio.py:205, in start() 204 def start(self) -> None: --> 205 self.asyncio_loop.run_forever()
File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:608, in run_forever() 607 while True: ... --> 351 raise TypeError(msg.format('\n'.join(disagreements))) 352 yield primals_out + tangents_out, (out_tree, primal_avals)
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got: primal uint32[2] with tangent uint32[2], expecting tangent ShapedArray(float0[2])
Getting the same error
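This specific error is due to a change in JAX; see https://github.com/jax-ml/jax/discussions/24262. Starting with jax 0.4.34, JAX checks that a custom JVP rule returns tangents of dtype float0 for integer-valued primal outputs (such as the uint32 arrays and PRNG keys in the tracebacks here). That is why the error reports "expecting tangent ShapedArray(float0[...])" when the rule passes the integer or key value through as its own tangent. If you are using jax 0.4.34 or later, try downgrading to jax 0.4.33.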
An error occurs when executing the following code in the example notebook train_bptt_state.ipynb:
JaxStackTraceBeforeTransformation Traceback (most recent call last) File <frozen runpy>:198, in _run_module_as_main()
File <frozen runpy>:88, in _run_code()
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel_launcher.py:18 16 from ipykernel import kernelapp as app ---> 18 app.launch_new_instance()
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/traitlets/config/application.py:1075, in launch_instance() 1074 app.initialize(argv) -> 1075 app.start()
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelapp.py:739, in start() 738 try: --> 739 self.io_loop.start() 740 except KeyboardInterrupt:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/tornado/platform/asyncio.py:205, in start() 204 def start(self) -> None: --> 205 self.asyncio_loop.run_forever()
File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:608, in run_forever() 607 while True: --> 608 self._run_once() 609 if self._stopping:
File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/base_events.py:1936, in _run_once() 1935 else: -> 1936 handle._run() 1937 handle = None
File ~/anaconda3/envs/flightning/lib/python3.11/asyncio/events.py:84, in _run() 83 try: ---> 84 self._context.run(self._callback, *self._args) 85 except (SystemExit, KeyboardInterrupt):
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:545, in dispatch_queue() 544 try: --> 545 await self.process_one() 546 except Exception:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:534, in process_one() 533 return --> 534 await dispatch(*args)
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:437, in dispatch_shell() 436 if inspect.isawaitable(result): --> 437 await result 438 except Exception:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/ipkernel.py:362, in execute_request() 361 self._associate_new_top_level_threads_with(parent_header) --> 362 await super().execute_request(stream, ident, parent)
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/kernelbase.py:778, in execute_request() 777 if inspect.isawaitable(reply_content): --> 778 reply_content = await reply_content 780 # Flush output before sending the reply.
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/ipkernel.py:449, in do_execute() 448 if accepts_params["cell_id"]: --> 449 res = shell.run_cell( 450 code, 451 store_history=store_history, 452 silent=silent, 453 cell_id=cell_id, 454 ) 455 else:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/ipykernel/zmqshell.py:549, in run_cell() 548 self._last_traceback = None --> 549 return super().run_cell(*args, **kwargs)
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3075, in run_cell() 3074 try: -> 3075 result = self._run_cell( 3076 raw_cell, store_history, silent, shell_futures, cell_id 3077 ) 3078 finally:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3130, in _run_cell() 3129 try: -> 3130 result = runner(coro) 3131 except BaseException as e:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/async_helpers.py:128, in _pseudo_sync_runner() 127 try: --> 128 coro.send(None) 129 except StopIteration as exc:
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3334, in run_cell_async() 3331 interactivity = "none" if silent else self.ast_node_interactivity -> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name, 3335 interactivity=interactivity, compiler=compiler, result=result) 3337 self.last_execution_succeeded = not has_raised
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3517, in run_ast_nodes() 3516 asy = compare(code) -> 3517 if await self.run_code(code, result, async_=asy): 3518 return True
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/IPython/core/interactiveshell.py:3577, in run_code() 3576 else: -> 3577 exec(code_obj, self.user_global_ns, self.user_ns) 3578 finally: 3579 # Reset our crash handler in place
Cell In[9], line 11 9 return train_state.apply_fn(train_state.params, obs) ---> 11 transitions = get_rollouts(env, policy, 10, jax.random.key(3))
Cell In[9], line 4, in get_rollouts() 3 rollout_keys = jax.random.split(key, num_rollouts) ----> 4 transitions = parallel_rollout(env, rollout_keys, policy) 5 return transitions
File ~/reach/rpg_flightning/flightning/envs/env_base.py:135, in rollout() 134 keys_steps = jax.random.split(key, numsteps) --> 135 _, transitions = jax.lax.scan(step_fn, (state, obs), keys_steps) 136 # concatenate all transitions
File ~/reach/rpg_flightning/flightning/envs/env_base.py:131, in step_fn() 130 else: --> 131 trans = env._step(env_state, action, key_step) 132 return (trans.state, trans.obs), trans
File ~/reach/rpg_flightning/flightning/envs/wrappers.py:212, in _step() 210 @partial(jax.jit, static_argnums=(0,)) 211 def _step(self, state, action, key) -> EnvTransition: --> 212 transition = self._env._step(state, action, key) 213 obs = normalize(transition.obs, self._obs_min, self._obs_max)
File ~/reach/rpg_flightning/flightning/envs/hovering_state_env.py:163, in _step() 162 f_1, omega_1 = action_1[0], action_1[1:] --> 163 quadrotor_state = self.quadrotor.step( 164 state.quadrotor_state, f_1, omega_1, dt_1 165 ) 167 if self.delay > 0: 168 # 2 step
File ~/reach/rpg_flightning/flightning/objects/quadrotor_obj.py:309, in step() 307 return state_new, state_dot_new --> 309 return _step(state, f_d, omega_d, dt)
JaxStackTraceBeforeTransformation: TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got: primal key[] with tangent key[], expecting tangent ShapedArray(float0[])
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
The above exception was the direct cause of the following exception:
TypeError Traceback (most recent call last) Cell In[11], line 2 1 time_start = time.time() ----> 2 res_dict = bptt.train( 3 env, 4 train_state, 5 num_epochs=100, 6 num_steps_per_epoch=env.max_steps_in_episode, 7 num_envs=100, 8 key=key_bptt, 9 ) 10 time_train = time.time() - time_start 11 print(f"Training time: {time_train}")
File ~/reach/rpg_flightning/flightning/algos/bptt.py:155, in train(env, train_state, num_epochs, num_steps_per_epoch, num_envs, key) 152 env_state, obs = env.reset(key_reset, None) 153 runner_state = RunnerState(train_state, env_state, obs, key, epoch_idx=0) --> 155 return jax.jit(_train)(runner_state)
File ~/reach/rpg_flightning/flightning/algos/bptt.py:143, in train.<locals>._train(runner_state)
140 return epoch_state, loss
142 # run epochs
--> 143 runner_state_final, losses = jax.lax.scan(
144 epoch_fn, runner_state, None, num_epochs
145 )
147 return {"runner_state": runner_state_final, "metrics": losses}
File ~/reach/rpg_flightning/flightning/algos/bptt.py:122, in train.<locals>._train.<locals>.epoch_fn(epoch_state, _unused)
120 # compute reward
121 train_state = epoch_state.train_state
--> 122 (loss, epoch_state), grad = loss_fn(
123 train_state.params, epoch_state
124 )
125 # update params
126 train_state = train_state.apply_gradients(grads=grad)
File ~/reach/rpg_flightning/flightning/algos/bptt.py:116, in train.<locals>._train.<locals>.epoch_fn.<locals>.loss_fn(params, runner_state)
113 return runner_state, trajectory
115 # collect data
--> 116 runner_state, trajectory = rollout(runner_state)
117 loss = -trajectory.reward.sum() / num_envs
118 return loss, runner_state
File ~/reach/rpg_flightning/flightning/algos/bptt.py:110, in train.<locals>._train.<locals>.epoch_fn.<locals>.loss_fn.<locals>.rollout(runner_state)
101 runner_state = RunnerState(
102 train_state, env_state, obs, key, epoch_idx
103 )
105 return (
106 runner_state,
107 TrajectoryState(reward=reward),
108 )
--> 110 runner_state, trajectory = jax.lax.scan(
111 step_fn, runner_state, None, num_steps_per_epoch
112 )
113 return runner_state, trajectory
File ~/anaconda3/envs/flightning/lib/python3.11/site-packages/jax/_src/custom_derivatives.py:351, in _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args) 344 msg = ("Custom JVP rule must produce primal and tangent outputs with " 345 "corresponding shapes and dtypes, but got:\n{}") 346 disagreements = ( 347 f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}" 348 for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out) 349 if av_et != av_t) --> 351 raise TypeError(msg.format('\n'.join(disagreements))) 352 yield primals_out + tangents_out, (out_tree, primal_avals)
TypeError: Custom JVP rule must produce primal and tangent outputs with corresponding shapes and dtypes, but got: primal key[] with tangent key[], expecting tangent ShapedArray(float0[])