I've simplified the code I'm running down to this minimal example:
```python
import jax
from jax import nn
import jax.numpy as jnp
import jaxnet
import jaxnet.optimizers

cell = jaxnet.Dense(1)

@jaxnet.parametrized
def step_fn(carry, x):
    y = cell(carry)
    return y, y

@jaxnet.parametrized
def loss(init):
    carry, y = jax.lax.scan(step_fn, init=init, xs=None, length=10)
    return y.sum()

opt = jaxnet.optimizers.Adam()
opt_state = opt.init(loss.init_parameters(jnp.zeros([1, 1]), key=jax.random.PRNGKey(0)))
for i in range(5):
    opt_state = opt.update(loss.apply, opt_state, jnp.zeros([1, 1]), jit=True)
```
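To check whether the failure is specific to jaxnet's scan handling, it may help to run the same computation in plain JAX with the parameters passed explicitly. This is a sketch assuming a JAX version whose `lax.scan` accepts the `unroll` keyword; `init_params`, the `w`/`b` parameter layout, and the plain SGD update (swapped in for jaxnet's Adam to keep it short) are illustrative and not jaxnet's API.

```python
import jax
import jax.numpy as jnp


def init_params(key):
    # Hypothetical parameter layout for a Dense(1) cell: a weight and a bias.
    return {"w": 0.1 * jax.random.normal(key, (1, 1)), "b": jnp.zeros((1,))}


def loss(params, init):
    def step_fn(carry, _):
        # Carry keeps the same shape each step: (1, 1) @ (1, 1) + (1,) -> (1, 1)
        y = carry @ params["w"] + params["b"]
        return y, y

    # xs=None with an explicit length, as in the question; unroll=1 is the default.
    _, ys = jax.lax.scan(step_fn, init, xs=None, length=10, unroll=1)
    return ys.sum()


params = init_params(jax.random.PRNGKey(0))
x0 = jnp.zeros((1, 1))
initial_loss = loss(params, x0)  # zero input and zero bias give all-zero outputs

grad_fn = jax.jit(jax.grad(loss))

# Five plain SGD steps, mirroring the five optimizer updates in the question.
for _ in range(5):
    grads = grad_fn(params, x0)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
```

If this runs but the jaxnet version does not, the problem is more likely in how jaxnet intercepts `scan` than in the model itself.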
Initially I get this error:
```
/usr/local/lib/python3.6/dist-packages/jax/util.py in split_dict(dct, names)
     83   dct = dict(dct)
     84   lst = [dct.pop(name) for name in names]
---> 85   assert not dct
     86   return lst
     87

AssertionError:
```
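`split_dict` copies the keyword parameters, pops each expected name, and then asserts that nothing is left over, so any parameter the caller does not list (such as a newly added `unroll`) produces a bare `AssertionError`. A minimal pure-Python sketch of that mechanism, re-implementing the helper from the traceback (with an added assertion message) and using a hypothetical list of scan parameter names rather than jaxnet's actual one:

```python
def split_dict(dct, names):
    """Pop each expected name from a copy of dct; fail if any key is left over."""
    dct = dict(dct)
    lst = [dct.pop(name) for name in names]
    assert not dct, f"unexpected keys: {sorted(dct)}"
    return lst


# Hypothetical names a scan implementation might expect (without "unroll").
EXPECTED = ["reverse", "length", "jaxpr", "num_consts", "num_carry", "linear"]

# Keyword parameters as a newer scan primitive might pass them, including unroll.
params = dict(reverse=False, length=10, jaxpr=None, num_consts=0,
              num_carry=1, linear=(False,), unroll=1)

error = None
try:
    split_dict(params, EXPECTED)
except AssertionError as e:
    error = str(e)  # the leftover "unroll" key trips the assertion

# Listing "unroll" as well consumes every key, so the assertion passes.
values = split_dict(params, EXPECTED + ["unroll"])
```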
This is because `unroll` is missing from the list of kwargs in `_custom_cell_scan_impl`. If I add it and pass its value through to `scan_p`, I then get this error:
```
/usr/local/lib/python3.6/dist-packages/jaxnet/core.py in next_parameters_for(self, primitive)
    557             return parameters
    558
--> 559         parameters = self.parameters[self._index]
    560         self._index += 1
    561         self.global_parameters_by_primitive[primitive] = parameters

IndexError: tuple index out of range
```
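The failing line reads parameters from a tuple using a running counter, so the `IndexError` means more primitives requested parameters during `apply` than there are entries in that tuple. A sketch of just that lookup pattern, using a hypothetical `ParameterReader` class (not jaxnet's actual class) with one recorded parameter set and two requests:

```python
class ParameterReader:
    """Hypothetical stand-in for the lookup in the traceback: each primitive
    that needs parameters takes the next entry from a flat tuple."""

    def __init__(self, parameters):
        self.parameters = tuple(parameters)
        self._index = 0

    def next_parameters_for(self, primitive):
        # Raises IndexError once _index reaches len(self.parameters).
        parameters = self.parameters[self._index]
        self._index += 1
        return parameters


# Only one parameter set was recorded, but two primitives ask for one.
reader = ParameterReader([{"kernel": 0.5, "bias": 0.0}])
first = reader.next_parameters_for("dense")

error = None
try:
    reader.next_parameters_for("scan")
except IndexError as e:
    error = str(e)
```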
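It appears that the parameters of the `Dense` cell aren't being added, so the `parameters` tuple has fewer entries than there are primitives asking for parameters, which is what triggers the `IndexError`.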
Any idea how to fix this?