Hello, I am not sure whether this should be written as a JAX or Flax issue.
Here is a simplified description of the training steps in my context (see the example below for more detail):
1. I apply the model to the whole input data.
2. I select some indices of the obtained predictions.
3. These predictions are used, along with the corresponding labels, for the loss computation.
The reason I do so is that I am working with a graph neural network (for node regression): I get a prediction for all nodes, but I split the nodes into batches and compute the average loss/gradient across batches.
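To make that concrete, here is a rough sketch of what I mean by batches of node indices (the sizes are made up, just for illustration):

import jax.numpy as N

# Illustration only: one prediction per node (1000 nodes here), and the node
# indices split into 10 batches of 100; each batch only gathers its own rows.
all_preds = N.zeros((1000, 1))
batch_indices = N.arange(1000).reshape(10, 100)
first_batch_preds = all_preds[batch_indices[0]]  # shape (100, 1)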
My loss function takes as input the indices to be used for the computation. Instead of calling the function once per batch, I wanted to use vmap on the loss function, which I expected to be faster. However, it seems to lead to different results. Here is an example of the issue:
import flax.linen as nn
import jax
import jax.numpy as N
import jax.random as R
import jax.tree_util as T
import optax
key = R.PRNGKey(0)
rng1, rng2, rng3, rng4 = R.split(key, 4)
X = R.normal(rng1, (1000, 20)) # Input: <num_instances> x <num_features>
Y = R.normal(rng2, (1000, 1)) # Label: <num_instances> x 1
# Batches of indices: <num_batches> x <len_batches>
batches = R.randint(rng3, (10, 500), 0, 1000)
model = nn.Dense(1)
model_vars = model.init(rng4, N.ones((1, 20)))
# Method 1: with vmap on loss function
@jax.jit
def train_step1(variables, x, y, bs):
    def loss_fn(v, i):
        p = model.apply(v, x)
        loss = N.mean(optax.sigmoid_binary_cross_entropy(p[i], y[i]))
        return loss
    v_loss_grad_fn = jax.vmap(jax.value_and_grad(loss_fn), in_axes=(None, 0), out_axes=0)
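    # in_axes=(None, 0) broadcasts the parameters and maps over the leading axis
    # of bs, so loss has shape (num_batches,) and every gradient leaf gains a
    # leading batch axis.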
    loss, grads = v_loss_grad_fn(variables, bs)
    loss = N.mean(loss, axis=0)
    grads = T.tree_map(lambda g: N.mean(g, axis=0), grads)
    variables = T.tree_map(lambda v, g: v - 0.01*g, variables, grads)
    return variables, loss
# Method 2: without vmap on loss function
@jax.jit
def train_step2(variables, x, y, bs):
    def loss_fn(v, i):
        p = model.apply(v, x)
        loss = N.mean(optax.sigmoid_binary_cross_entropy(p[i], y[i]))
        return loss
    loss_grad_fn = jax.value_and_grad(loss_fn)
    l_list = []
    g_list = []
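    # Under jit, this Python loop is unrolled at trace time into 10 sequential
    # value_and_grad calls, one per batch of indices.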
    for i in range(10): # Number of batches
        idx = bs[i]
        loss, grads = loss_grad_fn(variables, idx)
        l_list.append(loss)
        g_list.append(grads)
    loss = N.mean(N.stack(l_list, axis=0), axis=0)
    grads = T.tree_map(lambda *g: N.mean(N.stack(g, axis=0), axis=0), *g_list)
    variables = T.tree_map(lambda v, g: v - 0.01*g, variables, grads)
    return variables, loss
# Warming up
train_step1(model_vars, N.ones((1, 20)), N.ones((1, 1)), N.asarray([[0]]))
train_step2(model_vars, N.ones((1, 20)), N.ones((1, 1)), N.asarray([[0]]))
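# Note: these warm-up calls use different shapes from the real data below,
# so jit will trace and compile again on the first call with the real shapes.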
v1 = model_vars
v2 = model_vars
# Comparing results before trainings
p1 = model.apply(v1, X)
p2 = model.apply(v2, X)
print("Same results before trainings:", N.all(p1 == p2))
# Training
for _ in range(1000):
    v1, l1 = train_step1(v1, X, Y, batches)
    v2, l2 = train_step2(v2, X, Y, batches)
# Comparing results after both trainings
p1 = model.apply(v1, X)
p2 = model.apply(v2, X)
print("Same results after trainings:", N.all(p1 == p2))
print("Same losses after trainings", l1 == l2)
Maybe I misunderstood how vmap works, but I think both methods described above should behave the same, so I don't understand why they lead to different results. Could the differences be due to some floating-point approximation during the computation (the losses are still identical at the last iteration)?
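In case it is useful, this is roughly how I would compare the two runs up to floating-point tolerance instead of exact equality (just a sketch meant to be appended to the script above, reusing model, v1, v2 and X from there):

# Sketch: tolerance-based comparison of the two trained models
# (reuses model, v1, v2 and X from the script above).
p1 = model.apply(v1, X)
p2 = model.apply(v2, X)
print("Max abs difference:", N.max(N.abs(p1 - p2)))
print("All close (atol=1e-5):", N.allclose(p1, p2, atol=1e-5))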
Thank you for your help.