Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

cuDNN executor: No valid engine configs for MUL_Reduction_MUL_Matmul_MUL_ADD_SUB_EXP_Reshape_Matmul_Matmul_MUL_SUB_MUL_Reduction_MUL_Reshape_Matmul_Reshape_Matmul_ #625

Closed: t-vi closed this issue 1 week ago

t-vi commented 1 week ago

A quantization test that I want to merge does not work with the cuDNN executor:

__________________________________________________________________________________________________________________ test_quantization ___________________________________________________________________________________________________________________

    @requiresCUDA
    def test_quantization():
        from thunder.tests import litgpt_model
        from lightning.fabric.plugins import BitsandbytesPrecision

        config = litgpt_model.Config.from_name("llama2-like")
        with torch.device("cuda"):
            model_fp_reference = litgpt_model.GPT(config).to(torch.bfloat16)

        import lightning as L
        plugins = BitsandbytesPrecision("nf4", torch.bfloat16)
        fabric = L.Fabric(devices=1, precision=None, plugins=plugins)
        with fabric.init_module(empty_init=True):
            model = litgpt_model.GPT(config)

        with fabric.init_tensor():
            # set the max_seq_length to limit the memory usage to what we need
            model.max_seq_length = 20
            # enable the kv cache
            model.set_kv_cache(batch_size=1)
        model.eval()
        model.requires_grad_(False)
        model = fabric.setup_module(model)

        model.load_state_dict(model_fp_reference.state_dict())

        x = torch.randint(1, 255, (1, 10), device="cuda")
        input_pos = torch.arange(10, device="cuda")
        logits_expected = model(x, input_pos)

        from thunder.transforms.quantization import BitsAndBytesLinearQuant4bit, get_bitsandbytes_executor
        bitsandbytes_executor = get_bitsandbytes_executor()

        model_fp_reference.set_kv_cache(1, device="cuda", dtype=torch.bfloat16)
        model_fp_reference.max_seq_length = 20
        model_fp_reference.requires_grad_(False)

        jm = thunder.jit(model_fp_reference, executors=(bitsandbytes_executor, *thunder.get_default_executors()), early_transforms=[BitsAndBytesLinearQuant4bit()])

>       logits_thunder = jm(x, input_pos)

thunder/tests/test_networks.py:247: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/module.py:60: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:658: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/__init__.py:217: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:594: in get_computation_and_inputs
    extraces = transform_for_execution(
thunder/common.py:640: in transform_for_execution
    extrace = executors.passes.transform_for_execution(dce_trace, executors_list)
thunder/executors/passes.py:147: in transform_for_execution
    extrace = _transform_for_operator_executor_execution(trace, executors_list)
thunder/executors/passes.py:112: in _transform_for_operator_executor_execution
    extrace = transforms.visitor_transform(trace, visit_)
thunder/core/transforms.py:369: in visitor_transform
    visit_type = visit(bsym)
thunder/executors/passes.py:97: in visit_
    result: None | bool = visit_helper_(bsym)
thunder/executors/passes.py:67: in visit_helper_
    if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or (
thunder/extend/__init__.py:98: in can_execute
    return impl.checker(*bsym.args, **bsym.kwargs)
thunder/executors/cudnnex.py:397: in _cudnn_sdpa_checker
    _make_cudnn_sdpa_backward_graph(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

query = CudnnTensorAttributes(size=(1, 4, 10, 16), stride=(640, 160, 16, 1), dtype=bfloat16, device_index=0), key = CudnnTensorAttributes(size=(1, 4, 4096, 16), stride=(262144, 65536, 16, 1), dtype=bfloat16, device_index=0)
value = CudnnTensorAttributes(size=(1, 4, 4096, 16), stride=(262144, 65536, 16, 1), dtype=bfloat16, device_index=0), attn_mask = CudnnTensorAttributes(size=(1, 1, 10, 4096), stride=(40960, 40960, 4096, 1), dtype=bfloat16, device_index=0)
dropout_p = 0.0, is_causal = False, grad_query_stride = (640, 160, 16, 1), grad_key_stride = (262144, 65536, 16, 1), grad_value_stride = (262144, 65536, 16, 1)

    def _make_cudnn_sdpa_backward_graph(
        query, key, value, attn_mask, dropout_p, is_causal, grad_query_stride, grad_key_stride, grad_value_stride
    ):
        b, h, s_q, _ = query.size
        _, _, _, d_v = value.size

        graph = cudnn.pygraph(
            io_data_type=torch_to_cudnn_dtype(query.dtype),
            intermediate_data_type=cudnn.data_type.FLOAT,
            compute_data_type=cudnn.data_type.FLOAT,
            handle=_get_cudnn_handle(query.device_index),
        )

        Q = graph.tensor(name="Q", dim=query.size, stride=query.stride, data_type=torch_to_cudnn_dtype(query.dtype))
        K = graph.tensor(name="K", dim=key.size, stride=key.stride, data_type=torch_to_cudnn_dtype(key.dtype))
        V = graph.tensor(name="V", dim=value.size, stride=value.stride, data_type=torch_to_cudnn_dtype(value.dtype))

        dim_o = (b, h, s_q, d_v)
        stride_o = (h * s_q * d_v, s_q * d_v, d_v, 1)
        O = graph.tensor(name="O", dim=dim_o, stride=stride_o, data_type=torch_to_cudnn_dtype(query.dtype))
        dO = graph.tensor_like(O)

        dim_stats = (b, h, s_q, 1)
        stride_stats = (h * s_q, s_q, 1, 1)
        Stats = graph.tensor(name="Stats", dim=dim_stats, stride=stride_stats, data_type=cudnn.data_type.FLOAT)

        Bias = None
        dBias = None
        if attn_mask is not None:
            Bias = graph.tensor(
                name="bias", dim=attn_mask.size, stride=attn_mask.stride, data_type=torch_to_cudnn_dtype(attn_mask.dtype)
            )
            dBias = graph.tensor_like(Bias)

        scalar_dim_stride = tuple([1] * len(query.size))
        dropout_tuple = None
        Seed = None
        Offset = None
        if dropout_p != 0.0:
            Seed = graph.tensor(
                name="Seed", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32
            )
            Offset = graph.tensor(
                name="Offset", dim=scalar_dim_stride, stride=scalar_dim_stride, data_type=cudnn.data_type.INT32
            )
            dropout_tuple = (dropout_p, Seed, Offset)

        Attn_scale = graph.tensor(
            name="Attn_scale",
            dim=scalar_dim_stride,
            stride=scalar_dim_stride,
            data_type=cudnn.data_type.FLOAT,
            is_pass_by_value=True,
        )

        dQ, dK, dV = graph.scaled_dot_product_flash_attention_backward(
            q=Q,
            k=K,
            v=V,
            o=O,
            dO=dO,
            stats=Stats,
            attn_scale=Attn_scale,
            bias=Bias,
            dBias=dBias,
            use_causal_mask=is_causal,
            dropout=dropout_tuple,
        )

        dQ.set_output(True).set_dim(query.size).set_stride(grad_query_stride).set_data_type(
            torch_to_cudnn_dtype(query.dtype)
        )
        dK.set_output(True).set_dim(key.size).set_stride(grad_key_stride).set_data_type(torch_to_cudnn_dtype(key.dtype))
        dV.set_output(True).set_dim(value.size).set_stride(grad_value_stride).set_data_type(
            torch_to_cudnn_dtype(value.dtype)
        )

        cache_key = graph.key()
        # If a built graph does not exist in cache already, make one and place it in
        if cache_key not in _cudnnex_cache:
>           graph.build([cudnn.heur_mode.A])
E           RuntimeError: No valid engine configs for MUL_Reduction_MUL_Matmul_MUL_ADD_SUB_EXP_Reshape_Matmul_Matmul_MUL_SUB_MUL_Reduction_MUL_Reshape_Matmul_Reshape_Matmul_

thunder/executors/cudnnex.py:509: RuntimeError
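
For reference, the failing build can be reproduced outside of Thunder. Below is a minimal standalone sketch that feeds the shapes and strides from the failing checker call above straight into the cudnn frontend bindings; the explicit cudnn.data_type.BFLOAT16 enums stand in for what torch_to_cudnn_dtype returns for torch.bfloat16, and the handle argument is omitted (the frontend then creates its own). If this snippet raises the same "No valid engine configs" error, the problem lies with the local cuDNN installation rather than with Thunder.

    import cudnn

    graph = cudnn.pygraph(
        io_data_type=cudnn.data_type.BFLOAT16,
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )

    # Shapes/strides from the failing call: b=1, h=4, s_q=10, s_kv=4096, d=16
    Q = graph.tensor(name="Q", dim=(1, 4, 10, 16), stride=(640, 160, 16, 1), data_type=cudnn.data_type.BFLOAT16)
    K = graph.tensor(name="K", dim=(1, 4, 4096, 16), stride=(262144, 65536, 16, 1), data_type=cudnn.data_type.BFLOAT16)
    V = graph.tensor(name="V", dim=(1, 4, 4096, 16), stride=(262144, 65536, 16, 1), data_type=cudnn.data_type.BFLOAT16)
    O = graph.tensor(name="O", dim=(1, 4, 10, 16), stride=(640, 160, 16, 1), data_type=cudnn.data_type.BFLOAT16)
    dO = graph.tensor_like(O)
    Stats = graph.tensor(name="Stats", dim=(1, 4, 10, 1), stride=(40, 10, 1, 1), data_type=cudnn.data_type.FLOAT)

    # The attn_mask enters the graph as a bias tensor, as in _make_cudnn_sdpa_backward_graph
    Bias = graph.tensor(name="bias", dim=(1, 1, 10, 4096), stride=(40960, 40960, 4096, 1), data_type=cudnn.data_type.BFLOAT16)
    dBias = graph.tensor_like(Bias)

    Attn_scale = graph.tensor(name="Attn_scale", dim=(1, 1, 1, 1), stride=(1, 1, 1, 1),
                              data_type=cudnn.data_type.FLOAT, is_pass_by_value=True)

    dQ, dK, dV = graph.scaled_dot_product_flash_attention_backward(
        q=Q, k=K, v=V, o=O, dO=dO, stats=Stats,
        attn_scale=Attn_scale, bias=Bias, dBias=dBias,
        use_causal_mask=False, dropout=None,  # dropout_p == 0.0, is_causal == False
    )

    dQ.set_output(True).set_dim((1, 4, 10, 16)).set_stride((640, 160, 16, 1)).set_data_type(cudnn.data_type.BFLOAT16)
    dK.set_output(True).set_dim((1, 4, 4096, 16)).set_stride((262144, 65536, 16, 1)).set_data_type(cudnn.data_type.BFLOAT16)
    dV.set_output(True).set_dim((1, 4, 4096, 16)).set_stride((262144, 65536, 16, 1)).set_data_type(cudnn.data_type.BFLOAT16)

    graph.build([cudnn.heur_mode.A])  # raises the RuntimeError above on affected setups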
t-vi commented 1 week ago

This does not seem to happen on the CI, so it could be related to my funny setup...
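
One thing worth comparing between the failing machine and CI is the cuDNN frontend and backend versions. A quick sketch, assuming the usual attributes of the cudnn frontend python bindings:

    import cudnn

    print(cudnn.__version__)        # version of the python frontend package
    print(cudnn.backend_version())  # version of the cuDNN backend library it loaded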

t-vi commented 1 week ago

This was fixed by upgrading to cuDNN frontend 1.5.1.
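
For anyone else who hits this: assuming the frontend was installed from PyPI (where the package name is nvidia-cudnn-frontend), the upgrade would look like

    pip install --upgrade "nvidia-cudnn-frontend>=1.5.1"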