microsoft / microxcaling

PyTorch emulation library for Microscaling (MX)-compatible data formats
MIT License

How to set gradient and activation to different formats? #30

Closed: rensushan closed this issue 2 months ago

rensushan commented 2 months ago

I still cannot understand which option (w_elem_format_bp, a_elem_format_bp, a_elem_format_bp_ex, a_elem_format_bp_os) represents the gradient. In the BP process, I want to set the gradient to E5M2 but keep the activations as E4M3. Your suggestion seems to set all activations, weights and gradients to E5M2, which is not what I want.

So how do I define the gradient and activation formats separately (i.e. set the gradient to E5M2 and the activations/weights to E4M3 in the BP phase)? What is the difference between a_elem_format_bp, a_elem_format_bp_ex, and a_elem_format_bp_os? What do the postfixes bp, ex, and os mean? Which one represents the gradient and which one represents the activation?

gakolhe commented 2 months ago

Hi @rensushan, as I suggested previously, please refer to the new README.

If I understood your intentions correctly, I would set a_elem_format_bp_ex and a_elem_format_bp_os to FP8 E5M2; a_elem_format_bp can remain FP8 E4M3.
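A minimal sketch of the settings suggested above, written as a plain Python dict of the spec keys discussed in this thread. The format strings `'fp8_e5m2'` and `'fp8_e4m3'` and the commented-out `finalize_mx_specs` call are assumptions based on the repo's README, so check them against the version you have installed. Only the four backward-pass keys are shown; forward-pass keys and other options are omitted.

```python
# Backward-pass formats: gradients in E5M2, weights/activations in E4M3.
# Format strings assume the library's 'fp8_e5m2' / 'fp8_e4m3' naming.
mx_specs = {
    "w_elem_format_bp": "fp8_e4m3",     # weights in the backward pass (assumed role)
    "a_elem_format_bp": "fp8_e4m3",     # stashed activations
    "a_elem_format_bp_ex": "fp8_e5m2",  # incoming gradient -> weight gradient
    "a_elem_format_bp_os": "fp8_e5m2",  # incoming gradient -> outgoing gradient
}

# With the library installed, the dict would then be completed and applied,
# e.g. (assumed, per the README):
#   from mx import finalize_mx_specs
#   mx_specs = finalize_mx_specs(mx_specs)
```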

a_elem_format_bp -> this flag is used to quantize the stashed activations in the backward pass.
a_elem_format_bp_ex -> for a linear layer, this flag is used to quantize the incoming gradient that is used to calculate the weight gradient.
a_elem_format_bp_os -> this flag is used to quantize the incoming gradient that is used to calculate the outgoing gradient.
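To make these three definitions concrete, here is a sketch of a linear layer y = x @ W.T and the two matmuls of its backward pass, recording which flag selects the format of each operand. It performs no real quantization; `Operand` and `linear_backward_operands` are hypothetical names for illustration only. Mapping `w_elem_format_bp` to the weight operand of the outgoing-gradient matmul is an assumption based on the flag name, since the reply above does not spell it out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Operand:
    tensor: str  # which tensor enters the matmul
    flag: str    # spec key that selects its element format
    fmt: str     # format value taken from the specs


def linear_backward_operands(specs):
    """Hypothetical helper: for a linear layer y = x @ W.T, list the operands of
    each backward matmul and the element format each one would be quantized to."""

    def op(tensor, flag):
        return Operand(tensor, flag, specs[flag])

    return {
        # Weight gradient: dW = dy.T @ x
        "grad_weight": (
            op("grad_output", "a_elem_format_bp_ex"),
            op("stashed input activation", "a_elem_format_bp"),
        ),
        # Outgoing gradient: dx = dy @ W
        # (weight operand -> w_elem_format_bp is an assumption from the flag name)
        "grad_input": (
            op("grad_output", "a_elem_format_bp_os"),
            op("weight", "w_elem_format_bp"),
        ),
    }


specs = {
    "w_elem_format_bp": "fp8_e4m3",
    "a_elem_format_bp": "fp8_e4m3",
    "a_elem_format_bp_ex": "fp8_e5m2",
    "a_elem_format_bp_os": "fp8_e5m2",
}

for matmul, operands in linear_backward_operands(specs).items():
    for o in operands:
        print(f"{matmul:11s} {o.tensor:25s} {o.flag:20s} {o.fmt}")
```

With these settings, both uses of grad_output come out as fp8_e5m2 while the weight and the stashed activation stay fp8_e4m3, which is the split asked about in the question. Because the same incoming gradient feeds two different matmuls, it is quantized separately for each one, which is presumably why there are two flags (_ex and _os) for it.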

The image in the README visualizes this and should be helpful. Feel free to reopen the ticket if required.