google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

[nnx] fix Variable overloads and add shape/dtype properties #4049

Closed: cgarciae closed this PR 1 week ago

cgarciae commented 1 week ago

What does this PR do?

With these changes, in-place operations on Variables now work correctly:

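```python
import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable): ...

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0, dtype=jnp.uint32))

  def increment(self):
    self.count += 1  # in-place update applied directly to the Variable
```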

Previously you had to apply the operation to the value instead: self.count.value += 1.
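To illustrate the general technique behind in-place updates on a wrapper object and shape/dtype forwarding, here is a minimal sketch in plain Python and JAX. It is not Flax's implementation: `Box` is a hypothetical stand-in for an NNX Variable, and the `Counter` below is an ordinary class rather than an `nnx.Module`. Because JAX arrays are immutable, `__iadd__` builds a new array, stores it in `.value`, and returns the wrapper itself, so the attribute keeps referring to the same object.

```python
import jax.numpy as jnp


class Box:
  """Hypothetical mutable wrapper around a JAX array (illustration only)."""

  def __init__(self, value):
    self.value = value

  # Read-only views of the wrapped array's metadata.
  @property
  def shape(self):
    return self.value.shape

  @property
  def dtype(self):
    return self.value.dtype

  # Out-of-place `box + other` returns a plain array.
  def __add__(self, other):
    return self.value + other

  # In-place `box += other`: JAX arrays are immutable, so compute a new array,
  # rebind it to .value, and return the same Box so the holder keeps this wrapper.
  def __iadd__(self, other):
    self.value = self.value + other
    return self


class Counter:
  def __init__(self):
    self.count = Box(jnp.array(0, dtype=jnp.uint32))

  def increment(self):
    # Calls Box.__iadd__, then rebinds self.count to the returned (same) Box.
    self.count += 1


counter = Counter()
box = counter.count
counter.increment()
counter.increment()
print(int(counter.count.value))  # 2
print(counter.count is box)      # True
print(counter.count.shape, counter.count.dtype)
```

In this sketch, if `Box` defined only `__add__`, Python would evaluate `self.count += 1` as `self.count = self.count + 1` and replace the wrapper with a bare array; defining `__iadd__` to return `self` keeps the wrapper in place.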
