ml-explore / mlx-examples

Examples in the MLX framework

GPT-2 training examples: Gradient accumulation, learning rate decay, fp16 training #111

Closed · dx-dtran closed this 1 month ago

dx-dtran commented 9 months ago

Thank you so much for MLX, it rocks!!

I've been implementing GPT-2 training in MLX. Please feel free to check it out here: https://github.com/dx-dtran/gpt2-mlx

I had a few questions regarding gradient accumulation, learning rate decay, and float16 training:

  1. I implemented gradient accumulation in the following way. Is this the proper approach, or is there a better way to do it? (A sketch of the corresponding update step follows this list.)

     import mlx.core as mx
     from mlx.utils import tree_map

     # outside the training loop:
     # pre-allocate an accumulator with the same tree structure as the
     # gradients (which is the same structure as the model parameters)
     accumulated_grads = tree_map(
         lambda x: mx.zeros_like(x), model.parameters()
     )

     # inside the training loop:
     # accumulate by adding each mini-batch's gradients, scaled by
     # 1 / num_grad_accumulation_steps, into the pre-allocated accumulator
     accumulated_grads = tree_map(
         lambda acc, new: acc + new * (1.0 / num_grad_accumulation_steps),
         accumulated_grads,
         grads,
     )

     # evaluate the accumulated gradients at every mini-batch step so the
     # computation graph doesn't keep growing until we eventually update
     # the model parameters
     tree_map(
         lambda grad: mx.eval(grad),
         accumulated_grads,
     )
  2. I also tried implementing learning rate decay, but I had to define a custom optimizer because I couldn't tell whether the AdamW learning_rate attribute is exposed with a setter. Is there already a learning_rate setter? (A possible workaround is sketched after this list.)

  3. Finally, I tried implementing float16 training, but I couldn't find a way to set a custom dtype for the nn.Linear layers. I'm sure I just have to patiently wait for this, right? 😄 I read that you are focusing on implementing quantization. (Another possible workaround is sketched after this list.)
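
For point 1, here is a minimal sketch of the update step that the comments above allude to, i.e. what happens once every num_grad_accumulation_steps mini-batches. This is an illustration rather than the exact code from the linked repo: step and optimizer are assumed loop variables, and optimizer.update(model, grads) is the standard mlx.optimizers call.

```python
import mlx.core as mx
from mlx.utils import tree_map

# once per accumulation window, apply the averaged gradients and reset
if (step + 1) % num_grad_accumulation_steps == 0:
    optimizer.update(model, accumulated_grads)    # e.g. an mlx.optimizers.AdamW instance
    mx.eval(model.parameters(), optimizer.state)  # force the parameter update

    # clear the accumulator for the next window
    accumulated_grads = tree_map(
        lambda x: mx.zeros_like(x), model.parameters()
    )
```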
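For point 2, one workaround that avoids a custom optimizer is to recompute the learning rate in plain Python each step and assign it back to the optimizer. This is only a sketch: cosine_decay_lr is an illustrative helper (not part of MLX), and it assumes the learning_rate attribute can simply be reassigned, which is exactly what the question above is asking about.

```python
import math

def cosine_decay_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Linear warmup followed by cosine decay (the usual GPT-2 recipe)."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# inside the training loop, before the optimizer update
# (base_lr, warmup_steps, total_steps chosen elsewhere):
optimizer.learning_rate = cosine_decay_lr(step, base_lr, warmup_steps, total_steps)
```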
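For point 3, until the layers grow a dtype argument, one possible workaround is to cast the whole parameter tree and push it back into the model. Again a sketch, assuming nn.Module.update accepts the casted tree; the model's inputs would also need to be float16 for the cast to matter.

```python
import mlx.core as mx
from mlx.utils import tree_map

# cast every parameter (including nn.Linear weights and biases) to float16
half_params = tree_map(lambda p: p.astype(mx.float16), model.parameters())
model.update(half_params)
```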

Is merging something like my training script into the main repo something you guys would be interested in? If so, I'd be happy to help!

Here are some results I was able to get in full float32 precision:

| Hardware | Model | Batch Size | Grad Accum Steps | Time per Full Iteration |
| --- | --- | --- | --- | --- |
| M1 Max 64GB | GPT-2 124M | 2 | 1 | 0.6 seconds |
| M1 Max 64GB | GPT-2 124M | 12 | 1 | 4 seconds |
| M1 Max 64GB | GPT-2 124M | 12 | 40 | 142 seconds |
| M1 Max 64GB | GPT-2 124M | 16 | 1 | 8.13 seconds |
| M1 Max 64GB | GPT-2 XL 1.5B | 2 | 1 | 28 seconds |
| M1 Pro 16GB | GPT-2 124M | 1 | 4 | 2.4 seconds |
| M1 Pro 16GB | GPT-2 124M | 3 | 4 | 10.5 seconds |
| M1 Pro 16GB | GPT-2 124M | 2 | 240 | 270 seconds |
| M1 Pro 16GB | GPT-2 124M | 2 | 1 | 1.35 seconds |
vfsunny commented 5 months ago

@dx-dtran I am currently looking into gradient accumulation and found your issue here. Could you share some insights about your implementation and whether it works properly? Thanks already for your code, it's very helpful 🙏

dx-dtran commented 1 month ago

@vfsunny Thank you! Yes, the gradient accumulation implementation I wrote back then (shown above in the issue description) worked.

Essentially the way it worked was: scale each mini-batch's gradients by 1 / num_grad_accumulation_steps, add them into a pre-allocated accumulator, and eval the accumulator immediately so the computation graph doesn't keep growing; then, after num_grad_accumulation_steps mini-batches, apply the accumulated gradients to the model parameters and reset the accumulator.

Here's a test file I wrote to verify its correctness. It's been several months since I've modified my code, so perhaps MLX has implemented a better way to do gradient accumulation since then.
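
The test file itself isn't reproduced here, but the idea behind it can be illustrated with a small self-contained check: averaging the gradients of equally sized micro-batches should reproduce the gradient of the full batch when the loss is a mean over examples. The snippet below is a toy illustration of that idea, not the actual test file.

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_map

# toy check: averaged micro-batch gradients should match the full-batch gradient
model = nn.Linear(8, 1)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y, reduction="mean")

loss_and_grad = nn.value_and_grad(model, loss_fn)

x = mx.random.normal((16, 8))
y = mx.random.normal((16, 1))

# gradient of one full batch
_, full_grads = loss_and_grad(model, x, y)

# accumulated gradient over 4 equally sized micro-batches
num_steps = 4
accumulated = tree_map(lambda p: mx.zeros_like(p), model.parameters())
for i in range(num_steps):
    xb, yb = x[i * 4 : (i + 1) * 4], y[i * 4 : (i + 1) * 4]
    _, grads = loss_and_grad(model, xb, yb)
    accumulated = tree_map(lambda a, g: a + g / num_steps, accumulated, grads)

# compare the two gradient trees leaf by leaf
for (_, a), (_, b) in zip(tree_flatten(accumulated), tree_flatten(full_grads)):
    assert mx.allclose(a, b, atol=1e-5).item()
print("accumulated gradients match the full-batch gradients")
```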