ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
14.83k stars 845 forks source link

[BUG] Cannot replicate a Keras model in MLX when I reuse pretrained Keras weights #1057

Open thegodone opened 2 weeks ago

thegodone commented 2 weeks ago

Describe the bug

I have exactly the same model in Keras and MLX. With Keras's default random initialization weights, the outputs of the two models are identical. But if I reload pretrained Keras weights, there is a clear difference between the two models' outputs. I don't understand where the error is coming from.

To Reproduce

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import LSTM, Dense, Embedding, LeakyReLU, Dropout, Bidirectional, TimeDistributed, LayerNormalization
import math
from typing import Any
import mlx.nn as nn
from mlx.utils import tree_flatten
import numpy as np
import mlx.core as mx

from mlx.nn.layers.base import Module

# modified from https://github.com/sujitpal/eeap-examples
@tf.keras.saving.register_keras_serializable()
class AttentionM(Layer):
    """
    Keras layer to compute an attention vector on an incoming matrix.

    Every time step is scored with tanh(x . W + b); the scores are
    softmax-normalized over the time axis and used to weight the input.
    # Input
        enc - 3D Tensor of shape (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE);
              MAX_TIMESTEPS must be known at build time because b has one
              row per time step
    # Output
        2D Tensor of shape (BATCH_SIZE, EMBED_SIZE), or the attention
        weights of shape (BATCH_SIZE, MAX_TIMESTEPS, 1) when
        return_probabilities=True
    # Arguments
        axis - axis squeezed out of the (BATCH_SIZE, MAX_TIMESTEPS, 1)
               scores; the shapes above assume the default of -1
        return_probabilities - return the attention weights instead of the
               weighted sum (for visualization)
    # Usage
        enc = LSTM(EMBED_SIZE, return_sequences=True)(...)
        att = AttentionM()(enc)
    """
    def __init__(self, axis=-1, return_probabilities = False, **kwargs):
        super().__init__(**kwargs)
        self.return_probabilities = return_probabilities
        self.axis = axis

    def build(self, input_shape):
        # W: (EMBED_SIZE, 1)
        # b: (MAX_TIMESTEPS, 1)
        self.W = self.add_weight(name="W_{:s}".format(self.name),
                                 shape=(input_shape[-1], 1),
                                 initializer="normal")
        self.b = self.add_weight(name="b_{:s}".format(self.name),
                                 shape=(input_shape[1], 1),
                                 initializer="zeros")
        super().build(input_shape)

    def call(self, x, mask=None):
        # input: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        # et: (BATCH_SIZE, MAX_TIMESTEPS)
        et = K.squeeze(K.tanh(K.dot(x, self.W) + self.b), axis=self.axis)
        # at: (BATCH_SIZE, MAX_TIMESTEPS)
        at = K.softmax(et)
        if mask is not None:
            # NOTE: masked steps are zeroed *after* the softmax and the
            # weights are not renormalized, so they no longer sum to 1.
            at *= K.cast(mask, K.floatx())
        # atx: (BATCH_SIZE, MAX_TIMESTEPS, 1)
        atx = K.expand_dims(at, axis=-1)
        # ot: (BATCH_SIZE, MAX_TIMESTEPS, EMBED_SIZE)
        ot = x * atx
        # output: (BATCH_SIZE, EMBED_SIZE)
        if self.return_probabilities:
            return atx # for visualization of the attention weights
        else:
            return K.sum(ot, axis=1) # for prediction

    def compute_mask(self, input, input_mask=None):
        # do not pass the mask to the next layers
        return None

    def compute_output_shape(self, input_shape):
        # Must agree with call(): the attention weights keep the time axis.
        if self.return_probabilities:
            return (input_shape[0], input_shape[1], 1)
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        # Include the constructor arguments so that a layer rebuilt from its
        # config (e.g. when loading a saved model) keeps the same settings
        # instead of silently reverting to the defaults.
        config = super().get_config()
        config.update({
            "axis": self.axis,
            "return_probabilities": self.return_probabilities,
        })
        return config

# Checking that all the Dense layers work!

def exposekerasweights(model, transpose=('Output.0', 'Proj.0', 'TimeDistributed.0')):
    """Collect every weight tensor of a Keras model into a flat dict.

    Keys are ``"<layer name>.<index>"``, where ``index`` is the position of
    the tensor in ``layer.get_weights()``. Each layer's name is printed as
    it is visited.

    Args:
        model: object with a ``layers`` attribute whose items provide
            ``name`` and ``get_weights()`` (as Keras models and layers do).
        transpose: keys whose arrays are stored transposed. Keras ``Dense``
            kernels are laid out ``(in, out)`` while ``mlx.nn.Linear`` stores
            ``(out, in)``. The default is the list of Dense kernels in this
            script's model.

    Returns:
        dict mapping each key to its (possibly transposed) array.
    """
    # Set lookup instead of rebuilding/scanning a list for every tensor.
    transpose = frozenset(transpose)
    w = {}
    for layer in model.layers:
        weights = layer.get_weights()  # list of all weight tensors in the layer
        print(layer.name)
        for i, weight in enumerate(weights):
            key = f"{layer.name}.{i}"
            w[key] = weight.T if key in transpose else weight
    return w

# Define the model
# Keras reference model. The layer names are significant: exposekerasweights()
# keys the exported weights by them, and mlx_copy uses the same names for its
# attributes so that the renamed weights can be loaded by key.
model = Sequential([
    Embedding(input_dim=42,  output_dim=32, input_length=128,name="Embedding"),
    Dropout(0.1, name='Dropout'),
    Bidirectional(LSTM(32, return_sequences=True, use_bias=True),name="Image"),
    TimeDistributed(Dense(64),name="TimeDistributed"),
    AttentionM(return_probabilities=False,name="AttentionM"),
    LayerNormalization(name='LN1'),
    Dense(64, name='Proj'),  
    LeakyReLU(0.1, name='LK'),
    LayerNormalization(name='LN2'),
    Dense(1, name='Output') 
])

# compile() configures training (optimizer, loss, metrics); the comparison
# further down only uses model.predict().
model.compile(
    optimizer='adam',  
    loss='mse', 
    metrics=['mae'] 
)
model.summary()

# Uncommenting the next line (loading pretrained weights) is the case where the
# Keras and MLX outputs diverge; with freshly initialized weights they match.
# model.load_weights('model_keras_weights.h5')

class AttentionM_(Module):
    """MLX counterpart of the Keras ``AttentionM`` layer.

    Each time step of ``x`` is scored with ``tanh(x @ weight + bias)``. The
    scores are softmax-normalized over the step axis and used as weights for
    a sum of ``x`` over axis 1. Assumes ``x`` is 3-D (batch, time, features)
    — TODO confirm for other callers.

    Args:
        input_dims: size of the last axis of ``x`` (rows of ``weight``).
        output_dims: number of rows of the ``(output_dims, 1)`` bias. It is
            broadcast against ``x @ weight``, so it presumably has to equal
            the sequence length (128 in this script).
        bias: create the additive ``bias`` parameter when True.
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        self.output_dims = output_dims
        bound = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(low=-bound, high=bound, shape=(input_dims, 1))
        if bias:
            self.bias = mx.zeros(shape=(output_dims, 1))

    def _extra_repr(self) -> str:
        return f"input_dims={self.weight.shape[0]}, output_dims={self.output_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        w = self["weight"]
        logits = mx.addmm(self["bias"], x, w) if "bias" in self else x @ w
        # Drop the trailing singleton axis so the softmax runs over time
        # steps, then restore it so the weights broadcast over features.
        scores = mx.squeeze(mx.tanh(logits), axis=-1)
        attn = mx.expand_dims(mx.softmax(scores, axis=-1), axis=-1)
        return mx.sum(x * attn, axis=1)

class TimeDistributed_(nn.Module):
    """Apply ``func`` independently at every position of axis 1.

    The two leading axes of the input are merged into one, ``func`` runs
    once on the merged batch, and its output is split back into the
    original two leading axes (presumably batch and time).
    """

    def __init__(self, func: nn.Module):
        super().__init__()
        self.func = func

    def __call__(self, x):
        batch, steps = x.shape[0], x.shape[1]
        merged = self.func(x.flatten(0, 1))
        return merged.reshape(batch, steps, *merged.shape[1:])

class Bidirectionnal_(nn.Module):
    """MLX equivalent of ``keras.layers.Bidirectional`` (``merge_mode="concat"``)
    around recurrent layers that return full sequences.

    Args:
        func1: forward recurrent layer, run on the input as given.
        func2: backward recurrent layer, run on the time-reversed input.

    Both layers are expected to return a tuple whose first element is the
    sequence of hidden states (as ``mlx.nn.LSTM`` does).
    """

    def __init__(
        self,
        func1 : nn.Module,
        func2 : nn.Module):

        super().__init__()
        self.func1 = func1
        self.func2 = func2

    def __call__(self, x):
        # assumes x is (batch, time, features) — axis 1 is time.
        h_f = self.func1(x)[0]
        h_b = self.func2(x[:, ::-1, :])[0]
        # The backward outputs come out in reversed time order. Keras flips
        # them back before merging so that step t holds
        # [forward(t), backward(t)]; without this flip the backward half is
        # paired with step T-1-t and the outputs diverge from Keras as soon
        # as downstream layers are not symmetric in time (e.g. trained
        # attention).
        h_b = h_b[:, ::-1, :]
        return mx.concatenate([h_f, h_b], axis=-1)

# Export the Keras weights as a flat {"<layer name>.<index>": array} dict
# (with the listed Dense kernels transposed for MLX) and save it to w.npz.
w = exposekerasweights(model)
np.savez('w.npz', **w)

class mlx_copy(nn.Module):
    """MLX re-implementation of the Keras reference model.

    Attribute names mirror the Keras layer names, so weights exported by
    ``exposekerasweights`` and renamed by ``replace_key`` load by key.

    Args:
        input_dim: sequence length of the token-id input. It is used as the
            length of the attention bias, which has one row per time step.
            The script passes 128.
        hidden_dim: width of the ``Proj`` projection and of ``LN2``.
        output_dim: width of the final ``Output`` layer (1 for the Keras model).
    """

    def __init__(
        self,  input_dim: int, hidden_dim: int, output_dim: int
    ):
        super().__init__()
        # Vocabulary size (42) and embedding width (32) match the Keras
        # Embedding layer.
        self.Embedding = nn.Embedding(num_embeddings=42, dims=32)
        self.Dropout = nn.Dropout(0.1)
        # Forward and backward LSTMs (32 -> 32 each); their outputs are
        # concatenated into 64 features per time step.
        self.Image = Bidirectionnal_(nn.LSTM(32, 32, bias=True),
                                    nn.LSTM(32, 32, bias=True))
        self.TimeDistributed = TimeDistributed_(nn.Linear(64, 64))
        self.AttentionM = AttentionM_(64, input_dim, bias=True)
        # eps=1e-3 matches the Keras LayerNormalization default epsilon.
        self.LN1 = nn.LayerNorm(64, eps=0.001)
        self.Proj = nn.Linear(64, hidden_dim)
        self.LK = nn.LeakyReLU(0.1)
        self.LN2 = nn.LayerNorm(hidden_dim, eps=0.001)
        self.Output = nn.Linear(hidden_dim, output_dim)

    def __call__(self, x):
        x = self.Embedding(x)
        # No-op once the module is put in eval mode (train(False)).
        x = self.Dropout(x)
        x = self.Image(x)
        x = self.TimeDistributed(x)
        x = self.AttentionM(x)
        x = self.LN1(x)
        x = self.Proj(x)
        x = self.LK(x)
        x = self.LN2(x)
        x = self.Output(x)
        return x

# Build the MLX port (input_dim=128, hidden_dim=64, output_dim=1).
model_mlx = mlx_copy(128, 64, 1)
# Force evaluation of the (lazily computed) initial parameters.
mx.eval(model_mlx.parameters())

# Print each parameter's element count and dotted name, then the total,
# presumably to compare against model.summary() above.
we = 0
for k, x in tree_flatten(model_mlx.parameters()):
    we+=x.size
    print(x.size,k)
print(we)

# load the saved keras weights

w_loaded = np.load('w.npz')

# convert to MLX names and save them

def replace_key(key: str) -> str:
    """Translate an exported Keras weight key into an MLX parameter path.

    The ``(old, new)`` pairs below are applied one after another with
    ``str.replace``, so a pattern also matches as a substring of a longer
    key. A key that matches no pattern is returned unchanged.
    """
    substitutions = (
        ("Output.0", "Output.weight"),
        ("Output.1", "Output.bias"),
        ("Proj.0", "Proj.weight"),
        ("Proj.1", "Proj.bias"),
        ("Embedding.0", "Embedding.weight"),
        ("LSTM.0", "LSTM.Wx"),
        ("LSTM.1", "LSTM.Wh"),
        ("LSTM.2", "LSTM.bias"),
        # Bidirectional: indices 0-2 belong to the forward LSTM, 3-5 to the
        # backward one.
        ("Image.0", "Image.func1.Wx"),
        ("Image.1", "Image.func1.Wh"),
        ("Image.2", "Image.func1.bias"),
        ("Image.3", "Image.func2.Wx"),
        ("Image.4", "Image.func2.Wh"),
        ("Image.5", "Image.func2.bias"),
        ("AttentionM.0", "AttentionM.weight"),
        ("AttentionM.1", "AttentionM.bias"),
        ("TimeDistributed.0", "TimeDistributed.func.weight"),
        ("TimeDistributed.1", "TimeDistributed.func.bias"),
        ("LN1.0", "LN1.weight"),
        ("LN1.1", "LN1.bias"),
        ("LN2.0", "LN2.weight"),
        ("LN2.1", "LN2.bias"),
    )
    for old, new in substitutions:
        key = key.replace(old, new)
    return key

# switch layer names of saved keras tensors
# Rename every exported Keras key to the dotted parameter path used by mlx_copy.
tensors = {
    replace_key(key): tensor for key, tensor in w_loaded.items()
}

# Print old/new key and shape side by side. Renaming does not touch the data,
# so the last column should be 0 as long as no two keys map to the same name.
for (k,v),(kt,vt) in zip(w_loaded.items(),tensors.items()):
    print(k,v.shape, kt, vt.shape, np.max(np.abs(v-vt)))

np.savez('wconvert.npz', **tensors)

# load weights for mlx model
# Assign the renamed arrays to the MLX parameters by key, then freeze them.
model_mlx.load_weights('wconvert.npz')
model_mlx.freeze()

# Batch of 10 random token-id sequences of length 128.
# NOTE(review): randint's upper bound is exclusive, so id 41 of the 42-entry
# vocabulary is never sampled; use 42 to cover the whole vocabulary.
x_train = np.random.randint(0,41, (10,128))

keras_ = model.predict(x_train)
print(keras_.shape, keras_[:10])

# Switch the MLX model to inference mode (disables Dropout). train(False) and
# eval() are equivalent; either call alone is enough.
model_mlx.train(False)
model_mlx.eval()
h = model_mlx(mx.array(x_train))
print(h.shape, h[:10])

# The two frameworks' predictions must agree to within 0.1.
assert np.max(np.abs(h-mx.array(keras_))) < 1e-1

Expected behavior: if I uncomment this line

# model.load_weights('model_keras_weights.h5')
image

Otherwise, the outputs of both models match, although the difference is around 1e-2 rather than 1e-6:

image

Desktop (please complete the following information):

Additional context Add any other context about the problem here.