Open adam-hartshorne opened 3 weeks ago
Cool experiment! I think donating arguments should work, but I haven't tested it yet; it would probably be good to add a test for it.
This warning comes from my (so far unsuccessful) attempts to combine PyTorch's and JAX's batching transforms. I'm hoping to figure out how to do this properly.
It's odd that this warning triggers like this; it seems too noisy in this context.
I just tried a simplified example, but I don't get the warning:
import time
from collections import OrderedDict
import functools
import jax
import torch
import torch2jax
from torch2jax import t2j
import torchopt
if __name__ == "__main__":
    # Toy binary-classification data and a small PyTorch MLP.
    X = torch.randn(100, 10)
    y = torch.randint(0, 2, (100,))
    model = torch.nn.Sequential(
        torch.nn.Linear(10, 5), torch.nn.ReLU(), torch.nn.Linear(5, 2)
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    # Parameters as JAX arrays (carried through the scan below) and as torch
    # tensors (example inputs for torch2jax and the optimizer's init).
    state = jax.tree.map(torch2jax.t2j, model.state_dict())
    torch_state = jax.tree.map(torch2jax.j2t, state)

    optimizer = torchopt.adam(lr=1e-3)
    optimizer_state = optimizer.init(list(torch_state.values()))

    def train_step(params, X, y_target):
        """Run one Adam step in PyTorch and return the updated parameters.

        Side effect: rebinds the module-level ``optimizer_state``.
        """
        global optimizer_state
        params_ = list(params.values())
        for param in params_:
            param.requires_grad = True
        y_pred = torch.func.functional_call(model, params, X)
        loss = loss_fn(y_pred, y_target)
        loss.backward()
        grads_ = [param_.grad for param_ in params_]
        updates_, optimizer_state = optimizer.update(grads_, optimizer_state)
        # Watch out! This modifies the parameters in place.
        new_params_ = torchopt.apply_updates(params_, updates_)
        # Without OrderedDict, the pytree mechanism would sort the keys.
        return OrderedDict(zip(params.keys(), new_params_))

    train_step_jax = jax.jit(
        torch2jax.torch2jax(train_step, torch_state, X, y, output_shapes=state),
        donate_argnums=(0,),
    )

    @functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(3,))
    def train_steps(params, X, y_target, num_steps):
        """Apply ``train_step_jax`` ``num_steps`` times inside one ``lax.scan``."""

        def body(params, _):
            return train_step_jax(params, X, y_target), None

        return jax.lax.scan(body, params, length=num_steps)

    num_steps = 10000
    X_, y_ = t2j(X), t2j(y)
    # Compile ahead of time so the timing below excludes compilation.
    train_steps_compiled = train_steps.lower(state, X_, y_, num_steps).compile()

    t = time.perf_counter()
    # num_steps is static, so it is baked into the compiled executable.
    # (Uncompiled alternative: train_steps(state, X_, y_, num_steps).)
    params_ = train_steps_compiled(state, X_, y_)
    # JAX dispatches asynchronously: without blocking here the timer would
    # measure only the dispatch, not the training steps themselves.
    params_ = jax.block_until_ready(params_)
    t = time.perf_counter() - t
    print(f"Time per step: {t/num_steps:.4e}")
Do you have a small repro by any chance?
Sorry for the slow response; I have a paper deadline coming up. I'll reply ASAP.
Here is some example code that reproduces this issue. It is just this example, https://docs.kidger.site/diffrax/examples/neural_ode/ , with the MSE loss function in JAX swapped out for a call to PyTorch.
import time
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import torch
import matplotlib.pyplot as plt
import optax # https://github.com/deepmind/optax
from torch2jax import torch2jax, torch2jax_with_vjp # this converts a Python function to JAX
from torch2jax import Size, dtype_t2j
class Func(eqx.Module):
    """Vector field of the neural ODE: an MLP applied to the state ``y``.

    ``t`` and ``args`` are accepted so the module can be wrapped in
    ``diffrax.ODETerm`` (called as ``f(t, y, args)``), but they are ignored.
    """

    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        # Input and output widths both equal the state dimension.
        net = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            key=key,
        )
        self.mlp = net

    def __call__(self, t, y, args=None):
        del t, args  # unused
        return self.mlp(y)
class NeuralODE(eqx.Module):
    """Neural ODE: integrates a learned ``Func`` vector field with diffrax."""

    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, ts, y0):
        """Solve from ``y0`` over ``[ts[0], ts[-1]]`` and return states saved at ``ts``.

        Uses Tsit5 with a PID step-size controller (rtol=1e-3, atol=1e-6) and
        an initial step of ``ts[1] - ts[0]``.
        """
        start, stop = ts[0], ts[-1]
        first_step = ts[1] - ts[0]
        sol = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=start,
            t1=stop,
            dt0=first_step,
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return sol.ys
def torch_fn(x, y):
    """Return the mean squared error of ``x`` and ``y`` as a tensor of shape ``(1,)``."""
    squared_error = torch.square(x - y)
    return squared_error.mean().reshape(1)
def test_mse(x, y):
    """Evaluate the PyTorch ``torch_fn`` MSE on JAX arrays ``x`` and ``y``.

    ``torch_fn`` is wrapped with ``torch2jax_with_vjp`` (``depth=2``,
    ``use_torch_vjp=False``; see the torch2jax docs for their meaning) using
    input specs built from the shapes/dtypes of ``x`` and ``y``, and a declared
    output of shape ``(1,)`` with ``x``'s dtype.

    NOTE(review): the ``test_`` prefix means pytest will try to collect this
    function as a test if the module is ever imported into a test file.
    """

    def _spec(shape, dtype):
        return jax.ShapeDtypeStruct(shape, dtype_t2j(dtype))

    wrapped = torch2jax_with_vjp(
        torch_fn,
        _spec(x.shape, x.dtype),
        _spec(y.shape, y.dtype),
        output_shapes=_spec((1,), x.dtype),
        depth=2,
        use_torch_vjp=False,
    )
    return wrapped(x, y)
def _get_data(ts, *, key):
    """Simulate one trajectory of the toy 2-D system used as training data.

    The initial state is drawn uniformly with ``minval=-0.6, maxval=1`` per
    component; the system ``dy/dt = (z[1], -z[0])`` with ``z = y / (1 + y)``
    is integrated with Tsit5 (``dt0=0.1``) and saved at ``ts``.
    """
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def vector_field(t, y, args):
        z = y / (1 + y)
        return jnp.stack([z[1], -z[0]], axis=-1)

    solution = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        ts[0],
        ts[-1],
        0.1,
        y0,
        saveat=diffrax.SaveAt(ts=ts),
    )
    return solution.ys
def get_data(dataset_size, *, key):
    """Generate ``dataset_size`` trajectories on 100 evenly spaced times in [0, 10].

    Returns:
        ``(ts, ys)``, where ``ys`` stacks one ``_get_data`` trajectory per
        split key along a new leading axis.
    """
    ts = jnp.linspace(0, 10, 100)
    keys = jr.split(key, dataset_size)

    def one_trajectory(k):
        return _get_data(ts, key=k)

    return ts, jax.vmap(one_trajectory)(keys)
def dataloader(arrays, batch_size, *, key):
    """Yield shuffled minibatches from ``arrays`` forever.

    All arrays must share the same leading (dataset) dimension. Each yielded
    item is a tuple holding the same ``batch_size`` rows of every array. The
    rows are reshuffled with a fresh key at the start of each epoch. A trailing
    partial batch (when ``batch_size`` does not divide the dataset size) is
    dropped.

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` is not in
            ``[1, dataset_size]``. Previously such sizes made the generator
            loop forever (or yield empty batches) instead of failing.
    """
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    if not 1 <= batch_size <= dataset_size:
        raise ValueError(
            f"batch_size must be between 1 and the dataset size "
            f"({dataset_size}), got {batch_size}"
        )
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        # ``<=`` (not ``<``) so the final full batch of each epoch is yielded;
        # with ``<`` a batch ending exactly at dataset_size was always skipped.
        while end <= dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size
def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    """Train a neural ODE whose MSE loss is computed in PyTorch via torch2jax.

    Training runs in phases obtained by zipping ``lr_strategy``,
    ``steps_strategy`` and ``length_strategy``. Each phase uses a fresh
    adabelief optimiser with that learning rate, runs that many steps, and
    trains on that leading fraction of every time series.

    Returns:
        ``(ts, ys, model)``: the time grid, the full dataset and the trained model.
    """
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)
    ts, ys = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape
    model = NeuralODE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # Only thing to notice is that up until step 500 we train on only the first 10% of
    # each time series. This is a standard trick to avoid getting caught in a local
    # minimum.

    @eqx.filter_value_and_grad
    def grad_loss(model, ti, yi):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti, yi[:, 0])
        mse = test_mse(yi, y_pred)
        return jnp.mean(mse)

    # ``optim`` is an explicit argument rather than a closed-over variable:
    # filter_jit caches the trace on its arguments, so an optimiser rebound in
    # a later phase would otherwise keep the first phase's traced-in
    # hyperparameters. Being a non-array argument it is static, so a new
    # optimiser forces a retrace.
    #
    # "all-except-first" donates model, opt_state and the fresh batch ``yi``,
    # but not ``ti``: the same ``_ts`` array is passed on every step, and a
    # donated buffer must not be reused after the call.
    @eqx.filter_jit(donate="all-except-first")
    def make_step(ti, yi, model, opt_state, optim):
        loss, grads = grad_loss(model, ti, yi)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adabelief(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[: int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        for step, (yi,) in zip(
            range(steps), dataloader((_ys,), batch_size, key=loader_key)
        ):
            start = time.perf_counter()
            loss, model, opt_state = make_step(_ts, yi, model, opt_state, optim)
            # JAX dispatch is asynchronous; block so the measured time covers
            # the step's computation, not just its dispatch.
            loss = loss.block_until_ready()
            end = time.perf_counter()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        model_y = model(ts, ys[0, 0])
        plt.plot(ts, model_y[:, 0], c="crimson", label="Model")
        plt.plot(ts, model_y[:, 1], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model
# Script entry point: train with the default settings and keep the results
# (time grid, dataset, trained model) as module-level names.
if __name__ == "__main__":
    ts, ys, model = main()
I recently discovered that when you JIT your model, you can use the donate_argnums
https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
functionality to reuse memory, after reading the following:
First, we're going to arrange to "donate" memory, which specifies that we can re-use the memory for our input arrays (e.g. model parameters) to store the output arrays (e.g. updated model parameters). (This isn't technically related to autoparallelism, but it's good practice so you should do it anyway :)
https://docs.kidger.site/equinox/examples/parallelism/
Everything appears to work as expected, except that when I use the donate-args functionality, the following warning is emitted on every iteration (normally I only get it on the first JIT compilation).
Optimisation speed and progress in minimising the loss are the same as before / as expected. I can obviously suppress the warning, but I think it's more an annoyance that shouldn't be happening than anything else.