Closed DrJonnyT closed 5 months ago
Hey @DrJonnyT!
Hmmm, this is curious. I think you might be onto something here. I can tell that the difference is in the encoder.
Will investigate and get back to you!
The discrepancy is indeed in the Attention module. That's concerning.
Actually, @DrJonnyT, I'm thinking that things might be fine after all. Could you try running the following and seeing if you find that things are equal too?
import neuralprocesses.torch as nps_torch
import neuralprocesses.tensorflow as nps_tf
import lab as B
import numpy as np
import tensorflow as tf
import torch

model_tf = nps_tf.construct_gnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
model_torch = nps_torch.construct_gnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)

model_torch(
    B.randn(torch.float32, 16, 17, 10),
    B.randn(torch.float32, 16, 9, 10),
    B.randn(torch.float32, 16, 17, 10),
)
model_tf(
    B.randn(tf.float32, 16, 17, 10),
    B.randn(tf.float32, 16, 9, 10),
    B.randn(tf.float32, 16, 17, 10),
)

assert len(model_tf.get_weights()) == len(list(model_torch.parameters()))
for x, y in zip(model_tf.get_weights(), model_torch.parameters()):
    assert x.shape == y.shape or x.shape == (y.shape[1], y.shape[0])
print("Ok!")

model_tf = nps_tf.construct_agnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)
model_torch = nps_torch.construct_agnp(
    dim_x=17,
    dim_y=9,
    dim_embedding=128,
    num_enc_layers=6,
    num_dec_layers=6,
)

model_torch(
    B.randn(torch.float32, 16, 17, 10),
    B.randn(torch.float32, 16, 9, 10),
    B.randn(torch.float32, 16, 17, 10),
)
model_tf(
    B.randn(tf.float32, 16, 17, 10),
    B.randn(tf.float32, 16, 9, 10),
    B.randn(tf.float32, 16, 17, 10),
)

assert len(model_tf.get_weights()) == len(list(model_torch.parameters()))
for x, y in zip(model_tf.get_weights(), model_torch.parameters()):
    assert x.shape == y.shape or x.shape == (y.shape[1], y.shape[0])
print("Ok!")
However, the performance after n epochs is worse in torch.
Hmm, the performance may be sensitive to the initialisation, which could be different between PyTorch and TF, and to the precise optimiser settings (how do learning rate and batch size interact?). Are you sure that those are equal?
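For example, the Adam defaults are not quite identical between the two frameworks (PyTorch's torch.optim.Adam uses eps=1e-8 whereas Keras's Adam uses epsilon=1e-7), so it might be worth writing out every hyperparameter explicitly when constructing both optimisers. A rough sketch with a stand-in model, just to illustrate:

import tensorflow as tf
import torch

# Stand-in model purely for illustration; substitute your actual model.
model_torch = torch.nn.Linear(3, 3)

# Spell out every Adam hyperparameter so that nothing relies on framework defaults.
lr, beta_1, beta_2, eps = 1e-4, 0.9, 0.999, 1e-8
opt_torch = torch.optim.Adam(model_torch.parameters(), lr=lr, betas=(beta_1, beta_2), eps=eps)
opt_tf = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=beta_1, beta_2=beta_2, epsilon=eps)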
@wesselb That script gives me "Ok!" for both gnp and agnp, so all good there. I've just been testing with the script below. The results are very variable: with convgnp the performance is often comparable, but with gnp and agnp TensorFlow always seems to win. However, exactly how fast they learn also depends on your choice of optimizer.
import tensorflow as tf
import torch
import numpy as np
import neuralprocesses.torch as nps_torch
import neuralprocesses.tensorflow as nps_tf


def x_to_y(x):
    # Dummy function to make learnable y data from random x data
    shape = x.shape
    y = torch.randn(shape[0], 2, shape[2])
    y[:, 0, :] *= 2
    y[:, 1, :] *= 3
    y = y + torch.randn_like(y) * 0.1
    return y


# %%
num_batches = 32
xc_list_torch, yc_list_torch, xt_list_torch, yt_list_torch = [], [], [], []
for batch in range(num_batches):
    xc = torch.randn(16, 1, 10)
    xt = torch.randn(16, 1, 15)
    xc_list_torch.append(xc)  # Context inputs
    xt_list_torch.append(xt)  # Target inputs
    yc_list_torch.append(x_to_y(xc))  # Context outputs
    yt_list_torch.append(x_to_y(xt))  # Target outputs

# Construct models
agnp_torch = nps_torch.construct_gnp(dim_x=1, dim_y=2, likelihood="het")
agnp_tf = nps_tf.construct_gnp(dim_x=1, dim_y=2, likelihood="het")

# Construct optimisers with a low learning rate as the data are simple to learn.
# I have tuned the learning rate so it doesn't plateau with Adam after 5 epochs.
opt_torch = torch.optim.Adam(agnp_torch.parameters(), 5e-6)
opt_tf = tf.keras.optimizers.legacy.Adam(learning_rate=5e-6)

num_epochs = 5

# %%
# Training loop of 5 actual epochs, with one warmup epoch at the start to test
# the losses of the untrained models and check they are similar from the initial weights.
epochs_loss_torch = []
epochs_loss_tf = []
for epoch in range(num_epochs + 1):
    this_epoch_loss_torch = []
    this_epoch_loss_tf = []
    for batch in range(num_batches):
        # Torch version
        xc_torch = xc_list_torch[batch]
        yc_torch = yc_list_torch[batch]
        xt_torch = xt_list_torch[batch]
        yt_torch = yt_list_torch[batch]
        loss_torch = -torch.mean(nps_torch.loglik(agnp_torch, xc_torch, yc_torch, xt_torch, yt_torch, normalise=True))
        if epoch > 0:
            opt_torch.zero_grad(set_to_none=True)
            loss_torch.backward()
            opt_torch.step()
        this_epoch_loss_torch.append(loss_torch.detach().numpy())
        # Tensorflow version with the same data
        xc_tf = tf.convert_to_tensor(xc_torch.numpy())
        yc_tf = tf.convert_to_tensor(yc_torch.numpy())
        xt_tf = tf.convert_to_tensor(xt_torch.numpy())
        yt_tf = tf.convert_to_tensor(yt_torch.numpy())
        with tf.GradientTape() as tape:
            # Compute the loss
            loss_tf = -tf.reduce_mean(nps_tf.loglik(agnp_tf, xc_tf, yc_tf, xt_tf, yt_tf, normalise=True))
        if epoch > 0:
            gradients = tape.gradient(loss_tf, agnp_tf.trainable_variables)
            opt_tf.apply_gradients(zip(gradients, agnp_tf.trainable_variables))
        this_epoch_loss_tf.append(loss_tf.numpy())
    # Collate the losses per epoch
    epochs_loss_torch.append(np.mean(this_epoch_loss_torch).round(3))
    epochs_loss_tf.append(np.mean(this_epoch_loss_tf).round(3))
print(f"Torch losses:\n{epochs_loss_torch}")
print(f"TF losses:\n{epochs_loss_tf}")
@wesselb Here's a messy ChatGPT script that builds two identical ReLU networks, one in plain PyTorch and one in TensorFlow. The TensorFlow one seems to train much quicker, so I think it's probably just a TensorFlow vs Torch thing rather than an issue with neuralprocesses? I would have thought that would be better known, though?
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
# Define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)
        self.fc3 = nn.Linear(10, 10)
        self.fc4 = nn.Linear(10, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(x))
        x = self.fc4(x)
        return x
# Create the model
model = Net()
# Create some data
x = torch.randn(100, 10) + 1
y = torch.mean(x.pow(2) + 10 + torch.randn(100, 10) * 0.1, axis=1)
# Define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
# Training loop
num_epochs = 5
torch_losses = []
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(x)
    loss = criterion(outputs.squeeze(), y)
    loss.backward()
    optimizer.step()
    torch_losses.append(loss.detach().numpy().round(3))
# Tensorflow starts here
# Define the model
model = tf.keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(10,)),
    layers.Dense(10, activation='relu'),
    layers.Dense(10, activation='relu'),
    layers.Dense(1),
])
# Create some data using the same data from PyTorch
# Convert PyTorch tensor to NumPy array
x_np = x.detach().numpy()
y_np = y.detach().numpy()
x = tf.convert_to_tensor(x_np, dtype=tf.float32)
y = tf.convert_to_tensor(y_np, dtype=tf.float32)
# Compile the model
opt = tf.keras.optimizers.legacy.Adam(learning_rate=1e-2)
model.compile(optimizer=opt, loss='mse')
# Training loop
history = model.fit(x, y, epochs=num_epochs, verbose=0)
print(f"Torch losses:\n{torch_losses}")
print(f"TF losses:\n{history.history['loss']}")
@DrJonnyT Your ReLU example is a good one. I would chase that down. It should be possible to configure things so that the convergence is exactly the same between TF and PyTorch.
Do PyTorch and TF initialise the weights in the same way? That could make a big difference.
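For what it's worth, for a plain dense layer the defaults do differ: PyTorch's nn.Linear uses Kaiming-uniform weights and a small uniform bias, whereas Keras's Dense uses Glorot-uniform weights and a zero bias. A quick standalone check of the scales (purely an illustration, not taken from either of our scripts):

import tensorflow as tf
import torch

# Compare the default initialisations of a plain dense layer in each framework.
torch_layer = torch.nn.Linear(10, 10)
tf_layer = tf.keras.layers.Dense(10)
tf_layer.build((None, 10))  # Create the Keras weights.

print("torch weight std:", torch_layer.weight.detach().numpy().std())
print("keras weight std:", tf_layer.kernel.numpy().std())
print("torch bias std:", torch_layer.bias.detach().numpy().std())
print("keras bias std:", tf_layer.bias.numpy().std())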
@wesselb This gets you pretty close! I get these losses for 5 epochs (the first data point is testing the untrained model).
Torch losses:
[5.579, 5.314, 4.832, 4.471, 4.173, 3.899]
TF losses:
[5.593, 5.328, 4.847, 4.49, 4.204, 3.96]
I suspect/hope that if you set the same random seed it would come out exactly the same. I tried the optimizer in this configuration with some very basic TF/Torch models and the loss was exactly the same. So I'm happy to close. Phew! Script here:
import tensorflow as tf
import torch
import numpy as np
import neuralprocesses.torch as nps_torch
import neuralprocesses.tensorflow as nps_tf
def x_to_y(x):
    """Dummy function to make learnable y data from random x data"""
    shape = x.shape
    y = torch.randn(shape[0], 2, shape[2])
    y[:, 0, :] *= 2
    y[:, 1, :] *= 3
    y = y + torch.randn_like(y) * 0.1
    return y


def copy_weights_and_biases(model_torch, model_tf):
    """Copy weights from torch model to tf model"""
    weights_tf = model_tf.get_weights()
    weights_torch = [param.detach().numpy() for param in model_torch.parameters()]
    for i in range(len(weights_tf)):
        if weights_tf[i].shape == weights_torch[i].shape:
            weights_tf[i] = weights_torch[i]
        elif weights_tf[i].shape == (weights_torch[i].shape[1], weights_torch[i].shape[0]):
            weights_tf[i] = weights_torch[i].T
    model_tf.set_weights(weights_tf)
    print("Weights and biases copied successfully from PyTorch model to TensorFlow model")


def compare_models(model_torch, model_tf):
    """Check that the weights and biases in a tf and torch model are all the same"""
    # Convert PyTorch model to state_dict (dictionary object)
    pytorch_state_dict = model_torch.state_dict()
    # Get TensorFlow model variables
    tensorflow_variables = model_tf.trainable_variables
    # Check if the number of layers is the same
    if len(pytorch_state_dict) != len(tensorflow_variables):
        print("The models have a different number of layers.")
        return False
    # Iterate over PyTorch model parameters
    for item, ((name, param), tf_var) in enumerate(zip(pytorch_state_dict.items(), tensorflow_variables)):
        # Convert PyTorch tensor to numpy array
        pytorch_param = param.detach().numpy()
        # Get corresponding TensorFlow variable
        tensorflow_param = tf_var.numpy()
        # Check if the shapes are the same
        if pytorch_param.shape != tensorflow_param.shape:
            pytorch_param = pytorch_param.transpose()
            if pytorch_param.shape != tensorflow_param.shape:
                print(f"Difference found in layer: {name}. Different shapes.")
                return False
        # Check if the weights are the same
        if not np.allclose(pytorch_param, tensorflow_param, atol=1e-6):
            print(f"Difference found in layer: {name}. Weights are not the same.")
            return False
    print("All layers have the same shape and weights.")
    return True


# %%
# Make some data
num_batches = 8
xc_list_torch, yc_list_torch, xt_list_torch, yt_list_torch = [], [], [], []
for batch in range(num_batches):
    xc = torch.randn(16, 1, 10)
    xt = torch.randn(16, 1, 15)
    xc_list_torch.append(xc)  # Context inputs
    xt_list_torch.append(xt)  # Target inputs
    yc_list_torch.append(x_to_y(xc))  # Context outputs
    yt_list_torch.append(x_to_y(xt))  # Target outputs
# Construct models
gnp_torch = nps_torch.construct_gnp(dim_x=1, dim_y=2, likelihood="het")
gnp_tf = nps_tf.construct_gnp(dim_x=1, dim_y=2, likelihood="het")
# SGD Optimizers that I have tested to be equivalent
opt_torch = torch.optim.SGD(gnp_torch.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False)
opt_tf = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.0, nesterov=False)
# %%
# Copy weights and biases
copy_weights_and_biases(gnp_torch,gnp_tf)
assert compare_models(gnp_torch,gnp_tf)
# %%
# Training loop of 5 actual epochs, with one warmup epoch at the start to test
# the losses of the untrained models and check they are similar from the initial weights.
num_epochs = 5
epochs_loss_torch = []
epochs_loss_tf = []
for epoch in range(num_epochs + 1):
    this_epoch_loss_torch = []
    this_epoch_loss_tf = []
    for batch in range(num_batches):
        # Torch version
        xc_torch = xc_list_torch[batch]
        yc_torch = yc_list_torch[batch]
        xt_torch = xt_list_torch[batch]
        yt_torch = yt_list_torch[batch]
        loss_torch = -torch.mean(nps_torch.loglik(gnp_torch, xc_torch, yc_torch, xt_torch, yt_torch, normalise=True))
        if epoch > 0:
            opt_torch.zero_grad(set_to_none=True)
            loss_torch.backward()
            opt_torch.step()
        this_epoch_loss_torch.append(loss_torch.detach().numpy())
        # Tensorflow version with the same data
        xc_tf = tf.convert_to_tensor(xc_torch.numpy())
        yc_tf = tf.convert_to_tensor(yc_torch.numpy())
        xt_tf = tf.convert_to_tensor(xt_torch.numpy())
        yt_tf = tf.convert_to_tensor(yt_torch.numpy())
        with tf.GradientTape() as tape:
            # Compute the loss
            loss_tf = -tf.reduce_mean(nps_tf.loglik(gnp_tf, xc_tf, yc_tf, xt_tf, yt_tf, normalise=True))
        if epoch > 0:
            gradients = tape.gradient(loss_tf, gnp_tf.trainable_variables)
            opt_tf.apply_gradients(zip(gradients, gnp_tf.trainable_variables))
        this_epoch_loss_tf.append(loss_tf.numpy())
    # Collate the losses per epoch
    epochs_loss_torch.append(np.mean(this_epoch_loss_torch).round(3))
    epochs_loss_tf.append(np.mean(this_epoch_loss_tf).round(3))
print(f"Torch losses:\n{epochs_loss_torch}")
print(f"TF losses:\n{epochs_loss_tf}")
@DrJonnyT That's some impressive investigative work! :) Very nice!! Did you also check the attentive models? Perhaps it's worthwhile to do that too?
@wesselb Here's a slightly updated version that works for gnp and agnp. I tried convgnp, but the part where it checks that the weights are the same fails; the training losses still end up similar if you comment out the assert line, though.
@DrJonnyT That's amazing. This is a super good check. :)
I think the convolutional models do not line up exactly because TF adopts a channels-last convention whereas PyTorch is channels-first, so you may need to reorder the convolutional weights to get equality.
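For a 1D convolution, for instance, PyTorch stores the kernel as (out_channels, in_channels, kernel_size) while Keras Conv1D stores it as (kernel_size, in_channels, out_channels), so a transpose along these lines should do it (a sketch with standalone layers, not the exact layers inside neuralprocesses):

import tensorflow as tf
import torch

# Standalone layers just to illustrate the weight layouts.
conv_torch = torch.nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)
conv_tf = tf.keras.layers.Conv1D(filters=8, kernel_size=3)
conv_tf.build((None, 10, 4))  # Channels-last input: (batch, length, channels).

# PyTorch Conv1d weight: (out_channels, in_channels, kernel_size)
# Keras Conv1D kernel:   (kernel_size, in_channels, out_channels)
kernel = conv_torch.weight.detach().numpy().transpose(2, 1, 0)
bias = conv_torch.bias.detach().numpy()
conv_tf.set_weights([kernel, bias])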
@wesselb Cool, I've made a minor tweak to that gist and now it works for convgnp as well 👍
@DrJonnyT That's super good.
How would you feel about me linking the gist from the documentation? I think this is a super important check.
A more ambitious plan would be to turn it into a unit test for the library, but that might not be so simple.
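For example, a rough sketch of such a test (just assuming pytest, and reusing only the constructors and the shape check from the scripts above) could look something like this:

import lab as B
import pytest
import tensorflow as tf
import torch

import neuralprocesses.tensorflow as nps_tf
import neuralprocesses.torch as nps_torch


@pytest.mark.parametrize("name", ["construct_gnp", "construct_agnp"])
def test_tf_and_torch_parameter_shapes_match(name):
    """The TF and PyTorch versions of a model should have parameters of matching shapes."""
    kw = dict(dim_x=17, dim_y=9, dim_embedding=128, num_enc_layers=6, num_dec_layers=6)
    model_tf = getattr(nps_tf, name)(**kw)
    model_torch = getattr(nps_torch, name)(**kw)

    # Run both models once so that all parameters are created.
    model_torch(
        B.randn(torch.float32, 16, 17, 10),
        B.randn(torch.float32, 16, 9, 10),
        B.randn(torch.float32, 16, 17, 10),
    )
    model_tf(
        B.randn(tf.float32, 16, 17, 10),
        B.randn(tf.float32, 16, 9, 10),
        B.randn(tf.float32, 16, 17, 10),
    )

    weights_tf = model_tf.get_weights()
    weights_torch = list(model_torch.parameters())
    assert len(weights_tf) == len(weights_torch)
    for x, y in zip(weights_tf, weights_torch):
        # Dense kernels are transposed between the two frameworks.
        assert x.shape == tuple(y.shape) or x.shape == (y.shape[1], y.shape[0])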
@wesselb Sounds good!
I've been converting my code from tensorflow to pytorch and it's much easier to get it training faster. However, the performance after n epochs is worse in torch. After lots of digging, it seems like the model architectures come out different for AGNP? But it doesn't seem to be an issue for a GNP: