jhn-nt opened this issue 3 weeks ago
Same behavior after updating to flax==0.10.0 and jax[cuda12_local]==0.4.35.
Hey @jhn-nt, thanks for reporting this! Very curious why our CI is not failing. Easiest fix is to split the keys for the Rngs:
```python
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(nnx.Rngs(keys))
```
Will fix this quickly.
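To sanity-check the workaround, the vmapped constructor should now produce stacked parameters. A minimal sketch, assuming `axis_size=5` and the tutorial's `MLP(10, 32, 10)` (the `(5, 10, 32)` shape is an assumption based on those sizes, not taken from this thread):

```python
model = create_model(nnx.Rngs(keys))
# Each parameter should gain a leading axis of size 5, one slice per instance.
print(model.linear1.kernel.value.shape)  # expected (assumption): (5, 10, 32)
```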
Thanks a lot again for the prompt help!
Giovanni
Oh wait, the link you posted is for the old experimental docs on the 0.8.3 version of the site; this is fixed in the latest version: https://flax.readthedocs.io/en/latest/nnx_basics.html#scan-over-layers . Did you find this via Google?
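For anyone comparing, the updated guide avoids the problem by vmapping over explicit keys rather than a scalar seed. A sketch of that pattern (paraphrased, not a verbatim quote; check the linked page for the exact snippet):

```python
import jax
from flax import nnx

@nnx.vmap(in_axes=0, out_axes=0)
def create_model(key: jax.Array):
  return MLP(10, 32, 10, rngs=nnx.Rngs(key))

# One typed PRNG key per model instance, mapped over axis 0.
keys = jax.random.split(jax.random.key(0), 5)
model = create_model(keys)
```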
Ah, I see, that explains it then. Apologies for opening the issue; I should have checked in more detail.
But yes, I found it through Google, searching for "flax nnx".
Giovanni
I am encountering an IndexError when running this example in the documentation.
I am running the code in a Docker environment using an NVIDIA image for JAX.
Best,
Giovanni
System information
Problem you have encountered:
IndexError while going through the NNX tutorial.
What you expected to happen:
The documented example to create the vmapped model without errors.
Logs, error messages, etc:
```
IndexError                                Traceback (most recent call last)
Cell In[8], line 23
     19 @partial(nnx.vmap, axis_size=5)
     20 def create_model(rngs: nnx.Rngs):
     21   return MLP(10, 32, 10, rngs=rngs)
---> 23 model = create_model(nnx.Rngs(0))

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/graph.py:1158, in UpdateContextManager.__call__.<locals>.update_context_manager_wrapper(*args, **kwargs)
   1155 @functools.wraps(f)
   1156 def update_context_manager_wrapper(*args, **kwargs):
   1157   with self:
-> 1158     return f(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:339, in vmap.<locals>.vmap_wrapper(*args, **kwargs)
    335 args = resolve_kwargs(f, args, kwargs)
    336 pure_args = extract.to_tree(
    337     args, prefix=in_axes, split_fn=_vmap_split_fn, ctxtag='vmap'
    338 )
--> 339 pure_args_out, pure_out = vmapped_fn(*pure_args)
    340 _args_out, out = extract.from_tree((pure_args_out, pure_out), ctxtag='vmap')
    341 return out

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/transforms/iteration.py:164, in VmapFn.__call__(self, *pure_args)
    159 pure_args = _update_variable_sharding_metadata(
    160     pure_args, self.transform_metadata, spmd.remove_axis
    161 )
    162 args = extract.from_tree(pure_args, ctxtag='vmap')
--> 164 out = self.f(*args)
    166 args_out = extract.clear_non_graph_nodes(args)
    167 pure_args_out, pure_out = extract.to_tree(
    168     (args_out, out),
    169     prefix=(self.in_axes, self.out_axes),
    170     split_fn=_vmap_split_fn,
    171     ctxtag='vmap',
    172 )

Cell In[8], line 21
     19 @partial(nnx.vmap, axis_size=5)
     20 def create_model(rngs: nnx.Rngs):
---> 21   return MLP(10, 32, 10, rngs=rngs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
     78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79   return _graph_node_meta_call(cls, *args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
     86 node = cls.__new__(cls, *args, **kwargs)
     87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
     90 return node

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
     81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82   self.__init__(*args, **kwargs)

Cell In[8], line 8
      7 def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
----> 8   self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
      9   self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
     10   self.bn = nnx.BatchNorm(dmid, rngs=rngs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:79, in ObjectMeta.__call__(cls, *args, **kwargs)
     78 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
---> 79   return _graph_node_meta_call(cls, *args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:88, in _graph_node_meta_call(cls, *args, **kwargs)
     86 node = cls.__new__(cls, *args, **kwargs)
     87 vars(node)['_object__state'] = ObjectState()
---> 88 cls._object_meta_construct(node, *args, **kwargs)
     90 return node

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/object.py:82, in ObjectMeta._object_meta_construct(cls, self, *args, **kwargs)
     81 def _object_meta_construct(cls, self, *args, **kwargs):
---> 82   self.__init__(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/nn/linear.py:346, in Linear.__init__(self, in_features, out_features, use_bias, dtype, param_dtype, precision, kernel_init, bias_init, dot_general, rngs)
    332 def __init__(
    333     self,
    334     in_features: int,
    (...)
    344     rngs: rnglib.Rngs,
    345 ):
--> 346   kernel_key = rngs.params()
    347   self.kernel = nnx.Param(
    348       kernel_init(kernel_key, (in_features, out_features), param_dtype)
    349   )
    350   if use_bias:

File /usr/local/lib/python3.10/dist-packages/flax/nnx/nnx/rnglib.py:84, in RngStream.__call__(self)
     80 def __call__(self) -> jax.Array:
     81   self.check_valid_context(
     82       lambda: 'Cannot call RngStream from a different trace level'
     83   )
---> 84   key = jax.random.fold_in(self.key.value, self.count.value)
     85   self.count.value += 1
     86   return key

File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:262, in fold_in(key, data)
    251 def fold_in(key: KeyArrayLike, data: IntegerArray) -> KeyArray:
    252   """Folds in data to a PRNG key to form a new PRNG key.
    253
    254   Args:
    (...)
    260   statistically safe for producing a stream of new pseudo-random values.
    261   """
--> 262   key, wrapped = _check_prng_key("fold_in", key)
    263   if np.ndim(data):
    264     raise TypeError("fold_in accepts a scalar, but was given an array of"
    265                     f"shape {np.shape(data)} != (). Use jax.vmap for batching.")

File /usr/local/lib/python3.10/dist-packages/jax/_src/random.py:74, in _check_prng_key(name, key, allow_batched)
     72 def _check_prng_key(name: str, key: KeyArrayLike, *,
     73                     allow_batched: bool = False) -> tuple[KeyArray, bool]:
---> 74   if isinstance(key, Array) and dtypes.issubdtype(key.dtype, dtypes.prng_key):
     75     wrapped_key = key
     76     wrapped = False

File /usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/batching.py:346, in BatchTracer.aval(self)
    344   return aval
    345 elif type(self.batch_dim) is int:
--> 346   return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
    347 elif type(self.batch_dim) is RaggedAxis:
    348   new_aval = core.mapped_aval(
    349       aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)

IndexError: tuple index out of range
```
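Note: the final frames hint at the root cause. `nnx.Rngs(0)` wraps a single scalar typed key, so inside `vmap` the batching machinery tries to index a batch axis that a scalar key's shape `()` does not have. A plain-JAX sketch of the shapes involved (nothing Flax-specific):

```python
import jax

key = jax.random.key(0)          # scalar typed key: shape ()
keys = jax.random.split(key, 5)  # batch of keys: shape (5,)
print(key.shape, keys.shape)     # () (5,)
```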
Steps to reproduce:
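The failing cell, reconstructed from the traceback above (a sketch: `__call__` and any layers after `bn` are omitted because they are not visible in the trace):

```python
from functools import partial
from flax import nnx

class MLP(nnx.Module):
  def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
    self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
    self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)

@partial(nnx.vmap, axis_size=5)
def create_model(rngs: nnx.Rngs):
  return MLP(10, 32, 10, rngs=rngs)

model = create_model(nnx.Rngs(0))  # raises IndexError: tuple index out of range
```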