google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Dropout seems incompatible with jax.jit #4085

Open richardmkit opened 1 month ago

richardmkit commented 1 month ago
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from tqdm import tqdm

class DNN(nn.Module):
    num_hidden_units1:int
    num_hidden_units2:int
    num_outputs:int
    dropout_rate:float 

    @nn.compact
    def __call__(self,x,training):
        x=nn.Dense(self.num_hidden_units1)(x)
        x=nn.relu(x)
        x=nn.Dropout(rate=self.dropout_rate,deterministic=not training)(x)

        x=nn.Dense(self.num_hidden_units2)(x)
        x=nn.relu(x)
        x=nn.Dropout(rate=self.dropout_rate,deterministic=not training)(x)

        x=nn.Dense(self.num_outputs)(x)
        return x

def mse(params,X,y,training):
    def squared_error(X,y):
        y_pred=model.apply(params,X,training,rngs={'dropout':jax.random.PRNGKey(114)})
        diff=y_pred-y
        return jnp.inner(diff,diff)
    return jnp.mean(jax.vmap(squared_error)(X,y),axis=0)

@jax.jit
def train(params,opt_state,X,y,training):

    loss,grads=jax.value_and_grad(mse)(params,X,y,training)
    updates,opt_state=optimizer.update(grads,opt_state)
    params=optax.apply_updates(params,updates)

    return params,opt_state,loss

@jax.jit
def fit(params,opt_state,X,y,training):
    for i in tqdm(range(1000)):
        params,opt_state,loss=train(params,opt_state,X,y,training)

        if i%100==0:
            print(loss)

    return params

# Initialization (the DNN instance `model` and the data X, y are created
# earlier in the notebook and are not shown here)
params=model.init(jax.random.PRNGKey(114),X,False)

learning_rate=0.1
optimizer=optax.adam(learning_rate)
opt_state=optimizer.init(params)

train(params,opt_state,X,y,True)
The model trains successfully with the two dropout layers as long as I don't use jax.jit. However, as soon as I try to accelerate training with jax.jit, I get this error:
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[300], line 1
----> 1 train(params,opt_state,X,y,True)

    [... skipping hidden 11 frame]

Cell In[264], line 4, in train(params, opt_state, X, y, training)
      1 @jax.jit
      2 def train(params,opt_state,X,y,training):
----> 4     loss,grads=jax.value_and_grad(mse)(params,X,y,training)
      5     updates,opt_state=optimizer.update(grads,opt_state)
      6     params=optax.apply_updates(params,updates)

    [... skipping hidden 8 frame]

Cell In[296], line 6, in mse(params, X, y, training)
      4     diff=y_pred-y
      5     return jnp.inner(diff,diff)
----> 6 return jnp.mean(jax.vmap(squared_error)(X,y),axis=0)

    [... skipping hidden 3 frame]

Cell In[296], line 3, in mse.<locals>.squared_error(X, y)
      2 def squared_error(X,y):
----> 3     y_pred=model.apply(params,X,training,rngs={'dropout':jax.random.PRNGKey(114)})
      4     diff=y_pred-y
      5     return jnp.inner(diff,diff)

    [... skipping hidden 6 frame]

Cell In[293], line 11, in DNN.__call__(self, x, training)
      9 x=nn.Dense(self.num_hidden_units1)(x)
     10 x=nn.relu(x)
---> 11 x=nn.Dropout(rate=self.dropout_rate,deterministic=not training)(x)
     13 x=nn.Dense(self.num_hidden_units2)(x)
     14 x=nn.relu(x)

    [... skipping hidden 1 frame]

File /opt/conda/lib/python3.10/site-packages/jax/_src/core.py:1492, in concretization_function_error.<locals>.error(self, arg)
   1491 def error(self, arg):
-> 1492   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function train at /tmp/ipykernel_34/1452843710.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument training.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

It seems something is wrong with the training flag. How can I solve this? Thanks.

jonasguan commented 1 month ago

You have to mark training as a static argument when you jit your functions, so the compiler knows that you're ok with recompiling the function if its value changes. See: https://jax.readthedocs.io/en/latest/jit-compilation.html#marking-arguments-as-static

In short, changing your @jax.jit decorators to @partial(jax.jit, static_argnames=['training']) should do the trick. I know it's a bit confusing because the Flax dropout guide neglects to mention this.
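
For reference, here is a minimal sketch of that change applied to the train function from the report above (it assumes model, mse, optimizer, X, and y are already defined as in the original code):

from functools import partial

import jax
import optax

@partial(jax.jit, static_argnames=['training'])
def train(params, opt_state, X, y, training):
    # `training` is now a static (compile-time) Python bool, so
    # `deterministic=not training` inside nn.Dropout sees a concrete value
    # instead of a traced array.
    loss, grads = jax.value_and_grad(mse)(params, X, y, training)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params, opt_state, loss = train(params, opt_state, X, y, True)

Note that jit traces and compiles one version of the function per distinct value of a static argument (at most two here, True and False), which is cheap for a boolean flag but should be avoided for arguments that can take many values.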