System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04):
Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`):
flax version: 0.8.4
jax version: 0.4.30
jaxlib: 0.4.30
Python version: 3.10.12
GPU/TPU model and memory: GeForce RTX 3080 Ti Mobile
CUDA version (if applicable): cuda_12.5.r12.5/compiler.34177558_0
Problem you have encountered:
XLA crashes with a very opaque error when I try to initialize a model using a dummy input that has a large batch dimension.
What you expected to happen:
I'm new to JAX (and ML, in general), so I'm not sure whether I should even be initializing with data that has a batch dimension greater than 1, but a more descriptive error might be useful for other people who make the same mistake.
Steps to reproduce:
Run the following:
import flax
import jax
import jax.numpy as jnp
import flax.linen as nn

class MyNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(128)(x)

batch_dim = 32
model = MyNN()
key = jax.random.key(0)
x = jnp.zeros((batch_dim, 16**2))
params = model.init(key, x)
Error
Exception has occurred: XlaRuntimeError
INTERNAL: the library was not initialized