PhilipVinc opened this issue 4 months ago.
Hey @PhilipVinc, as you point out, #4081 is the solution we are working on for using Linen Modules in NNX and vice versa. It should be done soon-ish. In the meantime, maybe you can use something simple like:
```python
from typing import Any

import flax.linen as linen
from flax import nnx


class LinenToNNX(nnx.Module):
    def __init__(
        self,
        module: linen.Module,
        rngs: nnx.Rngs,
    ):
        self.module = module
        self.rngs = rngs
        self.initialized = False

    def __call__(
        self, *args: Any, **kwargs: Any
    ) -> Any:
        # Draw one key from every RNG stream held by this wrapper.
        _rngs = {name: stream() for name, stream in self.rngs.items()}
        # Linen initializes parameters from the 'params' stream.
        if 'params' not in _rngs and 'default' in _rngs:
            _rngs['params'] = _rngs.pop('default')
        if not self.initialized:
            # First call: create the Linen variables and keep the params as an nnx.Param.
            self.initialized = True
            out, variables = self.module.init_with_output(_rngs, *args, **kwargs)
            self.params = nnx.Param(variables['params'])
        else:
            variables = {'params': self.params.value}
            # Without `mutable`, apply returns only the output (no updated variables).
            out = self.module.apply(variables, *args, rngs=_rngs, **kwargs)
        return out
```
Hi! I am working on the `NNXToLinen` wrapper, which lets you use NNX modules within Linen. I will likely send out the actual PR in a few days, but for now here is my draft and an example of its use. Note that the final API might be slightly different.
```python
import functools
from typing import Callable

import jax
import flax.linen as nn
from flax import nnx


class NNXToLinen(nn.Module):
    module_op: Callable[..., nnx.Module]

    def setup(self):
        if self.is_initializing():
            # Build the NNX module once and store its state as a Linen variable.
            self.module = self.module_op(rngs=nnx.Rngs(self.make_rng('params')))
            self.gdef, state = nnx.split(self.module)
            self.put_variable('params', 'nnx_state', state)
            return
        self.nnx_state = self.variable('params', 'nnx_state').value

    def __call__(self, *args, **kwargs):
        if self.is_initializing():
            return self.module(*args, **kwargs)
        # Rebuild an abstract copy of the module, then load the stored state into it.
        module = nnx.eval_shape(self.module_op, rngs=nnx.Rngs(0))  # dummy rng
        nnx.update(module, self.nnx_state)
        return module(*args, **kwargs)


class NNXInner(nnx.Module):
    def __init__(self, din, dout, rngs):
        self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout)))
        self.bn = nnx.BatchNorm(dout, use_running_average=False, rngs=rngs)

    def __call__(self, x):
        return x @ self.w.value


class LinenOuter(nn.Module):
    dout: int

    @nn.compact
    def __call__(self, x):
        linear = NNXToLinen(functools.partial(NNXInner, x.shape[-1], self.dout))
        b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout))
        return linear(x) + b


x = jax.random.normal(jax.random.key(0), (2, 4))
model = LinenOuter(3)
var = model.init(jax.random.key(0), x)
print(f'{var = }')
y = model.apply(var, x)
assert y.shape == (2, 3)
print(y)
```
Hi,
I have a large library that we decided to build on top of `flax.linen` several years ago. I'd now like to begin testing `nnx`. However, given the size of the repo and the number of people using it, I cannot move everything over to `nnx` at once. Instead, I would like to keep using Linen-style code for a while, while allowing users to use models defined with `nnx` inside our library.

In brief, the way we use modules right now is:
I tried to use `nnx.split` for this, but because it returns a special object rather than a plain dictionary, I could not get this approach to work.

By inspecting `nnx.compat`/`bridge`, I see that you have several utilities for using Linen layers within NNX, but it is unclear to me how to do the opposite. It seems that `nnx.bridge.NNXWrapper` should do that, but it is unfinished, and it is not clear to me how to use `nnx.Module` for this. Is there anything I can use?