While developing #2316, Scope.path values were captured via a new mechanism made for tabulate. This mechanism assumes a 1:1 correspondence between path names and the variable structure; however, a couple of workarounds were needed because this assumption breaks under certain circumstances. Here are some minimal reproducible examples showing this unexpected behavior.
1. Lifted Modules
When using lifted Modules, Scope.path names use the following notation:
<transformation_name>(<module_name>)
So you get names like scan(ScanLSTMCell_0), whereas only ScanLSTMCell_0 appears in the variable structure. The nn.scan + LSTMCell example (its code is listed at the end of this report) shows the difference between the path names and the corresponding variable structure.
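To make the mismatch concrete, here is a minimal sketch of how a lifted path entry could be mapped back to the name used in the variable structure. The helpers `strip_lift_prefix` and `normalize_path` are hypothetical (they are not part of Flax), and the sketch assumes the `<transformation_name>(<module_name>)` notation described above; handling nested wrappers is an additional assumption.

```python
import re

# Hypothetical helper, not part of Flax: assumes lifted path entries look like
# "<transformation_name>(<module_name>)", e.g. "scan(ScanLSTMCell_0)".
_LIFTED = re.compile(r"^\w+\((.*)\)$")


def strip_lift_prefix(name: str) -> str:
    """Return the module name with any lift wrappers removed."""
    match = _LIFTED.match(name)
    while match:  # unwrap repeatedly in case wrappers are nested (assumption)
        name = match.group(1)
        match = _LIFTED.match(name)
    return name


def normalize_path(path: tuple) -> tuple:
    """Apply strip_lift_prefix to every entry of a Scope.path-like tuple."""
    return tuple(strip_lift_prefix(part) for part in path)


print(normalize_path(("scan(ScanLSTMCell_0)",)))  # ('ScanLSTMCell_0',)
```

Entries without a wrapper, such as 'block', pass through unchanged.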
2. Module reuse
When reusing a module (calling the same module more than once), the path may be wrong after the first call.
Using Setup
The code below creates a CNN module that has a ConvBlock submodule, which in turn calls Conv. Here, submodules are created during setup, and CNN calls self.block twice.
Code
```python
import jax
import jax.numpy as jnp
from flax import linen as nn

PATHS = []  # (label, scope path) pairs recorded during init


class Conv(nn.Conv):
    def __call__(self, *args, **kwargs):
        PATHS.append(("inside conv", self.scope.path))
        return super().__call__(*args, **kwargs)


class ConvBlock(nn.Module):
    def setup(self) -> None:
        self.conv = Conv(32, [3, 3])

    def __call__(self, x):
        PATHS.append(("ConvBlock start", self.scope.path))
        x = self.conv(x)
        PATHS.append(("ConvBlock end", self.scope.path))
        return x


class CNN(nn.Module):
    def setup(self):
        self.block = ConvBlock()

    def __call__(self, x):
        x = self.block(x)  # first call
        x = self.block(x)  # second call reuses the same submodule
        return x


x = jnp.ones((4, 28, 28, 32))
variables = CNN().init(jax.random.PRNGKey(0), x)

for p in PATHS:
    print(p)
```
```
('ConvBlock start', ('block',)) # correct
('inside conv', ('block', 'conv')) # correct
('ConvBlock end', ('block',)) # correct
('ConvBlock start', ('block',)) # correct
('inside conv', ()) # wrong
('ConvBlock end', ('block',)) # correct
```
Notice that the 'inside conv' path is empty the second time self.block is called.
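The discrepancy can also be detected mechanically by grouping the recorded paths by label and flagging any label that was seen with more than one distinct path. The following is a minimal sketch that uses the output above as literal data; `inconsistent_labels` is a hypothetical helper, not a Flax function.

```python
from collections import defaultdict

# Paths copied from the output above (setup variant).
recorded = [
    ("ConvBlock start", ("block",)),
    ("inside conv", ("block", "conv")),
    ("ConvBlock end", ("block",)),
    ("ConvBlock start", ("block",)),
    ("inside conv", ()),
    ("ConvBlock end", ("block",)),
]


def inconsistent_labels(paths):
    """Return labels that were recorded with more than one distinct path."""
    seen = defaultdict(set)
    for label, path in paths:
        seen[label].add(path)
    return sorted(label for label, distinct in seen.items() if len(distinct) > 1)


print(inconsistent_labels(recorded))  # ['inside conv']
```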
Using nn.compact
If you use nn.compact instead of setup, path names are wrong in a different way. Here, block is instantiated at the beginning of CNN.__call__ and used twice.
Code
```python
import jax
import jax.numpy as jnp
from flax import linen as nn

PATHS = []


class Conv(nn.Conv):
    def __call__(self, *args, **kwargs):
        PATHS.append(("inside conv", self.scope.path))
        return super().__call__(*args, **kwargs)


class ConvBlock(nn.Module):
    @nn.compact
    def __call__(self, x):
        PATHS.append(("ConvBlock start", self.scope.path))
        x = Conv(32, [3, 3])(x)
        PATHS.append(("ConvBlock end", self.scope.path))
        return x


class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        block = ConvBlock()
        x = block(x)
        x = block(x)
        return x


x = jnp.ones((4, 28, 28, 32))
variables = CNN().init(jax.random.PRNGKey(0), x)

for p in PATHS:
    print(p)
```
```
('ConvBlock start', ('ConvBlock_0',)) # correct
('inside conv', ('ConvBlock_0', 'Conv_0')) # correct
('ConvBlock end', ('ConvBlock_0',)) # correct
('ConvBlock start', ()) # wrong
('inside conv', ('Conv_0',)) # wrong
('ConvBlock end', ()) # wrong
```
Notice that now all paths are wrong the second time around.
Case 1 is part of the functional core API, so it's not really an error, just a mismatch between the functional API and the Module API.
Case 2 was fixed by #2360.
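The 1:1 assumption between path names and variable structure can be written down as a simple check. This is a minimal sketch, not the mechanism used by tabulate: `path_in_variables` is a hypothetical helper, and the `variables` dict is a toy stand-in for the structure produced by the nn.scan + LSTMCell example.

```python
from collections.abc import Mapping


def path_in_variables(variables: Mapping, collection: str, path: tuple) -> bool:
    """Return True if `path` names a sub-tree of `variables[collection]`."""
    node = variables.get(collection, {})
    for part in path:
        if not isinstance(node, Mapping) or part not in node:
            return False
        node = node[part]
    return True


# Toy stand-in for the variable structure of the nn.scan + LSTMCell example;
# leaf contents are elided because only the key structure matters here.
variables = {"params": {"ScanLSTMCell_0": {"hf": {}, "hg": {}, "hi": {}, "ho": {}}}}

print(path_in_variables(variables, "params", ("scan(ScanLSTMCell_0)",)))  # False
print(path_in_variables(variables, "params", ("ScanLSTMCell_0",)))  # True
```

Note that an empty path trivially passes this check, so it would not flag the () paths seen in case 2.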
Code for case 1 (Lifted Modules), the nn.scan + LSTMCell example:
```python
import flax.linen as nn
import jax
import jax.numpy as jnp
from jax import random

PATHS = []


class LSTMCell(nn.LSTMCell):
    def __call__(self, carry, inputs):
        PATHS.append(self.scope.path)
        return super().__call__(carry, inputs)


class LSTM(nn.Module):
    out_feat: int

    @nn.compact
    def __call__(self, x):
        PATHS.append(self.scope.path)
        carry = nn.LSTMCell.initialize_carry(
            random.PRNGKey(0), x.shape[:1], self.out_feat
        )
        Cell = nn.scan(
            LSTMCell,
            variable_broadcast="params",
            split_rngs={"params": False},
            variable_axes={"intermediates": 1},
            in_axes=1,
            out_axes=1,
        )
        return Cell()(carry, x)


lstm = LSTM(out_feat=128)
variables = lstm.init(random.PRNGKey(0), jnp.ones((32, 128, 64)))

print(PATHS, "\n")
print(jax.tree_map(lambda x: x.shape, variables))
```
```
[(), ('scan(ScanLSTMCell_0)',), ('scan(ScanLSTMCell_0)',)]

FrozenDict({
    params: {
        ScanLSTMCell_0: {
            hf: {...},
            hg: {...},
            hi: {...},
            ho: {...},
            if: {...},
            ig: {...},
            ii: {...},
            io: {...},
        },
    },
})
```