michaeldeistler opened 5 months ago
Do you have a hunch where we need so much precision that `float32` is not sufficient? Sounds like a suitable transform of different units might be able to take care of this and could save valuable GPU memory. If you have an example, I could look into it. I just saw NEURON uses double as well.
This notebook will give `NaN` if either one uses `float32` or if one changes the `max_value` of `save_exp`, e.g. to 100 or even more.
I do not have a clear idea of why the `NaN`s happen. The most concrete thing I have observed so far is the following (and it is the reason I introduced the `save_exp`):

In our channels, e.g. in the sodium part of `HH()`, we often take an `exp()`. This value can become huge if the neuron is pushed beyond the normal range of voltages. For example, if `v = -300` mV, one can get gigantic values in this exp:
```python
v = -300
x = -(v + 35) / 10 + 1  # = 27.5
# exp(x) -> huge
```
Especially when using `float32`, these huge values are a problem and make things unstable. Such strongly negative voltages can happen if extremely strong inhibitory synapses are being learned.
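To illustrate (a quick check, independent of the notebook): `exp` overflows `float32` for inputs above roughly 88.7:

```python
import jax.numpy as jnp

x32 = jnp.asarray(89.0, dtype=jnp.float32)
x64 = jnp.asarray(89.0, dtype=jnp.float64)  # requires jax_enable_x64
print(jnp.exp(x32))  # inf: exp(89) exceeds the float32 maximum (~3.4e38)
print(jnp.exp(x64))  # ~4.5e38: still finite in float64
```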
A very simple reproducible example should look something like this (but you need to turn off the `save_exp` in the channels; I did not test this right now):
```python
import jaxley as jx
from jaxley.channels import HH

dt, t_max = 0.025, 10.0  # ms
comp = jx.Compartment()
# Strong negative current that hyperpolarizes the cell far below -100 mV.
current = jx.step_current(0.0, 10.0, -10.0, dt, t_max)
comp.stimulate(current)
comp.insert(HH())
comp.record()
voltages = jx.integrate(comp)  # expected to give NaNs without save_exp
```
That being said, I have also observed `NaN`s in single neurons (with morphological detail) when using `float32`, but I don't have an example right now.
Interesting, might look into this. Thanks for all the hints.
I took a deep dive into this, using your example.
```python
import jax.numpy as jnp
import jaxley as jx

# Stimulus.
i_delay = 3.0  # ms
i_amp = 0.05  # nA
i_dur = 2.0  # ms

# Duration and step size.
dt = 0.025  # ms
t_max = 10.0  # ms

time_vec = jnp.arange(0.0, t_max + dt, dt)
current = jx.step_current(i_delay, i_dur, i_amp, dt, t_max)

comp = jx.Compartment()
comp.stimulate(current)
# ModHH: HH channel modified so that the clipping of the exponential
# can be configured.
comp.insert(ModHH(clip_exp=clip_exp))
```
This leads to `NaN`s at some point; higher precision essentially just delays this. Safe exponentials don't have this problem, since they prevent the floating-point format from overflowing after a while.
Here is what it essentially comes down to:
```python
def solve_gate_exponential(
    x: jnp.ndarray,
    dt: float,
    alpha: jnp.ndarray,
    beta: jnp.ndarray,
):
    tau = 1 / (alpha + beta)  # <--- alpha can become inf here, hence tau becomes 0
    xinf = alpha * tau        # <--- this means tau * alpha = inf * 0, which is nan
    return exponential_euler(x, dt, xinf, tau)
```
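The `inf * 0` failure is easy to reproduce in isolation (a minimal sketch):

```python
import jax.numpy as jnp

alpha = jnp.exp(jnp.asarray(89.0, dtype=jnp.float32))  # overflows to inf
beta = jnp.asarray(1.0, dtype=jnp.float32)
tau = 1.0 / (alpha + beta)  # 1 / inf -> 0.0
print(alpha * tau)          # inf * 0.0 -> nan
```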
The crux of the issue, I think, is `alpha` overflowing the floating-point format in the gates already, as you pointed out as well.
EDIT: Idea for a better `safe_exp`:
```python
def save_exp(x):
    """Clip the input below the maximum value that the dtype can support."""
    # Floored log of the maximum value that can be represented by the dtype.
    max_value = 88.0 if x.dtype == np.float32 else 709.0
    x = jnp.clip(x, a_max=max_value)
    return jnp.exp(x)
```
`max_value` can be obtained from `max_dtype_val = np.finfo(dtype).max; max_value = np.floor(np.log(max_dtype_val))`.
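For reference, a dtype-generic variant of the same idea, computing the constant at call time (a sketch, not the version in Jaxley):

```python
import numpy as np
import jax.numpy as jnp

def save_exp_generic(x):
    """Exponential that clips its input so exp(x) stays finite in x's dtype."""
    # Floored log of the largest value representable by the dtype
    # (88.0 for float32, 709.0 for float64).
    max_value = np.floor(np.log(np.finfo(x.dtype).max))
    return jnp.exp(jnp.clip(x, a_max=max_value))
```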
I spent some time looking into this.
> This notebook will give `NaN` if either one uses `float32` or if one changes the `max_value` of `save_exp`, e.g. to 100 or even more.
I don't exactly know why one gets `NaN`s yet, but I was at least able to trace it to the checkpointing. Here is a minimal reproducing example that you can run in the notebook above if you have set up the network:
```python
config.update("jax_enable_x64", False)

# Expose checkpointing and stimulus duration as kwargs in the loss.
def cross_entropy_loss(opt_params, image, label, t_dur=2.301, checkpoint_lengths=None):
    params = transform.forward(opt_params)

    def simulate(params, image):
        tau = 500.0
        i_amp = 10.0 / tau
        currents = jx.datapoint_to_step_currents(0.1, 1.0, i_amp * image, dt, t_dur)
        data_stimuli = net[range(784), 0, 0].data_stimulate(currents, None)
        return jx.integrate(
            net,
            params=params,
            data_stimuli=data_stimuli,
            tridiag_solver="thomas",
            checkpoint_lengths=checkpoint_lengths,
        )

    vs = simulate(params, image)
    prediction = vs[:, -1]
    prediction += 60.0
    prediction /= 10.0
    log_prob = prediction[label] - logsumexp(prediction)
    return -log_prob
```
```python
# For a set of parameters ...
with open("results/parameters/tmp_state_2.pkl", "rb") as handle:
    opt_params, batch = pickle.load(handle)

# ... on a specific training pair.
image_batch, label_batch = tfds.as_numpy(batch)
image, label = image_batch[1], label_batch[1]

# Computing the cross_entropy_loss works.
l = cross_entropy_loss(opt_params, image, label, t_dur=2.301)

# BUT: computing the gradient fails, raising a `FloatingPointError`.
try:
    grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label)
except FloatingPointError as e:
    print(e)

# Either shortening the time ...
grads = grad(cross_entropy_loss, argnums=0)(opt_params, image, label, t_dur=2.300)

# ... or using multiple checkpoints fixes the issue.
grads = grad(cross_entropy_loss, argnums=0)(
    opt_params, image, label, t_dur=2.301, checkpoint_lengths=[103, 4]
)
```
The model checkpoint where this happens is the following, which you also get by running the training in the notebook: tmp_state_2.zip
I have not looked at why changing the `save_exp` affects this, but it's somehow linked to checkpointing, as adding multiple levels seems to fix the issue. Also, the `FloatingPointError` originates here https://github.com/jaxleyverse/jaxley/blob/72278f8cd560b6e3885a3af7ec37ac7ec7a55df1/jaxley/utils/jax_utils.py#L65-L66 and happens somewhere in `lax.scan`.
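(For context: JAX typically raises `FloatingPointError` on a `NaN` only when NaN debugging is enabled; turning the flag on deliberately is also a way to localize where the first `NaN` appears.)

```python
from jax import config

# Make JAX raise a FloatingPointError at the first NaN instead of
# silently propagating it through the computation.
config.update("jax_debug_nans", True)
```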
Looking forward to hearing your thoughts on this.
I did even more `NaN` chasing, prompted by @manuelgloeckler, who found that https://github.com/mackelab/jaxley_experiments/blob/main/nex/smc/smc_allen_experimental.py sometimes yielded `NaN`s in the simulations.
I looked into one particular example (see plot).
When I run the jitted version of `jx.integrate`, no `NaN`s are returned. However, running the unjitted version returns `NaN`s after some point. As pointed out above, the issue is somewhere in `lax.scan`. Running `lax.scan` on the unjitted `body_fun` yields `NaN`s. In contrast, running `lax.scan` until the last good output and then running `body_fun` once does not yield `NaN`s in the output, while running `lax.scan` or `jit(body_fun)` on the problematic inputs does (thanks @manuelgloeckler for the pointer)! Thanks @michaeldeistler for the hint about tridiax solvers: the issue is only present for the implicit solve. I have not gotten further than this yet, though.
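For reference, the bisection strategy above looks roughly like this (hypothetical names: `body_fun` is the scanned step function, `init_carry` the initial solver state, `xs` the per-step inputs, and `last_good_step` the index of the last finite step):

```python
import jax
import jax.lax as lax

# Scan up to the last step whose output is still finite ...
carry_good, _ = lax.scan(body_fun, init_carry, xs[:last_good_step])

# ... then apply a single step on the problematic inputs:
# eagerly it stays finite, but jitted/scanned it produces NaNs.
carry_eager, _ = body_fun(carry_good, xs[last_good_step])
carry_jit, _ = jax.jit(body_fun)(carry_good, xs[last_good_step])
```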
I have not checked with `unroll=True`, since it takes ages.
You can find the notebook here if you want to have a look: https://github.com/mackelab/jaxley_experiments/blob/fix_nans/nex/smc/nan_issues.ipynb
Below are several reasons for experiencing `NaN`s, ranked from most to least likely. If you have tried out all of the options below, we would be happy to receive a bug report (as an issue on GitHub) with the following information:

- Does the `NaN` occur during training or during simulation?

1. You are using `float32`.

In most cases where I have encountered `NaN`s so far, they were resolved by switching to `float64`:
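In JAX, double precision is enabled globally via:

```python
from jax import config

config.update("jax_enable_x64", True)
```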
2. The mechanisms (channels or synapses) are unstable.
The most likely reason for this is that the channels you are using contain a `jnp.exp()` whose input gets very large (>100.0), such that the result will be `inf`. For example, our initial implementation of channels gave `NaN`s when a strong negative stimulus was inserted, such that the neuron was very strongly hyperpolarized (voltages below -200 mV). You can prevent this by changing `jnp.exp()` to `save_exp` from `jaxley/solver_gate.py`.
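For example, in a Hodgkin-Huxley-style gate (a sketch using the standard `alpha_m` rate; the exact expressions in your channel may differ):

```python
from jaxley.solver_gate import save_exp

def alpha_m(v):
    # save_exp clips its argument, so extreme voltages cannot
    # overflow the exponential to inf.
    return 0.1 * (v + 40.0) / (1.0 - save_exp(-(v + 40.0) / 10.0))
```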
3. The `ParamTransform` saturates.

To debug this, print the maximum value of the transformed parameters after every gradient update:
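Something like this (a sketch; assumes `transform` and `opt_params` as in the training loop above, with the transformed parameters being a list of dicts of arrays):

```python
import jax.numpy as jnp

params = transform.forward(opt_params)
max_val = max(float(jnp.max(jnp.abs(v))) for p in params for v in p.values())
print(f"max_val = {max_val:.1f}")
```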
If `max_val > 50.0`, then you are probably in trouble.