rdyro / torch2jax

Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff.
https://rdyro.github.io/torch2jax/
MIT License

New bug when optimising function #4

Closed adam-hartshorne closed 1 year ago

adam-hartshorne commented 1 year ago

Attached is a minimal example and the resulting issue. It definitely looks like a synchronisation issue.

Furthermore, in my real usage, which runs in x64 mode, I also found that as the loss / gradients get small the optimisation seems to get stuck. It looks like there is some sort of issue with the numerical precision of the gradients (I think that is what is going on with my x64 result shown below).
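
For context, "x64 mode" here means enabling double precision in JAX before any arrays are created, e.g. (a minimal sketch; the repro below runs in the default x32 precision):

import jax

# Enable 64-bit (double) precision globally in JAX; must run before any arrays are created.
jax.config.update("jax_enable_x64", True)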

import torch
from torch2jax import torch2jax, torch2jax_with_vjp  # converts a PyTorch function into a JAX-callable one
import jax
import jax.numpy as jnp
import jax.random as jrandom
import optax
import equinox as eqx
import numpy as np
import os

key = jrandom.PRNGKey(0)
x_key, y_key, nn_key = jrandom.split(key, 3)
x = jrandom.normal(x_key, (1, 5000, 3))
y = jrandom.normal(y_key, (1, 5000, 3))
data = (x, y)

def jax_chamfer_distance(normalised_data, output_locations):
    difference = (jnp.expand_dims(normalised_data, axis=-2) - jnp.expand_dims(output_locations, axis=-3))
    # Calculate the square distances between each two points: |ai - bj|^2.
    square_distances = jnp.einsum("...i,...i->...", difference, difference)

    minimum_square_distance_a_to_b = jnp.min(square_distances, axis=-1)
    minimum_square_distance_b_to_a = jnp.min(square_distances, axis=-2)
    return 0.5 * jnp.sum(jnp.mean(minimum_square_distance_a_to_b, axis=-1) + jnp.mean(minimum_square_distance_b_to_a, axis=-1))

def torch_chamfer_distance(normalised_data, output_locations):
    difference = (normalised_data.unsqueeze(-2) - output_locations.unsqueeze(-3))
    # Calculate the square distances between each two points: |ai - bj|^2.
    square_distances = torch.einsum("...i,...i->...", difference, difference)

    minimum_square_distance_a_to_b = torch.min(square_distances, dim=-1)[0]
    minimum_square_distance_b_to_a = torch.min(square_distances, dim=-2)[0]
    return 0.5 * torch.sum(torch.mean(minimum_square_distance_a_to_b, dim=-1) + torch.mean(minimum_square_distance_b_to_a, dim=-1))

def keops_calc_chamfer_loss(x, y):
    xy_cost = torch2jax_with_vjp(
        torch_chamfer_distance,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        jax.ShapeDtypeStruct(y.shape, x.dtype),
        output_shapes=jax.ShapeDtypeStruct((x.shape[0],), x.dtype),
        depth=2,
        use_torch_vjp=False
    )(x, y)
    return xy_cost

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jax.nn.relu,
            key=key,
        )

    def __call__(self, x, args=None):
        return jax.vmap(self.mlp)(x)

model = Func(data_size=3, width_size=128, depth=2, key=nn_key)
optimiser_creation_func = optax.adabelief
opt_init, opt_update = optimiser_creation_func(learning_rate=1e-3)
opt_state = opt_init(eqx.filter(model, eqx.is_inexact_array))

@eqx.filter_value_and_grad
def compute_loss(model, data):
    x, y = data
    y_pred = jax.vmap(model, (0))(x)
    loss = keops_calc_chamfer_loss(y_pred, y) #DOESN'T WORK
    # loss = jax_chamfer_distance(y_pred, y) #WORKS

    return jnp.mean(loss)

@eqx.filter_jit
def make_step(model, data, opt_state):
    loss, grads = compute_loss(model, data)
    updates, opt_state = opt_update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

for step in range(25):
    loss, model, opt_state = make_step(model, data, opt_state)
    print(loss)
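
As a side note, one way to see the discrepancy directly is to compare the gradient of the wrapped torch loss against the pure-JAX implementation (a minimal sketch reusing keops_calc_chamfer_loss and jax_chamfer_distance from above; not part of the original script):

# Compare the gradient of the torch2jax-wrapped loss with the pure-JAX loss w.r.t. x.
# If the wrapped vjp is correct, the difference should be at the level of float rounding.
grad_wrapped = jax.grad(lambda x_: jnp.mean(keops_calc_chamfer_loss(x_, y)))(x)
grad_pure = jax.grad(lambda x_: jax_chamfer_distance(x_, y))(x)
print(jnp.max(jnp.abs(grad_wrapped - grad_pure)))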

Running pure JAX

1.1784207 1.0076947 0.8537498 0.7236823 0.6130994 0.5178365 0.4353472 0.36305848 0.2981228 0.23942232 0.18905437 0.14701995 0.11380868 0.088957496 0.07115974 0.05860029 0.050203666 0.044576604 0.040941507 0.03873283 0.037716232 0.03754989 0.03794671 0.038578086 0.039351515

Running torch2jax

1.1784207 1.0076947 0.8537498 0.72368234 0.61309934 0.51783663 0.43534735 0.36305854 0.29812282 0.23942222 0.18905428 0.14701991 0.113808714 0.088957384 1.4932078 <--------------- 0.058600325 0.050203983 0.04457575 1.5029242 <--------------- 0.03873186 0.037716057 0.037549183 0.037946653 0.03857988 1.513658 <---------------

Running the same torch_chamfer_distance in a pure PyTorch training loop:

1.176 1.020 0.879 0.749 0.632 0.527 0.434 0.353 0.284 0.227 0.181 0.144 0.115 0.092 0.074 0.062 0.052 0.045 0.041 0.037 0.035 0.034 0.033 0.032 0.032

adam-hartshorne commented 1 year ago

Setting things to x64 mode is even more of an issue when using torch2jax_with_vjp:

1.1318159821788687 1.4969821824080587 1.4862239786651517 1.4780943113719986 1.472037333444222 1.4678562092608634 ....

rdyro commented 1 year ago

Thanks for submitting this!

Thanks to your code example, I was able to catch and fix two errors.

Make sure to delete ~/.cache/torch2jax or call

from torch2jax import compile_and_import_module

compile_and_import_module(force_recompile=True)

at least once after pulling the new changes.
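
For reference, a minimal sketch of clearing that cache from Python (assuming the default cache location ~/.cache/torch2jax mentioned above):

import os
import shutil

# Remove the compiled-extension cache so torch2jax rebuilds it on the next import.
shutil.rmtree(os.path.expanduser("~/.cache/torch2jax"), ignore_errors=True)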

adam-hartshorne commented 1 year ago

Looks good from my end!

Really appreciate how fast you are on correcting these bugs.

rdyro commented 1 year ago

Great, no worries. Your example code makes for much better testing than I managed to write myself; I really appreciate your help.