linkedin / Liger-Kernel

Efficient Triton Kernels for LLM Training
https://arxiv.org/pdf/2410.10989
BSD 2-Clause "Simplified" License

Restore monkey patched modules #232

Closed: austin362667 closed this 1 month ago

austin362667 commented 1 month ago

Summary

Fixes https://github.com/linkedin/Liger-Kernel/issues/176

There are several ways to restore a monkey-patched library in Python, including using context managers, decorators, pytest fixtures, or reloading the entire module.
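
For comparison with the reload approach, the context-manager and decorator options can be sketched with the standard library's `unittest.mock.patch.object`, which undoes the patch when the `with` block or the decorated call exits. This is an illustrative sketch, not what this PR does: `fake_lib`, `original_rms_norm`, `liger_rms_norm`, and `run_patched` are hypothetical stand-ins for a patched library and its replacement. pytest's built-in `monkeypatch` fixture gives the same undo-on-teardown behavior through `monkeypatch.setattr`.

```python
import types
from unittest import mock

# Hypothetical stand-in for a library module that a Liger run would patch.
fake_lib = types.ModuleType("fake_lib")


def original_rms_norm(x):
    return "original"


def liger_rms_norm(x):
    return "liger"


fake_lib.rms_norm = original_rms_norm

# Context-manager form: the patch exists only inside the `with` block.
with mock.patch.object(fake_lib, "rms_norm", liger_rms_norm):
    assert fake_lib.rms_norm(0) == "liger"
assert fake_lib.rms_norm(0) == "original"


# Decorator form: the patch exists only while the decorated function runs.
# Because `new` is given explicitly, no mock argument is passed to the function.
@mock.patch.object(fake_lib, "rms_norm", liger_rms_norm)
def run_patched():
    return fake_lib.rms_norm(0)


print(run_patched())         # liger
print(fake_lib.rms_norm(0))  # original
```

These approaches require every patch to go through the same helper; a patch applied elsewhere by direct assignment is not undone, which is one reason reloading the whole module can be simpler.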

This PR focuses on reverting monkey-patched modules when with_liger is disabled in convergence tests.

import importlib
import target.module  # placeholder for the module that was monkey-patched
importlib.reload(target.module)

Reloading resets the patched library in a single step, helps prevent unintended side effects, and is easier than manually reassigning each patched function.
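
A minimal, self-contained sketch of the reload approach, assuming a throwaway module `toy_modeling` written to a temporary directory (a hypothetical stand-in for a patched modeling module; the names here are not from Liger-Kernel). It also shows the main caveat of reloading: references captured before the reload, for example through `from toy_modeling import rms_norm` in another module, keep pointing at the patched object.

```python
import importlib
import pathlib
import sys
import tempfile

# Write a throwaway module to disk so reload() has source code to re-execute.
tmp_dir = tempfile.mkdtemp()
pathlib.Path(tmp_dir, "toy_modeling.py").write_text(
    "def rms_norm(x):\n    return 'original'\n"
)
sys.path.insert(0, tmp_dir)
importlib.invalidate_caches()
toy_modeling = importlib.import_module("toy_modeling")

# Simulate the monkey patch that a with_liger=True run would apply.
toy_modeling.rms_norm = lambda x: "liger"
# A reference taken before the reload, like `from toy_modeling import rms_norm`.
stale_ref = toy_modeling.rms_norm

# reload() re-runs the module source in the same module object,
# rebinding every top-level name to its original definition.
importlib.reload(toy_modeling)

print(toy_modeling.rms_norm(0))  # original
print(stale_ref(0))              # liger: the old reference is not updated
```

The trade-off is that reload only repairs the module's own namespace; objects or names bound elsewhere before the reload still need to be re-fetched.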

Follow-up

If this PR resolves https://github.com/linkedin/Liger-Kernel/issues/176, it may surface other value-mismatch failures: once the baseline run is no longer left patched, the tests actually compare Liger against the original implementation, so the convergence tolerances may need adjusting. For instance:

______________________ test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] _______________________

model_name = 'mini_mixtral', num_steps = 32, lr = 0.0001, dtype = torch.bfloat16, loss_atol = 1e-08, loss_rtol = 1e-05
logits_atol = 0.1, logits_rtol = 1e-05, param_atol = 0.01, param_rtol = 1e-05

    @pytest.mark.parametrize(
        "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol",
        [
            ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 2e-5, 1e-4, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_llama3",
                32,
                1e-4,
                torch.bfloat16,
                5e-3,
                1e-5,
                1e-2,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_qwen2",
                32,
                1e-4,
                torch.bfloat16,
                1e-8,
                1e-5,
                1e-2,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            pytest.param(
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.float32,
                1e-8,
                1e-5,
                5e-3,
                1e-5,
                5e-3,
                1e-5,
                marks=pytest.mark.skipif(
                    not QWEN2_VL_AVAILABLE,
                    reason="Qwen2-VL not available in this version of transformers",
                ),
            ),
            pytest.param(
                "mini_qwen2_vl",
                32,
                1e-4,
                torch.bfloat16,
                1e-8,
                1e-5,
                1e-2,
                1e-5,
                1e-2,
                1e-5,
                marks=[
                    pytest.mark.skipif(
                        not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                    ),
                    pytest.mark.skipif(
                        not QWEN2_VL_AVAILABLE,
                        reason="Qwen2-VL not available in this version of transformers",
                    ),
                ],
            ),
            ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_phi3",
                32,
                1e-4,
                torch.bfloat16,
                1e-8,
                1e-5,
                1e-2,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_mistral",
                32,
                1e-4,
                torch.bfloat16,
                1e-8,
                1e-5,
                1e-2,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
            pytest.param(
                "mini_mixtral",
                32,
                1e-4,
                torch.bfloat16,
                1e-8,
                1e-5,
                1e-1,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            # Gemma 1.1 and 2 have more tolerance because, currently, the kernel is not a perfect match (casts are not done the same way)
            ("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_gemma1",
                32,
                1e-4,
                torch.bfloat16,
                1e-2,
                1e-4,
                2e-1,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            ("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_gemma1.1",
                32,
                1e-4,
                torch.bfloat16,
                1e-2,
                1e-4,
                2e-1,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
            ("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
            pytest.param(
                "mini_gemma2",
                32,
                1e-4,
                torch.bfloat16,
                1e-2,
                1e-4,
                2e-1,
                1e-5,
                1e-2,
                1e-5,
                marks=pytest.mark.skipif(
                    not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                ),
            ),
        ],
    )
    def test_mini_model(
        model_name,
        num_steps,
        lr,
        dtype,
        loss_atol,
        loss_rtol,
        logits_atol,
        logits_rtol,
        param_atol,
        param_rtol,
    ):
        # Non-liger models should be initialized and tested first to avoid the module being overridden

        expected_output = run_mini_model(
            model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr
        )

        actual_output = run_mini_model(
            model_name=model_name, num_steps=num_steps, dtype=dtype, lr=lr, with_liger=True
        )

        # Compare every step of the loss
>       assert_verbose_allclose(
            torch.tensor([expected_output["loss"]]),
            torch.tensor([actual_output["loss"]]),
            atol=loss_atol,
            rtol=loss_rtol,
        )

test/convergence/test_mini_models_no_logits.py:594:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

tensor1 = tensor([[10.9374,  7.0162,  4.8162,  3.2886,  2.4254,  1.9993,  1.6753,  1.7743,
          1.4267,  1.4742,  1.4458,  ...  1.0867,  0.8353,  0.9219,  0.8796,
          0.8610,  0.8183,  0.7559,  0.8734,  0.9647,  0.7261,  1.0963,  0.8136]])
tensor2 = tensor([[10.9383,  7.0052,  4.8145,  3.3515,  2.3853,  2.0174,  1.6758,  1.7778,
          1.4256,  1.4737,  1.4442,  ...  1.0870,  0.8346,  0.9222,  0.8817,
          0.8610,  0.8181,  0.7554,  0.8736,  0.9671,  0.7263,  1.0966,  0.8171]])
rtol = 1e-05, atol = 1e-08, max_print = 5

    def assert_verbose_allclose(tensor1, tensor2, rtol=1e-05, atol=1e-08, max_print=5):
        """
        Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.

        Parameters:
        tensor1 (torch.Tensor): First tensor to compare.
        tensor2 (torch.Tensor): Second tensor to compare.
        rtol (float): Relative tolerance.
        atol (float): Absolute tolerance.
        max_print (int): Maximum number of mismatched elements to print.

        Raises:
        AssertionError: If the tensors are not all close within the given tolerance.
        """
        # Check if the shapes of the tensors match
        if tensor1.shape != tensor2.shape:
            raise AssertionError("Input tensors must have the same shape.")

        # Calculate the difference between the tensors
        diff = torch.abs(tensor1 - tensor2)

        # Determine the tolerance
        tolerance = atol + rtol * torch.abs(tensor2)

        # Find mismatched elements
        mismatched = diff > tolerance

        # Get the indices of mismatched elements
        mismatched_indices = torch.nonzero(mismatched)

        # Count the number of mismatched elements
        num_mismatched = mismatched.sum().item()

        # Check if all elements are close
        all_close = num_mismatched == 0

        # Raise AssertionError with detailed information if there are mismatches
        if not all_close and num_mismatched > 1:
            mismatch_details = [f"Number of mismatched elements: {num_mismatched}"]
            print_count = min(max_print, num_mismatched)
            for index in mismatched_indices[:print_count]:
                i = tuple(index.tolist())
                mismatch_details.append(
                    f"Mismatch at index {i}: tensor1[{i}] = {tensor1[i]}, tensor2[{i}] = {tensor2[i]}"
                )
            if num_mismatched > max_print:
                mismatch_details.append(
                    f"... and {num_mismatched - max_print} more mismatched elements."
                )

>           raise AssertionError("\n".join(mismatch_details))
E           AssertionError: Number of mismatched elements: 32
E           Mismatch at index (0, 0): tensor1[(0, 0)] = 10.937411308288574, tensor2[(0, 0)] = 10.938319206237793
E           Mismatch at index (0, 1): tensor1[(0, 1)] = 7.016175270080566, tensor2[(0, 1)] = 7.0052409172058105
E           Mismatch at index (0, 2): tensor1[(0, 2)] = 4.8161821365356445, tensor2[(0, 2)] = 4.814478397369385
E           Mismatch at index (0, 3): tensor1[(0, 3)] = 3.288573980331421, tensor2[(0, 3)] = 3.3514533042907715
E           Mismatch at index (0, 4): tensor1[(0, 4)] = 2.425377368927002, tensor2[(0, 4)] = 2.3853368759155273
E           ... and 27 more mismatched elements.

test/utils.py:83: AssertionError
---------------------------------------------------- Captured stdout call -----------------------------------------------------
Liger kernel patches have been reverted.
Step 0, Loss: 10.937411308288574
Step 1, Loss: 7.016175270080566
Step 2, Loss: 4.8161821365356445
Step 3, Loss: 3.288573980331421
Step 4, Loss: 2.425377368927002
Step 5, Loss: 1.999261736869812
Step 6, Loss: 1.675323486328125
Step 7, Loss: 1.7742501497268677
Step 8, Loss: 1.4266773462295532
Step 9, Loss: 1.474155068397522
Step 10, Loss: 1.4458246231079102
Step 11, Loss: 1.1540931463241577
Step 12, Loss: 1.3520232439041138
Step 13, Loss: 1.311019778251648
Step 14, Loss: 1.219789981842041
Step 15, Loss: 1.3071205615997314
Step 16, Loss: 1.2621395587921143
Step 17, Loss: 1.3119654655456543
Step 18, Loss: 1.1880946159362793
Step 19, Loss: 1.2357648611068726
Step 20, Loss: 1.0867037773132324
Step 21, Loss: 0.8352738618850708
Step 22, Loss: 0.9218576550483704
Step 23, Loss: 0.879619836807251
Step 24, Loss: 0.8610480427742004
Step 25, Loss: 0.8182975053787231
Step 26, Loss: 0.7558884620666504
Step 27, Loss: 0.8734312057495117
Step 28, Loss: 0.9646832942962646
Step 29, Loss: 0.7261283993721008
Step 30, Loss: 1.0963469743728638
Step 31, Loss: 0.8136419057846069
Step 0, Loss: 10.938319206237793
Step 1, Loss: 7.0052409172058105
Step 2, Loss: 4.814478397369385
Step 3, Loss: 3.3514533042907715
Step 4, Loss: 2.3853368759155273
Step 5, Loss: 2.0173795223236084
Step 6, Loss: 1.6758073568344116
Step 7, Loss: 1.777788519859314
Step 8, Loss: 1.4255633354187012
Step 9, Loss: 1.4737187623977661
Step 10, Loss: 1.4441752433776855
Step 11, Loss: 1.1313129663467407
Step 12, Loss: 1.3452619314193726
Step 13, Loss: 1.299330234527588
Step 14, Loss: 1.2130300998687744
Step 15, Loss: 1.3027563095092773
Step 16, Loss: 1.2582926750183105
Step 17, Loss: 1.3112103939056396
Step 18, Loss: 1.1886006593704224
Step 19, Loss: 1.235780954360962
Step 20, Loss: 1.0869864225387573
Step 21, Loss: 0.8346381187438965
Step 22, Loss: 0.9222478866577148
Step 23, Loss: 0.8816985487937927
Step 24, Loss: 0.8609745502471924
Step 25, Loss: 0.81810462474823
Step 26, Loss: 0.7554237246513367
Step 27, Loss: 0.8736312389373779
Step 28, Loss: 0.967080295085907
Step 29, Loss: 0.7262533903121948
Step 30, Loss: 1.0965538024902344
Step 31, Loss: 0.8171141147613525
=================================================== short test summary info ===================================================
FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1-32-0.0001-dtype1-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
FAILED test/convergence/test_mini_models.py::test_mini_model[mini_gemma1.1-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
FAILED test/convergence/test_mini_models.py::test_mini_model[mini_llama3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 29
FAILED test/convergence/test_mini_models.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
FAILED test/convergence/test_mini_models.py::test_mini_model[mini_qwen2-32-0.0001-dtype11-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31
FAILED test/convergence/test_mini_models.py::test_mini_model[mini_phi3-32-0.0001-dtype13-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31
FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_qwen2-32-0.0001-dtype3-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_phi3-32-0.0001-dtype7-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 31
FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mistral-32-0.0001-dtype9-1e-08-1e-05-0.01-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 27
FAILED test/convergence/test_mini_models_no_logits.py::test_mini_model[mini_mixtral-32-0.0001-dtype11-1e-08-1e-05-0.1-1e-05-0.01-1e-05] - AssertionError: Number of mismatched elements: 32
============================== 10 failed, 20 passed, 4 skipped, 4 warnings in 226.37s (0:03:46) ===============================
make: *** [Makefile:23: test-convergence] Error 1
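
The mini_mixtral failure above follows from the element-wise rule in `assert_verbose_allclose`: an element mismatches when |tensor1 - tensor2| > atol + rtol x |tensor2|. At step 0, for example, the difference is about 9.1e-4 while the allowed tolerance is 1e-8 + 1e-5 x 10.938, roughly 1.1e-4. The plain-Python sketch below recomputes this for the first five reported steps; the looser `atol=1e-1` is only an illustration, not the value adopted in the follow-up tolerance PR. Note also that the helper as printed raises only when `num_mismatched > 1`, so a run with exactly one out-of-tolerance element would pass; the sketch counts every failing element.

```python
# Losses from the first five "Mismatch at index" lines above:
# tensor1 = expected (non-Liger run), tensor2 = actual (Liger run).
expected = [10.937411308288574, 7.016175270080566, 4.8161821365356445,
            3.288573980331421, 2.425377368927002]
actual = [10.938319206237793, 7.0052409172058105, 4.814478397369385,
          3.3514533042907715, 2.3853368759155273]


def count_mismatches(a, b, atol, rtol):
    """Count elements with |a - b| > atol + rtol * |b| (the assert_verbose_allclose rule)."""
    return sum(abs(x - y) > atol + rtol * abs(y) for x, y in zip(a, b))


max_abs_diff = max(abs(x - y) for x, y in zip(expected, actual))
print(f"largest |diff| = {max_abs_diff:.4f}")                    # 0.0629 (step 3)
print(count_mismatches(expected, actual, atol=1e-8, rtol=1e-5))  # 5: every step fails
print(count_mismatches(expected, actual, atol=1e-1, rtol=1e-5))  # 0 with the looser atol
```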

Testing Done

austin362667 commented 1 month ago

Hi Byron, I’m considering what tests we could add to prevent similar mistakes, such as a model being tested against itself, from happening again. I’d appreciate your feedback. Thanks in advance!
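
One possible guard, offered only as a sketch under assumptions: record where each patchable attribute was defined (`__module__` and `__qualname__`) before any Liger run, then assert that the baseline run still sees the same definitions. Comparing definitions rather than object identity keeps the check compatible with `importlib.reload`, which creates new but equivalent objects. `toy`, `RMSNorm`, `LigerRMSNorm`, `fingerprint`, and `assert_unpatched` are hypothetical names, not part of Liger-Kernel's test utilities.

```python
import types


def fingerprint(module, names):
    """Map each attribute name to (defining module, qualified name)."""
    return {
        name: (getattr(module, name).__module__, getattr(module, name).__qualname__)
        for name in names
    }


def assert_unpatched(module, expected):
    """Fail if any attribute no longer matches its recorded definition."""
    current = fingerprint(module, expected)
    changed = sorted(name for name in expected if current[name] != expected[name])
    assert not changed, f"{module.__name__} is still patched: {changed}"


class RMSNorm:  # stand-in for the library's original class
    pass


class LigerRMSNorm:  # stand-in for the replacement a patch would install
    pass


# Hypothetical module that a with_liger=True run would patch.
toy = types.ModuleType("toy_modeling")
toy.RMSNorm = RMSNorm
EXPECTED = fingerprint(toy, ["RMSNorm"])

# Simulate a leftover patch: the baseline would be comparing Liger against Liger.
toy.RMSNorm = LigerRMSNorm
try:
    assert_unpatched(toy, EXPECTED)
    leak_detected = False
except AssertionError as err:
    leak_detected = True
    print(err)  # toy_modeling is still patched: ['RMSNorm']

# Once the patch is reverted, the guard passes.
toy.RMSNorm = RMSNorm
assert_unpatched(toy, EXPECTED)
```

Calling such a guard at the start of the non-Liger run would turn a silently self-compared test into an explicit failure.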

tyler-romero commented 1 month ago

Thanks! Could you also update the new multimodal convergence test to support this? https://github.com/linkedin/Liger-Kernel/blob/main/test/convergence/test_mini_models_multimodal.py

austin362667 commented 1 month ago

@tyler-romero Sure, done!

ByronHsu commented 1 month ago

LGTM. I have opened a PR in your fork to relax the tolerances: https://github.com/austin362667/Liger-Kernel/pull/1. Please consolidate those changes into this PR.