rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License
37 stars 1 forks source link

Issue when attempting to optimise using torch2jax function #3

Closed adam-hartshorne closed 1 year ago

adam-hartshorne commented 1 year ago

I have attached minimal example code below which should allow you to reproduce the issue along with the error log.

Everything works unless I use a function that calls the KeOps library inside your torch2jax wrapper and I need its gradient, i.e. when running an optimiser. The same function works fine when simply called from JAX, and everything works fine in pure PyTorch.

import torch
from pykeops.torch import LazyTensor
from torch2jax import torch2jax, torch2jax_with_vjp # this converts a Python function to JAX

import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
import equinox as eqx
import numpy as np
import os

# JAX reads the XLA_PYTHON_CLIENT_* variables when its GPU backend is first
# initialised (on first use, not at `import jax`), so this must run before any
# JAX computation below.
#
# XLA_PYTHON_CLIENT_PREALLOCATE only understands boolean values ("true"/"false");
# the previous value "0.1" was not a false-y value, so preallocation stayed on
# and the setting was a no-op. A fractional value belongs in
# XLA_PYTHON_CLIENT_MEM_FRACTION, which caps JAX's share of GPU memory so the
# rest stays available to PyTorch/KeOps running in the same process.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.1"

USE_X64 = False  # set True to enable JAX 64-bit precision

if USE_X64:
    # use jax 64-bit precision
    jax.config.update("jax_enable_x64", True)
    dtype = jnp.float64
    np_dtype = np.float64
else:
    dtype = jnp.float32
    np_dtype = np.float32

def keops_rbf_loss(x, y):
    """Sum exp(-sqdist(x_i, y_j)) over all point pairs, evaluated lazily by KeOps.

    x is indexed as ``x[:, None, :]`` and y as ``y[None, :, :]``, so both are
    expected to be 2-D point clouds sharing the last (feature) dimension.
    The scalar total is returned as a shape-(1,) tensor.
    """
    x_i = LazyTensor(x[:, None, :])  # symbolic, one row per x point
    y_j = LazyTensor(y[None, :, :])  # symbolic, one column per y point
    kernel = (-x_i.sqdist(y_j)).exp()
    row_sums = kernel.sum(dim=1)
    return torch.reshape(torch.sum(row_sums), (1,))

def torch_rbf_loss(x, y):
    """Dense PyTorch version of the RBF kernel sum.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b to build the
    full (len(x), len(y)) squared-distance matrix. It then returns the sum of
    exp(-distance) as a shape-(1,) tensor.
    """
    x_sq = (x * x).sum(dim=1, keepdim=True)        # (N, 1)
    y_sq = (y * y).sum(dim=1, keepdim=True)        # (M, 1)
    cross = torch.einsum('ij,kj->ik', x, y)         # (N, M) inner products
    sq_dist = x_sq + y_sq.transpose(0, 1) - 2.0 * cross
    return torch.exp(-sq_dist).sum().reshape((1,))

# Fixed seed so the repro is deterministic; one subkey each for x, y and the network.
key = jrandom.PRNGKey(0)
x_key, y_key, nn_key = jrandom.split(key, 3)
y = jrandom.normal(y_key, (200, 3))
x = jrandom.normal(x_key, (100, 3))
data = (x, y)

def jax_keops_rbf_loss(x, y):
    """Call the KeOps-based ``keops_rbf_loss`` from JAX through torch2jax.

    The output is declared as a length-1 array with x's dtype.

    NOTE: this is the failing repro from the issue. Taking gradients through
    this wrapper failed with KeOps unless ``use_torch_vjp=False`` was passed.
    The maintainer later made torch2jax fall back to torch.autograd.grad
    automatically, with a warning.
    """
    x_spec = jax.ShapeDtypeStruct(x.shape, dtype=x.dtype)
    y_spec = jax.ShapeDtypeStruct(y.shape, dtype=y.dtype)
    out_spec = jax.ShapeDtypeStruct((1,), dtype=x.dtype)
    wrapped = torch2jax_with_vjp(keops_rbf_loss, x_spec, y_spec,
                                 output_shapes=out_spec, depth=2)
    return wrapped(x, y)

def jax_torch_rbf_loss(x, y):
    """Call the dense PyTorch ``torch_rbf_loss`` from JAX through torch2jax.

    The output is declared as a length-1 array with x's dtype.
    """
    x_spec = jax.ShapeDtypeStruct(x.shape, dtype=x.dtype)
    y_spec = jax.ShapeDtypeStruct(y.shape, dtype=y.dtype)
    out_spec = jax.ShapeDtypeStruct((1,), dtype=x.dtype)
    wrapped = torch2jax_with_vjp(torch_rbf_loss, x_spec, y_spec,
                                 output_shapes=out_spec, depth=2)
    return wrapped(x, y)

def jax_rbf_loss(x, y):
    """Pure-JAX reference for the RBF kernel sum.

    Returns a scalar (shape ``()``), unlike the torch2jax wrappers, which
    return a length-1 array.
    """
    x_sq = (x * x).sum(axis=1, keepdims=True)
    y_sq = (y * y).sum(axis=1, keepdims=True)
    sq_dist = x_sq + y_sq.T - 2.0 * jnp.einsum('ij,kj->ik', x, y)
    return jnp.exp(-sq_dist).sum()

class Func(eqx.Module):
    """Equinox module that applies one MLP to every row of its input.

    The MLP maps ``data_size`` features to ``data_size`` features using ReLU
    hidden layers of width ``width_size``. ``depth`` is passed straight
    through to ``eqx.nn.MLP``.
    """

    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            key=key,
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jax.nn.relu,
        )

    def __call__(self, x, args=None):
        # `args` is accepted but unused.
        per_row = jax.vmap(self.mlp)  # map the MLP over the leading axis of x
        return per_row(x)

# Model plus an AdaBelief optimiser whose state covers only the model's
# inexact (floating-point) array leaves.
model = Func(depth=2, width_size=128, data_size=3, key=nn_key)
optimiser_creation_func = optax.adabelief
opt_init, opt_update = optimiser_creation_func(learning_rate=1e-3)
opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_value_and_grad
def compute_loss(model, data):
    """Mean RBF loss between the model's predictions and the targets.

    The decorator turns this into a function returning ``(loss, grads)``.
    Swap the active line to compare the three loss implementations.
    """
    inputs, targets = data
    predictions = model(inputs)

    per_call = jax_keops_rbf_loss(predictions, targets)  # DOESN'T WORK
    # per_call = jax_torch_rbf_loss(predictions, targets)  # WORKS
    # per_call = jax_rbf_loss(predictions, targets)  # WORKS

    return jnp.mean(per_call)

@eqx.filter_jit
def make_step(model, data, opt_state):
    """Run one jitted optimisation step.

    Returns ``(loss, updated_model, updated_opt_state)``.
    """
    loss_value, gradients = compute_loss(model, data)
    updates, new_opt_state = opt_update(gradients, opt_state)
    new_model = eqx.apply_updates(model, updates)
    return loss_value, new_model, new_opt_state

# Run ten optimisation steps, printing the loss returned by each step.
for step in range(10):
    loss, model, opt_state = make_step(model, data, opt_state)
    print(loss)

# -----------------------------------
# ALL PYTORCH CODE BELOW TO SHOW THAT THE LOSS FUNCTIONS WORK IN PURE PYTORCH
# -----------------------------------

from torch import nn
class MLP(nn.Module):
    """Multilayer perceptron: 3 -> 128 -> 128 -> 3 with ReLU activations."""

    def __init__(self):
        super().__init__()
        hidden = 128
        self.layers = nn.Sequential(
            nn.Linear(3, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 3),
        )

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.layers(x)

# Pure-PyTorch control experiment: train the same kind of MLP directly with
# the KeOps loss to show that its backward pass works outside torch2jax.
x_tensor = torch.randn((100, 3)).to('cuda')
y_tensor = torch.randn((200, 3)).to('cuda')

mlp = MLP()
mlp.to('cuda')

optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

# Full-batch training: each epoch is exactly one optimisation step.
num_epochs = 100
for epoch in range(num_epochs):
    print(f'Starting epoch {epoch + 1}')

    optimizer.zero_grad()

    outputs = mlp(x_tensor)

    # Both losses return a shape-(1,) tensor, which backward() accepts because
    # it has exactly one element.
    # loss = torch_rbf_loss(outputs, y_tensor) # WORKS
    loss = keops_rbf_loss(outputs, y_tensor) # WORKS

    loss.backward()
    optimizer.step()

    # One batch per epoch, so there is nothing to accumulate: report the loss directly.
    print(f'Loss after {epoch + 1}: {loss.item():.3f}')

torch2jax_error.txt

rdyro commented 1 year ago

I managed to reproduce it, thanks for submitting it!

It seems keops is not entirely compatible with the PyTorch functional interface (the error message seems to suggest it is a missing implementation in keops).

I made changes that allow users of torch2jax to fall back to the old torch.autograd.grad interface for taking gradients. In your example, the only change would be:

def jax_keops_rbf_loss(x, y):
    """Call ``keops_rbf_loss`` from JAX, differentiating via torch.autograd.grad.

    ``use_torch_vjp=False`` makes torch2jax use torch.autograd.grad instead of
    torch.func.vjp, which is what KeOps' custom autograd function supports.
    """
    wrapped = torch2jax_with_vjp(
        keops_rbf_loss,
        jax.ShapeDtypeStruct(x.shape, dtype=x.dtype),
        jax.ShapeDtypeStruct(y.shape, dtype=y.dtype),
        output_shapes=jax.ShapeDtypeStruct((1,), dtype=x.dtype),
        depth=2,
        use_torch_vjp=False,
    )
    return wrapped(x, y)

Notice, the use_torch_vjp=False flag.

This switches the gradient definition in torch2jax_with_vjp from torch.func.vjp to torch.autograd.grad. Importantly, the latter is not compatible with torch.vmap, so the flag also makes torch2jax switch to simple (experimental) looping instead of torch.vmap. This should not be any slower as long as jax.vmap is not applied over the wrapped function, which your code doesn't do anyway.

Let me know if this works for you!

adam-hartshorne commented 1 year ago

Thank you for such a swift response.

Also thank you for mentioning the potential torch.vmap / jax.vmap issue, as the current structure of my real code (in which I want to replace a JAX function with a Keops function) uses this.

It might be worth adding this issue to the docs, as I wouldn't be surprised if quite a few libraries that interface with torch still use the old torch.autograd.grad method for handling gradients.

rdyro commented 1 year ago

No worries, I'm glad you're finding the package useful.

After reading through the KeOps code, I found that they define a custom autograd function the old way (using forward & backward instead of forward, backward, and setup_context). I experimented a bit with editing that code, and it's possible to make it work with the new style (and hence with torch.func.vjp), but I'm not sure my edits don't break higher-order derivatives.

That's a good idea. For now, I made the code detect a failure when using torch.func.vjp and automatically fall back to torch.autograd.grad with a warning. I'll look into how to add a description of this pitfall to the documentation.

adam-hartshorne commented 1 year ago

They might be happy to accept a pull request for modernising their gradients.

I am currently experiencing a slight issue with intermittent NaN gradients in my code when using torch / Keops vs pure JAX. However, it is only occurring when using much larger batch sizes than possible with JAX (because of the lazy evaluation functionality of Keops). I haven't managed to narrow it down any further, so I can't say for absolute certain where the issue lies i.e. Keops, torch2jax or something in my own JAX code.

I will obviously raise an issue here if I believe it is something to do with torch2jax.

rdyro commented 1 year ago

I'll look into a pull request once I have a little more time, it's a good idea.

I recently spotted a possible synchronization bug in my CUDA implementation and have already pushed a fix. Because the fix is in the C++ part of the code, the compiled extension has to be rebuilt; you can either run:

from torch2jax import compile_and_import_module

# Force a rebuild of torch2jax's compiled C++/CUDA extension so the pushed
# synchronization fix is used instead of a previously cached build.
compile_and_import_module(force_recompile=True)

or delete ~/.cache/torch2jax. Sorry about that.

Thanks, keep me updated!

adam-hartshorne commented 1 year ago

OK, I have discovered the bug that is causing the NaNs. It is very sporadic, which makes it hard to debug, but I think there is some sort of weird interaction between torch2jax, torch, and KeOps.

Basically, performing a simple division inside the PyTorch part of the code (even if it is 1.0 / 1.0) can cause these NaN gradients.

**THIS CAN RESULT IN NAN GRADIENTS**

def batch_keops_ls(x, y, lengthscale):
    """Batched RBF kernel sum parameterised by a length scale (KeOps, lazy).

    For each batch element b this returns
    sum_{n,m} exp(-||x[b,n] - y[b,m]||^2 / lengthscale), as a shape-(B,) tensor.

    Args:
        x: (B, N, D) source points (rank fixed by the unpacking below).
        y: (B, M, D) target points.
        lengthscale: length scale; divided into 1.0 to get the inverse scale.
    """
    B, N, D = x.shape  # Batch size, number of source points, features
    _, M, _ = y.shape  # Batch size, number of target points, features

    # Use the function argument. The original body read the global
    # `length_scale` (a JAX array), so the value passed in was never used.
    gamma = 1.0 / lengthscale

    # Encode as symbolic tensors:
    x_i = LazyTensor(x.view(B, N, 1, D))  # (B, N, 1, D)
    y_j = LazyTensor(y.view(B, 1, M, D))  # (B, 1, M, D)
    D2 = x_i.sqdist(y_j)
    K = (-D2 * gamma).exp()
    # Sum over targets (axis 2), then over sources, leaving one value per batch.
    return torch.reshape(torch.sum(K.sum(2), dim=1), (-1,))

# x: (B, N, D) and y: (B, M, D) are assumed to be defined by the surrounding
# code (not shown in this snippet).
length_scale = jnp.array([1.0])

# Dedented: the original leading space before `xy_cost` is an IndentationError
# at module level.
xy_cost = torch2jax_with_vjp(
    batch_keops_ls,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(length_scale.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, length_scale)
**THIS WORKS**

def batch_keops_gamma(x, y, gamma):
    """Batched RBF kernel sum parameterised by an inverse length scale (KeOps, lazy).

    For each batch element b this returns
    sum_{n,m} exp(-gamma * ||x[b,n] - y[b,m]||^2), as a shape-(B,) tensor.
    x is (B, N, D) and y is (B, M, D), as enforced by the shape unpacking.
    """
    batch, n_src, dim = x.shape
    _, n_tgt, _ = y.shape

    # Symbolic (lazy) views: sources along axis 1, targets along axis 2.
    x_i = LazyTensor(x.view(batch, n_src, 1, dim))
    y_j = LazyTensor(y.view(batch, 1, n_tgt, dim))
    sq_dist = x_i.sqdist(y_j)
    kernel = (-sq_dist * gamma).exp()
    per_source = kernel.sum(2)
    return per_source.sum(dim=1).reshape((-1,))

# x: (B, N, D) and y: (B, M, D) are assumed to be defined by the surrounding
# code (not shown in this snippet).
gamma = jnp.array([1.0])

# Dedented: the original leading space before `xy_cost` is an IndentationError
# at module level.
xy_cost = torch2jax_with_vjp(
    batch_keops_gamma,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(gamma.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, gamma)
**THIS CAN RESULT IN NAN GRADIENTS**
def torch_wrapper(x, y, length_scale):
    """Convert a length scale to its inverse inside PyTorch and call batch_keops_gamma."""
    return batch_keops_gamma(x, y, 1.0 / length_scale)

# x: (B, N, D) and y: (B, M, D) are assumed to be defined by the surrounding
# code (not shown in this snippet).
length_scale = jnp.array([1.0])

# Dedented: the original leading space before `xy_cost` is an IndentationError
# at module level.
xy_cost = torch2jax_with_vjp(
    torch_wrapper,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(length_scale.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, length_scale)
**THIS WORKS**
# x: (B, N, D) and y: (B, M, D) are assumed to be defined by the surrounding
# code (not shown in this snippet).
length_scale = jnp.array([1.0])
gamma = 1.0 / length_scale  # division done on the JAX side, before torch2jax

# Dedented: the original leading space before `xy_cost` is an IndentationError
# at module level.
xy_cost = torch2jax_with_vjp(
    batch_keops_gamma,
    jax.ShapeDtypeStruct(x.shape, x.dtype),
    jax.ShapeDtypeStruct(y.shape, x.dtype),
    jax.ShapeDtypeStruct(gamma.shape, x.dtype),
    output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
    depth=2,
    use_torch_vjp=False,
)(x, y, gamma)
rdyro commented 1 year ago

I made two bug fixes thanks to #4; please let me know if those fixes also help with this.

adam-hartshorne commented 1 year ago

Looks good from my end!

Really appreciate how fast you are on correcting these bugs.