rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

Donate Args Causing Repeated Spawning of Warning #19

Open adam-hartshorne opened 3 weeks ago

adam-hartshorne commented 3 weeks ago

I have recently discovered that when you jit your model, you can use the donate_argnums functionality,

https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html

to reuse memory, as described in the following:

First, we're going to arrange to "donate" memory, which specifies that we can re-use the memory for our input arrays (e.g. model parameters) to store the output arrays (e.g. updated model parameters). (This isn't technically related to autoparallelism, but it's good practice so you should do it anyway :)

https://docs.kidger.site/equinox/examples/parallelism/
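
For concreteness, here is the pattern I mean in plain JAX (a minimal sketch, no torch2jax involved):

import jax
import jax.numpy as jnp

# donate argument 0 so XLA may reuse its buffer for the output
f = jax.jit(lambda x: x + 1, donate_argnums=(0,))
x = jnp.ones(1000)
y = f(x)  # x's buffer can now back y
# the donated input is consumed afterwards: x.is_deleted() returns True
# (on backends that honor donation)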

Everything appears to work as expected, except that when I use the donate args functionality, the following warning spawns every iteration (normally I just get this once, on the first JIT):

/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/wrap_torch2jax/gradients.py:172: UserWarning: You are NOT using PyTorch's functional VJP. This is highly experimental.
  warn("You are NOT using PyTorch's functional VJP. This is highly experimental.")

Speed of optimisation and progress in minimising the loss are as before / as expected. I can obviously suppress the warning, so I think this is more an annoyance that shouldn't be happening than anything else.
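
In the meantime I just filter it out with the standard warnings machinery (nothing torch2jax-specific):

import warnings

warnings.filterwarnings(
    "ignore",
    message="You are NOT using PyTorch's functional VJP",
    category=UserWarning,
)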

rdyro commented 3 weeks ago

Cool experiment! I think donate args should work, but I haven't tested it yet; it's probably good to introduce a test for it, something like the sketch below.
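
A minimal sketch (untested, and I'm not yet sure donation is even honored through the torch callback):

import jax
import jax.numpy as jnp
import torch
from torch2jax import torch2jax, t2j

def test_donate_args():
  x = torch.randn(16)
  fn = jax.jit(
    torch2jax(lambda x: x * 2.0, x,
              output_shapes=jax.ShapeDtypeStruct((16,), jnp.float32)),
    donate_argnums=(0,),
  )
  x_jax = t2j(x)
  y = fn(x_jax)
  assert jnp.allclose(y, t2j(x) * 2.0)
  # the donated input should be consumed if the backend honored the donation
  assert x_jax.is_deleted()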

This warning comes from my attempts (unsuccessful so far) to marry PyTorch's and JAX's batching transforms. I'm hoping to figure out how to do this properly.

It's weird that this warning is triggering like this; it looks like it's firing more often than is useful in this context.

rdyro commented 2 weeks ago

I just tried a simplified example, but I don't get the warning:

import time
from collections import OrderedDict
import functools

import jax
import torch
import torch2jax
from torch2jax import t2j
import torchopt

if __name__ == "__main__":
  X = torch.randn(100, 10)
  y = torch.randint(0, 2, (100,))

  model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), 
                              torch.nn.Linear(5, 2))
  loss_fn = torch.nn.CrossEntropyLoss()
  state = jax.tree.map(torch2jax.t2j, model.state_dict())
  torch_state = jax.tree.map(torch2jax.j2t, state)
  optimizer = torchopt.adam(lr=1e-3)
  optimizer_state = optimizer.init(list(torch_state.values()))

  def train_step(params, X, y_target):
    global optimizer_state
    params_ = [x for x in params.values()]
    for param in params_:
      param.requires_grad = True
    y_pred = torch.func.functional_call(model, params, X)
    loss = loss_fn(y_pred, y_target)
    loss.backward()
    grads_ = [param_.grad for param_ in params_]
    updates_, optimizer_state = optimizer.update(grads_, optimizer_state)
    old_params_ = params_ 
    # watch out! this modifies parameters in place
    new_params_ = torchopt.apply_updates(old_params_, updates_)
    # without OrderedDict, the pytree mechanism will sort keys
    new_params = OrderedDict(zip(params.keys(), new_params_))
    return new_params

  train_step_jax = jax.jit(
    torch2jax.torch2jax(train_step, torch_state, X, y, output_shapes=state), 
    donate_argnums=(0,))
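  # donating argument 0 (the params pytree) asks XLA to reuse those buffers
  # for the updated params that train_step returns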

  @functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(3,))
  def train_steps(params, X, y_target, num_steps):
    return jax.lax.scan(lambda params, _: (train_step_jax(params, X, y_target), 
                                           None), params, length=num_steps)

  num_steps = 10000
  X_, y_ = t2j(X), t2j(y)
  train_steps_compiled = train_steps.lower(state, X_, y_, num_steps).compile()

  t = time.time()
  params_ = train_steps_compiled(state, X_, y_)  # num_steps compiled in
  #params_ = train_steps(state, X_, y_, num_steps)
  t = time.time() - t

  print(f"Time per step: {t/num_steps:.4e}")
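
As a quick sanity check that donation actually took effect, the donated inputs should be consumed after the call, e.g. appended at the end:

  # every leaf of the donated params pytree should report is_deleted() == True
  # (on backends that honor donation)
  print(jax.tree.map(lambda x: x.is_deleted(), state))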

Do you have a small repro by any chance?

adam-hartshorne commented 2 weeks ago

Sorry for the slow response. I have a paper deadline coming up. Will give a reply ASAP.

adam-hartshorne commented 1 week ago

Here is some example code that triggers this issue. It is just this example, https://docs.kidger.site/diffrax/examples/neural_ode/ , with the MSE loss function in JAX swapped out for a call to PyTorch.

import time
import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr

import torch

import matplotlib.pyplot as plt
import optax  # https://github.com/deepmind/optax

from torch2jax import torch2jax_with_vjp  # this converts a PyTorch function to JAX
from torch2jax import dtype_t2j

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )

    def __call__(self, t, y, args=None):
        return self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=ts[0],
            t1=ts[-1],
            dt0=ts[1] - ts[0],
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

def torch_fn(x, y):
    return torch.mean(torch.square(x - y)).reshape(1)

def test_mse(x, y):
    torch_cost_fn = torch2jax_with_vjp(
        torch_fn,
        jax.ShapeDtypeStruct(x.shape, dtype_t2j(x.dtype)),
        jax.ShapeDtypeStruct(y.shape, dtype_t2j(y.dtype)),
        output_shapes=jax.ShapeDtypeStruct((1,), dtype_t2j(x.dtype)),
        depth=2,
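        # use_torch_vjp=False below seems to be what emits the "functional VJP" warning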
        use_torch_vjp=False,
    )

    return torch_cost_fn(x, y)

def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys

def get_data(dataset_size, *, key):
    ts = jnp.linspace(0, 10, 100)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    return ts, ys

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralODE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # Only thing to notice is that up until step 500 we train on only the first 10% of
    # each time series. This is a standard trick to avoid getting caught in a local
    # minimum.

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        mse = test_mse(yi, y_pred)
        return jnp.mean(mse)

    @eqx.filter_jit(donate='all')
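    # donate='all' consumes the input buffers on every step; per this issue,
    # the donation is what makes the warning above repeat each iteration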
    def make_step(ti, yi, model, opt_state):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        model_y = model(ts, ys[0, 0])
        plt.plot(ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model

# __main__
if __name__ == "__main__":
    ts, ys, model = main()