google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html

Flattening issue in predict.gradient_descent? #11

Closed atishagarwala closed 4 years ago

atishagarwala commented 4 years ago

There seems to be a bug in nt.predict.gradient_descent, perhaps related to flattening of the inputs. Code snippet and stack trace below.

Code snippet:

import jax.numpy as np
from jax import jit
from jax.experimental import stax as jstax
import neural_tangents as nt

def scaled_loss_for_ntk(beta):
  # Cross-entropy-style loss on the NTK outputs, scaled by an inverse temperature beta.
  def ntk_loss(fx, y_hat):
    return -np.mean(np.sum(jstax.logsoftmax(beta * fx) * y_hat, axis=1))
  return ntk_loss

g_dd = kernel_fn(x_train, x_train, 'ntk')  # kernel_fn from nt.stax.serial
g_td = kernel_fn(x_test, x_train, 'ntk')   # test and train numpy arrays
ntk_loss = jit(scaled_loss_for_ntk(beta))

predict_fn = nt.predict.gradient_descent(g_dd, y_train, ntk_loss, g_td)
predict_fn(0.1, fx_train_initial, fx_test_initial)

Stack trace of the error:

ValueError                                Traceback (most recent call last)
<ipython-input-97-592ef42dc248> in <module>()
      5   ntk_outputs, ntk_loss_fn, ntk_acc_fn = get_ntk_dynamics(
      6       kernel_fn,x_train,x_test,y_train,
----> 7       y_test,fx_train_initial,fx_test_initial,beta)
      8   # get results
      9   train_loss = nnp.zeros(len(ts))

25 frames
<ipython-input-96-9b76a38cb837> in get_ntk_dynamics(kernel_fn, x_train, x_test, y_train, y_test, fx_train_initial, fx_test_initial, beta)
     22   print('NTK initial loss: {}'.format(ntk_loss(fx_train_initial,y_train)))
     23   predict_fn = nt.predict.gradient_descent(g_dd, y_train, ntk_loss, g_td)
---> 24   predict_fn(0.1,fx_train_initial,fx_test_initial)
     25 
     26   ntk_outputs = functools.partial(

google3/third_party/py/neural_tangents/predict.py in predict(dt, fx_train, fx_test)
    276       train_size, output_dim = fx_train.shape
    277       r.set_initial_value(fx, 0).set_f_params(train_size * output_dim)
--> 278       r.integrate(dt)
    279       fx = ufl(r.y)
    280 

google3/third_party/py/scipy/integrate/_ode.py in integrate(self, t, step, relax)
    430             self._y, self.t = mth(self.f, self.jac or (lambda: None),
    431                                   self._y, self.t, t,
--> 432                                   self.f_params, self.jac_params)
    433         except SystemError:
    434             # f2py issue with tuple returns, see ticket 1187.

google3/third_party/py/scipy/integrate/_ode.py in run(self, f, jac, y0, t0, t1, f_params, jac_params)
   1170     def run(self, f, jac, y0, t0, t1, f_params, jac_params):
   1171         x, y, iwork, istate = self.runner(*((f, t0, y0, t1) +
-> 1172                                           tuple(self.call_args) + (f_params,)))
   1173         self.istate = istate
   1174         if istate < 0:

google3/third_party/py/neural_tangents/predict.py in dfx_dt(unused_t, fx, train_size)
    266     def dfx_dt(unused_t, fx, train_size):
    267       fx_train = fx[:train_size]
--> 268       dfx_train = -ifl(np.dot(g_dd, iufl(grad_loss(fx_train))))
    269       dfx_test = -ifl(np.dot(g_td, iufl(grad_loss(fx_train))))
    270       return np.concatenate((dfx_train, dfx_test), axis=0)

google3/third_party/py/jax/api.py in grad_f(*args, **kwargs)
    353   @wraps(fun, docstr=docstr, argnums=argnums)
    354   def grad_f(*args, **kwargs):
--> 355     _, g = value_and_grad_f(*args, **kwargs)
    356     return g
    357 

google3/third_party/py/jax/api.py in value_and_grad_f(*args, **kwargs)
    408     f_partial, dyn_args = _argnums_partial(f, argnums, args)
    409     if not has_aux:
--> 410       ans, vjp_py = vjp(f_partial, *dyn_args)
    411     else:
    412       ans, vjp_py, aux = vjp(f_partial, *dyn_args, has_aux=True)

google3/third_party/py/jax/api.py in vjp(fun, *primals, **kwargs)
   1267   if not has_aux:
   1268     flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 1269     out_primal, out_vjp = ad.vjp(flat_fun, primals_flat)
   1270     out_tree = out_tree()
   1271   else:

google3/third_party/py/jax/interpreters/ad.py in vjp(traceable, primals, has_aux)
    106 def vjp(traceable, primals, has_aux=False):
    107   if not has_aux:
--> 108     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    109   else:
    110     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

google3/third_party/py/jax/interpreters/ad.py in linearize(traceable, *primals, **kwargs)
     95   _, in_tree = tree_flatten(((primals, primals), {}))
     96   jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
---> 97   jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
     98   pval_primals, pval_tangents = tree_unflatten(out_tree(), out_pvals)
     99   aval_primals, const_primals = unzip2(pval_primals)

google3/third_party/py/jax/interpreters/partial_eval.py in trace_to_jaxpr(fun, pvals, **kwargs)
    313   with new_master(JaxprTrace) as master:
    314     fun = trace_to_subjaxpr(fun, master, instantiate)
--> 315     jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    316     assert not env
    317     del master

google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    151     gen = None
    152 
--> 153     ans = self.f(*args, **dict(self.params, **kwargs))
    154     del args
    155     while stack:

google3/third_party/py/jax/api.py in f_jitted(*args, **kwargs)
    148     _check_args(args_flat)
    149     flat_fun, out_tree = flatten_fun(f, in_tree)
--> 150     out = xla.xla_call(flat_fun, *args_flat, device=device, backend=backend)
    151     return tree_unflatten(out_tree(), out)
    152 

google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
    593   else:
    594     tracers = map(top_trace.full_raise, args)
--> 595     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    596   return apply_todos(env_trace_todo(), outs)
    597 

google3/third_party/py/jax/interpreters/ad.py in process_call(self, call_primitive, f, tracers, params)
    324     nonzero_tangents, in_tree_def = tree_flatten(tangents)
    325     f_jvp, out_tree_def = traceable(jvp_subtrace(f, self.master), len(primals), in_tree_def)
--> 326     result = call_primitive.bind(f_jvp, *(primals + nonzero_tangents), **params)
    327     primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    328     return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
    593   else:
    594     tracers = map(top_trace.full_raise, args)
--> 595     outs = map(full_lower, top_trace.process_call(primitive, f, tracers, params))
    596   return apply_todos(env_trace_todo(), outs)
    597 

google3/third_party/py/jax/interpreters/partial_eval.py in process_call(self, call_primitive, f, tracers, params)
    113     in_pvs, in_consts = unzip2([t.pval for t in tracers])
    114     fun, aux = partial_eval(f, self, in_pvs)
--> 115     out_flat = call_primitive.bind(fun, *in_consts, **params)
    116     out_pvs, jaxpr, env = aux()
    117     out_pv_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)])

google3/third_party/py/jax/core.py in call_bind(primitive, f, *args, **params)
    590   if top_trace is None:
    591     with new_sublevel():
--> 592       outs = primitive.impl(f, *args, **params)
    593   else:
    594     tracers = map(top_trace.full_raise, args)

google3/third_party/py/jax/interpreters/xla.py in _xla_call_impl(fun, *args, **params)
    398   device = params['device']
    399   backend = params.get('backend', None)
--> 400   compiled_fun = _xla_callable(fun, device, backend, *map(abstractify, args))
    401   try:
    402     return compiled_fun(*args)

google3/third_party/py/jax/linear_util.py in memoized_fun(fun, *args)
    207       fun.populate_stores(stores)
    208     else:
--> 209       ans = call(fun, *args)
    210       cache[key] = (ans, fun.stores)
    211     return ans

google3/third_party/py/jax/interpreters/xla.py in _xla_callable(fun, device, backend, *abstract_args)
    410   pvals = [pe.PartialVal((aval, core.unit)) for aval in abstract_args]
    411   with core.new_master(pe.JaxprTrace, True) as master:
--> 412     jaxpr, (pvals, consts, env) = pe.trace_to_subjaxpr(fun, master, False).call_wrapped(pvals)
    413     assert not env  # no subtraces here
    414     del master, env

google3/third_party/py/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    151     gen = None
    152 
--> 153     ans = self.f(*args, **dict(self.params, **kwargs))
    154     del args
    155     while stack:

<ipython-input-96-9b76a38cb837> in ntk_loss(fx, y_hat)
      2 def scaled_loss_for_ntk(beta):
      3   def ntk_loss(fx,y_hat):
----> 4     return -np.mean(np.sum(jstax.logsoftmax(beta*fx) * y_hat,axis=1))
      5   return ntk_loss
      6 

google3/third_party/py/jax/numpy/lax_numpy.py in reduction(a, axis, dtype, out, keepdims)
   1184     a = a if isinstance(a, ndarray) else asarray(a)
   1185     a = preproc(a) if preproc else a
-> 1186     dims = _reduction_dims(a, axis)
   1187     result_dtype = dtype or _dtype(np_fun(onp.ones((), dtype=_dtype(a))))
   1188     if upcast_f16_for_computation and issubdtype(result_dtype, inexact):

google3/third_party/py/jax/numpy/lax_numpy.py in _reduction_dims(a, axis)
   1206     return tuple(_canonicalize_axis(x, ndim(a)) for x in axis)
   1207   elif isinstance(axis, int):
-> 1208     return (_canonicalize_axis(axis, ndim(a)),)
   1209   else:
   1210     raise TypeError("Unexpected type of axis argument: {}".format(type(axis)))

google3/third_party/py/jax/numpy/lax_numpy.py in _canonicalize_axis(axis, num_dims)
    353       raise ValueError(
    354           "axis {} is out of bounds for array of dimension {}".format(
--> 355               axis, num_dims))
    356   return axis
    357 

ValueError: axis 1 is out of bounds for array of dimension 1
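
For context, a minimal sketch of what seems to go wrong (shapes here are made up for illustration, not taken from the actual run): predict.py integrates a flattened 1-D state vector, so dfx_dt hands the loss a 1-D fx where the loss expects a (train_size, output_dim) array, and the axis=1 reduction fails exactly as above.

import jax.numpy as np

fx_2d = np.ones((4, 10))      # shape the loss expects: (train_size, output_dim)
fx_flat = fx_2d.reshape(-1)   # flattened state the ODE integrator actually passes

np.sum(fx_2d, axis=1)         # fine
np.sum(fx_flat, axis=1)       # ValueError: axis 1 is out of bounds for array of dimension 1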
romanngg commented 4 years ago

Looks like it could be related indeed... Do you happen to have the full code to reproduce this? That would help us speed up debugging.

sschoenholz commented 4 years ago

I debugged this a bit with @atishagarwala last night. Will follow up next week. I think it should be a fairly simple fix, and we should also add a test.

romanngg commented 4 years ago

Sorry for the enormous delay - fixed in a76bbb494f19af4f8c9c1a1b0904e91b105f769e (v0.3.0)! Example using the new API: https://colab.research.google.com/gist/romanngg/b6cd8595fcd5e12ac56b7c78747851db/flatten_issue.ipynb
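
Roughly, the call now looks like the sketch below (argument names and order are from memory, so treat this as a sketch and defer to the notebook for the exact code): the loss comes first, followed by the train-train kernel and labels, and the test-train kernel goes to the returned prediction function.

# Sketch only -- see the linked Colab for the authoritative post-fix example.
loss = lambda fx, y_hat: -np.mean(np.sum(jstax.logsoftmax(beta * fx) * y_hat, axis=1))

predict_fn = nt.predict.gradient_descent(loss, g_dd, y_train, learning_rate=1.0)
fx_train_t, fx_test_t = predict_fn(t=0.1, fx_train_0=fx_train_initial,
                                   fx_test_0=fx_test_initial, k_test_train=g_td)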

Please reopen if I missed anything!