ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

[Feature] Add a `clip_grad_norm` to `mlx.optimizers` #1040

Closed by awni 2 weeks ago

awni commented 3 weeks ago

It's relatively simple to implement, but maybe worth adding since it's also quite common. Here's a possible implementation:

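import mlx.core as mx
from mlx.utils import tree_map

def clip_grad_norm(grads, max_norm):
    # Accumulate the squared L2 norm over every leaf array in the gradient tree.
    norm = mx.array(0.0)
    def accumulate(g):
        nonlocal norm
        norm += g.square().sum()

    # tree_map is used only for its side effect here; its result is discarded.
    tree_map(accumulate, grads)
    norm = mx.sqrt(norm)

    # Leave gradients unchanged when the global norm is below max_norm;
    # otherwise rescale them so the global norm is approximately max_norm.
    clip = lambda g: mx.where(norm < max_norm, g, g * max_norm / (norm + 1e-6))
    grads = tree_map(clip, grads)
    return grads, norm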
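
For reference, the computation in the snippet above can be written out as follows. The symbols $g_i$ (the leaf arrays of `grads`), $c$ (`max_norm`), and $\epsilon = 10^{-6}$ are notation introduced here for illustration, not names from the proposal.

$$
\lVert g \rVert = \sqrt{\sum_i \sum_j g_{i,j}^2},
\qquad
g_i \leftarrow
\begin{cases}
g_i & \text{if } \lVert g \rVert < c \\
g_i \cdot \dfrac{c}{\lVert g \rVert + \epsilon} & \text{otherwise}
\end{cases}
$$

The function returns both the (possibly rescaled) gradients and the pre-clipping norm $\lVert g \rVert$.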

We could also add a `tree_reduce`, which would simplify it a bit.
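
To illustrate how a reduction helper might simplify the norm computation, here is a minimal pure-Python sketch. It makes several assumptions: the `tree_reduce` signature shown is hypothetical (the comment above does not specify one), `simple_tree_map` is a stand-in written for this example rather than `mlx.utils.tree_map`, and plain floats stand in for array leaves so the snippet runs without MLX.

```python
import math

def tree_reduce(fn, tree, initial):
    """Hypothetical helper: fold fn(acc, leaf) over every leaf of a nested dict/list/tuple."""
    acc = initial
    if isinstance(tree, dict):
        children = tree.values()
    elif isinstance(tree, (list, tuple)):
        children = tree
    else:
        return fn(acc, tree)  # a leaf: combine it into the accumulator
    for child in children:
        acc = tree_reduce(fn, child, acc)
    return acc

def simple_tree_map(fn, tree):
    """Stand-in for a tree map: apply fn to every leaf, keeping the structure."""
    if isinstance(tree, dict):
        return {k: simple_tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(simple_tree_map(fn, v) for v in tree)
    return fn(tree)

def clip_grad_norm(grads, max_norm, eps=1e-6):
    # The global L2 norm becomes a single reduction, with no nonlocal accumulator.
    norm = math.sqrt(tree_reduce(lambda acc, g: acc + g * g, grads, 0.0))
    # Same rule as the proposal: rescale only when the norm is not below max_norm.
    scale = 1.0 if norm < max_norm else max_norm / (norm + eps)
    return simple_tree_map(lambda g: g * scale, grads), norm

# Float leaves stand in for gradient arrays; the global norm here is 5.0.
grads = {"w": [3.0, 0.0], "b": 4.0}
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
```

With array leaves, the reducing function would sum the squared entries of each leaf (as `g.square().sum()` does in the proposal) instead of squaring a single float.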