ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Add Parameter Group Support to MLX Framework #331

Open m0saan opened 4 months ago

m0saan commented 4 months ago

Issue Description:

Feature Request

Summary: I propose adding support for parameter groups in MLX to enhance the flexibility and customization of model optimization.

Details: The addition of parameter groups would enable users to group and apply different optimization configurations to specific subsets of model parameters. This is a common feature in many deep learning frameworks and can significantly improve the efficiency of training and fine-tuning models.

Expected Behavior:

Motivation:

Example Usage:

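import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Placeholder model; assumed to expose layer1, layer2, and layer3 submodules
model = nn.Sequential([...])

optimizer = optim.SGD(learning_rate=0.01)

# Proposed: one entry per parameter subset, each with its own learning rate
parameter_groups = [
    {'params': model.layer1.parameters(), 'lr': 0.001},
    {'params': model.layer2.parameters(), 'lr': 0.005},
    {'params': model.layer3.parameters(), 'lr': 0.01},
]

# Proposed API: add_param_groups is the requested feature, not an existing MLX method
optimizer.add_param_groups(parameter_groups)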

m0saan commented 4 months ago

@awni your thoughts please!

awni commented 4 months ago

Thanks @m0saan. I'm not sure we need parameter groups yet. Let's keep this issue open, but I would mark it as low priority until we have a reason to think otherwise.

In MLX it's a lot easier to have multiple optimizers, each working on a subset of the model, since things are a bit more decoupled than in PyTorch. So this is a case where the added functionality doesn't make as much sense for us.