google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

LBFGS not working for custom classes #1017

Closed. Michaelhess17 closed this issue 2 months ago.

Michaelhess17 commented 2 months ago

Hello!

I am trying to fit a neural ODE network using Diffrax and Optax, but when I change the optimizer in the demo code (https://docs.kidger.site/diffrax/examples/latent_ode/) from optax.adam to optax.lbfgs, I get an error that seems to be related to the fact that the neural network is wrapped in a custom class.
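
The only change from the demo is roughly the following swap (a sketch, not my exact code; everything else is as in the published example):

optim = optax.lbfgs(lr)  # was: optax.adam(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

Here is the stack trace for the error: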

Cell In[11], line 1
----> 1 main()

Cell In[10], line 56, in main(dataset_size, batch_size, lr, steps, save_every, hidden_size, latent_size, width_size, depth, seed)
     52 for step, (ts_i, ys_i) in zip(
     53     range(steps), dataloader((ts, ys), batch_size, key=loader_key)
     54 ):
     55     start = time.time()
---> 56     value, model, opt_state, train_key = make_step(
     57         model, opt_state, ts_i, ys_i, train_key
     58     )
     59     end = time.time()
     60     print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")

    [... skipping hidden 15 frame]

Cell In[10], line 38, in main.<locals>.make_step(model, opt_state, ts_i, ys_i, key_i)
     36 value, grads = loss(model, ts_i, ys_i, key_i)
     37 key_i = jr.split(key_i, 1)[0]
---> 38 updates, opt_state = optim.update(grads, opt_state)
     39 model = eqx.apply_updates(model, updates)
     40 return value, model, opt_state, key_i

File ~/Code/optax/optax/transforms/_combining.py:73, in chain.<locals>.update_fn(updates, state, params, **extra_args)
     71 new_state = []
     72 for s, fn in zip(state, update_fns):
---> 73   updates, new_s = fn(updates, s, params, **extra_args)
     74   new_state.append(new_s)
     75 return updates, tuple(new_state)

File ~/Code/optax/optax/_src/base.py:330, in with_extra_args_support.<locals>.update(***failed resolving arguments***)
    328 def update(updates, state, params=None, **extra_args):
    329   del extra_args
--> 330   return tx.update(updates, state, params)

File ~/Code/optax/optax/_src/transform.py:1440, in scale_by_lbfgs.<locals>.update_fn(updates, state, params)
   1438 diff_params = otu.tree_sub(params, state.params)
   1439 diff_updates = otu.tree_sub(updates, state.updates)
-> 1440 vdot_diff_params_updates = otu.tree_vdot(diff_updates, diff_params)
   1441 weight = jnp.where(
   1442     vdot_diff_params_updates == 0.0, 0.0, 1./vdot_diff_params_updates
   1443 )
   1444 # params_diff, updates_diff, weight depend on differences of parameters
   1445 # that are not defined at the first iteration. Hence we keep them at 0 if
   1446 # state.count = 0.

File ~/Code/optax/optax/tree_utils/_tree_math.py:154, in tree_vdot(tree_x, tree_y)
    133 def tree_vdot(tree_x: Any, tree_y: Any) -> chex.Numeric:
    134   r"""Compute the inner product between two pytrees.
    135 
    136   Examples:
   (...)
    152   numerical issues.
    153   """
--> 154   vdots = jtu.tree_map(_vdot_safe, tree_x, tree_y)
    155   return jtu.tree_reduce(operator.add, vdots)

    [... skipping hidden 1 frame]

File ~/.conda/envs/jax/lib/python3.11/site-packages/jax/_src/tree_util.py:320, in <listcomp>(.0)
    283 """Maps a multi-input function over pytree args to produce a new pytree.
    284 
    285 Args:
   (...)
    317   - :func:`jax.tree.reduce`
    318 """
    319 leaves, treedef = tree_flatten(tree, is_leaf)
--> 320 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
    321 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

ValueError: Custom node type mismatch: expected type: <class '__main__.LatentODE'>, value: None.

I'm pretty new to JAX, so I'm not entirely sure what's going on here. It's worth noting that I was unable to install using the pip+git command (due to a timeout error connecting to GitHub), so I cloned the repo separately and ran pip install . from within the repo folder.

Python version: 3.11
JAX version: 0.4.26
Diffrax version: 0.5.0
Optax version: 0.2.4.dev

vroulet commented 2 months ago

Hello @Michaelhess17,

Thanks for pointing this out. The issue stems from the fact that optax's tree primitives may not handle Equinox modules properly: for example, optax.tree_utils.tree_sub(model, model), where model is an Equinox module, will throw an error because it tries to subtract leaves of the pytree that are not arrays (or anything else supporting a subtraction operation).
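
For example, something along these lines will fail (a minimal sketch with a hypothetical Tiny module; the offending non-array leaf is the activation function):

from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import optax

class Tiny(eqx.Module):
    weight: jax.Array
    activation: Callable  # non-array leaf: a plain Python function

model = Tiny(jnp.ones(3), jnp.tanh)
# tree_sub maps subtraction over *all* leaves, including the function,
# so this raises an error instead of returning a pytree of zeros.
optax.tree_utils.tree_sub(model, model)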

There is a workaround, though: use Equinox to filter out the elements of the module (and of the gradients and optimizer state) that are not arrays.

So here is the Equinox example running with lbfgs (the only difference is in the make_step function, where I filter out the elements of the model, optimizer state, and gradients that are not arrays).

If that solves your problem, feel free to close the issue.

import time

import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import optax

matplotlib.rcParams.update({"font.size": 30})

class Func(eqx.Module):
    scale: jnp.ndarray
    mlp: eqx.nn.MLP

    def __call__(self, t, y, args):
        return self.scale * self.mlp(y)

class LatentODE(eqx.Module):
    func: Func
    rnn_cell: eqx.nn.GRUCell

    hidden_to_latent: eqx.nn.Linear
    latent_to_hidden: eqx.nn.MLP
    hidden_to_data: eqx.nn.Linear

    hidden_size: int
    latent_size: int

    def __init__(
        self, *, data_size, hidden_size, latent_size, width_size, depth, key, **kwargs
    ):
        super().__init__(**kwargs)

        mkey, gkey, hlkey, lhkey, hdkey = jr.split(key, 5)

        scale = jnp.ones(())
        mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            final_activation=jnn.tanh,
            key=mkey,
        )
        self.func = Func(scale, mlp)
        self.rnn_cell = eqx.nn.GRUCell(data_size + 1, hidden_size, key=gkey)

        self.hidden_to_latent = eqx.nn.Linear(hidden_size, 2 * latent_size, key=hlkey)
        self.latent_to_hidden = eqx.nn.MLP(
            latent_size, hidden_size, width_size=width_size, depth=depth, key=lhkey
        )
        self.hidden_to_data = eqx.nn.Linear(hidden_size, data_size, key=hdkey)

        self.hidden_size = hidden_size
        self.latent_size = latent_size

    # Encoder of the VAE
    def _latent(self, ts, ys, key):
        data = jnp.concatenate([ts[:, None], ys], axis=1)
        hidden = jnp.zeros((self.hidden_size,))
        for data_i in reversed(data):
            hidden = self.rnn_cell(data_i, hidden)
        context = self.hidden_to_latent(hidden)
        mean, logstd = context[: self.latent_size], context[self.latent_size :]
        std = jnp.exp(logstd)
        latent = mean + jr.normal(key, (self.latent_size,)) * std
        return latent, mean, std

    # Decoder of the VAE
    def _sample(self, ts, latent):
        dt0 = 0.4  # selected as a reasonable choice for this problem
        y0 = self.latent_to_hidden(latent)
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            ts[0],
            ts[-1],
            dt0,
            y0,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return jax.vmap(self.hidden_to_data)(sol.ys)

    @staticmethod
    def _loss(ys, pred_ys, mean, std):
        # -log p_θ with Gaussian p_θ
        reconstruction_loss = 0.5 * jnp.sum((ys - pred_ys) ** 2)
        # KL(N(mean, std^2) || N(0, 1))
        variational_loss = 0.5 * jnp.sum(mean**2 + std**2 - 2 * jnp.log(std) - 1)
        return reconstruction_loss + variational_loss

    # Run both encoder and decoder during training.
    def train(self, ts, ys, *, key):
        latent, mean, std = self._latent(ts, ys, key)
        pred_ys = self._sample(ts, latent)
        return self._loss(ys, pred_ys, mean, std)

    # Run just the decoder during inference.
    def sample(self, ts, *, key):
        latent = jr.normal(key, (self.latent_size,))
        return self._sample(ts, latent)

def get_data(dataset_size, *, key):
    ykey, tkey1, tkey2 = jr.split(key, 3)

    y0 = jr.normal(ykey, (dataset_size, 2))

    t0 = 0
    t1 = 2 + jr.uniform(tkey1, (dataset_size,))
    ts = jr.uniform(tkey2, (dataset_size, 20)) * (t1[:, None] - t0) + t0
    ts = jnp.sort(ts)
    dt0 = 0.1

    def func(t, y, args):
        return jnp.array([[-0.1, 1.3], [-1, -0.1]]) @ y

    def solve(ts, y0):
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(func),
            diffrax.Tsit5(),
            ts[0],
            ts[-1],
            dt0,
            y0,
            saveat=diffrax.SaveAt(ts=ts),
        )
        return sol.ys

    ys = jax.vmap(solve)(ts, y0)

    return ts, ys

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while start < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

def main(
    dataset_size=10000,
    batch_size=256,
    lr=1e-2,
    steps=250,
    save_every=50,
    hidden_size=16,
    latent_size=16,
    width_size=16,
    depth=2,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key, train_key, sample_key = jr.split(key, 5)

    ts, ys = get_data(dataset_size, key=data_key)

    model = LatentODE(
        data_size=ys.shape[-1],
        hidden_size=hidden_size,
        latent_size=latent_size,
        width_size=width_size,
        depth=depth,
        key=model_key,
    )

    @eqx.filter_value_and_grad
    def loss(model, ts_i, ys_i, key_i):
        batch_size, _ = ts_i.shape
        key_i = jr.split(key_i, batch_size)
        loss = jax.vmap(model.train)(ts_i, ys_i, key=key_i)
        return jnp.mean(loss)

    @eqx.filter_jit
    def make_step(model, opt_state, ts_i, ys_i, key_i):
        value, grads = loss(model, ts_i, ys_i, key_i)
        key_i = jr.split(key_i, 1)[0]
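        # Replace every non-array leaf (activation functions, static fields, ...)
        # with None; JAX treats None as an empty subtree, so optax's tree
        # operations (tree_sub, tree_vdot, ...) only ever see array leaves.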
        grads = eqx.filter(grads, eqx.is_array)
        opt_state = eqx.filter(opt_state, eqx.is_array)
        model_ = eqx.filter(model, eqx.is_array)
        updates, opt_state = optim.update(grads, opt_state, model_)
        model = eqx.apply_updates(model, updates)
        return value, model, opt_state, key_i

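    # L-BFGS with a fixed step size: linesearch=None disables the built-in
    # line search, so updates are only scaled by learning_rate. The state is
    # initialized from the array-only view of the model, matching the filtered
    # pytrees passed to optim.update above.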
    optim = optax.lbfgs(learning_rate=1e-3, linesearch=None)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

    # Plot results
    num_plots = 1 + (steps - 1) // save_every
    if ((steps - 1) % save_every) != 0:
        num_plots += 1
    fig, axs = plt.subplots(1, num_plots, figsize=(num_plots * 8, 8))
    axs[0].set_ylabel("x")
    axs = iter(axs)
    for step, (ts_i, ys_i) in zip(
        range(steps), dataloader((ts, ys), batch_size, key=loader_key)
    ):
        start = time.time()
        value, model, opt_state, train_key = make_step(
            model, opt_state, ts_i, ys_i, train_key
        )
        end = time.time()
        print(f"Step: {step}, Loss: {value}, Computation time: {end - start}")

        if (step % save_every) == 0 or step == steps - 1:
            ax = next(axs)
            # Sample over a longer time interval than we trained on. The model will be
            # sufficiently good that it will correctly extrapolate!
            sample_t = jnp.linspace(0, 12, 300)
            sample_y = model.sample(sample_t, key=sample_key)
            sample_t = np.asarray(sample_t)
            sample_y = np.asarray(sample_y)
            ax.plot(sample_t, sample_y[:, 0])
            ax.plot(sample_t, sample_y[:, 1])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel("t")

    plt.savefig("latent_ode.png")
    plt.show()

main()