google / flax

Flax is a neural network library for JAX that is designed for flexibility.

Using `vmap` on the loss function changes results #3973

Open gduflo opened 3 months ago

gduflo commented 3 months ago

Hello, I am not sure whether this should be written as a JAX or Flax issue.

Here is a simplified description of the training step in my context (see the example below for more detail):

- the model produces a prediction for every node in a single forward pass;
- the node indices are split into batches;
- for each batch, the loss and its gradient are computed on the corresponding slice of the prediction;
- the per-batch losses and gradients are averaged before the parameter update.

The reason I proceed this way is that I am working with a graph neural network (for node regression): I get a prediction for all nodes at once, but I split the nodes into batches to compute the average loss/gradient over the batches.

My loss function takes as input the indices to use for the computation. Instead of calling the function once per batch, I wanted to use vmap on the loss function, which I expected to be faster. However, it seems to lead to different results. Here is an example of the issue:

```python
import flax.linen as nn
import jax
import jax.numpy as N
import jax.random as R
import jax.tree_util as T
import optax

key = R.PRNGKey(0)
rng1, rng2, rng3, rng4 = R.split(key, 4)

X = R.normal(rng1, (1000, 20)) # Input: <num_instances> x <num_features>
Y = R.normal(rng2, (1000, 1))  # Label: <num_instances> x 1
# Batches of indices: <num_batches> x <len_batches>
batches = R.randint(rng3, (10, 500), 0, 1000)

model = nn.Dense(1)
model_vars = model.init(rng4, N.ones((1, 20)))

# Method 1: with vmap on loss function
@jax.jit
def train_step1(variables, x, y, bs):

    def loss_fn(v, i):
        p = model.apply(v, x)
        loss = N.mean(optax.sigmoid_binary_cross_entropy(p[i], y[i]))
        return loss

    # Vectorize value_and_grad over the leading axis of the batch indices.
    v_loss_grad_fn = jax.vmap(jax.value_and_grad(loss_fn), in_axes=(None, 0), out_axes=0)
    loss, grads = v_loss_grad_fn(variables, bs)
    # Average the loss and the gradients over the batches, then take an SGD step.
    loss = N.mean(loss, axis=0)
    grads = T.tree_map(lambda g: N.mean(g, axis=0), grads)
    variables = T.tree_map(lambda v, g: v - 0.01*g, variables, grads)
    return variables, loss

# Method 2: without vmap on loss function
@jax.jit
def train_step2(variables, x, y, bs):

    def loss_fn(v, i):
        p = model.apply(v, x)
        loss = N.mean(optax.sigmoid_binary_cross_entropy(p[i], y[i]))
        return loss

    loss_grad_fn = jax.value_and_grad(loss_fn)
    l_list = []
    g_list = []
    for i in range(10): # Number of batches
        idx = bs[i]
        loss, grads = loss_grad_fn(variables, idx)
        l_list.append(loss)
        g_list.append(grads)
    # Average the loss and the gradients over the batches, then take an SGD step.
    loss = N.mean(N.stack(l_list, axis=0), axis=0)
    grads = T.tree_map(lambda *g: N.mean(N.stack(g, axis=0), axis=0), *g_list)
    variables = T.tree_map(lambda v, g: v - 0.01*g, variables, grads)
    return variables, loss

# Warming up
train_step1(model_vars, N.ones((1, 20)), N.ones((1, 1)), N.asarray([[0]]))
train_step2(model_vars, N.ones((1, 20)), N.ones((1, 1)), N.asarray([[0]]))

v1 = model_vars
v2 = model_vars

# Comparing predictions before training
p1 = model.apply(v1, X)
p2 = model.apply(v2, X)
print("Same predictions before training:", N.all(p1 == p2))

# Training
for _ in range(1000):
    v1, l1 = train_step1(v1, X, Y, batches)
    v2, l2 = train_step2(v2, X, Y, batches)

# Comparing predictions after both trainings
p1 = model.apply(v1, X)
p2 = model.apply(v2, X)
print("Same predictions after training:", N.all(p1 == p2))
print("Same losses after training:", l1 == l2)
```

Maybe I misunderstood how vmap works, but I think both methods described above should behave identically, so I don't understand why they lead to different results. Could the differences come from some approximation during the computation (the losses are still identical at the last iteration)? Thank you for your help.
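To illustrate the kind of approximation I have in mind: floating-point addition is not associative, so two mathematically equivalent reduction orders (a single fused mean vs. a mean of per-batch means) can already disagree in the last bits. A minimal, self-contained sketch, independent of the script above:

```python
import jax.numpy as N
import jax.random as R

xs = R.normal(R.PRNGKey(0), (10, 500))
a = N.mean(xs)                                # one global reduction
b = N.mean(N.stack([N.mean(x) for x in xs]))  # mean of per-row means
# Mathematically equal (all rows have the same length), but the summation
# order differs, so the float32 results can disagree in the last bits.
print("Bitwise equal:", a == b, "difference:", a - b)
```

A per-step difference of this order could accumulate over the 1000 updates, which would be consistent with the predictions drifting apart even while the averaged losses happen to print as equal.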