patrick-kidger / equinox

Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/
Apache License 2.0

using batchnorm/dropout layers from flax.linen along with diffrax package #394

Open Negar-Erfanian opened 1 year ago

Negar-Erfanian commented 1 year ago

Hi Patrick,

I am using the Diffrax ODE solver in my code and need to use batchnorm/dropout layers in the function that is passed to the solver. However, this is the error I am getting:

```
  solution = diffrax.diffeqsolve(
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 99, in __call__
    return self._call(False, args, kwargs)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 95, in _call
    out = self._cached(dynamic, static)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/equinox/_jit.py", line 37, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/integrate.py", line 676, in diffeqsolve
    solver_state = solver.init(terms, t0, tnext, y0, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/solver/runge_kutta.py", line 269, in init
    return terms.vf(t0, y0, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 364, in vf
    return self.term.vf(t, y, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 173, in vf
    return self.vector_field(t, y, args)
  File "/data/ne12/Kuramoto/model/neuralODE.py", line 56, in fn
    y0 = self.batchnorm(y0, use_running_average=not training)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/flax/linen/normalization.py", line 256, in __call__
    ra_mean = self.variable('batch_stats', 'mean',
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/flax/core/tracers.py", line 36, in check_trace_level
    raise errors.JaxTransformError()
flax.errors.JaxTransformError: Jax transforms and Flax models cannot be mixed. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.JaxTransformError)
```

I am not using Equinox; instead I am using flax.linen.Module to build my neural network.

```python
class NeuralODE(nn.Module):
    act_fn: Callable
    node_size: int
    hdims: str
    batchnorm: Callable = nn.BatchNorm()
    dropout: Callable = nn.Dropout(0.1)

    @nn.compact
    def __call__(self, ts, ys, training):
        y, args = ys
        if self.hdims != None:
            hdims_list = [int(i) for i in self.hdims.split('-')]
        else:
            hdims_list = []
        kernels = []
        first_dim = self.node_size
        for i, dim in enumerate(hdims_list):
            kernels.append(self.param(f'kernel{i}', nn.initializers.normal(),
                                      [dim, self.node_size, first_dim]))
            first_dim = dim
        kernels.append(self.param('kernel', nn.initializers.normal(),
                                  [self.node_size, self.node_size, first_dim]))

        def fn(t, y, args):
            y0 = y
            bias, data_adj = args
            if len(y0.shape) == 2:
                y0 = jnp.expand_dims(y0, -1)
            elif len(y0.shape) == 1:
                y0 = jnp.expand_dims(jnp.expand_dims(y0, -1), 0)
            for kernel in kernels:
                y0 = jnp.einsum('ijk,lmj->ilm', y0, kernel)
                y0 = self.act_fn(y0)
                y0 = self.batchnorm(y0, use_running_average=not training)
                y0 = self.dropout(y0, deterministic=not training)

            if y0.shape[0] == 1:
                y0 = jnp.squeeze(y0, 0)
            if len(y0.shape) == 2:
                out = jnp.einsum('ij,ij->ij', data_adj, y0).sum(-1)
            elif len(y0.shape) == 3:
                out = jnp.einsum('aij,aij->aij', data_adj, y0).sum(-1)
            out = jnp.squeeze(bias, -1) - out  # B*N

            return out  # , bias, data_adj

        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(fn),
            diffrax.Dopri5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=0.01,  # ts[1] - ts[0],
            y0=y,
            args=args,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6, dtmax=0.1),
            saveat=diffrax.SaveAt(ts=ts),
            made_jump=True,
        )
        return solution.ys
```

The error is raised when I initialise the model using `model.init(model_key, ts, inputs, training=False)`.

Any idea how to solve this? Thanks so much.

patrick-kidger commented 1 year ago

You will need to convert your Flax module into init/apply style -- don't call your Flax layers directly.

(This is basically the standard way to convert Flax models into something compatible with the rest of JAX; I expect Flax has some documentation on how to do this.)
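Roughly, a minimal sketch of that pattern (illustrative only, not code from this thread; it assumes the running statistics are frozen, so no mutable `batch_stats` collection is needed):

```python
import flax.linen as nn
import jax
import jax.numpy as jnp

# Initialise the layer once, outside of any JAX transform.
bn = nn.BatchNorm(use_running_average=True)
variables = bn.init(jax.random.PRNGKey(0), jnp.ones((8,)))

# Later, only ever call it functionally via `.apply` -- this is safe inside
# jit/grad/Diffrax, because only plain arrays and pytrees cross the boundary.
def f(y):
    return bn.apply(variables, y)
```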

Negar-Erfanian commented 1 year ago

Thanks for your reply. However, I don't believe this is the problem, as I am using init and apply and don't call the function directly. If you trace the error, it is raised as the solver is being initialised in

```
  solver_state = solver.init(terms, t0, tnext, y0, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/solver/runge_kutta.py", line 269, in init
    return terms.vf(t0, y0, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 364, in vf
    return self.term.vf(t, y, args)
  File "/data/ne12/.conda/envs/Negar2/lib/python3.8/site-packages/diffrax/term.py", line 173, in vf
    return self.vector_field(t, y, args)
```

and I am using the batchnorm/dropout in the function that is being sent to the solver. Have you used any flax layers in the functions that you send to diffrax solvers before?

patrick-kidger commented 1 year ago

Your example is misformatted, but it looks rather like you are using the Flax layers directly inside Diffrax, on the lines:

```python
y0 = self.act_fn(y0)
y0 = self.batchnorm(y0,
```

You will need to convert to init/apply style before that. (Or just use Equinox.)

Negar-Erfanian commented 1 year ago

Are you suggesting that I init/apply once before sending fn to Diffrax, and then init/apply again for my network (which has the solver inside it)? Would you please give me an example of how you would write the init/apply code for batchnorm before sending the function to Diffrax?

patrick-kidger commented 1 year ago

I'm afraid that's a Flax issue -- I'd have to figure out that one just like you would.

(There's a reason I wrote Equinox at the same time as Diffrax.)

Negar-Erfanian commented 1 year ago

Got it, thank you very much. It looks like there is no easy way to mix Flax layers with the Diffrax solver, since Diffrax is Equinox-based: initialising the parameters of specific layers such as batchnorm or dropout taken from Flax causes an incompatibility when they enter the solver. So before I switch to Equinox, I would like to ask whether Equinox has anything similar to Flax's train_state, so that we can use state in Equinox too? Thank you so much for your reply.

cgarciae commented 1 year ago

Hey @Negar-Erfanian, here is an example using Diffrax from Flax:

https://github.com/google/flax/discussions/2891#discussioncomment-5066071

Using Flax + Diffrax within Equinox should also be possible as you would just be using Flax's params pytree, but I cannot read your code to figure out the problem 😅
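For illustration, a minimal sketch of that "params pytree" idea, assuming a plain Dense layer with no batch statistics (hypothetical names, not the code from the linked discussion):

```python
import diffrax
import flax.linen as nn
import jax
import jax.numpy as jnp

dense = nn.Dense(4)
params = dense.init(jax.random.PRNGKey(0), jnp.ones((4,)))

def vf(t, y, args):
    # `args` is simply the Flax params pytree, applied functionally.
    return dense.apply(args, y)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(vf),
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.ones((4,)),
    args=params,
)
```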

Negar-Erfanian commented 1 year ago

> Hey @Negar-Erfanian, here is an example using Diffrax from Flax:
>
> google/flax#2891 (comment)
>
> Using Flax + Diffrax within Equinox should also be possible as you would just be using Flax's params pytree, but I cannot read your code to figure out the problem 😅

Dear @cgarciae, thank you for your reply. I have replied to your comment in that thread, based on the example you gave there. Thanks in advance.

Negar-Erfanian commented 1 year ago

Hi @patrick-kidger,

I am trying to convert my code from Flax to Equinox. However, I cannot find the equivalent of Flax's Module.param in Equinox. How can I define trainable parameters in the neural network from scratch, and how do I initialise them? Do you have an example you could send me? Thanks!

patrick-kidger commented 1 year ago

Just initialise them as normal JAX arrays. Here's an example.

Also just take a look through the equinox.nn source code, e.g. for equinox.nn.MLP, which should prove instructive.
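For instance, a rough sketch of a custom layer whose trainable parameters are just JAX arrays stored on the module (illustrative names, not taken from equinox.nn):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class MyLinear(eqx.Module):
    # Any JAX array stored on an eqx.Module is a pytree leaf, and hence trainable.
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        self.weight = jax.random.normal(key, (out_size, in_size)) * 0.01
        self.bias = jnp.zeros((out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

layer = MyLinear(3, 2, key=jax.random.PRNGKey(0))
```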

Negar-Erfanian commented 1 year ago

Hi @patrick-kidger,

Do you have an example of using stateful operations with neural ODEs (other than this)? I am a bit confused about how to use the state in the `__call__` function when passing it to the solver. I have switched to Equinox now. Thanks!

patrick-kidger commented 1 year ago

What stateful operation are you trying to do?

If you're not careful then it's not necessarily defined what you mean -- you can't arbitrarily insert stateful operations and still have an ODE.

Negar-Erfanian commented 1 year ago

Hi @patrick-kidger

Thanks for your reply. My aim is to learn a set of interacting oscillators, whose characteristics are defined by an ODE. However, as a first "101" case I am defining the weights myself to simulate the data, and then aiming to learn those weights back. In my first try without dropout, the model was not learning the weights and was overfitting to the training data (the same weight tensor is defined for the train and test sets). After adding dropout the overfitting is fine now, but the weights are still not being learnt. I have not used any sort of normalisation, which is why I was wondering whether a batchnorm might help -- seemingly it is a stateful operation.

patrick-kidger commented 1 year ago

Batch norm probably won't help I'm afraid -- it's generally less useful for neural ODEs.

Moreover, even in non-neural-ODE contexts it's generally out of favour these days, and LayerNorm is more commonly used.

You'll probably want to try the usual things to combat overfitting -- get more data, regularise, smaller network, etc. etc.
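If you do want normalisation inside the vector field, a minimal sketch of the LayerNorm variant in Equinox might look like this (illustrative only; LayerNorm is stateless, so it composes cleanly with Diffrax):

```python
import equinox as eqx
import jax

class Func(eqx.Module):
    norm: eqx.nn.LayerNorm
    mlp: eqx.nn.MLP

    def __init__(self, size, *, key):
        self.norm = eqx.nn.LayerNorm(size)
        self.mlp = eqx.nn.MLP(size, size, width_size=64, depth=2, key=key)

    def __call__(self, t, y, args):
        # Normalise the state, then apply the MLP -- no running statistics involved.
        return self.mlp(self.norm(y))

func = Func(8, key=jax.random.PRNGKey(0))
```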