ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[BUG] How to improve the training performance in MLX compared to PyTorch and Keras? #1542

Closed: thegodone closed this issue 1 week ago

thegodone commented 1 month ago

Describe the bug I have a major issue that I have also seen in many other trials. MLX training rarely gives a good result, while Torch and Keras are more stable and better. This is a real bottleneck for using MLX: you need to train your model 10 to 20 times to get a good result, while Torch and Keras are systematically in a good range (RMSE: 0.50-0.55). Important: the models (TF/Keras, Torch and MLX) have the same number of trainable parameters, and we use the same train, validation and test splits for the three frameworks.

To Reproduce

Run the following notebooks several times; the best result jumps around compared to the PyTorch and TF/Keras results:
https://github.com/thegodone/apple_ai_model/blob/main/AttFP_mlx_faster.ipynb
https://github.com/thegodone/apple_ai_model/blob/main/AttFP_torch.ipynb
https://github.com/thegodone/apple_ai_model/blob/main/AttFP_tf.ipynb

Expected behavior I don't know whether it is the weight initialization or the optimizer that causes this huge difference between the three packages.

Desktop (please complete the following information): see #1531

thegodone commented 3 weeks ago

Could it be related to this comment: https://github.com/ml-explore/mlx/issues/1153#issuecomment-2128294303 ?

from typing import Union, List, Callable
import mlx.core as mx
from mlx.optimizers import Optimizer

class Adam(Optimizer):
    def __init__(self, learning_rate: Union[float, Callable], betas: List[float] = [0.9, 0.999], eps: float = 1e-8):
        super().__init__()
        self._maybe_schedule("learning_rate", learning_rate)
        self.betas = betas
        self.eps = eps

    def init_single(self, parameter: mx.array, state: dict):
        state["m"] = mx.zeros_like(parameter)  # first moment (momentum)
        state["v"] = mx.zeros_like(parameter)  # second moment (velocity)

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        lr = self.learning_rate.astype(gradient.dtype)
        beta1, beta2 = self.betas
        eps = self.eps

        # Update biased first moment estimate
        state["m"] = beta1 * state["m"] + (1 - beta1) * gradient
        # Update biased second moment estimate
        state["v"] = beta2 * state["v"] + (1 - beta2) * mx.square(gradient)

        # Bias-corrected estimates. Note: this uses a constant divisor (1 - beta),
        # whereas PyTorch/Keras divide by the step-dependent (1 - beta ** step).
        m_hat = state["m"] / (1 - beta1)
        v_hat = state["v"] / (1 - beta2)

        # Parameter update
        return parameter - lr * m_hat / (mx.sqrt(v_hat) + eps)

class AdamW(Adam):
    def __init__(self, learning_rate: Union[float, Callable], betas: List[float] = [0.9, 0.999], eps: float = 1e-8, weight_decay: float = 0.01):
        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        lr = self.learning_rate.astype(gradient.dtype)
        # Decoupled weight decay, applied before the core Adam update
        parameter = parameter * (1 - lr * self.weight_decay)
        # Delegate to the parent Adam's apply_single() for the Adam step
        return super().apply_single(gradient, parameter, state)
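
For context, a minimal sketch of how the custom AdamW above could be dropped into a standard MLX training step. The model, shapes, and hyperparameters are placeholders, not the ones from the notebooks:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(16, 1)  # placeholder model
optimizer = AdamW(learning_rate=1e-3, betas=[0.9, 0.999], eps=1e-8, weight_decay=0.01)

def loss_fn(model, x, y):
    return mx.mean((model(x) - y) ** 2)  # batch-averaged MSE

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
x, y = mx.random.normal((32, 16)), mx.random.normal((32, 1))
loss, grads = loss_and_grad_fn(model, x, y)
optimizer.update(model, grads)  # one step of the custom AdamW
mx.eval(model.parameters(), optimizer.state)
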
awni commented 2 weeks ago

If you can try either of those and report back, that would be useful to know.

> Is it possible to add a parameter for the Adam optimizer to be strictly identical to PyTorch / TensorFlow?

It's possible... it would be good to know if it fixes your issue first, though.

thegodone commented 2 weeks ago

Here are the differences between the 3 trainings. I now used the bias-corrected AdamW in MLX.

One remark: we now have similar speed between Torch and MLX.

awni commented 2 weeks ago

The 0 training loss in MLX seems incorrect, particularly given that the training MSE seems reasonable. I would double-check that you are averaging the loss in MLX correctly.

Otherwise it mostly looks reasonable... fine-tuning learning rates, warmups, initializations, etc. could all help.
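
For anyone hitting the same symptom, a minimal sketch of one way to average the training loss per epoch in MLX so the reported number is comparable to what PyTorch/Keras print. The model and data are placeholders, not the notebook's code:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

model = nn.Linear(8, 1)  # stand-in for the real model
optimizer = optim.Adam(learning_rate=1e-3)

def mse_loss(model, x, y):
    # Per-batch loss: mean over all elements, not a sum
    return mx.mean((model(x) - y) ** 2)

loss_and_grad_fn = nn.value_and_grad(model, mse_loss)

# Epoch-level average: weight each batch loss by its batch size
batches = [(mx.random.normal((32, 8)), mx.random.normal((32, 1))) for _ in range(4)]
total_loss, total_count = 0.0, 0
for x, y in batches:
    loss, grads = loss_and_grad_fn(model, x, y)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state)
    total_loss += loss.item() * x.shape[0]
    total_count += x.shape[0]
print("epoch training loss:", total_loss / total_count)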

thegodone commented 2 weeks ago

I will synchronize the LR schedulers tomorrow to be sure that is not the part that causes the model deviation. Yes, I will look at the loss; that is strange.

thegodone commented 2 weeks ago

OK, I fixed the loss and now I get this kind of result; interestingly, PyTorch is more efficient than MLX or TensorFlow. Comparison (1).

thegodone commented 2 weeks ago

I now use this LR scheduler in MLX. One potential issue is that in PyTorch/TensorFlow the LR scheduler is per epoch, while in MLX it is per step; is it possible to have an epoch equivalent? (See the sketch after the schedule code below.)

import mlx.optimizers as optim

def cosineannealingwarmrestartfactor_(initial_lr, restart, decay_steps, warmup_factor, Tmin):
    # Cosine annealing with `restart` warm restarts; each restart scales the
    # peak LR and the minimum LR (Tmin) by warmup_factor.
    schedules = []
    boundaries = []  # join_schedules expects one fewer boundary than schedules
    schedules.append(optim.cosine_decay(initial_lr, decay_steps, Tmin))
    for i in range(restart - 1):
        Tmin *= warmup_factor
        initial_lr *= warmup_factor

        schedules.append(optim.cosine_decay(initial_lr, decay_steps, Tmin))
        boundaries.append(decay_steps * (i + 1))

    lr_schedule = optim.join_schedules(schedules, boundaries)
    return lr_schedule
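
Regarding the per-step vs per-epoch question: MLX schedules are just functions of a step count, so one workaround (a sketch, assuming the schedule above is in scope and interpreting decay_steps in epochs) is to evaluate the schedule once per epoch and assign the value to the optimizer directly, so the LR only changes at epoch boundaries like the PyTorch/Keras schedulers:

import mlx.core as mx
import mlx.optimizers as optim

optimizer = optim.Adam(learning_rate=1e-3)
# decay_steps is interpreted as epochs here, since the schedule advances once per epoch
schedule = cosineannealingwarmrestartfactor_(
    initial_lr=1e-3, restart=3, decay_steps=100, warmup_factor=0.5, Tmin=1e-5)

num_epochs = 300
for epoch in range(num_epochs):
    # Evaluate the step-based schedule at the epoch index and assign it directly
    optimizer.learning_rate = schedule(mx.array(epoch))
    # ... run all training batches for this epoch with optimizer.update(...) ...
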
thegodone commented 2 weeks ago

This is the MLX version with the official Adam (without bias correction): it clearly does not perform as well as the bias-corrected version used in the previous posts. Comparison (2).
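
For clarity, the two setups being compared are roughly the following (hyperparameter values are placeholders):

import mlx.optimizers as optim

# Built-in MLX AdamW: no bias correction of the moment estimates
opt_official = optim.AdamW(learning_rate=1e-3, betas=[0.9, 0.999], eps=1e-8, weight_decay=0.01)

# Custom AdamW from earlier in this thread: divides the moments by (1 - beta),
# a constant approximation of the bias correction used by PyTorch/Keras
opt_bias_corrected = AdamW(learning_rate=1e-3, betas=[0.9, 0.999], eps=1e-8, weight_decay=0.01)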

thegodone commented 2 weeks ago

Is there a way to set a seed for MLX, similar to torch.manual_seed(int)?

awni commented 2 weeks ago

mx.random.seed
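
For reference, a minimal usage sketch of seeding the frameworks consistently (seeding makes each framework's own runs repeatable, but it does not make the initial weights identical across frameworks):

import mlx.core as mx
import numpy as np

mx.random.seed(0)        # seeds MLX's global PRNG (weight init, dropout, etc.)
np.random.seed(0)        # if data shuffling/splitting uses NumPy
# torch.manual_seed(0)   # PyTorch equivalent
# tf.random.set_seed(0)  # TensorFlow/Keras equivalent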

thegodone commented 2 weeks ago

Yes, I expected that was the case. The good accuracy on Torch is due to a specific seed, so we can close this issue.

angeloskath commented 1 week ago

Closing the issue per the OP's last message, i.e. the difference is due to a specific random seed in PyTorch.