Describe the bug
I have exactly the same model in Keras and in MLX, and with the default Keras random-initialization weights the two models' outputs are identical. But if I reload pretrained Keras weights, there is a clear difference between the two models' outputs. I don't understand where the error is coming from.
To Reproduce
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import LSTM, Dense, Embedding, LeakyReLU, Dropout, Bidirectional, TimeDistributed, LayerNormalization
import math
from typing import Any
import mlx.nn as nn
from mlx.utils import tree_flatten
import numpy as np
import mlx.core as mx
from mlx.nn.layers.base import Module
# modified from https://github.com/sujitpal/eeap-examples
@tf.keras.saving.register_keras_serializable()
class AttentionM(Layer):
    """
    Keras layer to compute an attention vector on an incoming matrix.
    # Input
        enc - 3D Tensor of shape (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
    # Output
        2D Tensor of shape (BATCH_SIZE, EMBED_SIZE)
    # Usage
        enc = LSTM(EMBED_SIZE, return_sequences=True)(...)
        att = AttentionM()(enc)
    """
    def __init__(self, axis=-1, return_probabilities=False, **kwargs):
        self.return_probabilities = return_probabilities
        self.axis = axis
        super(AttentionM, self).__init__(**kwargs)

    def build(self, input_shape):
        # W: (EMBED_SIZE, 1)
        # b: (MAX_TIMESTEPS, 1)
        self.W = self.add_weight(name="W_{:s}".format(self.name),
                                 shape=(input_shape[-1], 1),
                                 initializer="normal")
        self.b = self.add_weight(name="b_{:s}".format(self.name),
                                 shape=(input_shape[1], 1),
                                 initializer="zeros")
        super(AttentionM, self).build(input_shape)

    def call(self, x, mask=None):
        # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        # et: (BATCH_SIZE, MAX_TIMESTEPS)
        et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=self.axis)
        # at: (BATCH_SIZE, MAX_TIMESTEPS)
        at = K.softmax(et)
        if mask is not None:
            at *= K.cast(mask, K.floatx())
        # atx: (BATCH_SIZE, MAX_TIMESTEPS, 1)
        atx = K.expand_dims(at, axis=-1)
        # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        ot = x * atx
        # output: (BATCH_SIZE, EMBED_SIZE)
        if self.return_probabilities:
            return atx  # for visualization of the attention weights
        else:
            return K.sum(ot, axis=1)  # for prediction

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        return super(AttentionM, self).get_config()
# export all keras layer weights into a flat dict keyed by "<layer name>.<index>"
# (checked layer by layer; the Dense handling is working)
def exposekerasweights(model):
    w = {}
    for layer in model.layers:
        weights = layer.get_weights()  # returns a list of all weight tensors in the layer
        print(layer.name)
        for i, weight in enumerate(weights):
            # any Dense / linear kernel needs to be transposed
            if layer.name + "." + str(i) in ['Output.0', 'Proj.0', 'TimeDistributed.0']:
                w[layer.name + "." + str(i)] = weight.T
            else:
                w[layer.name + "." + str(i)] = weight
    return w
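# Layout note (illustrative, not part of the original repro): a keras Dense kernel is
# stored as (in_features, out_features), while mlx.nn.Linear stores `weight` as
# (out_features, in_features) -- hence the transposes above.
_d = Dense(3)
_d.build((None, 2))                           # keras kernel: (2, 3)
_l = nn.Linear(2, 3)                          # mlx weight:   (3, 2)
assert _d.get_weights()[0].T.shape == tuple(_l.weight.shape)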
# Define the model
model = Sequential([
    Embedding(input_dim=42, output_dim=32, input_length=128, name="Embedding"),
    Dropout(0.1, name='Dropout'),
    Bidirectional(LSTM(32, return_sequences=True, use_bias=True), name="Image"),
    TimeDistributed(Dense(64), name="TimeDistributed"),
    AttentionM(return_probabilities=False, name="AttentionM"),
    LayerNormalization(name='LN1'),
    Dense(64, name='Proj'),
    LeakyReLU(0.1, name='LK'),
    LayerNormalization(name='LN2'),
    Dense(1, name='Output')
])
model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mae']
)
model.summary()
# the outputs diverge when pretrained weights are loaded, i.e. when the next line is uncommented:
# model.load_weights('model_keras_weights.h5')
class AttentionM_(Module):
    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        self.output_dims = output_dims
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, 1),
        )
        if bias:
            self.bias = mx.zeros(shape=(output_dims, 1))

    def _extra_repr(self) -> str:
        return f"input_dims={self.weight.shape[0]}, output_dims={self.output_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if "bias" in self:
            x_ = mx.addmm(self["bias"], x, self["weight"])
        else:
            x_ = x @ self["weight"]
        x_ = mx.tanh(x_)
        x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_, axis=-1), axis=-1), axis=-1)
        x = mx.sum(x * x_, axis=1)
        return x
class TimeDistributed_(nn.Module):
    def __init__(self, func: nn.Module):
        super().__init__()
        self.func = func

    def __call__(self, x):
        b_, t_ = x.shape[:2]
        c_ = self.func(x.flatten(0, 1))
        return c_.reshape(b_, t_, *c_.shape[1:])


class Bidirectionnal_(nn.Module):
    def __init__(self, func1: nn.Module, func2: nn.Module):
        super().__init__()
        self.func1 = func1
        self.func2 = func2

    def __call__(self, x):
        h_f, h_b = self.func1(x), self.func2(x[:, ::-1, :])
        return mx.concatenate([h_f[0], h_b[0]], axis=-1)
w = exposekerasweights(model)
np.savez('w.npz', **w)
class mlx_copy(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.Embedding = nn.Embedding(num_embeddings=42, dims=32)
        self.Dropout = nn.Dropout(0.1)
        self.Image = Bidirectionnal_(nn.LSTM(32, 32, bias=True),
                                     nn.LSTM(32, 32, bias=True))
        self.TimeDistributed = TimeDistributed_(nn.Linear(64, 64))
        self.AttentionM = AttentionM_(64, 128, bias=True)
        self.LN1 = nn.LayerNorm(64, eps=0.001)
        self.Proj = nn.Linear(64, hidden_dim)
        self.LK = nn.LeakyReLU(0.1)
        self.LN2 = nn.LayerNorm(hidden_dim, eps=0.001)
        self.Output = nn.Linear(hidden_dim, 1)

    def __call__(self, x):
        x = self.Embedding(x)
        x = self.Dropout(x)
        x = self.Image(x)
        x = self.TimeDistributed(x)
        x = self.AttentionM(x)
        x = self.LN1(x)
        x = self.Proj(x)
        x = self.LK(x)
        x = self.LN2(x)
        x = self.Output(x)
        return x
model_mlx = mlx_copy(128, 64, 1)
mx.eval(model_mlx.parameters())
we = 0
for k, x in tree_flatten(model_mlx.parameters()):
    we += x.size
    print(x.size, k)
print(we)
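# Cross-check (illustrative): the keras model should report the same number of
# trainable parameters as the mlx copy counted above.
print(model.count_params())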
# load the saved keras weights
w_loaded = np.load('w.npz')

# convert the keras layer names to the mlx parameter names
def replace_key(key: str) -> str:
    key = key.replace("Output.0", "Output.weight")
    key = key.replace("Output.1", "Output.bias")
    key = key.replace("Proj.0", "Proj.weight")
    key = key.replace("Proj.1", "Proj.bias")
    key = key.replace("Embedding.0", "Embedding.weight")
    key = key.replace("LSTM.0", "LSTM.Wx")
    key = key.replace("LSTM.1", "LSTM.Wh")
    key = key.replace("LSTM.2", "LSTM.bias")
    key = key.replace("Image.0", "Image.func1.Wx")
    key = key.replace("Image.1", "Image.func1.Wh")
    key = key.replace("Image.2", "Image.func1.bias")
    key = key.replace("Image.3", "Image.func2.Wx")
    key = key.replace("Image.4", "Image.func2.Wh")
    key = key.replace("Image.5", "Image.func2.bias")
    key = key.replace("AttentionM.0", "AttentionM.weight")
    key = key.replace("AttentionM.1", "AttentionM.bias")
    key = key.replace("TimeDistributed.0", "TimeDistributed.func.weight")
    key = key.replace("TimeDistributed.1", "TimeDistributed.func.bias")
    key = key.replace("LN1.0", "LN1.weight")
    key = key.replace("LN1.1", "LN1.bias")
    key = key.replace("LN2.0", "LN2.weight")
    key = key.replace("LN2.1", "LN2.bias")
    return key

# switch layer names of the saved keras tensors
tensors = {
    replace_key(key): tensor for key, tensor in w_loaded.items()
}
for (k, v), (kt, vt) in zip(w_loaded.items(), tensors.items()):
    print(k, v.shape, kt, vt.shape, np.max(np.abs(v - vt)))
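# Coverage check (illustrative, not in the original repro): every mlx parameter name
# should appear among the converted keys, and vice versa.
mlx_param_names = {k for k, _ in tree_flatten(model_mlx.parameters())}
print("missing from conversion:", mlx_param_names - set(tensors))
print("unused converted keys:", set(tensors) - mlx_param_names)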
np.savez('wconvert.npz', **tensors)
# load weights for mlx model
model_mlx.load_weights('wconvert.npz')
model_mlx.freeze()
x_train = np.random.randint(0,41, (10,128))
keras_ = model.predict(x_train)
print(keras_.shape, keras_[:10])
model_mlx.train(False)
model_mlx.eval()
h = model_mlx(mx.array(x_train))
print(h.shape, h[:10])
assert np.max(np.abs(h-mx.array(keras_))) < 1e-1
Expected behavior
The two outputs should match. The mismatch only appears if I uncomment this line:
# model.load_weights('model_keras_weights.h5')
Otherwise, both models' outputs match, although the agreement is only around 1e-2 rather than 1e-6.
Desktop (please complete the following information):
OS Version: macOS 14.4
MLX Version: 0.7.0
Additional context
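If it helps to localize where the two forward passes start to diverge, here is a minimal per-layer comparison sketch. It reuses model, model_mlx and x_train from the script above and assumes the Sequential model exposes model.inputs and per-layer .output (TF Keras 2.x style):

stages = ["Embedding", "Image", "TimeDistributed", "AttentionM", "LN1", "Proj", "LK", "LN2", "Output"]
out = mx.array(x_train)
for name in stages:
    # mlx activations up to (and including) this layer; Dropout is the identity at
    # inference and is skipped
    out = getattr(model_mlx, name)(out)
    # keras activations up to the same layer
    sub = Model(model.inputs, model.get_layer(name).output)
    ref = sub.predict(x_train, verbose=0)
    print(name, np.max(np.abs(np.array(out) - ref)))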