juliuskunze / jaxnet

Concise deep learning for JAX
Apache License 2.0

Issues with basic RNN... #26

Closed: adarob closed this issue 4 years ago

adarob commented 4 years ago

I've simplified the code I'm running down to this minimal example:

import jax
from jax import nn
import jax.numpy as jnp
import jaxnet
import jaxnet.optimizers

cell = jaxnet.Dense(1)

@jaxnet.parametrized
def step_fn(carry, x):
  y = cell(carry)
  return y, y

@jaxnet.parametrized
def loss(init):
  carry, y = jax.lax.scan(step_fn, init=init, xs=None, length=10)
  return y.sum()

opt = jaxnet.optimizers.Adam()
opt_state = opt.init(loss.init_parameters(jnp.zeros([1, 1]), key=jax.random.PRNGKey(0)))
for i in range(5):
  opt_state = opt.update(loss.apply, opt_state, jnp.zeros([1, 1]), jit=True)

Initially I get this error:

/usr/local/lib/python3.6/dist-packages/jax/util.py in split_dict(dct, names)
     83   dct = dict(dct)
     84   lst = [dct.pop(name) for name in names]
---> 85   assert not dct
     86   return lst
     87 

AssertionError: 
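For readers unfamiliar with this helper, here is a minimal standalone sketch of why the assertion fires. The body of split_dict is copied from the traceback above; the parameter names and values are hypothetical placeholders chosen for illustration and are not the real arguments passed to scan. The point is only that any key not listed in names is left over in the dict and trips `assert not dct`.

```python
def split_dict(dct, names):
    """Replica of the helper shown in the traceback above."""
    dct = dict(dct)  # copy so the caller's dict is left untouched
    lst = [dct.pop(name) for name in names]
    assert not dct  # fails if any key was not listed in names
    return lst


def leftover_keys(dct, names):
    """Hypothetical debugging aid: the keys that would trip the assertion."""
    return sorted(set(dct) - set(names))


# Hypothetical parameter names and values, for illustration only.
expected_names = ["length", "jaxpr"]
params = {"length": 10, "jaxpr": "<jaxpr>", "unroll": 1}

try:
    split_dict(params, expected_names)
    raised = False
except AssertionError:
    raised = True  # the unexpected "unroll" key is left over

# Listing the extra key makes the same call succeed.
fixed = split_dict(params, expected_names + ["unroll"])
```

A helper like leftover_keys can be handy when debugging, because the bare AssertionError carries no message about which key was unexpected.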

This is because unroll is missing from the list of kwargs in _custom_cell_scan_impl. If I add it and pass its value through to scan_p, I then get this error:

/usr/local/lib/python3.6/dist-packages/jaxnet/core.py in next_parameters_for(self, primitive)
    557             return parameters
    558 
--> 559         parameters = self.parameters[self._index]
    560         self._index += 1
    561         self.global_parameters_by_primitive[primitive] = parameters

IndexError: tuple index out of range

It appears that the parameters of the Dense layer aren't being added:

> opt.get_parameters(opt_state)
loss(step_fn=step_fn())
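The connection between the empty step_fn() entry and the IndexError can be sketched as follows. This is not jaxnet's implementation: ParameterCursor is a hypothetical stand-in that only mirrors the two lines visible in the traceback (indexing self.parameters with self._index, then incrementing). It assumes the parameters are stored as a tuple, which the "tuple index out of range" message suggests.

```python
class ParameterCursor:
    """Hypothetical stand-in that hands out one parameter entry per call."""

    def __init__(self, parameters):
        self.parameters = tuple(parameters)
        self._index = 0

    def next_parameters(self):
        # Raises IndexError when no entry is left to hand out.
        parameters = self.parameters[self._index]
        self._index += 1
        return parameters


# step_fn() printed with no fields, so model it as an empty tuple.
empty = ParameterCursor(())
try:
    empty.next_parameters()
    failed = False
except IndexError:
    failed = True  # same failure mode as in the traceback

# With one (made-up) Dense entry present, the lookup succeeds.
populated = ParameterCursor([{"kernel": [[0.5]], "bias": [0.0]}])
dense_params = populated.next_parameters()
```

Under these assumptions, the error is a symptom rather than the cause: the lookup fails because nothing was collected for the cell inside the scanned function.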

Any idea how to fix this?

juliuskunze commented 4 years ago

Thanks for the clear report! This is now fixed.