google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

NNXWrapper #4088

Open PhilipVinc opened 4 months ago

PhilipVinc commented 4 months ago

Hi,

I have a large library that we decided to build on top of flax.linen several years ago. I'd now like to begin testing nnx. However, given the size of the repo and the number of people using it, I cannot switch everything over to nnx at once; instead I would like to keep using linen-style code for a while, while allowing users to use models defined with nnx inside our library.

In brief, the way we use modules right now is

model = LinenModel(...)
model_state, parameters = fcore.pop(model.init(jax.random.key(1), ...), "params")
...
# jit boundary
variables = {"params": parameters, **model_state}
model.apply(variables, inputs, ...)

I tried to use nnx.split to this end, but because it returns a special object rather than a simple dictionary, I cannot make this approach work.
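
For reference, a minimal sketch of what nnx.split returns (using nnx.Linear as a stand-in for an arbitrary nnx module); the second value is an nnx.State object rather than a plain dict, which is what breaks the pattern above:

from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# split separates the static structure (GraphDef) from the variables (State)
graphdef, state = nnx.split(model)
print(type(state))  # an nnx.State, not a plain {"params": ...} dict

# the two pieces can be recombined later with merge
model_again = nnx.merge(graphdef, state)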

By inspecting nnx.compat/bridge I see that you have several utilities for using linen layers within nnx, but it is unclear to me how to do the opposite. It seems that nnx.bridge.NNXWrapper should do that, but it is unfinished, and it is also not clear to me how to use nnx.Module.

Is there anything I can use?

cgarciae commented 4 months ago

Hey @PhilipVinc, as you point out, #4081 is the solution we are working on for using Linen Modules in NNX and vice versa. It should be done soon-ish. In the meantime, maybe you can use something simple like:

from typing import Any

import flax.linen as linen
from flax import nnx

class LinenToNNX(nnx.Module):
  def __init__(
    self,
    module: linen.Module,
    rngs: nnx.Rngs,
  ):
    self.module = module
    self.rngs = rngs
    self.initialized = False

  def __call__(
    self, *args: Any, **kwargs: Any
  ) -> Any:
    # draw a fresh key from every rng stream we were given
    _rngs = {name: stream() for name, stream in self.rngs.items()}
    if 'params' not in _rngs and 'default' in _rngs:
      _rngs['params'] = _rngs.pop('default')

    if not self.initialized:
      self.initialized = True

      # first call: run Linen init and store the params on the nnx module
      out, variables = self.module.init_with_output(_rngs, *args, **kwargs)
      self.params = nnx.Param(variables['params'])
    else:
      # later calls: apply with the stored params and write back any updates
      variables = {'params': self.params.value}
      out, variables = self.module.apply(
        variables, *args, rngs=_rngs, mutable='params', **kwargs
      )
      self.params.value = variables['params']

    return out
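
A hypothetical usage of this sketch, wrapping a plain linen.Dense (module choice and shapes here are illustrative, not from the original comment):

import jax
import flax.linen as linen
from flax import nnx

wrapped = LinenToNNX(linen.Dense(4), rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.key(0), (2, 8))
y = wrapped(x)  # first call runs Linen init and stores the params
y = wrapped(x)  # later calls reuse (and update) the stored params
assert y.shape == (2, 4)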

IvyZX commented 4 months ago

Hi! I am working on the NNXToLinen wrapper that allows you to use NNX within Linen. I will likely send out the actual PR in a few days, but for now here is my draft and an example of its use. Note that the final API might be slightly different.

import functools
from typing import Callable

import jax
import flax.linen as nn
from flax import nnx

class NNXToLinen(nn.Module):
  module_op: Callable[..., nnx.Module]

  def setup(self):
    if self.is_initializing():
      # build the nnx module once and store its whole state as a Linen variable
      self.module = self.module_op(rngs=nnx.Rngs(self.make_rng()))
      self.gdef, state = nnx.split(self.module)
      self.put_variable('params', 'nnx_state', state)
      return
    self.nnx_state = self.variable('params', 'nnx_state').value

  def __call__(self, *args, **kwargs):
    if self.is_initializing():
      return self.module(*args, **kwargs)
    # rebuild the nnx module abstractly, then load the stored state into it
    module = nnx.eval_shape(self.module_op, rngs=nnx.Rngs(0))  # dummy rng
    nnx.update(module, self.nnx_state)
    return module(*args, **kwargs)

class NNXInner(nnx.Module):
  def __init__(self, din, dout, rngs):
    self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
    self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)  # unused in __call__, but its variables still end up in nnx_state

  def __call__(self, x):
    return x @ self.w.value

class LinenOuter(nn.Module):
  dout: int
  @nn.compact
  def __call__(self, x):
    linear = NNXToLinen(functools.partial(NNXInner, x.shape[-1], self.dout))
    b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout))
    return linear(x) + b

x = jax.random.normal(jax.random.key(0), (2, 4))
model = LinenOuter(3)
var = model.init(jax.random.key(0), x)
print(f'{var = }')
y = model.apply(var, x)
assert y.shape == (2, 3)
print(y)