ELS-RD / kernl

Kernl lets you run PyTorch transformer models several times faster on GPU with a single line of code, and is designed to be easily hackable.
http://www.kernl.ai
Apache License 2.0

Question regarding larger T5 model support #188

Closed: caffeinetoomuch closed this issue 1 year ago

caffeinetoomuch commented 1 year ago

Hey, I was checking out kernl and was following along with this notebook: https://github.com/ELS-RD/kernl/blob/main/tutorial/t5%20e2e.ipynb

This seems to work well with t5-small, but I was getting a memory allocation error (MemoryError: std::bad_alloc) for t5-3b. Do you have any idea what I could be setting wrong? Or is T5 only supported for smaller checkpoints?
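For reference, this is roughly what is being run, condensed from the notebook (a sketch only; the exact prompt, precision handling, and generation arguments in the notebook differ):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from kernl.model_optimization import optimize_model

# t5-small works fine, t5-3b raises MemoryError: std::bad_alloc
model_name = "t5-3b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).eval().half().cuda()

# kernel replacement on encoder and decoder, as in the tutorial
optimize_model(model.encoder)
optimize_model(model.decoder)

inputs = tokenizer(
    "translate English to French: The house is wonderful.", return_tensors="pt"
).to("cuda")
with torch.inference_mode(), torch.cuda.amp.autocast():
    output = model.generate(**inputs, max_length=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```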

I would also like to ask whether LongT5 is applicable, i.e. do we need to add a replacement pattern to make it work?

jonathlela commented 1 year ago

Sadly, I don't have the necessary hardware to test the t5-3b model, so I can't really answer the first point. For the second one, looking at the Hugging Face code, I think the current optimization won't apply: LongT5 does not compute attention with the same primitives as the ones we match in our attention pattern (it uses torch.einsum, for example). We'll have to write a specific pattern for it.
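To make the primitive mismatch concrete, here is a generic torch.fx sketch (not kernl's actual pattern code; the toy modules are only illustrative): a pattern written with torch.matmul matches T5-style attention but finds nothing in einsum-based attention.

```python
import torch
import torch.fx as fx

class MatmulAttention(torch.nn.Module):  # T5-style attention math
    def forward(self, q, k, v):
        scores = torch.matmul(q, k.transpose(-1, -2))
        return torch.matmul(torch.softmax(scores, dim=-1), v)

class EinsumAttention(torch.nn.Module):  # LongT5-style, uses einsum
    def forward(self, q, k, v):
        scores = torch.einsum("bqd,bkd->bqk", q, k)
        return torch.einsum("bqk,bkd->bqd", torch.softmax(scores, dim=-1), v)

def pattern(q, k, v):
    # subgraph we search for in the traced model
    scores = torch.matmul(q, k.transpose(-1, -2))
    return torch.matmul(torch.softmax(scores, dim=-1), v)

def replacement(q, k, v):
    # stand-in for a fused attention kernel (needs PyTorch >= 2.0)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)

print(len(fx.replace_pattern(fx.symbolic_trace(MatmulAttention()), pattern, replacement)))  # 1 match
print(len(fx.replace_pattern(fx.symbolic_trace(EinsumAttention()), pattern, replacement)))  # 0 matches
```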

jonathlela commented 1 year ago

Actually, for the decoder part there is nothing to do, as it's the same pattern as the base T5 implementation for self-attention. For the encoder part, we don't cover local self-attention and transient global self-attention yet.

caffeinetoomuch commented 1 year ago

> Sadly, I don't have the necessary hardware to test the t5-3b model, so I can't really answer the first point.

@jonathlela Would it help if I posted a specific stack trace or error messages? The T5 generation notebook example seems to work fine up to t5-large, but starts to give errors from t5-3b. Is there anything I can try to fix it?

pommedeterresautee commented 1 year ago

Thanks @ice-americano for the report. It appears to be a bug in dynamo (PyTorch 2.0), as I have been able to reproduce it with the eager compiler (i.e. applying dynamo without replacing any kernel).

In src/kernl/model_optimization.py I used:

    @torchdynamo.optimize('eager')  # <- instead of _compiler, which is our kernel-replacing backend
    def run(*args, **kwargs):
        return original_model.forward2(*args, **kwargs)

I got:

...
are/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 235, in register_attr_or_module
    return wrap_name(k)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 209, in wrap_name
    return TensorVariable.create(
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/variables/tensor.py", line 206, in create
    proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/variables/tensor.py", line 154, in _clone_input
    value = clone_input(value)
  File "/home/geantvert/.local/share/virtualenvs/kernl/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 412, in clone_input
    y = torch.clone(x)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 23.69 GiB total capacity; 21.50 GiB already allocated; 45.94 MiB free; 22.67 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

It's related to a tensor copy done by dynamo (outside our scope).


def clone_input(x):
    """copy while preserving strides"""
    with torch.no_grad():
        needed_size = sum(
            (shape - 1) * stride for shape, stride in zip(x.size(), x.stride())
        )
        if x.is_quantized:
            result = torch.empty_quantized((needed_size + 32,), x)
        else:
            result = torch.empty(needed_size + 32, dtype=x.dtype, device=x.device)
        cache_line_offset = (
            (x.data_ptr() - result.data_ptr()) % 32
        ) // x.element_size()
        result.as_strided_(x.size(), x.stride(), cache_line_offset)
        try:
            result.copy_(x.clone())
            if x.is_leaf:
                result.requires_grad_(x.requires_grad)
            if x.is_leaf and x.grad is not None:
                result.grad = clone_input(x.grad)
        except RuntimeError:
            # RuntimeError: unsupported operation: more than one element of the written-to
            # tensor refers to a single memory location. Please clone() the tensor before
            # performing the operation.
            y = torch.clone(x)  # <-- CRASH HERE
            if x.is_leaf:
                y.requires_grad_(x.requires_grad)
            if x.is_leaf and x.grad is not None:
                y.grad = clone_input(x.grad)
            return y
        return result

Need to dig into it, but it may be related to the issue we have with Whisper large and beam 5. Possible, @gaetansnl?

pommedeterresautee commented 1 year ago

@gaetansnl also of interest: https://github.com/pytorch/pytorch/issues/93774. And to conclude: there is no issue with dynamo + our kernels if we only perform the replacement on the encoder part.
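Concretely, encoder-only replacement would look something like this (a sketch, assuming optimize_model can be applied to the encoder submodule as in the e2e notebook):

```python
from transformers import AutoModelForSeq2SeqLM
from kernl.model_optimization import optimize_model

model = AutoModelForSeq2SeqLM.from_pretrained("t5-3b").eval().half().cuda()

# replace kernels on the encoder only; the decoder stays on stock PyTorch,
# so the dynamo clone_input OOM above is not triggered
optimize_model(model.encoder)
```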

caffeinetoomuch commented 1 year ago

@pommedeterresautee Oh, so just to understand the context: the fix needs to happen in torchdynamo, right? And this comes from optimizing the decoder?

pommedeterresautee commented 1 year ago

Yep. They offered a dirty fix, not yet tested and too hacky for us to implement: https://github.com/pytorch/torchdynamo/issues/1955#issuecomment-1342899757

caffeinetoomuch commented 1 year ago

One quick question! If you optimize with eager, is it the same as applying just torchdynamo's optimization, or are there extra optimizations still being done in eager mode? I am also asking because I am a little confused about where (and how) CUDA graphs are actually used for the optimization.

pommedeterresautee commented 1 year ago

You are right on the first point: eager == torchdynamo, i.e. capture the Python code, convert it to an (FX) graph, and replay it (with guards, etc.). By itself, torchdynamo may make code slightly faster, mostly by cleaning it up, but keep in mind that's a side effect; its purpose is "just" to capture the graph and decide when to re-capture or replay the existing one.
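To make the backend idea concrete, here is a minimal sketch (using PyTorch 2.0's torch._dynamo module path; the inspection backend is purely illustrative): the backend is a callback that receives the captured FX graph and decides what to run, and "eager" simply runs it unchanged.

```python
import torch
import torch._dynamo as torchdynamo

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # dynamo hands us the captured FX graph; "eager" would just run it as-is,
    # while a real compiler backend would rewrite it (e.g. to call fused kernels)
    gm.graph.print_tabular()
    return gm.forward

@torchdynamo.optimize(inspect_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(8))  # first call: graph capture + backend call
f(torch.randn(8))  # second call: guards pass, the compiled graph is replayed
```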

A CUDA graph, on the other hand, is a way to record the list of CUDA kernels launched (with their parameters). It means you don't go through Python to replay the graph: you make a single call and the GPU takes care of everything (replaying is even faster than issuing the individual CUDA launches). As a bonus, it slightly optimizes the graph.

It matters because in vanilla PyTorch you execute Python code, which calls CUDA code asynchronously. The Python overhead is basically hidden by the fact that the calls are async (they return almost instantly, without waiting for the CUDA kernel to finish). In training this works well, BUT in generation/inference the kernels are often so fast that the Python and dispatcher overhead is no longer hidden and becomes the bottleneck.

So, in those situations, you just don't want to go through Python, and you use a CUDA graph instead. (FWIW, that's also the reason why a compiler is coming to PyTorch; check https://pytorch-dev-podcast.simplecast.com/episodes/pytorch-20 about the "death threat" of faster GPUs.)
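For illustration, capture and replay with the stock PyTorch CUDA graph API (a minimal sketch following the warm-up pattern from the PyTorch docs; kernl wraps this internally):

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.ReLU()).cuda().eval()
static_input = torch.randn(8, 1024, device="cuda")

# warm-up on a side stream before capture (required by the CUDA graphs API)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# capture: every kernel launched inside this block is recorded, not executed eagerly
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    static_output = model(static_input)

# replay: one call re-launches the whole recorded kernel sequence,
# with no per-operator Python or dispatcher work
static_input.copy_(torch.randn(8, 1024, device="cuda"))
g.replay()
print(static_output.sum())
```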

A second reason to use CUDA graphs in this project is that we based our code on Triton. Unfortunately, Triton has a lot of CPU overhead, and to get decent performance we need CUDA graphs. There is a new endpoint in very recent versions of Triton which has less overhead; @ayoub-louati is working on leveraging it (which would make CUDA graphs an optional improvement).

See https://developer.nvidia.com/blog/cuda-graphs/ for more info.

github-actions[bot] commented 1 year ago

This issue is marked as stale because it has been open for 30 days with no activity.

caffeinetoomuch commented 1 year ago

According to the other thread, it seems the memory issue has been fixed with the recent nightly version of torch 2.0. Should I try again?

pommedeterresautee commented 1 year ago

PyTorch 2.0 has not fixed the point yet, but the new kernl version has. So it's time to retry.

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.