Closed by awni.
Gradient norm clipping is relatively simple to implement, but it may be worth adding since it is also quite common. Here's a possible implementation:
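```python
import mlx.core as mx
from mlx.utils import tree_map


def clip_grad_norm(grads, max_norm):
    # Accumulate the squared L2 norm over every array in the gradient tree.
    norm = mx.array(0.0)

    def accumulate(g):
        nonlocal norm
        norm += g.square().sum()

    tree_map(accumulate, grads)
    # One global norm across all parameters, not a per-array norm.
    norm = mx.sqrt(norm)

    # Rescale only when the global norm reaches max_norm; the 1e-6 term
    # guards against division by zero.
    clip = lambda g: mx.where(norm < max_norm, g, g * max_norm / (norm + 1e-6))
    grads = tree_map(clip, grads)
    return grads, norm
```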
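For reference, the function above computes a single global norm over all gradient arrays $g_i$ and, when that norm is too large, rescales every array by the same factor. Writing $c$ for `max_norm` and $\epsilon = 10^{-6}$:

$$
\|g\| = \sqrt{\sum_i \sum_j g_{i,j}^2},
\qquad
g_i \leftarrow
\begin{cases}
g_i & \text{if } \|g\| < c,\\[4pt]
g_i \cdot \dfrac{c}{\|g\| + \epsilon} & \text{otherwise.}
\end{cases}
$$

Because the same scale factor is applied to every array, the direction of the full gradient is preserved and only its length is capped at roughly $c$.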
We could also add a `tree_reduce`, which would simplify this a bit.
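A `tree_reduce` along these lines would fold a function over the leaves of a tree. The following is a minimal, hypothetical sketch in plain Python: the name matches the suggestion above, but the `(fn, tree, initializer)` argument order and the choice to recurse only into lists, tuples and dicts are assumptions, not an existing `mlx.utils` signature. Plain floats stand in for arrays so it runs without MLX:

```python
from typing import Any, Callable


def tree_reduce(fn: Callable[[Any, Any], Any], tree: Any, initializer: Any) -> Any:
    """Fold fn(acc, leaf) over every leaf of a nested list/tuple/dict tree.

    Hypothetical sketch: lists, tuples and dicts are treated as containers;
    anything else (e.g. an array or a float) is treated as a leaf.
    """
    acc = initializer
    if isinstance(tree, (list, tuple)):
        for child in tree:
            acc = tree_reduce(fn, child, acc)
    elif isinstance(tree, dict):
        for child in tree.values():
            acc = tree_reduce(fn, child, acc)
    else:
        acc = fn(acc, tree)
    return acc


# Toy gradient tree with plain floats standing in for arrays.
grads = {"layer1": {"w": [3.0, 0.0], "b": 0.0}, "layer2": (4.0,)}
sum_sq = tree_reduce(lambda acc, g: acc + g * g, grads, 0.0)
print(sum_sq ** 0.5)  # 5.0
```

Assuming an MLX helper with the same argument order, the `accumulate`/`nonlocal` pattern in `clip_grad_norm` could shrink to a single expression such as `norm = mx.sqrt(tree_reduce(lambda acc, g: acc + g.square().sum(), grads, mx.array(0.0)))`.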