ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License
17.1k stars 988 forks source link

[BUG] Different outputs from PyTorch and MLX on a simple MLP #1424

Closed sachinraja13 closed 1 month ago

sachinraja13 commented 1 month ago

Describe the bug: PyTorch and MLX produce different outputs for a simple MLP despite identical weight initialisation and identical input.

To Reproduce


from torch import nn
import torch.nn.functional as F
import torch
import mlx.core as mx
import mlx.nn as mnn
import numpy as np
from mlx.utils import tree_flatten, tree_unflatten

class MLP(nn.Module):
    """Feed-forward network (FFN): stacked linear layers with ReLU in between."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        # Consecutive entries of `widths` are the (in, out) sizes of each layer:
        # input -> hidden (repeated num_layers - 1 times) -> output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )

    def forward(self, x):
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            # Every layer before the last one is followed by a ReLU.
            if index < last_index:
                x = F.relu(x)
        return x

class MLP_mlx(mnn.Module):
    """MLX feed-forward network (FFN): linear layers with ReLU in between."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        # Consecutive entries of `widths` are the (in, out) sizes of each layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = [
            mnn.Linear(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]

    def __call__(self, x):
        last_index = len(self.layers) - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            # The final layer's output is returned without an activation.
            if index < last_index:
                x = mnn.relu(x)
        return x

def initialize_numpy_arrays_and_model(model):
    """Overwrite the model's parameters with standard-normal float32 values.

    Parameters whose name contains 'params.' are left untouched. For every
    other parameter a fresh array of the same shape is drawn with
    ``np.random.randn`` and installed as the parameter's data.

    Returns:
        A tuple ``(model, arrays)`` where ``arrays`` maps each parameter name
        to the NumPy array that was loaded into it.
    """
    drawn = {}
    for name, param in model.named_parameters():
        if 'params.' not in name:
            print(name, param.shape)
            values = np.random.randn(*param.shape).astype(np.float32)
            drawn[name] = values
            # The tensor is created from the same array that is recorded above.
            param.data = torch.from_numpy(values)
    return model, drawn

def apply(dst, parameters, param_key = ''):
    """Recursively copy values from ``parameters`` into ``dst`` in place.

    ``parameters`` is walked as a nested structure of dicts and lists and
    ``dst`` is indexed with the same keys / positions.

    * dict level: keys missing from ``dst`` are skipped silently. An
      ``mx.array`` entry is replaced only when the shapes match; otherwise
      the mismatch is reported on stdout and the entry is left as it was.
    * list level: an ``mx.array`` entry is replaced without a shape check.
    * ``mnn.Module``, dict and list entries are descended into.

    Args:
        dst: container (dict-like or list) that is updated in place.
        parameters: nested dict/list holding the new values.
        param_key: dotted path of ``dst`` within the whole tree; used only
            to label the printed messages.
    """
    if isinstance(parameters, dict):
        for k, new_value in parameters.items():
            if k not in dst:
                continue
            current_value = dst[k]
            # Build the child path from the unchanged prefix so that sibling
            # keys do not get concatenated into each other's labels.
            child_key = param_key + k
            if isinstance(current_value, mx.array):
                if current_value.shape == new_value.shape:
                    dst[k] = new_value
                else:
                    print("Not updated : " , child_key)
                    print(current_value.shape, new_value.shape)
            elif isinstance(current_value, (mnn.Module, dict, list)):
                apply(current_value, new_value, child_key + ".")
    elif isinstance(parameters, list):
        for i, new_value in enumerate(parameters):
            current_value = dst[i]
            # Include the position so the label identifies the list element.
            child_key = param_key + str(i)
            if isinstance(current_value, mx.array):
                print("Updated : " + child_key)
                dst[i] = new_value
            elif isinstance(current_value, (mnn.Module, dict, list)):
                apply(current_value, new_value, child_key + ".")

def initialize_mlx_model(mlx_model, numpy_arrays):
    """Build a nested MLX parameter tree from arrays keyed by flat names.

    The flattened entries of ``mlx_model`` are enumerated; entries whose name
    contains 'params.' are ignored. For every remaining name that is present
    in ``numpy_arrays`` the array is converted with ``mx.array``. Names
    without a matching array are reported on stdout and left out of the
    result, and arrays whose name never appeared in the model are reported
    as well.

    Args:
        mlx_model: tree accepted by ``tree_flatten`` (the MLX model).
        numpy_arrays: mapping from flat parameter name to NumPy array.

    Returns:
        The result of ``tree_unflatten`` over the converted
        ``(name, mx.array)`` pairs. ``mlx_model`` itself is not modified.
    """
    new_params = []
    mapped_names = set()
    for name, _ in tree_flatten(mlx_model):
        if 'params.' in name:
            continue
        print(name)
        # Only a missing name is treated as "could not initialize"; any
        # conversion error is allowed to propagate instead of being hidden.
        if name not in numpy_arrays:
            print("Could not initialize : " + name)
            continue
        new_params.append((name, mx.array(numpy_arrays[name])))
        mapped_names.add(name)
    for name in numpy_arrays:
        if name not in mapped_names:
            print("Key not found in mlx model: ", name)
    return tree_unflatten(new_params)

def initialize_models(pt_model, mlx_model):
    """Load one set of random weights into both the PyTorch and the MLX model.

    Both models are modified in place; nothing is returned.
    """
    # Draw random NumPy weights and install them in the PyTorch model.
    _, shared_weights = initialize_numpy_arrays_and_model(pt_model)
    # Turn the same arrays into a nested MLX parameter tree ...
    mlx_tree = initialize_mlx_model(mlx_model, shared_weights)
    # ... and write that tree into the MLX model's children.
    apply(mlx_model.children(), mlx_tree)

def assert_close(pytorch_result, mlx_result, rtol=1e-6, atol=1e-6):
    """Print a report of how closely two arrays agree.

    Despite its name this function never raises on a mismatch; it only
    prints the total element count, the number of elements that satisfy
    ``np.isclose`` and a histogram of the absolute differences of the
    remaining elements.

    Differences smaller than the first bin edge (1e-6) or larger than the
    last one (10) fall outside the histogram and are not counted in any bin.

    Args:
        pytorch_result: NumPy array with the reference values.
        mlx_result: NumPy array compared against ``pytorch_result``.
        rtol: relative tolerance forwarded to ``np.isclose``.
        atol: absolute tolerance forwarded to ``np.isclose``.
    """
    # Evaluate the element-wise comparison once and reuse the mask.
    close_flag = np.isclose(pytorch_result, mlx_result, rtol=rtol, atol=atol)
    count_close = close_flag.sum()
    print("Total elements: " , pytorch_result.size)
    print("Close elements: " , count_close)
    print("Count of Differing elements: " , pytorch_result.size - count_close)
    bins = [0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 1.5, 2, 5, 10]
    differing = ~close_flag
    abs_diff = np.abs(pytorch_result[differing].flatten() - mlx_result[differing].flatten())
    hist = np.histogram(abs_diff, bins=bins)[0]
    # Pair each bin's lower and upper edge with its count.
    for lower, upper, count in zip(bins[:-1], bins[1:], hist):
        print("Bin:\t", lower, " <= ", upper)
        print("\t\t\t\tCount:\t", count)

def test_mlp():
    """Run the same random input through both MLPs and print the comparison.

    A PyTorch ``MLP`` and an MLX ``MLP_mlx`` are built with identical
    dimensions, given the same weights via ``initialize_models`` and fed the
    same float32 input; ``assert_close`` then reports how the outputs differ.
    """
    batch_size = 2
    dims = dict(input_dim=256, hidden_dim=512, output_dim=256, num_layers=5)

    np.random.seed(42)  # Fixed seed so the input and the weights are reproducible
    inputs_np = np.random.rand(batch_size, dims["input_dim"]).astype(np.float32)

    torch_model = MLP(**dims)
    mlx_model = MLP_mlx(**dims)
    initialize_models(torch_model, mlx_model)

    torch_out = torch_model(torch.from_numpy(inputs_np))
    mlx_out = mlx_model(mx.array(inputs_np))

    assert_close(torch_out.detach().numpy(), np.asarray(mlx_out), rtol=1e-6, atol=1e-6)

# Run the PyTorch-vs-MLX comparison and print its report.
test_mlp()

Expected behavior: All output values from the PyTorch and MLX models agree to within 1e-4.

Actual Output

Total elements:  512
Close elements:  313
Count of Differing elements:  199
Bin:     1e-06  <=  1e-05
                Count:   0
Bin:     1e-05  <=  0.0001
                Count:   0
Bin:     0.0001  <=  0.001
                Count:   0
Bin:     0.001  <=  0.01
                Count:   0
Bin:     0.01  <=  0.1
                Count:   7
Bin:     0.1  <=  1
                Count:   179
Bin:     1  <=  1.5
                Count:   12
Bin:     1.5  <=  2
                Count:   1
Bin:     2  <=  5
                Count:   0
Bin:     5  <=  10
                Count:   0

Additional context: MLX 0.17.3, PyTorch 2.1.2

angeloskath commented 1 month ago

There is no bug here; it is just the expected numerical difference after 5 matrix multiplications. The matrices are initialized with very large values, which puts the final result in the millions and makes matters worse.

For instance, if you change the initialization to the common fan-in scheme, i.e. initialize from N(0, 1/sqrt(input_dims)) instead of N(0, 1), then the results pass the check and the output is in a much more reasonable range, ~[-0.5, 0.5].

I did the above by changing the following line in initialize_numpy_arrays_and_model:

np_array = np.random.randn(*param_shape).astype(np.float32) / param_shape[-1]**0.5

I will close the issue but feel free to reopen it if you think I didn't cover it.

sachinraja13 commented 1 month ago

That helps @angeloskath . Many thanks!