jaxleyverse / jaxley

Differentiable neuron simulations with biophysical detail on CPU, GPU, or TPU.
https://jaxleyverse.github.io/jaxley/
Apache License 2.0

Potential reasons for `NaN` during training #317

Open michaeldeistler opened 5 months ago

michaeldeistler commented 5 months ago

Below are several reasons for experiencing NaN, ranked from most to least likely. If you have tried out all of the options below, we would be happy to receive a bug report (as an issue on GitHub) with the following information:

1. You are using float32.

In most cases where I have encountered NaN so far, they were resolved by switching to float64:

from jax import config
config.update("jax_enable_x64", True)

2. The mechanisms (channels or synapses) are unstable.

The most likely reason for this is that the channels you are using contain a jnp.exp() whose input gets very large (e.g. >100.0), such that the result is inf. For example, our initial implementation of channels gave NaN when a strong negative stimulus was inserted, such that the neuron was very strongly hyperpolarized (voltages below -200mV). You can prevent this by changing jnp.exp() to save_exp from jaxley/solver_gate.py.
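For reference, the point at which exp() overflows follows from the largest representable value of each dtype. A small NumPy sketch (exp_overflow_limit and exp_in are illustrative helpers, not jaxley functions):

```python
import numpy as np

def exp_overflow_limit(dtype):
    """Approximate largest x for which exp(x) is still finite in `dtype` (log of its max value)."""
    return float(np.log(np.finfo(dtype).max))

def exp_in(dtype, x):
    """Evaluate exp(x) in `dtype`, silencing NumPy's overflow warning."""
    with np.errstate(over="ignore"):
        return np.exp(dtype(x))

print(exp_overflow_limit(np.float32))  # roughly 88.7
print(exp_overflow_limit(np.float64))  # roughly 709.8
print(exp_in(np.float32, 100.0))       # inf
print(exp_in(np.float64, 100.0))       # about 2.7e43, still finite
```

This is also why switching to float64 (first point above) moves the overflow from inputs around 89 to inputs around 710 rather than removing it.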

3. The ParamTransform saturates:

import jax.numpy as jnp
from jax import value_and_grad
import jaxley as jx

transform = jx.ParamTransform(lowers={"x": -1.0}, uppers={"x": 1.0})

def tf(params):
    return jnp.sum(transform.forward(params)[0]["x"])

# Interestingly, only negative values return a `NaN` gradient; positive values return `0`.
p = [{"x": jnp.asarray([-100.0])}]

tf_grad_fn = value_and_grad(tf)
print(tf(p))
print(tf_grad_fn(p))
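The asymmetry above (NaN for large negative inputs, 0 for large positive ones) is exactly what a sigmoid written with jnp.exp produces in float32. We have not checked how jx.ParamTransform computes its squashing internally, so naive_sigmoid below is only an assumed stand-in. jax.nn.sigmoid is shown for comparison because JAX defines its gradient in terms of the output, s * (1 - s), which stays finite:

```python
import jax
import jax.numpy as jnp

def naive_sigmoid(x):
    # Textbook formula: exp(-x) overflows to inf for very negative x in float32.
    return 1.0 / (1.0 + jnp.exp(-x))

def grad_at(f, x):
    # Gradient of sum(f(x)) with respect to x.
    return jax.grad(lambda v: jnp.sum(f(v)))(x)

x_neg = jnp.asarray([-100.0], dtype=jnp.float32)
x_pos = jnp.asarray([100.0], dtype=jnp.float32)

print(grad_at(naive_sigmoid, x_neg))   # NaN: the backward pass multiplies 0 by exp(100) = inf
print(grad_at(naive_sigmoid, x_pos))   # ~0: exp(-100) underflows
print(grad_at(jax.nn.sigmoid, x_neg))  # finite
```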

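To debug this, print the maximum absolute value of the parameters in the unconstrained (optimizer) space after every gradient update:

import jax.numpy as jnp
from jax import tree_util

parameters = net.get_parameters()
opt_params = transform.inverse(parameters)

leaves, _ = tree_util.tree_flatten(opt_params)
# Leaves can have different shapes, so reduce each one before taking the overall max.
max_val = max(float(jnp.max(jnp.abs(leaf))) for leaf in leaves)
print(max_val)

If max_val > 50.0, you are probably in trouble.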

jnsbck commented 5 months ago

Do you have a hunch where we need so much precision that float32 is not sufficient? It sounds like a suitable change of units might take care of this and could save valuable GPU memory. If you have an example, I could look into it. I just saw that NEURON uses double as well.

michaeldeistler commented 5 months ago

This notebook will give NaN if one either uses float32 or changes the max_value of save_exp to, e.g., 100 or more.

michaeldeistler commented 5 months ago

I do not have a clear idea of why the NaNs happen. The most concrete thing I have observed so far is the following (and it is the reason I introduced save_exp):

In our channels, e.g. in the sodium part of HH(), we often take an exp(). This value can become huge if the neuron is hyperpolarized beyond the normal range of voltages. For example, if v=-300mV, one can get gigantic values in this exp:

import math

v = -300.0  # mV
x = -(v + 35.0) / 10.0 + 1.0  # = 27.5
# math.exp(x) -> huge (about 8.8e11)

Especially when using float32, these huge values are a problem and make things unstable. Such strongly negative voltages can happen if extremely strong inhibitory synapses are learned.

michaeldeistler commented 5 months ago

A very simple reproducible example should look something like this (but you need to turn off save_exp in the channels; I have not tested this):

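import jaxley as jx
from jaxley.channels import HH

dt = 0.025  # ms (example value)
t_max = 10.0  # ms (example value)

comp = jx.Compartment()

# Strong negative current
current = jx.step_current(0.0, 10.0, -10.0, dt, t_max)
comp.stimulate(current)

comp.insert(HH())
comp.record()
v = jx.integrate(comp)

michaeldeistler commented 5 months ago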

That being said, I have also observed NaN in single neurons (with morphological detail) when using float32, but I don't have an example right now.

jnsbck commented 5 months ago

Interesting, might look into this. Thanks for all the hints.

jnsbck commented 5 months ago

I took a deep dive into this, using your example.

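import jax.numpy as jnp
import jaxley as jx

# Stimulus.
i_delay = 3.0  # ms
i_amp = 0.05  # nA
i_dur = 2.0  # ms

# Duration and step size.
dt = 0.025  # ms
t_max = 10.0  # ms

time_vec = jnp.arange(0.0, t_max+dt, dt)
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)

# ModHH is my modified HH channel (not shown here) whose exp() clipping
# can be switched on or off; NaNs appear with clipping switched off.
clip_exp = False
comp = jx.Compartment()
comp.stimulate(current)
comp.insert(ModHH(clip_exp=clip_exp))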

This leads to NaNs at some point; higher precision essentially just delays this. save_exp does not have this problem, since it prevents values from exceeding the largest representable floating-point number.

Here is what it essentially comes down to:

import jax.numpy as jnp
from jaxley.solver_gate import exponential_euler

def solve_gate_exponential(
    x: jnp.ndarray,
    dt: float,
    alpha: jnp.ndarray,
    beta: jnp.ndarray,
):
    tau = 1 / (alpha + beta)  # <--- alpha can become inf here, hence tau becomes 0
    xinf = alpha * tau  # <--- this means alpha * tau = inf * 0, which is NaN
    return exponential_euler(x, dt, xinf, tau)

The crux of the issue, I think, is that alpha already overflows the floating-point format inside the gates, as you pointed out as well.
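The inf * 0 mechanism can be reproduced in isolation in float32. In this sketch, xinf_and_tau is a hypothetical helper repeating the two lines above, beta = 0.5 is an arbitrary finite rate, and the clip at 88.0 mirrors the proposed save_exp:

```python
import jax.numpy as jnp

def xinf_and_tau(alpha, beta):
    """Hypothetical helper repeating the tau / xinf arithmetic of solve_gate_exponential."""
    tau = 1.0 / (alpha + beta)
    xinf = alpha * tau
    return xinf, tau

exp_arg = jnp.float32(100.0)  # an exp() input that overflows float32
beta = jnp.float32(0.5)       # arbitrary finite rate (assumption)

alpha_raw = jnp.exp(exp_arg)                         # inf
alpha_clipped = jnp.exp(jnp.minimum(exp_arg, 88.0))  # large but finite

print(xinf_and_tau(alpha_raw, beta))      # xinf is nan (inf * 0), tau is 0
print(xinf_and_tau(alpha_clipped, beta))  # both finite
```

With clipping, tau can end up as a float32 subnormal (or be flushed to 0 on some backends), so xinf is not guaranteed to be exactly 1, but it stays finite.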

EDIT: Idea for a better save_exp:

import numpy as np
import jax.numpy as jnp

def save_exp(x):
    """Clip the input below the maximum value that the dtype can support."""
    # Floored log of the maximum value that can be represented by the dtype.
    max_value = 88.0 if x.dtype == np.float32 else 709.0
    x = jnp.minimum(x, max_value)  # upper clip only
    return jnp.exp(x)

max_value can be obtained from: max_dtype_val = np.finfo(dtype).max; max_value = np.floor(np.log(max_dtype_val))
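Following that finfo recipe, the clip value can be derived from the input's dtype instead of hard-coding 88.0 and 709.0. This is a hypothetical variant (save_exp_dtype_aware is my name for it, not a jaxley function), sketched under the assumption that only an upper clip is needed:

```python
import numpy as np
import jax.numpy as jnp

def max_exp_input(dtype):
    """Floored log of the largest finite value of `dtype` (88 for float32, 709 for float64)."""
    return float(np.floor(np.log(np.finfo(dtype).max)))

def save_exp_dtype_aware(x):
    """Hypothetical save_exp variant: clip x from above using its own dtype, then exponentiate."""
    x = jnp.asarray(x)
    return jnp.exp(jnp.minimum(x, max_exp_input(x.dtype)))

y = save_exp_dtype_aware(jnp.asarray([1000.0, 0.0], dtype=jnp.float32))
print(y)  # first entry is exp(88), large but finite; second is 1.0
```

For float16 this gives a limit of 11, whereas the hard-coded version would fall back to 709 and still overflow.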

jnsbck commented 5 months ago

I spent some time looking into this.

> This notebook will give NaN if either one uses float32 or if one changes the max_value of save_exp e.g. to 100 or even more

I don't know exactly why one gets NaNs yet, but I was at least able to trace it to the checkpointing. Here is a minimal reproducing example that you can run in the notebook above once you have set up the network:

import pickle

import jaxley as jx
import tensorflow_datasets as tfds
from jax import config, grad
from jax.scipy.special import logsumexp

# `net`, `transform` and `dt` come from the notebook.
config.update("jax_enable_x64", False)
# Assumed to be on in the notebook: with jax_debug_nans, JAX raises a FloatingPointError on NaNs.
config.update("jax_debug_nans", True)

# expose checkpointing and stimulus duration (t_dur) as kwargs in loss
def cross_entropy_loss(opt_params, image, label, t_dur=2.301, checkpoint_lengths=None):
    params = transform.forward(opt_params)

    def simulate(params, image):
        tau = 500.0
        i_amp = 10.0 / tau
        currents = jx.datapoint_to_step_currents(0.1, 1.0, i_amp*image, dt, t_dur)
        data_stimuli = net[range(784), 0, 0].data_stimulate(currents, None)
        return jx.integrate(net, params=params, data_stimuli=data_stimuli, tridiag_solver="thomas", checkpoint_lengths=checkpoint_lengths)

    vs = simulate(params, image)
    prediction = vs[:, -1]
    prediction += 60.0
    prediction /= 10.0
    log_prob = prediction[label] - logsumexp(prediction)
    return -log_prob

# For a set of parameters
with open("results/parameters/tmp_state_2.pkl", "rb") as handle:
    opt_params, batch = pickle.load(handle)

# ... on a specific training pair
image_batch, label_batch = tfds.as_numpy(batch)
image, label = image_batch[1], label_batch[1]

# computing the cross_entropy_loss works
l = cross_entropy_loss(opt_params, image, label, t_dur=2.301)

# BUT: computing the gradient fails, raising a `FloatingPointError`
try:
    grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label)
except FloatingPointError as e:
    print(e)

# Either shortening the time or ...
grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label, t_dur=2.300)
# ... using multiple checkpoints fixes this issue
grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label, t_dur=2.301, checkpoint_lengths=[103, 4])

The model checkpoint where this happens is attached (tmp_state_2.zip); you also get it by running the training in the notebook.

I have not looked at why changing save_exp affects this, but it's somehow linked to checkpointing, as adding multiple levels seems to fix the issue. Also, the FloatingPointError originates here https://github.com/jaxleyverse/jaxley/blob/72278f8cd560b6e3885a3af7ec37ac7ec7a55df1/jaxley/utils/jax_utils.py#L65-L66 and happens somewhere in lax.scan.
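To make concrete what adding checkpoint levels changes, here is a generic two-level checkpointed scan in plain JAX. This is not jaxley's implementation: step is a hypothetical toy update, and reading something like checkpoint_lengths=[103, 4] as "inner chunks of 103 steps, scanned 4 times" is my assumption about the convention. Each inner chunk is wrapped in jax.checkpoint, so its intermediate states are recomputed during the backward pass instead of being stored. The gradient is mathematically the same; only how the backward pass is organized differs.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(v, _):
    # Hypothetical toy update standing in for one solver step.
    v_new = v + 0.01 * (-v + jnp.tanh(v))
    return v_new, None

def plain_scan(v0, num_steps):
    v, _ = lax.scan(step, v0, xs=None, length=num_steps)
    return v

def nested_checkpoint_scan(v0, inner, outer):
    # Only `outer` carries are stored for the backward pass; each inner
    # chunk of `inner` steps is rematerialized when its gradient is needed.
    @jax.checkpoint
    def chunk(v, _):
        v, _ = lax.scan(step, v, xs=None, length=inner)
        return v, None

    v, _ = lax.scan(chunk, v0, xs=None, length=outer)
    return v

v0 = jnp.linspace(-1.0, 1.0, 5)
g_plain = jax.grad(lambda v: jnp.sum(plain_scan(v, 412)))(v0)
g_nested = jax.grad(lambda v: jnp.sum(nested_checkpoint_scan(v, 103, 4)))(v0)
print(jnp.max(jnp.abs(g_plain - g_nested)))  # should be close to 0
```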

Looking forward to hearing your thoughts on this.

jnsbck commented 4 months ago

I did even more NaN chasing, prompted by @manuelgloeckler, who found that https://github.com/mackelab/jaxley_experiments/blob/main/nex/smc/smc_allen_experimental.py sometimes yielded NaNs in the simulations.

I looked into one particular example (see plot).

When running the jitted version of jx.integrate, no NaNs are returned. However, the unjitted version returns NaNs after some point. As pointed out above, the issue is somewhere in lax.scan: running lax.scan on the unjitted body_fun yields NaNs, whereas running lax.scan until the last good output and then calling body_fun once does not yield NaNs in the output. Running lax.scan or jit(body_fun) on the problematic inputs, however, does (thanks @manuelgloeckler for the pointer)! Thanks @michaeldeistler for the hint about the tridiax solvers, since this issue is only present for the implicit solve. I have not gotten further than this yet, though.
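The bisection described above (scan up to the last good output, then re-run a single step) can be sketched generically. body_fun here is a hypothetical toy update that grows tenfold per step until float32 overflows; in the real case it would be the scan body inside jx.integrate. first_bad_step (also hypothetical) finds the first non-finite output, and the last good carry is then fed through body_fun once eagerly and once under jax.jit so the two results can be compared:

```python
import jax
import jax.numpy as jnp
from jax import lax

def body_fun(carry, _):
    # Hypothetical toy update: grows 10x per step until float32 overflows to inf.
    new = carry * 10.0
    return new, new

def first_bad_step(init, num_steps):
    """Return the index of the first non-finite scan output (or -1) and all outputs."""
    _, outputs = lax.scan(body_fun, init, xs=None, length=num_steps)
    bad = ~jnp.isfinite(outputs)
    idx = int(jnp.argmax(bad)) if bool(jnp.any(bad)) else -1
    return idx, outputs

idx, outputs = first_bad_step(jnp.float32(1.0), 50)
last_good = outputs[idx - 1]  # assumes idx > 0

# Re-run the problematic step from the last good carry, eagerly and jitted.
eager_next, _ = body_fun(last_good, None)
jitted_next, _ = jax.jit(body_fun)(last_good, None)
print(idx, last_good, eager_next, jitted_next)
```

In this toy both paths agree; the point is only to isolate the single step where eager and jitted execution can then be compared on the real inputs.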

I have not checked with unroll=True since it takes ages.

You can find the notebook here if you want to have a look: https://github.com/mackelab/jaxley_experiments/blob/fix_nans/nex/smc/nan_issues.ipynb

jnsbck commented 3 months ago

From diffrax