@dx-dtran I am currently looking into gradient accumulation and found your issue here. Can you maybe share some insights regarding your implementation and whether it works properly? Thanks already for your code, very helpful!
@vfsunny Thank you! Yes, the implementation for gradient accumulation I wrote back then (shown above in the issue description) worked.
Essentially the way it worked was:

- accumulate each newly computed gradient into a separate storage array, scaled by the inverse of the number of accumulation steps (`acc + new * (1.0 / num_grad_accumulation_steps)`)
- call `mx.eval` on the storage array during the mini-batch step to free up memory
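In code, the idea looks roughly like this. It's a sketch of the scheme rather than the exact code from my repo; the toy model, loss, and data are stand-ins so it runs on its own:

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map

# Toy stand-ins; in the real script these are the GPT-2 model, its loss,
# and the data loader.
model = nn.Linear(8, 8)
optimizer = optim.AdamW(learning_rate=1e-3)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

num_grad_accumulation_steps = 4
acc_grads = None  # storage array (really a tree of arrays) for accumulated grads

for step in range(8):
    x = mx.random.normal((2, 8))
    y = mx.random.normal((2, 8))
    loss, grads = loss_and_grad_fn(model, x, y)

    # acc + new * (1.0 / num_grad_accumulation_steps)
    scaled = tree_map(lambda g: g * (1.0 / num_grad_accumulation_steps), grads)
    acc_grads = scaled if acc_grads is None else tree_map(
        lambda a, n: a + n, acc_grads, scaled
    )

    # Evaluate the storage array at every micro-batch step so the lazy graph
    # (and the memory it holds) gets freed.
    mx.eval(acc_grads)

    # After num_grad_accumulation_steps micro-batches, apply one optimizer step.
    if (step + 1) % num_grad_accumulation_steps == 0:
        optimizer.update(model, acc_grads)
        mx.eval(model.parameters(), optimizer.state)
        acc_grads = None
```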
Here's a test file I wrote to verify its correctness. It's been several months since I modified my code, so perhaps MLX has implemented a better way to do gradient accumulation since then.
Thank you so much for MLX, it rocks!!
I've been implementing GPT-2 training in MLX. Please feel free to check it out here: https://github.com/dx-dtran/gpt2-mlx
I had a few questions regarding gradient accumulation, learning rate decay, and float16 training:
I also tried implementing learning rate decay, but had to define a custom optimizer because I couldn't find whether the AdamW `learning_rate` attribute was exposed with a setter. Is there already a `learning_rate` setter?
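What I'd like to write is roughly the following. The schedule itself is just an illustrative warmup-plus-cosine function (the numbers are placeholders, not my real config), and the `optimizer.learning_rate = ...` assignment is exactly the setter I'm unsure exists:

```python
import math
import mlx.optimizers as optim

# Hypothetical warmup + cosine decay schedule; placeholder hyperparameters.
def get_lr(it, max_lr=6e-4, min_lr=6e-5, warmup_iters=100, decay_iters=5000):
    if it < warmup_iters:
        return max_lr * (it + 1) / warmup_iters
    if it > decay_iters:
        return min_lr
    ratio = (it - warmup_iters) / (decay_iters - warmup_iters)
    return min_lr + 0.5 * (1.0 + math.cos(math.pi * ratio)) * (max_lr - min_lr)

optimizer = optim.AdamW(learning_rate=get_lr(0))

for it in range(5000):
    # ... compute grads for this iteration, then:
    optimizer.learning_rate = get_lr(it)  # <- the setter I'm asking about
    # optimizer.update(model, grads)
```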
Finally, I tried implementing float16 training but couldn't find a way to set a custom `dtype` for the `nn.Linear` layers. But I'm sure I have to patiently wait for this, right? Because I read you guys are focusing on implementing quantization. (A sketch of the kind of cast I have in mind is below, after my last question.)

Is merging something like my training script into the main repo something you guys would be interested in? If so, I'd be happy to help!
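Coming back to the float16 question: the kind of thing I was hoping to write is the following. It's just a sketch with a toy model (not my actual GPT-2 code), and I'm not sure it's the intended way:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map

# Toy model standing in for the real network; nn.Linear has no dtype argument,
# so the idea is to cast the parameters after construction.
model = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768))

# Cast every parameter (including the nn.Linear weights and biases) to float16.
model.update(tree_map(lambda p: p.astype(mx.float16), model.parameters()))

x = mx.random.normal((1, 768)).astype(mx.float16)
print(model(x).dtype)  # expected: mlx.core.float16
```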
Here are some results I was able to get in full float32 precision: