ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
14.83k stars 845 forks source link

[BUG] Layernorm provide strange results during inference #1041

Closed thegodone closed 3 weeks ago

thegodone commented 3 weeks ago

Describe the bug When I convert a Keras model into a mlx model and I run model(X) with different X length I got a different results. I provides only the 3 first input results in the following code:

To Reproduce Include code snippet

Y = model(X[0:100,:])
provides:
Input: (100, 128)
LayerNorm 1: (100, 64)
proj: (100, 64)
LayerNorm 2: (100, 64)
Output: (100, 1)
array([[0.637482],
       [0.639297],
       [0.87695],
while 
Y = model(X[0:200,:])
provides:
Input: (200, 128)
LayerNorm 1: (200, 64)
proj: (200, 64)
LayerNorm 2: (200, 64)
Output: (200, 1)
array([[0.645814],
       [0.658127],
       [0.862314],

Expected behavior LayerNorm should not change result during inference ?

Desktop (please complete the following information):

Additional context Add any other context about the problem here.

awni commented 3 weeks ago

Are you able to provide some code to reproduce your finding?

For example this works as expected:

import mlx.core as mx
import mlx.nn as nn

l = nn.LayerNorm(64)

x = mx.random.uniform(shape=(200, 64))

print(l(x[:100, :])[0, :])
print(l(x[:200, :])[0, :])
thegodone commented 3 weeks ago
import math
from typing import Any
import mlx.nn as nn
from mlx.utils import tree_flatten

import mlx.core as mx

from mlx.nn.layers.base import Module

class AttentionM(Module):
    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        self.output_dims = output_dims
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, 1),
        )
        if bias:
            self.bias = mx.zeros(shape=(output_dims,1))

    def _extra_repr(self) -> str:
        return f"input_dims={self.weight.shape[0]}, output_dims={self.output_dims}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        if "bias" in self:
            x_ = mx.addmm(self["bias"], x, self["weight"])
            x_ = mx.tanh(x_)
            x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_,axis=-1)),axis=-1)
            x = mx.sum(x*x_,axis=1)
        else:
            x_ = x @ self["weight"]
            x_ = mx.tanh(x_)
            x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_,axis=-1)),axis=-1)
            x = mx.sum(x*x_,axis=1)
        return x

class TimeDistributed(nn.Module):
    def __init__(
        self,
        func : nn.Module
    ):
        super().__init__()
        self.func = func

    def __call__(self, x):
        b_, t_ = x.shape[:2]
        c_ = self.func(x.flatten(0,1))
        return c_.reshape(b_, t_, *c_.shape[1:])

class Bidirectionnal(nn.Module):
    def __init__(
        self,
        func1 : nn.Module,
        func2 : nn.Module
    ):
        super().__init__()
        self.func1 = func1
        self.func2 = func2

    def __call__(self, x):

        h_f, h_b = self.func1(x), self.func2(x[:, ::-1, :]) 
        return  mx.stack([h_f[0], h_b[0]], axis=-1).flatten(-2,-1)

class SmilesX(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        inputdim: int,
        embdim: int,
        lstmdim: int,
        densedim1: int,
        densedim2: int,
        checkpoint: bool,
        debug: bool,
    ):
        super().__init__()

        self.Embedding = nn.Embedding(num_embeddings=vocab_size, dims=embdim)

        self.Image = Bidirectionnal(nn.LSTM(embdim, lstmdim, bias=True),
                                    nn.LSTM(embdim, lstmdim, bias=True))
        self.TimeDistributed = TimeDistributed(nn.Linear(2*lstmdim,densedim1))
        self.AttentionM  = AttentionM(densedim1,inputdim, bias=True)
        self.Layernorm1 =  nn.LayerNorm(densedim1)
        self.Proj = nn.Linear(densedim1,densedim2)
        self.Layernorm2 =  nn.LayerNorm(densedim2)
        self.Output = nn.Linear(densedim2,1)
        self.lk = nn.LeakyReLU(0.1)
        self.debug = debug

    def __call__(self, x):
        if self.debug:
            print('Input:',x.shape)

        # embedding
        x = self.Embedding(x)
        if self.debug:

            print('Embedding:',x.shape)
        # Bidirectional 
        x = self.Image(x)
        if self.debug:
            print('BiLSTM:',x.shape)

        #  TimeDistributed 
        x = self.TimeDistributed(x)
        if self.debug:
            print('TimeDistributed:',x.shape)
       # self attention
        x = self.AttentionM(x)
        if self.debug:
            print('AttentionM:',x.shape)

        # Layer norm 
        x = self.Layernorm1(x)
        if self.debug:
            print('LayerNorm 1:',x.shape)        
        x = self.Proj(x)        
        if self.debug:
            print('proj:',x.shape)

        x = self.lk(x)
        # Layer norm 
        x = self.Layernorm2(x)
        if self.debug:
            print('LayerNorm 2:',x.shape)

        x = self.Output(x)
        if self.debug:
            print('Output:',x.shape)
        return x

model = SmilesX(vocab_size=42, 
                inputdim = 128,
                embdim = 32,
                lstmdim =  32,
                densedim1 = 64,
                densedim2 = 64,

                checkpoint=False,
                debug=False)

# Initialize model:
nparams = sum(
    x.size for k, x in tree_flatten(model.parameters()))
print(f"Training a Model with {nparams}  parameters")

for k, x in tree_flatten(model.parameters()):

    print(x.size,k)

x = mx.random.randint(0,42,[200,128])

print(model(x[:100, :])[0, :])
print(model(x[:200, :])[0, :])
awni commented 3 weeks ago

The problem is you are not specifying the axis dimension in the softmax so it is taking the softmax over the full input (treating it as a single vector).

You should change:

x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_,axis=-1)),axis=-1)

To:

x_ = mx.expand_dims(mx.softmax(mx.squeeze(x_,axis=-1), axis=-1),axis=-1)

Docs on softmax here https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.softmax.html#mlx.core.softmax