ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Feature: implementing BitNet #512

Closed: thegodone closed this 7 months ago

thegodone commented 7 months ago

Could you add an example of BitNet from Microsoft? https://github.com/kyegomez/BitNet

mzbac commented 7 months ago

I would love to see an MLX example for BitNet as well, but I would be very cautious about using unofficial implementations as references, especially those from kyegomez. Just a heads-up: https://www.reddit.com/r/LocalLLaMA/comments/15spxn3/potential_scammer_on_github/

awni commented 7 months ago

I don't think we are going to add this to MLX examples. It's still a bit niche. I would love to see a community contribution though!

RonanKMcGovern commented 6 months ago

@awni, would you be able to point me in the right direction for how to approach this? Basically, the key is supporting linear layers whose weights are binary or ternary instead of full precision.

awni commented 6 months ago

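@RonanKMcGovern you have two options:

  • Simulate the BitNet ops with casting and quantization/dequantization before matmuls.
  • Implement the quantized kernels themselves with custom gradients.

Maybe you could say more about your goals here, though. In either case, training will probably be a lot slower (and the first option would be much slower).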
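The first option, simulating the BitNet ops by quantizing and dequantizing before ordinary matmuls, can be sketched in a few lines. This is a minimal pure-Python sketch, not MLX code, assuming BitNet b1.58-style absmean ternary quantization of the weights only (the paper also quantizes activations to 8 bits, which is omitted here). Training such a layer needs some gradient through the rounding step; the sketch assumes a straight-through estimator (STE), a common choice that this thread does not itself specify. The helper names `absmean_ternary`, `fake_quant_forward`, and `ste_sgd_step` are hypothetical.

```python
# Minimal sketch (not MLX code): "fake-quantized" ternary linear layer in pure
# Python, plus one SGD step using a straight-through estimator (STE).

def absmean_ternary(weights, eps=1e-8):
    """Quantize a 2-D weight matrix to {-1, 0, +1} with a single absmean scale."""
    flat = [w for row in weights for w in row]
    scale = sum(abs(w) for w in flat) / len(flat)
    q = [[max(-1, min(1, round(w / (scale + eps)))) for w in row] for row in weights]
    return q, scale


def fake_quant_forward(latent, x):
    """y = W_deq @ x, where W_deq = scale * ternary(W) is dequantized to float.

    The matmul itself still runs in full precision; only the weight values
    are restricted to three levels.
    """
    wq, scale = absmean_ternary(latent)
    return [scale * sum(w * v for w, v in zip(row, x)) for row in wq]


def ste_sgd_step(latent, x, target, lr=0.1):
    """One SGD step on L = 0.5 * ||y - target||^2 using the STE.

    Rounding has zero gradient almost everywhere, so the STE treats the
    quantizer as the identity in the backward pass (and here also ignores
    the dependence of the scale on W): the gradient w.r.t. the dequantized
    weights is applied directly to the full-precision latent weights.
    """
    y = fake_quant_forward(latent, x)
    grad_y = [yi - ti for yi, ti in zip(y, target)]  # dL/dy
    loss = 0.5 * sum(g * g for g in grad_y)
    new_latent = [
        [w - lr * gy * v for w, v in zip(row, x)]  # dL/dW[i][j] = grad_y[i] * x[j]
        for row, gy in zip(latent, grad_y)
    ]
    return new_latent, loss


latent = [[0.5, -0.2], [0.1, -0.9]]
print(absmean_ternary(latent)[0])  # [[1, 0], [0, -1]]
new_latent, loss = ste_sgd_step(latent, x=[1.0, 2.0], target=[1.0, -1.0])
```

In MLX the same STE could likely be written with array ops as `w + mx.stop_gradient(w_deq - w)`, so the forward pass sees `w_deq` while gradients flow to `w` unchanged. Either way, every step still runs full-precision matmuls plus the extra quantize/dequantize work, and the latent weights stay in full precision, which is why this route is slower than plain training.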

RonanKMcGovern commented 6 months ago

Thanks, Awni. My goals were probably ill-conceived.

After seeing the BitNet and BitNet b1.58 papers, I had thought there could be merit in using smaller 1- to 2-bit kernels, both for a) reducing VRAM and b) reducing FLOPs.
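The FLOPs argument rests on the fact that a ternary weight never needs a real multiplication: it either adds an input, subtracts it, or skips it. A minimal pure-Python sketch of that idea, assuming one shared scale per weight matrix; `ternary_matvec` is a hypothetical helper, not an MLX function:

```python
def ternary_matvec(wq, scale, x):
    """Compute scale * (wq @ x) for a weight matrix with entries in {-1, 0, +1}.

    Each weight only decides whether x[j] is added, subtracted, or skipped,
    so the inner loop needs no multiplications; a single multiply per output
    applies the shared scale.
    """
    out = []
    for row in wq:
        acc = 0.0
        for w, v in zip(row, x):
            if w == 1:
                acc += v
            elif w == -1:
                acc -= v
            # w == 0 contributes nothing and is skipped
        out.append(acc * scale)
    return out


wq = [[1, 0, -1], [-1, 1, 1]]
x = [2.0, 3.0, 4.0]
print(ternary_matvec(wq, 0.5, x))  # [-1.0, 2.5]
```

Written like this in Python there is no speedup, of course; the savings only materialize in kernels or hardware built to exploit them, which is the crux of the hardware concern in this thread.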

However, it appears that:

  1. Hardware doesn't support formats smaller than fp8 (although Blackwell will support fp4), so you're always up/down-casting during inference, which costs time.
  2. Training actually isn't stable in any of these smaller formats. Even NVIDIA's Transformer Engine upcasts a lot, so there isn't much gain in training speed (maybe a little in the forward pass).
  3. So one is left with two options:
     a. Quantize a pre-trained model to 1 or 1.58 bits to save on VRAM. But quality will be poor from quantization alone, so further training will be required.
     b. Train a 1- or 1.58-bit model from scratch. But who is going to do that at the 7B scale, when it's possibly slower than training in bf16 (perhaps even if BitNet's results do scale and quality holds up in the 1/1.58-bit format)?
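Points 1 and 3 can be made concrete with a small sketch: ternary weights can be stored at 2 bits each (four per byte), but without a native low-bit matmul they have to be expanded back to a wider type before every multiply, while the storage saving is what drives the VRAM argument. The 2-bit code and the helper names below are hypothetical, and the memory estimate ignores quantization scales and other metadata.

```python
import math

# Hypothetical 2-bit code for ternary weights: -1 -> 0b00, 0 -> 0b01, +1 -> 0b10.
ENCODE = {-1: 0b00, 0: 0b01, 1: 0b10}
DECODE = {0b00: -1.0, 0b01: 0.0, 0b10: 1.0}


def pack_ternary(values):
    """Pack ternary values four per byte, 2 bits each (lowest bits first)."""
    packed = bytearray((len(values) + 3) // 4)
    for i, v in enumerate(values):
        packed[i // 4] |= ENCODE[v] << (2 * (i % 4))
    return bytes(packed)


def unpack_ternary(packed, n):
    """Expand the first n packed values back to floats before a matmul.

    This expansion is the up-casting cost paid on every forward pass when the
    hardware has no native ternary/2-bit matmul. Padding slots in the last
    byte are never read because only n values are decoded.
    """
    return [DECODE[(packed[i // 4] >> (2 * (i % 4))) & 0b11] for i in range(n)]


def weight_memory_gb(n_params, bits_per_param):
    """Approximate weight storage in GB (1e9 bytes), ignoring scales and metadata."""
    return n_params * bits_per_param / 8 / 1e9


weights = [1, 0, -1, -1, 1]
packed = pack_ternary(weights)
print(len(packed), unpack_ternary(packed, len(weights)))  # 2 [1.0, 0.0, -1.0, -1.0, 1.0]

for label, bits in [("bf16", 16), ("2-bit packed", 2), ("log2(3) bits", math.log2(3))]:
    print(f"7B params, {label}: {weight_memory_gb(7e9, bits):.2f} GB")
```

For a 7B-parameter model this gives roughly 14 GB of weights in bf16 versus 1.75 GB with 2-bit packing (about 1.39 GB at the theoretical log2(3) ≈ 1.58 bits per weight), which is the VRAM benefit behind option 3a; the unpack step is the recurring casting cost from point 1.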

On Wed, Mar 20, 2024 at 5:03 AM, Awni Hannun wrote:

@RonanKMcGovern you have two options:

  • Simulate the BitNet ops with casting and quantization/dequantization before matmuls.
  • Implement the quantized kernels themselves with custom gradients.

Maybe you could say more about your goals here, though. In either case, training will probably be a lot slower (and the first option would be much slower).


awni commented 6 months ago

Yeah, I agree with your assessment.