Open kevinzakka opened 2 years ago
Could you provide a self-contained example that reproduces the bug?
There's a known issue where you will get a hang if you donate the same argument twice. Confirm you aren't doing that?
(I'm going to look into fixing that sometime soon.)
Hi @jakevdp and @hawkinsp, apologies for not providing code; I didn't want to dump something big and didn't have a self-contained example. I don't see any double donation, although I don't trust myself since I'm still fairly new to JAX.
The code looks something like this:
Params = flax.core.frozen_dict.FrozenDict
Config = ...  # a dataclass with config fields.

@struct.dataclass
class TrainState:
    policy_params: Params
    config: Config = struct.field(pytree_node=False)

    @functools.partial(jax.jit, donate_argnums=0)
    def learner_step(
        self,
        transitions: Transition,
        rng_key: jax.random.KeyArray,
    ) -> Tuple["TrainState", Dict[str, jnp.ndarray]]:
        new_policy_params, policy_loss = ...  # do something with transitions
        new_state = TrainState(
            config=self.config,
            policy_params=new_policy_params,
        )
        metrics = {
            "policy_loss": policy_loss,
        }
        return new_state, metrics
Appreciate the fast response btw :)
Do you call learner_step more than once? If so you're probably hitting the bug that @hawkinsp mentioned.
BTW by "twice" I mean "if you pass the same argument to the same call to a jitted function more than once".
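To make "the same argument passed more than once to the same call" concrete, here is a minimal sketch. The function `add` and the arrays `a` and `b` are made up for illustration and are not from the code in this issue. The problematic call is left as a comment rather than executed, since its behaviour (hang, error, or warning) may depend on the JAX version and backend.

```python
import functools

import jax
import jax.numpy as jnp


# Both positional arguments are marked as donated.
@functools.partial(jax.jit, donate_argnums=(0, 1))
def add(x, y):
    return x + y


a = jnp.ones((4,))

# Problematic pattern (deliberately not run here):
#     add(a, a)
# The same buffer would be donated twice within a single call.

# One possible workaround: give the second argument its own buffer.
b = jnp.array(a, copy=True)
out = add(a, b)

# `a` and `b` were donated, so only `out` should be used from here on.
print(out)
```

JAX may print a warning that some donated buffers were not usable (for example on backends with limited donation support); that warning is separate from the double-donation issue described above.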
It's hard to say without a reproduction we can run, though!
It's only called once to my knowledge. I've made the repo public to make debugging a little easier, here's the relevant line of code. The codebase is pretty lightweight, should be easy to parse!
I also had the problem (https://github.com/google/jax/issues/12627). Can you check whether the same leaf appears more than once in the pytree?
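A rough way to run that check, as a sketch: flatten the pytree and look for leaves that are the same Python object. The helper name `count_shared_leaves` and the example dicts are hypothetical. This only catches object identity; it would not detect two distinct array objects that happen to share an underlying buffer.

```python
import jax
import jax.numpy as jnp


def count_shared_leaves(tree):
    """Count leaves that are the same Python object as an earlier leaf."""
    seen = set()
    shared = 0
    # tree_leaves returns a list, which keeps every leaf alive while we
    # compare ids, so equal ids really mean the same object.
    for leaf in jax.tree_util.tree_leaves(tree):
        if id(leaf) in seen:
            shared += 1
        seen.add(id(leaf))
    return shared


w = jnp.ones((3,))
aliased = {"actor": w, "critic": w}  # same array stored under two keys
distinct = {"actor": jnp.zeros((3,)), "critic": jnp.arange(3.0)}

print(count_shared_leaves(aliased))
print(count_shared_leaves(distinct))
```

If the count is non-zero for the object being donated (here that would be the TrainState passed as self), donating it would donate the shared leaf more than once in a single call.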
As the title suggests, my program will hang when I try to donate self in a method of a @struct.dataclass.