google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

A type error in predict.gradient_descent #112

Closed kim-hyunsu closed 3 years ago

kim-hyunsu commented 3 years ago

I wrote a simple script using the monte_carlo_kernel_fn and gradient_descent modules, but it raised a type error I can't identify, even though I never manipulate any types in the code. I basically followed the examples shown in the source code, except that I used jax.experimental.stax.Tanh to build a two-layer neural network with a hyperbolic tangent activation.

The code I ran was as follows:

import neural_tangents as nt
import jax.experimental.stax as ostax
from jax import random as jrandom
import jax.numpy as np

key = jrandom.PRNGKey(0)

def gen_key():
    global key
    key, k = jrandom.split(key, 2)
    return k

def cross_entropy(fx, y_hat):
    return -np.mean(ostax.logsoftmax(fx)*y_hat)

x_train = jrandom.normal(gen_key(), (20, 784))
x_test = jrandom.normal(gen_key(), (20, 784))
y_train = jrandom.normal(gen_key(), (20, 50))

init_fn, apply_fn = ostax.serial(
    ostax.Dense(200), ostax.Tanh, ostax.Dense(50))

_, params = init_fn(gen_key(), x_train.shape)

kernel_fn = nt.monte_carlo_kernel_fn(
    init_fn, apply_fn, key=gen_key(), n_samples=100)

k_train_train = kernel_fn(x_train, None, get='ntk')
k_test_train = kernel_fn(x_test, x_train, get='ntk')

predict_fn = nt.predict.gradient_descent(
    cross_entropy, k_train_train, y_train, 1e-2, 0.9)
fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)

t = 1e-7
fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0, k_test_train)
print(fx_train_t)
print(fx_test_t)

The error raised was as follows:

Traceback (most recent call last):
  File "nt-practice.py", line 42, in <module>
    fx_train_t, fx_test_t = predict_fn(t, fx_train_0, fx_test_0, k_test_train)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/neural_tangents/predict.py", line 472, in predict_fn
    state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 173, in odeint
    return _odeint_wrapper(converted, rtol, atol, mxstep, y0, t, *args, *consts)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/api.py", line 338, in cache_miss
    donated_invars=donated_invars)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 1402, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 1393, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 1405, in process
    return trace.process_call(self, fun, tracers, params)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/core.py", line 600, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 577, in _xla_call_impl
    *unsafe_map(arg_spec, args))
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 260, in memoized_fun
    ans = call(fun, *args)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/xla.py", line 652, in _xla_callable
    jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1209, in trace_to_jaxpr_final
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 179, in _odeint_wrapper
    out = _odeint(func, rtol, atol, mxstep, y0, ts, *args)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/custom_derivatives.py", line 485, in __call__
    out_trees=out_trees)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/custom_derivatives.py", line 566, in bind
    out_trees=out_trees)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1137, in process_custom_vjp_call
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 214, in _odeint
    _, ys = lax.scan(scan_fun, init_carry, ts[1:])
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 1276, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 1263, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals, "scan")
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 185, in wrapper
    return cached(bool(config.x64_enabled), *args, **kwargs)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 178, in cached
    return f(*args, **kwargs)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 77, in _initial_style_jaxpr
    transform_name)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 185, in wrapper
    return cached(bool(config.x64_enabled), *args, **kwargs)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/util.py", line 178, in cached
    return f(*args, **kwargs)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 70, in _initial_style_open_jaxpr
    transform_name=transform_name)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1178, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1188, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/experimental/ode.py", line 204, in scan_fun
    _, *carry = lax.while_loop(cond_fun, body_fun, [0] + carry)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 301, in while_loop
    in_tree_children[0], init_avals)
  File "/home/hyunsu/restricted-boltzmann-machines/.venv/lib/python3.6/site-packages/jax/_src/lax/control_flow.py", line 1940, in _check_tree_and_avals
    f"{what} must have identical types, got\n"
TypeError: body_fun output and input must have identical types, got
[ShapedArray(int64[], weak_type=True), ShapedArray(float64[4000]), ShapedArray(float64[4000]), ShapedArray(float64[]), ShapedArray(float64[]), ShapedArray(float32[]), ShapedArray(float64[5,4000])]
and
[ShapedArray(int64[], weak_type=True), ShapedArray(float64[4000]), ShapedArray(float64[4000]), ShapedArray(float32[]), ShapedArray(float64[]), ShapedArray(float32[]), ShapedArray(float64[5,4000])].

Does anyone have an idea how to address this problem?

romanngg commented 3 years ago

Thanks for the report! I can't seem to reproduce it in Colab: https://colab.research.google.com/gist/romanngg/693d46b3c4a89649ab23a37542319eef/https-github-com-google-neural-tangents-issues-112.ipynb

Could you double-check that you're using the latest versions of NT/JAX, as in the colab above? Also, are you by any chance running this on a machine with multiple GPUs? If so, how many? (There could be some issues with parallel execution, although I'd expect different error messages.)

kim-hyunsu commented 3 years ago

Thank you for the quick reply. I'm using the following package versions:

jax 0.2.12
jaxlib 0.1.65+cuda110
neural-tangents 0.3.6
numpy 1.19.4

with 4 GPUs, CUDA 11.0 and Nvidia driver 450.51.05. Just in case: Python is version 3.6 and the OS is Ubuntu 18.04 LTS. These appear to be the latest versions as far as I can tell.
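
For reference, one way to double-check the installed versions from within Python (a small sketch, assuming each of these packages exposes a __version__ attribute):

import jax
import jaxlib
import neural_tangents
import numpy

# Print the versions of the packages relevant to this issue.
print("jax", jax.__version__)
print("jaxlib", jaxlib.__version__)
print("neural-tangents", neural_tangents.__version__)
print("numpy", numpy.__version__)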

romanngg commented 3 years ago

Hm, so the issue seems to be that you have 64-bit precision enabled, but fx_train_0, fx_test_0 have different types from k_train_train and k_test_train. A quick fix should be to cast them all to the same type before creating the predictor function predict_fn, e.g.

k_test_train = k_test_train.astype(fx_train_0.dtype)
k_train_train = k_train_train.astype(fx_test_0.dtype)

This appears to stem from the fact that for x64 x_train and x32 params, apply_fn(params, x) is x64, but jacobian(apply_fn)(params, x) is x32. Filed https://github.com/google/jax/issues/6638
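
To illustrate the mismatch, here is a minimal sketch under the same setup as the script above (x64 enabled via JAX_ENABLE_X64, float32 params from init_fn); the dtypes in the comments are the ones described above rather than anything guaranteed by the API:

import jax

# Forward pass: the float64 inputs promote the result to float64.
fx = apply_fn(params, x_train)
print(fx.dtype)  # float64

# Jacobian with respect to the float32 params: comes out float32.
jac = jax.jacobian(apply_fn)(params, x_train)
print(jax.tree_util.tree_leaves(jac)[0].dtype)  # float32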

IIUC this might also mean that k_train_train and k_test_train are computed in lower, 32-bit precision, in which case you might as well disable x64 precision altogether for faster x32 performance.

kim-hyunsu commented 3 years ago

I tried each of

k_test_train = k_test_train.astype(fx_train_0.dtype)
k_train_train = k_train_train.astype(fx_test_0.dtype)

and

from jax.config import config
config.update("jax_enable_x64", False)

Now both ways work. Thank you for the kind answer. I didn't know 64-bit precision was enabled. Is that a default setting?

romanngg commented 3 years ago

I don't think so. Perhaps you had it enabled once in an ipython/colab runtime, and then ran this code in that runtime? Restarting the runtime should also reset this flag to False AFAIK.
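
If in doubt, one way to check whether 64-bit mode is currently on is to read the config attribute that shows up in the traceback above (just a quick sketch):

from jax.config import config

# False by default; True when JAX_ENABLE_X64 / jax_enable_x64 has been set.
print(config.x64_enabled)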

kim-hyunsu commented 3 years ago

Ah, it turns out that I had enabled it via an environment variable,

export JAX_ENABLE_X64=True

Now everything makes sense. I appreciate your help.

romanngg commented 3 years ago

I believe after https://github.com/google/jax/commit/693d2e20cf40e17b567c4a252f37a4d6b9366e5d there should be no further type mismatches like this one: stax networks will initialize weights with the dtype corresponding to JAX_ENABLE_X64 instead of always defaulting to jnp.float32.
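
A quick sketch of what that change implies, reusing the init_fn and gen_key from the script above (the expected float64 assumes a JAX build that includes the commit):

from jax.config import config
config.update("jax_enable_x64", True)

# With x64 enabled, stax-initialized weights should now be float64 rather than
# always float32.
_, params = init_fn(gen_key(), x_train.shape)
W, b = params[0]  # weights and bias of the first Dense layer
print(W.dtype)  # expected: float64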

Here's a repro with JAX installed from head where it now works: https://colab.research.google.com/gist/romanngg/38f635cc20ba1ba667d34408728c1512/issue_112_fixed.ipynb

Thanks for noticing this!