ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[FEATURE] In Keras, LayerNorm by default applies to the last dimension only #1049

Closed thegodone closed 2 weeks ago

thegodone commented 2 weeks ago

Describe the bug It looks like mlx.core.LayerNorm normalizes over all dimensions. Can we add an axis parameter to control this, as in Keras?

Expected behavior I am trying to port some Keras code, and the LayerNorm behaviour is not identical to Keras. Can we add an option to normalize only over the axis we want?
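For reference, this is the Keras parameter being requested; a short sketch (assuming tensorflow.keras is installed):

from tensorflow.keras.layers import LayerNormalization

# Keras normalizes over the axes given by `axis` (default -1, i.e. the last axis).
ln_last = LayerNormalization(axis=-1)
ln_time = LayerNormalization(axis=1)  # normalize over axis 1 instead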


awni commented 2 weeks ago

Actually, that's not true. LayerNorm only and always normalizes over the last axis.

import mlx.nn as nn
import mlx.core as mx

ln = nn.LayerNorm(32)
x = mx.random.uniform(shape=(10, 32))
print(ln(x).sum(axis=-1)) # close to 0
print(ln(x).sum(axis=0)) # not close to 0

Since MLX NN standardizes on the feature dimension being last, we don't have plans to include an axis parameter in our LayerNorm. It sounds like the current behavior works for what you want since it is consistent with Keras?
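If Keras-style control over the normalized axis is ever needed, one workaround (just a sketch, not an MLX API) is to move the target axis to the last position, apply LayerNorm, and move it back:

import mlx.core as mx
import mlx.nn as nn

def layer_norm_on_axis(x, ln, axis):
    # Hypothetical helper: normalize over `axis` by making it the last axis first.
    x = mx.moveaxis(x, axis, -1)
    x = ln(x)  # nn.LayerNorm always normalizes the last axis
    return mx.moveaxis(x, -1, axis)

ln = nn.LayerNorm(20)
x = mx.random.uniform(shape=(10, 20, 32))
y = layer_norm_on_axis(x, ln, axis=1)  # normalize over the axis of size 20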

If there is something more here, let me know and we can reopen/discuss further.

thegodone commented 2 weeks ago

But if I use the weights and bias from Keras in LayerNorm, I don't get the same result. Why?

If I remove the LayerNorm from both models, the assert passes.

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization
import numpy as np
import mlx.nn as nn
from mlx.utils import tree_flatten
import mlx.core as mx

# Define the model
model = Sequential([
    Embedding(input_dim=40,  output_dim=32, input_length=20,name="Embedding"),
    LayerNormalization(name='LN1'),

])

# For Embedding, input_dim is the vocabulary size and output_dim is the embedding dimension.
model.compile(
    optimizer='adam',  # Optimizer
    loss='mse',  # Mean Squared Error for regression tasks
    metrics=['mae']  # Mean Absolute Error for regression metrics
)
model.summary()

def exposeweights(model):
    w = {}
    j =0
    for layer in model.layers:
        weights = layer.get_weights()  # returns a list of all weight tensors in the layer
        print(layer.name)
        for i, weight in enumerate(weights):
            # Keras Dense / Linear kernels are transposed relative to MLX, so transpose them here
            if layer.name+"."+str(i) in ['Output.0','Proj.0','TimeDistributed.0'] :
                w[layer.name+"."+str(i)]=weight.T
            else:
                w[layer.name+"."+str(i)]=weight
            j+=1
    return w

w = exposeweights(model)
np.savez('w.npz', **w)

class mlxduplicate(nn.Module):
    def __init__(
        self):
        super().__init__()
        self.Embedding = nn.Embedding(num_embeddings=40, dims=32)
        self.LN1 = nn.LayerNorm(32)

    def __call__(self, x):
        x = self.Embedding(x)
        x = self.LN1(x)
        return x 

model_mlx = mlxduplicate()
we = 0
for k, x in tree_flatten(model_mlx.parameters()):
    we+=x.size
    print(x.size,k)
print(we)

tensor_loaded = np.load('w.npz')

def replace_key(key: str) -> str:
    key = key.replace("Embedding.0", "Embedding.weight")
    key = key.replace("LN1.0", "LN1.weight")
    key = key.replace("LN1.1", "LN1.bias")
    return key

# switch layer names of saved keras tensors
tensors_mlx = {
    replace_key(key): tensor for key, tensor in tensor_loaded.items()
}

for k,v in tensor_loaded.items():
    print(k,v.shape)

for k,v in tensors_mlx.items():
    print(k,v.shape)

np.savez('w_convert_to_mlx.npz', **tensors_mlx)

model_mlx.load_weights('w_convert_to_mlx.npz')

x_train = np.random.randint(0,39, (2,20))
keras_output = model.predict(x_train)
keras_output.shape

mlx_output = model_mlx(mx.array(x_train))
assert mlx_output.shape == keras_output.shape

assert np.max(np.abs(mlx_output-mx.array(keras_output))) < 1e-6

2024-04-28 20:12:54 metal_plugin: Metal device set to: Apple M3 Max (systemMemory: 128.00 GB, maxCacheSize: 48.00 GB)

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 Embedding (Embedding)       (None, 20, 32)            1280
 LN1 (LayerNormalization)    (None, 20, 32)            64
=================================================================
Total params: 1344 (5.25 KB)
Trainable params: 1344 (5.25 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________

Embedding
LN1
1280 Embedding.weight
32 LN1.bias
32 LN1.weight
1344
Embedding.0 (40, 32)
LN1.0 (32,)
LN1.1 (32,)
Embedding.weight (40, 32)
LN1.weight (32,)
LN1.bias (32,)
1/1 [==============================] - 1s 832ms/step

AssertionError                            Traceback (most recent call last)
Cell In[2], line 93
     89 mlx_output = model_mlx(mx.array(x_train))
     90 assert mlx_output.shape == keras_output.shape
---> 93 assert np.max(np.abs(mlx_output-mx.array(keras_output))) < 1e-6

AssertionError:

awni commented 2 weeks ago

It looks like Keras uses a different (and much higher) default epsilon for numerical stability: 1e-3, versus 1e-5 in MLX. You can set this in the MLX LayerNorm constructor. The following passes:

from tensorflow.keras.layers import LayerNormalization
import numpy as np
import mlx.nn as nn
import mlx.core as mx

# Define the model
ln = LayerNormalization(name='LN1')
x = np.random.uniform(size=(10, 32))
out_keras = np.array(ln(x))

ln_mlx = nn.LayerNorm(32, eps=1e-3) # note setting epsilon here
out_mlx = np.array(ln_mlx(mx.array(x)))
assert np.abs((out_keras - out_mlx)).max() < 1e-6

thegodone commented 2 weeks ago

Thanks @awni, I very much appreciate your help on that. I still have one question: can you explain this error? For very large datasets I see strange behaviour (screenshot attached). I use this code:

import math
from typing import Any
import mlx.nn as nn
from mlx.utils import tree_flatten
import numpy as np
import mlx.core as mx

from mlx.nn.layers.base import Module

class AttentionM_(Module):
    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        self.output_dims = output_dims
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, 1),
        )
        if bias:
            self.bias = mx.zeros(shape=(output_dims, 1))

    def _extra_repr(self) -> str:
        return f"input_dims={self.weight.shape[0]}, output_dims={self.output_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if "bias" in self:
            x_ = mx.addmm(self["bias"], x, self["weight"])
        else:
            x_ = x @ self["weight"]
        x_ = mx.tanh(x_)
        x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_,axis=-1), axis=-1),axis=-1)
        x = mx.sum(x*x_,axis=1)
        return x

class TimeDistributed_(nn.Module):
    def __init__(
        self,
        func : nn.Module
    ):
        super().__init__()
        self.func = func

    def __call__(self, x):
        b_, t_ = x.shape[:2]
        c_ = self.func(x.flatten(0,1))
        return c_.reshape(b_, t_, *c_.shape[1:])

class Bidirectionnal_(nn.Module):
    def __init__(
        self,
        func1 : nn.Module,
        func2 : nn.Module
    ):
        super().__init__()
        self.func1 = func1
        self.func2 = func2

    def __call__(self, x):

        h_f, h_b = self.func1(x), self.func2(x[:, ::-1, :]) 
        return  mx.stack([h_f[0], h_b[0]], axis=-1).flatten(-2,-1)

class SmilesX(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        inputdim: int,
        embdim: int,
        lstmdim: int,
        densedim1: int,
        densedim2: int,
        checkpoint: bool,
        debug: bool,
    ):
        super().__init__()

        self.Embedding = nn.Embedding(num_embeddings=vocab_size, dims=embdim)

        self.Image = Bidirectionnal_(nn.LSTM(embdim, lstmdim, bias=True),
                                    nn.LSTM(embdim, lstmdim, bias=True))
        self.TimeDistributed = TimeDistributed_(nn.Linear(2*lstmdim,densedim1))
        self.AttentionM  = AttentionM_(densedim1,inputdim, bias=True)
        self.Layernorm1 =  nn.LayerNorm(densedim1,eps=0.001)
        self.Proj = nn.Linear(densedim1,densedim2)
        self.Layernorm2 =  nn.LayerNorm(densedim2,eps=0.001)
        self.Output = nn.Linear(densedim2,1)
        self.lk = nn.LeakyReLU(0.1)
        self.debug = debug

    def __call__(self, x):
        if self.debug:
            print('Input:',x.shape)

        # embedding
        x = self.Embedding(x)
        if self.debug:

            print('Embedding:',x.shape)
        # Bidirectional 
        x = self.Image(x)
        if self.debug:
            print('BiLSTM:',x.shape)

        #  TimeDistributed 
        x = self.TimeDistributed(x)
        if self.debug:
            print('TimeDistributed:',x.shape)
       # self attention
        x = self.AttentionM(x)
        if self.debug:
            print('AttentionM:',x.shape)

        # Layer norm 
        x = self.Layernorm1(x)
        if self.debug:
            print('LayerNorm 1:',x.shape)        
        x = self.Proj(x)        
        if self.debug:
            print('proj:',x.shape)

        x = self.lk(x)
        # Layer norm 
        x = self.Layernorm2(x)
        if self.debug:
            print('LayerNorm 2:',x.shape)

        x = self.Output(x)
        if self.debug:
            print('Output:',x.shape)
        return x

model = SmilesX(vocab_size=42, 
                inputdim = 128,
                embdim = 32,
                lstmdim =  32,
                densedim1 = 64,
                densedim2 = 64,
                checkpoint=False,
                debug=False)

# Initialize model:
nparams = sum(
    x.size for k, x in tree_flatten(model.parameters()))
print(f"Training a SMILES-X Model with {nparams}  parameters")

xt = 0
for k, x in tree_flatten(model.parameters()):
    print(x.size,k)
    xt+=x.size
print(xt)

# test for big dataset:
X = mx.random.randint(0,42,[600000,128])

# evaluate by data size the results
Y1 = model(X[:10,:])
Y3 = model(X[:100,:])
Y2 = model(X[:200,:])
Y4 = model(X[:1000,:])
Y5 = model(X[:10000,:])
Y6 = model(X[:100000,:])
Y7 = model(X[:600000,:])

# validate the results 
assert np.max(np.abs(Y1 - Y3[:10]))<1e-8
assert np.max(np.abs(Y1 - Y2[:10]))<1e-8
assert np.max(np.abs(Y1 - Y4[:10]))<1e-8
assert np.max(np.abs(Y1 - Y5[:10]))<1e-8
assert np.max(np.abs(Y1 - Y6[:10]))<1e-8
assert np.max(np.abs(Y1 - Y7[:10]))<1e-8

awni commented 2 weeks ago

The size is really large; I think some intermediate matrices are well over 4B entries. My guess is it's overflowing an integer index somewhere, but I'm not sure where. I'll look into it to see if we can add an error message or fix it. For now I would stick to smaller sizes.
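As a rough sanity check (back-of-the-envelope arithmetic, assuming the offending tensor is the LSTM's (batch, seq, 4 * hidden) input projection, which matches the ~131k threshold mentioned below):

batch, seq_len, hidden = 600_000, 128, 32

# Element count of the LSTM input projection; 2**31 - 1 is the 32-bit index limit.
n_elements = batch * seq_len * 4 * hidden
print(n_elements, n_elements > 2**31 - 1)     # 9830400000 True

# Largest batch size that stays under the 32-bit limit for these dimensions:
print((2**31 - 1) // (seq_len * 4 * hidden))  # 131071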

awni commented 2 weeks ago

I filed a separate issue about this https://github.com/ml-explore/mlx/issues/1051

thegodone commented 2 weeks ago

Does it happen during evaluation too? It would be nice to add a batch size for inference.

awni commented 2 weeks ago

Does it happen during evaluation too? It would be nice to add a batch size for inference.

The problem is the large matmul in the LSTM. So if the batch size is larger than about 131k (for the LSTM / model dimensions you provided), it will break regardless of inference / training mode.
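Until that is fixed, one workaround (a sketch, not part of MLX; it assumes the model and X defined above) is to run inference in chunks and concatenate the results:

import mlx.core as mx

def predict_in_batches(model, X, batch_size=10_000):
    # Hypothetical helper: evaluate the model chunk by chunk so each
    # intermediate matmul stays well below the 32-bit element limit.
    outputs = []
    for start in range(0, X.shape[0], batch_size):
        y = model(X[start:start + batch_size])
        mx.eval(y)  # materialize this chunk before starting the next one
        outputs.append(y)
    return mx.concatenate(outputs, axis=0)

# Y = predict_in_batches(model, X, batch_size=10_000)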

thegodone commented 2 weeks ago

Thanks for the clarification.
