microsoft / TransformerCompression

For releasing code related to compression methods for transformers, accompanying our publications
MIT License

Layer fusion with Llama #129

Closed. kiucho closed this issue 5 months ago

kiucho commented 5 months ago

Thank you for sharing your research.

I have a question about converting LayerNorm to RMSNorm and would like some clarification.

SliceGPT requires RMSNorm, so to apply SliceGPT to a given model, we must replace the normalization modules with RMSNorm.

To convert LayerNorm to RMSNorm, we need to fuse the mean-subtraction matrix M into the weights of the layer before the normalization, and fuse the LayerNorm weight into the weights of the layer after it. Then we can replace the LayerNorm with the RMSN class defined in modules.py.
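Here is a minimal NumPy sketch of the conversion described above. It checks numerically that the fused model matches the original. It assumes row-vector activations, a linear layer on each side of the norm, and a weightless RMSNorm. The repo's RMSN class may differ in its scaling or eps handling. M = I - (1/d)11^T is the mean-subtraction matrix from the SliceGPT paper. The LayerNorm bias beta is also folded into the next layer's bias here. All function and variable names are illustrative, not the repo's API.

```python
import numpy as np

def layer_norm(z, gamma, beta, eps=1e-5):
    # Standard LayerNorm over the last (hidden) axis, population variance
    mu = z.mean(axis=-1, keepdims=True)
    var = z.var(axis=-1, keepdims=True)
    return (z - mu) / np.sqrt(var + eps) * gamma + beta

def rms_norm(z, eps=1e-5):
    # Weightless RMSNorm (an assumption standing in for RMSN): divide by row RMS
    return z / np.sqrt((z ** 2).mean(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
n, d_in, d, d_out = 4, 6, 8, 5
x = rng.normal(size=(n, d_in))
W_in, b_in = rng.normal(size=(d_in, d)), rng.normal(size=d)
W_out, b_out = rng.normal(size=(d, d_out)), rng.normal(size=d_out)
gamma, beta = rng.normal(size=d), rng.normal(size=d)

# Original block: linear -> LayerNorm -> linear
y_ref = layer_norm(x @ W_in + b_in, gamma, beta) @ W_out + b_out

# Mean-subtraction matrix M = I - (1/d) 11^T
M = np.eye(d) - np.ones((d, d)) / d

# Bake M into the linear layer BEFORE the norm, so its outputs are zero-mean
W_in_f, b_in_f = W_in @ M, b_in @ M
# Fold gamma (as diag(gamma)) and beta into the linear layer AFTER the norm
W_out_f = gamma[:, None] * W_out
b_out_f = beta @ W_out + b_out

# Converted block: fused linear -> weightless RMSNorm -> fused linear
y_fused = rms_norm(x @ W_in_f + b_in_f) @ W_out_f + b_out_f
print(np.allclose(y_ref, y_fused))  # True
```

The two outputs agree because, once the input is zero-mean, LayerNorm's variance equals the mean of squares. At that point LayerNorm reduces to RMSNorm followed by an affine map.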

However, Llama already uses RMSNorm, so I would like to know why any changes are needed. SliceGPT also appears to replace the LlamaRMSNorm class with the RMSN class.

I understand fusing the RMSNorm weight into the layer after the normalization, since LlamaRMSNorm also has a learnable weight. However, I do not understand why the M matrix should be fused into the layer before the normalization.
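For a model that already uses RMSNorm, this minimal NumPy sketch shows that folding the norm weight into the following linear layer is enough. No mean subtraction is involved. The llama_rms_norm function mirrors the math of LlamaRMSNorm but ignores its dtype casting. rms_norm is a hypothetical weightless stand-in for RMSN. All names are illustrative.

```python
import numpy as np

def llama_rms_norm(z, weight, eps=1e-6):
    # Scale each row by 1/RMS, then by a learned per-channel weight
    return z / np.sqrt((z ** 2).mean(axis=-1, keepdims=True) + eps) * weight

def rms_norm(z, eps=1e-6):
    # Weightless RMSNorm (assumed stand-in for RMSN)
    return z / np.sqrt((z ** 2).mean(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(1)
n, d, d_out = 3, 8, 4
z = rng.normal(size=(n, d))
weight = rng.normal(size=d)
W_out = rng.normal(size=(d, d_out))

y_ref = llama_rms_norm(z, weight) @ W_out
# Fold the norm weight into the next layer: diag(weight) @ W_out
y_fused = rms_norm(z) @ (weight[:, None] * W_out)
print(np.allclose(y_ref, y_fused))  # True
```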

If there is any misunderstanding on my part, please let me know.

Thank you for your reply.

kiucho commented 5 months ago

Apart from the embedding layer weight, it seems that the M matrix is fused only when should_bake_mean_into_linear is True, which is the case for OPT but not for Llama.
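This minimal NumPy sketch shows why such a flag matters, using the embedding table as the example. Right-multiplying by M is the same as subtracting each row's mean over the hidden dimension. LayerNorm discards that mean anyway, so centering is harmless for OPT-style models. RMSNorm does not discard it, so centering changes Llama-style outputs. The helper prepare_embedding is hypothetical and only borrows the flag name from the thread. It is not the repo's code.

```python
import numpy as np

def bake_mean(W):
    # Equivalent to W @ M with M = I - (1/d) 11^T: center every row
    return W - W.mean(axis=-1, keepdims=True)

def layer_norm(z, eps=1e-5):
    return (z - z.mean(axis=-1, keepdims=True)) / np.sqrt(z.var(axis=-1, keepdims=True) + eps)

def rms_norm(z, eps=1e-5):
    return z / np.sqrt((z ** 2).mean(axis=-1, keepdims=True) + eps)

def prepare_embedding(E, should_bake_mean_into_linear):
    # Hypothetical helper: center the table only for LayerNorm (OPT-style) models
    return bake_mean(E) if should_bake_mean_into_linear else E

rng = np.random.default_rng(2)
vocab, d = 10, 8
E = rng.normal(size=(vocab, d))
ids = np.array([1, 4, 7])

# OPT-style: LayerNorm removes the mean itself, so baking it in is harmless
opt_same = np.allclose(layer_norm(E[ids]), layer_norm(prepare_embedding(E, True)[ids]))
# Llama-style: RMSNorm keeps the mean, so baking it in changes the output
llama_if_baked = np.allclose(rms_norm(E[ids]), rms_norm(bake_mean(E)[ids]))
# Llama-style with the flag off: embeddings are untouched
llama_unbaked = np.allclose(rms_norm(E[ids]), rms_norm(prepare_embedding(E, False)[ids]))
print(opt_same, llama_if_baked, llama_unbaked)
```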

jameshensman commented 5 months ago

For Llama, you’re right: there’s no need to mean-subtract. Are we subtracting the mean from the embeddings in Llama? If so, that’s a bug.

James

On Mon, 8 Apr 2024 at 12:00, KIUCHO @.***> wrote:

Except Embedding layer weight, seems like M matrix is fused based on should_bake_mean_into_linear which becomes False for Llama, True for OPT.
