google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

nnx.vmap example from documentation raises an IndexError #4355

Open jhn-nt opened 3 weeks ago

jhn-nt commented 3 weeks ago

I am encountering an IndexError when running this example from the documentation.

I am running the code in a Docker environment using an NVIDIA image for JAX.

Best,
Giovanni

System information

Problem you have encountered:

Error while going through the NNX tutorial.

What you expected to happen:

Logs, error messages, etc:


IndexError                                Traceback (most recent call last)
Cell In[8], line 23
     19 @partial(nnx.vmap, axis_size=5)
     20 def create_model(rngs: nnx.Rngs):
     21   return MLP(10, 32, 10, rngs=rngs)
---> 23 model = create_model(nnx.Rngs(0))

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py:1158, in UpdateContextManager.__call__.<locals>.update_context_manager_wrapper(*args, **kwargs)
   1155 @functools.wraps(f)
   1156 def update_context_manager_wrapper(*args, **kwargs):
   1157   with self:
-> 1158     return f(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:339, in vmap.<locals>.vmap_wrapper(*args, **kwargs)
    335 args = resolve_kwargs(f, args, kwargs)
    336 pure_args = extract.to_tree(
    337   args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
    338 )
--> 339 pure_args_out, pure_out = vmapped_fn(*pure_args)
    340 _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap')
    341 return out

    [... skipping hidden 3 frame]

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:164, in VmapFn.__call__(self, *pure_args)
    159 pure_args = _update_variable_sharding_metadata(
    160   pure_args, self.transform_metadata, spmd.remove_axis
    161 )
    162 args = extract.from_tree(pure_args, ctxtag='vmap')
--> 164 out = self.f(*args)
    166 args_out = extract.clear_non_graph_nodes(args)
    167 pure_args_out, pure_out = extract.to_tree(
    168   (args_out, out),
    169   prefix=(self.in_axes, self.out_axes),
    170   split_fn=_vmap_split_fn,
    171   ctxtag='vmap',
    172 )

Cell In[8], line 21
     19 @partial(nnx.vmap, axis_size=5)
     20 def create_model(rngs: nnx.Rngs):
---> 21   return MLP(10, 32, 10, rngs=rngs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
     78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79   return _graph_node_meta_call(cls, *args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
     86 node = cls.__new__(cls, *args, **kwargs)
     87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
     90 return node

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
     81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82   self.__init__(*args, **kwargs)

Cell In[8], line 8
      7 def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
----> 8   self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
      9   self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
     10   self.bn = nnx.BatchNorm(dmid, rngs=rngs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
     78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79   return _graph_node_meta_call(cls, *args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
     86 node = cls.__new__(cls, *args, **kwargs)
     87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
     90 return node

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
     81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82   self.__init__(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/nn/linear.py:346, in Linear.__init__(self, in_features, out_features, use_bias, dtype, param_dtype, precision, kernel_init, bias_init, dot_general, rngs)
    332 def __init__(
    333   self,
    334   in_features: int,
   (...)
    344   rngs: rnglib.Rngs,
    345 ):
--> 346   kernel_key = rngs.params()
    347   self.kernel = nnx.Param(
    348     kernel_init(kernel_key, (in_features, out_features), param_dtype)
    349   )
    350   if use_bias:

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/rnglib.py:84, in RngStream.__call__(self)
     80 def __call__(self) -> jax.Array:
     81   self.check_valid_context(
     82     lambda: 'Cannot call RngStream from a different trace level'
     83   )
---> 84   key = jax.random.fold_in(self.key.value, self.count.value)
     85   self.count.value += 1
     86   return key

File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:262, in fold_in(key, data)
    251 def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
    252   """Folds in data to a PRNG key to form a new PRNG key.
    253
    254   Args:
   (...)
    260   statistically safe for producing a stream of new pseudo-random values.
    261   """
--> 262   key, wrapped = _check_prng_key("fold_in", key)
    263   if np.ndim(data):
    264     raise TypeError("fold_in accepts a scalar, but was given an array of"
    265                     f"shape {np.shape(data)} != (). Use jax.vmap for batching.")

File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:74, in _check_prng_key(name, key, allow_batched)
     72 def _check_prng_key(name: str, key: KeyArrayLike, *,
     73                     allow_batched: bool = False) -> tuple[KeyArray, bool]:
---> 74   if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
     75     wrapped_key = key
     76     wrapped = False

    [... skipping hidden 1 frame]

File /usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/batching.py:346, in BatchTracer.aval(self)
    344   return aval
    345 elif type(self.batch_dim) is int:
--> 346   return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
    347 elif type(self.batch_dim) is RaggedAxis:
    348   new_aval = core.mapped_aval(
    349     aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)

IndexError: tuple index out of range

Steps to reproduce:

from flax import nnx
import jax
from functools import partial

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x: jax.Array):
    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
    return self.linear2(x)

@partial(nnx.vmap, axis_size=5)
def create_model(rngs: nnx.Rngs):
  return MLP(10, 32, 10, rngs=rngs)

model = create_model(nnx.Rngs(0))
jhn-nt commented 3 weeks ago

Same behavior when updating to flax==0.10.0 and jax[cuda12_local]==0.4.35.

cgarciae commented 3 weeks ago

Hey @jhn-nt, thanks for reporting this! Very curious why our CI is not failing. The easiest fix is to split the keys for the Rngs:

keys = jax.random.split(jax.random.key(0), 5)
model = create_model(nnx.Rngs(keys))
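
(Presumably the scalar key inside nnx.Rngs(0) has shape (), so there is no batch axis for vmap to slice, which is where the tuple index out of range comes from.) Spelled out against the reproduction above, the workaround would look roughly like this; a sketch assuming the same MLP class, not re-run here:

from flax import nnx
from functools import partial
import jax

# Same MLP class as in the reproduction above.

@partial(nnx.vmap, axis_size=5)
def create_model(rngs: nnx.Rngs):
  return MLP(10, 32, 10, rngs=rngs)

# Split the base key so each of the 5 vmapped model copies gets its own key,
# then wrap the batched keys in nnx.Rngs before calling the vmapped constructor.
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(nnx.Rngs(keys))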

Will fix this quickly.

jhn-nt commented 3 weeks ago

Thanks a lot again for the prompt help!

Giovanni

cgarciae commented 2 weeks ago

Oh wait, the link you posted is for the old experimental docs on the 0.8.3 version of the site; this is fixed in the new version: https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers . Did you find this via Google?

jhn-nt commented 2 weeks ago

Ah, I see, that explains it then. Apologies for opening the issue; I should have checked in more detail.

But yes, I found it through Google, searching for "flax nnx".

Giovanni