google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
5.8k stars, 613 forks

Opaque XLA crash when initializing model #4054

Closed: fernandopalafox closed this issue 5 days ago

fernandopalafox commented 5 days ago

System information

Problem you have encountered:

XLA crashes with a very opaque error when I try to initialize a model using a dummy input with a large batch dimension.

What you expected to happen:

I'm new to JAX (and ML in general), so I'm not sure whether I should even be initializing with data whose batch dimension is greater than 1, but a more descriptive error would be useful for other people who make the same mistake.

Steps to reproduce:

Run the following:

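import flax
import jax
import jax.numpy as jnp
import flax.linen as nn

class MyNN(nn.Module):
    """Minimal model: a single dense layer with 128 output features."""

    @nn.compact
    def __call__(self, x):
        return nn.Dense(128)(x)

batch_dim = 32
model = MyNN()
key = jax.random.key(0)
# Dummy input of shape (batch_dim, 256), used only to initialize the parameters.
x = jnp.zeros((batch_dim, 16**2))
# The crash is raised here.
params = model.init(key, x)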

Error

Exception has occurred: XlaRuntimeError
INTERNAL: the library was not initialized

fernandopalafox commented 5 days ago

I ended up reinstalling the entire venv, and this works now.
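
Since a reinstall fixed it, the crash may have come from a broken or mismatched install rather than from the batch size. That is an assumption; the thread does not confirm the cause. For anyone hitting the same error, here is a minimal sketch for recording which versions are installed before and after rebuilding the venv, which would also fill in the empty "System information" section above. The helper name installed_versions and the package list are illustrative choices, not part of Flax or JAX; the sketch uses only the Python standard library.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Package list is an assumption: the three distributions most relevant here.
    for name, version in installed_versions(["jax", "jaxlib", "flax"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Running this in the broken venv and again in the fresh one, then comparing the two outputs, would show whether the reinstall changed any versions.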