linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
https://arxiv.org/pdf/2410.10989
BSD 2-Clause "Simplified" License

mllama patch modifies nn.LayerNorm globally #315

Open tyler-romero opened 1 month ago

tyler-romero commented 1 month ago

🐛 Describe the bug

Instead of only patching the transformers mllama module (transformers.models.mllama.modeling_mllama), apply_liger_kernel_to_mllama modifies torch.nn.LayerNorm globally.

The issue is here.
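To make the aliasing explicit: because modeling_mllama does from torch import nn, its nn attribute is the very same module object as torch.nn, so assigning through that alias rebinds the class for every PyTorch user in the process. A minimal sketch of the problem (the assignment line paraphrases the current patch rather than quoting it):

import torch.nn
from transformers.models.mllama import modeling_mllama
from liger_kernel.transformers.layer_norm import LigerLayerNorm

# "from torch import nn" inside modeling_mllama binds the torch.nn module itself,
# so both names refer to one and the same module object:
assert modeling_mllama.nn is torch.nn

# Assigning through the alias (roughly what the current patch does) therefore
# rebinds torch.nn.LayerNorm for the whole process, not just for mllama:
modeling_mllama.nn.LayerNorm = LigerLayerNorm
assert torch.nn.LayerNorm is LigerLayerNorm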

The fix would be to (see the sketch after this list):

1. Stop patching LayerNorm in Liger by assigning to modeling_mllama.nn.LayerNorm.
2. Change transformers.models.mllama.modeling_mllama so that it no longer relies on from torch import nn for LayerNorm, and instead imports it directly with from torch.nn import LayerNorm.
3. Have Liger patch LayerNorm by assigning to modeling_mllama.LayerNorm instead.
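A rough sketch of what the scoped patch could look like once modeling_mllama exposes LayerNorm as a module-level name (hypothetical, since it depends on the transformers change described in step 2):

from liger_kernel.transformers.layer_norm import LigerLayerNorm
from transformers.models.mllama import modeling_mllama

# Rebinds only the name that mllama's modeling code resolves at call time;
# torch.nn.LayerNorm itself is left untouched for every other model.
modeling_mllama.LayerNorm = LigerLayerNorm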

Reproduce

pip install transformers==4.45 liger-kernel-nightly

from liger_kernel.transformers import apply_liger_kernel_to_mllama
from torch import nn

apply_liger_kernel_to_mllama()
print(nn.LayerNorm)
# Output: <class 'liger_kernel.transformers.layer_norm.LigerLayerNorm'>
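The global rebinding also means that any LayerNorm constructed after the patch, in any model, is silently a Liger kernel (assuming LigerLayerNorm accepts the same positional hidden-size argument as torch.nn.LayerNorm):

layer = nn.LayerNorm(16)
print(type(layer))
# <class 'liger_kernel.transformers.layer_norm.LigerLayerNorm'>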

Versions

Environment Report:

Operating System: Linux-6.1.85+-x86_64-with-glibc2.35
Python version: 3.10.12
PyTorch version: 2.4.1+cu121
CUDA version: Not available
Triton version: 3.1.0
Transformers version: 4.45.0

ByronHsu commented 1 month ago

Curious, does this require changing the transformers source code? I think we could maybe raise a request there.

tyler-romero commented 1 month ago

Yeah, the proposed fix would unfortunately require a change to transformers. The way mllama was implemented differs slightly from the conventions used in the other transformers modeling files.

ByronHsu commented 1 month ago

Sounds good, let's try to send a PR there.