jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Program Hangs When Using `donate_argnums` #10737

Open kevinzakka opened 2 years ago

kevinzakka commented 2 years ago

As the title suggests, my program hangs when I try to donate `self` in a method of a `@struct.dataclass`.

jakevdp commented 2 years ago

Could you provide a self-contained example that reproduces the bug?

hawkinsp commented 2 years ago

There's a known issue where you will get a hang if you donate the same argument twice. Can you confirm you aren't doing that?

(I'm going to look into fixing that sometime soon.)

kevinzakka commented 2 years ago

Hi @jakevdp and @hawkinsp, apologies for not providing code; I didn't want to paste something big and didn't have a self-contained example. I don't see any double donation, although I don't trust myself since I'm still fairly new to JAX.

The code looks something like this:

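import functools
from typing import Dict, Tuple

import flax
import jax
import jax.numpy as jnp
from flax import struct

Params = flax.core.frozen_dict.FrozenDict
Config = ...  # placeholder: a dataclass with config fields.
Transition = ...  # placeholder: the transition type, not shown in this post.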

@struct.dataclass
class TrainState:
    policy_params: Params
    config: Config = struct.field(pytree_node=False)

    @functools.partial(jax.jit, donate_argnums=0)
    def learner_step(
        self,
        transitions: Transition,
        rng_key: jax.random.KeyArray,
    ) -> Tuple["TrainState", Dict[str, jnp.ndarray]]:

        new_policy_params, policy_loss = ...  # placeholder: do something with transitions

        new_state = TrainState(
            config=self.config,
            policy_params=new_policy_params,
        )

        metrics = {
            "policy_loss": policy_loss,
        }

        return new_state, metrics

Appreciate the fast response btw :)

jakevdp commented 2 years ago

Do you call learner_step more than once? If so you're probably hitting the bug that @hawkinsp mentioned.

hawkinsp commented 2 years ago

BTW by "twice" I mean "if you pass the same argument to the same call to a jitted function more than once".

It's hard to say without a reproduction we can run, though!
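
To make "the same argument to the same call more than once" concrete, here is a minimal pure-Python sketch. The helper name `repeated_argument_positions` is hypothetical and not part of JAX, and plain lists stand in for arrays. It only reports positional arguments that are the identical object, which is the situation described above; it does not reproduce the hang itself.

```python
from typing import Any, Dict, List


def repeated_argument_positions(*args: Any) -> List[List[int]]:
    """Group positional indices whose arguments are the very same object.

    Hypothetical helper for illustration only; it compares identity
    (`id`), not contents, so equal-but-distinct objects are not flagged.
    """
    positions_by_id: Dict[int, List[int]] = {}
    for index, arg in enumerate(args):
        positions_by_id.setdefault(id(arg), []).append(index)
    return [idxs for idxs in positions_by_id.values() if len(idxs) > 1]


buffer = [0.0, 1.0, 2.0]  # stand-in for an array
other = [0.0, 1.0, 2.0]   # equal contents, but a different object

# Same object passed twice in one call: the case to avoid when donating.
same_twice = repeated_argument_positions(buffer, buffer)

# Two distinct objects: nothing is reported.
distinct = repeated_argument_positions(buffer, other)
```

Running a check like this on the arguments right before the jitted call would be one way to rule out double donation, assuming the arguments are passed positionally.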

kevinzakka commented 2 years ago

It's only called once, to my knowledge. I've made the repo public to make debugging a little easier; here's the relevant line of code. The codebase is pretty lightweight and should be easy to parse!

imoneoi commented 2 years ago

I also had this problem (https://github.com/google/jax/issues/12627). Could you check whether the same leaf appears more than once in the pytree?
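
One way to run that check is sketched below in pure Python, under some assumptions: the names `iter_leaves` and `shared_leaves` are hypothetical, the walker only understands dicts, lists, and tuples, and `bytearray` objects stand in for arrays. With a real JAX pytree, the leaves could instead be obtained with `jax.tree_util.tree_leaves` and compared by identity in the same way.

```python
from typing import Any, Dict, Iterator, List, Tuple

Path = Tuple[Any, ...]


def iter_leaves(tree: Any, path: Path = ()) -> Iterator[Tuple[Path, Any]]:
    """Yield (path, leaf) pairs from nested dicts, lists, and tuples."""
    if isinstance(tree, dict):
        for key, child in tree.items():
            yield from iter_leaves(child, path + (key,))
    elif isinstance(tree, (list, tuple)):
        for index, child in enumerate(tree):
            yield from iter_leaves(child, path + (index,))
    else:
        yield path, tree


def shared_leaves(tree: Any) -> List[List[Path]]:
    """Return groups of paths that point at the very same leaf object.

    Identity (`id`) is compared, not contents. This is meant for
    array-like leaves; small ints or None may share identity by accident.
    """
    paths_by_id: Dict[int, List[Path]] = {}
    for path, leaf in iter_leaves(tree):
        paths_by_id.setdefault(id(leaf), []).append(path)
    return [paths for paths in paths_by_id.values() if len(paths) > 1]


weights = bytearray(8)  # stand-in for an array buffer
params = {
    "policy": {"w": weights, "b": bytearray(8)},
    "target": {"w": weights},  # same object reused as a second leaf
}

duplicates = shared_leaves(params)
```

Here `duplicates` lists the two paths that share `weights`. If a donated state turns out to contain such a shared leaf, copying one of the entries before the call is a plausible workaround to try, though whether it resolves the hang in this issue is not confirmed in the thread.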