microsoft / microxcaling

PyTorch emulation library for Microscaling (MX)-compatible data formats
MIT License

Example or docs for getting started and converting an existing model to MX dtypes, e.g. FP6? #5

Closed: lessw2020 closed this issue 1 year ago

lessw2020 commented 1 year ago

I'm very interested in testing out MX FP6 on a model, since the paper reports results similar to FP32.

However, I was hoping to see an example of how to convert a given PyTorch model to use FP6, or at least some documentation on how to integrate MX with a model. Is there such an example or documentation somewhere that gives getting-started info to people new to MX?

I only see the ops and the tests for them, and it's not clear to me how to modify a given model to start using them. Thanks for any assistance!

rizhao-msft commented 1 year ago

Good point. We'll be updating this week with a short guide on how to integrate MX into model code. The short version: __init__.py imports things like Linear, LayerNorm, Conv2d, matmul, etc. These are drop-in replacements for torch.nn.Linear, torch.nn.LayerNorm, and so on. So you just need to swap your torch modules and functions for their MX equivalents and supply an msfp_spec to each module/function.
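As a rough illustration of the swap described above, the sketch below walks a PyTorch model and replaces each nn.Linear with a replacement class, copying over the trained weights. HypotheticalMXLinear and the spec dict are placeholders, not this repo's API: the real MX Linear quantizes its operands and takes its own spec argument, so check the integration guide for the exact constructor. The include predicate (also hypothetical) restricts the swap to part of the model, e.g. only the MLP blocks.

```python
import torch
import torch.nn as nn


class HypotheticalMXLinear(nn.Linear):
    """Hypothetical stand-in for an MX drop-in Linear.

    It only stores a spec and runs the ordinary float Linear forward;
    the real MX module would quantize inputs and weights per its spec.
    """

    def __init__(self, in_features, out_features, bias=True, spec=None,
                 device=None, dtype=None):
        super().__init__(in_features, out_features, bias=bias,
                         device=device, dtype=dtype)
        self.spec = spec


def swap_linears(module, spec, include=lambda name: True, prefix=""):
    """Recursively replace nn.Linear children with HypotheticalMXLinear.

    `include` receives each child's qualified name (e.g. "mlp.0"), so only
    a subset of the model can be converted if desired.
    """
    for name, child in module.named_children():
        qualified = f"{prefix}.{name}" if prefix else name
        if (isinstance(child, nn.Linear)
                and not isinstance(child, HypotheticalMXLinear)
                and include(qualified)):
            new = HypotheticalMXLinear(
                child.in_features, child.out_features,
                bias=child.bias is not None, spec=spec,
                device=child.weight.device, dtype=child.weight.dtype)
            # Copy the existing parameters so the swap preserves the model.
            with torch.no_grad():
                new.weight.copy_(child.weight)
                if child.bias is not None:
                    new.bias.copy_(child.bias)
            setattr(module, name, new)
        else:
            swap_linears(child, spec, include, qualified)
    return module


if __name__ == "__main__":
    model = nn.ModuleDict({
        "attn": nn.Linear(8, 8),
        "mlp": nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8)),
    })
    # Convert only the MLP; the spec dict is a made-up placeholder.
    swap_linears(model, spec={"format": "fp6"},
                 include=lambda name: name.startswith("mlp"))
    print(model)
```

Walking named_children and using setattr works for Sequential and ModuleDict containers alike, which is why the sketch avoids rebuilding the model by hand.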

lessw2020 commented 1 year ago

@rizhao-msft Awesome, thanks for the initial input above.

I suspected that was the case (set the spec, use the various drop-in replacements), but I wasn't sure whether, for example, training would still work if I only converted the MLP and not the attention, or whether I need to move everything to MX FP6 to get the near-FP32 results.

Thanks again. I look forward to the guide and will start some initial MX FP6 runs later this week.
Also, congrats on these new data types. It's a very exciting area overall, and I'm really looking forward to leveraging these new MX dtypes.

rizhao-msft commented 1 year ago

We moved everything to MX for our experiments. The most important pieces are the Linear layers and the matmuls in the attention mechanism. LayerNorm and Softmax are computed with individual bfloat16 ops. Lastly, the residual adds are also implemented with the simd_add function.
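To show what running an elementwise op such as the residual add in bfloat16 means numerically, here is a stdlib-only sketch that emulates bfloat16 by rounding a float32 to its top 16 bits (round to nearest, ties to even). It assumes the operands and the result are each rounded to bfloat16; this is not the library's simd_add, whose precision may be configurable. to_bfloat16 and bf16_add are hypothetical helper names.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even).

    bfloat16 keeps float32's sign, 8 exponent bits and top 7 mantissa bits,
    so it can be emulated by rounding away the low 16 bits of the float32
    encoding. Assumes finite inputs within float32 range (no NaN handling).
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    lsb = (bits >> 16) & 1                      # last bit that will be kept
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_add(a, b):
    """Elementwise add with inputs and outputs rounded to bfloat16."""
    return [to_bfloat16(to_bfloat16(x) + to_bfloat16(y)) for x, y in zip(a, b)]


if __name__ == "__main__":
    # A residual of 2**-8 is below bfloat16's resolution at 1.0 and is lost.
    print(bf16_add([1.0, 1.0], [2**-8, 2**-6]))  # [1.0, 1.015625]
```

The example output shows the main accuracy cost of bfloat16 residual adds: small updates to large activations can round away entirely.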

Our goal is to emulate how the hardware would use MX, so all matmul operations, including in attention, needed to be in MX.
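To make "all matmul operations in MX" concrete, the following pure-Python sketch computes an MX-quantized dot product following the OCP Microscaling spec's scheme: each block of 32 values shares one power-of-two (E8M0) scale, and each element is stored in FP6 E2M3 (2 exponent bits, 3 mantissa bits, bias 1, max value 7.5). This is not the library's PyTorch implementation. Saturating out-of-range elements, using scale 1 for all-zero blocks, and accumulating in Python doubles are choices made for this sketch; all function names are hypothetical.

```python
import math

# FP6 E2M3 element format; FP6 has no inf/NaN encodings.
E_BITS, M_BITS, BIAS = 2, 3, 1
EMAX = (2 ** E_BITS - 1) - BIAS                   # largest exponent: 2
EMIN = 1 - BIAS                                   # smallest normal exponent: 0
MAX_NORMAL = 2.0 ** EMAX * (2 - 2.0 ** -M_BITS)   # 7.5
BLOCK = 32                                        # MX block size


def quantize_elem(v: float) -> float:
    """Round v to the nearest FP6 E2M3 value (ties to even), saturating."""
    a = abs(v)
    if a == 0.0:
        return 0.0
    exp = max(math.frexp(a)[1] - 1, EMIN)  # floor(log2(a)); subnormals use EMIN
    step = 2.0 ** (exp - M_BITS)           # spacing of representable values
    q = min(round(a / step) * step, MAX_NORMAL)
    return math.copysign(q, v)


def mx_quantize_block(block):
    """Return (shared power-of-two scale, quantized elements) for one block."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0, [0.0] * len(block)
    shared_exp = math.frexp(amax)[1] - 1 - EMAX   # floor(log2(amax)) - emax
    shared_exp = max(-127, min(127, shared_exp))  # E8M0 scale range
    scale = 2.0 ** shared_exp
    return scale, [quantize_elem(v / scale) for v in block]


def mx_dot(x, y):
    """Dot product with both operands quantized to MX blocks first."""
    total = 0.0
    for i in range(0, len(x), BLOCK):
        sx, px = mx_quantize_block(x[i:i + BLOCK])
        sy, py = mx_quantize_block(y[i:i + BLOCK])
        total += sx * sy * sum(a * b for a, b in zip(px, py))
    return total


if __name__ == "__main__":
    # One large outlier sets the block scale, so small values flush to zero.
    print(mx_quantize_block([100.0] + [0.01] * 31))
    print(mx_dot([1.0] * 32, [2.0] * 32))
```

Setting E_BITS, M_BITS, BIAS = 3, 2, 3 gives FP6 E3M2 (max value 28) instead. The outlier example is the usual argument for small blocks: a single large value only coarsens the 31 other values in its own block.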

rizhao-msft commented 1 year ago

Added integration guide: de7dc148af6ed2fa0c1e6fae1c2c8ce6fa832c5d