pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Inductor generates unnecessary allocation + copy operations for custom ops with mutable inputs #127660

Closed HanGuo97 closed 1 month ago

HanGuo97 commented 2 months ago

🐛 Describe the bug

We have a custom kernel with the following schema:

custom_func(Tensor input, Tensor weight, Tensor(a!) workspace) -> Tensor

Here, workspace is a scratch buffer that we pass into the (CUDA) kernel for its own use, and it is very large. We therefore pre-allocate it once up front and pass the same workspace to every call, regardless of the other arguments. The workspace is zero-initialized, and the kernel zeroes it again before returning. Since it is only scratch space, we do not return it; we simply mark it as a mutable input argument.

When we compile a model that uses this function (here, just an MLP with 3 separate calls to it), we observe the generated code below. Importantly, the generated code includes one extra allocation of a temporary buffer the same size as the workspace. The code then copies the workspace into that buffer, passes the buffer to the kernel calls, copies the buffer back into the workspace, and finally deletes its reference to the workspace. Because the workspace is very large, these unnecessary allocation and copy operations significantly slow down the compiled function.

A workaround I found is simply not marking workspace as mutable at all, but this is a bit hacky. Is there a better way to handle this? Thanks in advance for your help!


triton_poi_fused_0 = async_compile.triton('triton_poi_fused_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[134217728], 
    filename=__file__,
    triton_meta={'signature': {0: '*u8', 1: '*u8', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_0', 'mutated_arg_names': [], 'no_x_dim': False, 'backend_hash': '85130a58675c9402113372c3f498d2365c29a03094a20bdb7eac4a0e95af91d8'},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_0(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 113247936
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

triton_poi_fused_2 = async_compile.triton('triton_poi_fused_2', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor import triton_helpers, triton_heuristics
from torch._inductor.ir import ReductionHint, TileHint
from torch._inductor.triton_helpers import libdevice, math as tl_math
from torch._inductor.triton_heuristics import AutotuneHint
from torch._inductor.utils import instance_descriptor

@triton_heuristics.pointwise(
    size_hints=[134217728], 
    filename=__file__,
    triton_meta={'signature': {0: '*u8', 1: '*u8', 2: 'i32'}, 'device': 0, 'device_type': 'cuda', 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_2', 'mutated_arg_names': ['out_ptr0'], 'no_x_dim': False, 'backend_hash': '85130a58675c9402113372c3f498d2365c29a03094a20bdb7eac4a0e95af91d8'},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_2(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 113247936
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tl.store(out_ptr0 + (x0), tmp0, xmask)
''', device_str='cuda')

    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        workspace_buffer = empty_strided_cuda((113247936, ), (1, ), torch.uint8)
        # Source Nodes: [], Original ATen: []
        stream0 = get_raw_stream(0)
        triton_poi_fused_0.run(arg4_1, workspace_buffer, 113247936, grid=grid(113247936), stream=stream0)
        # Source Nodes: [], Original ATen: []
        buf1 = torch.ops.custom_ops.custom_kernel.default(..., reinterpret_tensor(workspace_buffer, (113247936, ), (1, ), 0))
        ...
        # Source Nodes: [], Original ATen: []
        buf4 = torch.ops.custom_ops.custom_kernel.default(..., reinterpret_tensor(workspace_buffer, (113247936, ), (1, ), 0))
        ...
        # Source Nodes: [mul, silu], Original ATen: [aten.mul, aten.silu]
        buf8 = torch.ops.custom_ops.custom_kernel.default(..., reinterpret_tensor(workspace_buffer, (113247936, ), (1, ), 0))
        ...
        # Source Nodes: [], Original ATen: []
        triton_poi_fused_2.run(workspace_buffer, arg4_1, 113247936, grid=grid(113247936), stream=stream0)
        del arg4_1
    return (buf10, )
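For a rough sense of scale (a back-of-the-envelope sketch, not a measurement): the kernels above operate on xnumel = 113247936 uint8 elements, so the temporary buffer is about 108 MiB, and the clone and the copy-back each read and write the whole workspace once. The sketch assumes exactly one clone and one copy-back per compiled call, as in this output, and ignores allocator and kernel-launch overhead.

```python
# Back-of-the-envelope estimate of the extra work caused by clone + copy-back.
WORKSPACE_NUMEL = 113_247_936  # xnumel in triton_poi_fused_0 / triton_poi_fused_2
BYTES_PER_ELEM = 1             # the workspace is torch.uint8


def extra_traffic_bytes(numel: int, elem_size: int) -> int:
    """Bytes moved by one clone plus one copy-back of a buffer."""
    nbytes = numel * elem_size
    clone = 2 * nbytes      # triton_poi_fused_0: read arg4_1, write the temporary
    copy_back = 2 * nbytes  # triton_poi_fused_2: read the temporary, write arg4_1
    return clone + copy_back


extra_alloc = WORKSPACE_NUMEL * BYTES_PER_ELEM
traffic = extra_traffic_bytes(WORKSPACE_NUMEL, BYTES_PER_ELEM)
print(f"extra allocation: {extra_alloc / 2**20:.1f} MiB")
print(f"extra traffic per call: {traffic / 2**20:.1f} MiB")
```

With these numbers it reports an extra allocation of about 108.0 MiB and roughly 432.0 MiB of additional memory traffic per call.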

Error logs

No response

Minified repro

No response

Versions

PyTorch version: 2.3.0
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.35

Python version: 3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-182-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 545.23.08
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             80
On-line CPU(s) list:                0-79
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz
CPU family:                         6
Model:                              85
Thread(s) per core:                 2
Core(s) per socket:                 20
Socket(s):                          2
Stepping:                           7
CPU max MHz:                        4000.0000
CPU min MHz:                        800.0000
BogoMIPS:                           4200.00
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm mpx rdt_a avx512f avx512dq rdseed adx smap clflushopt clwb intel_pt avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts hwp hwp_act_window hwp_epp hwp_pkg_req pku ospke avx512_vnni md_clear flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          1.3 MiB (40 instances)
L1i cache:                          1.3 MiB (40 instances)
L2 cache:                           40 MiB (40 instances)
L3 cache:                           55 MiB (2 instances)
NUMA node(s):                       2
NUMA node0 CPU(s):                  0-19,40-59
NUMA node1 CPU(s):                  20-39,60-79
Vulnerability Gather data sampling: Mitigation; Microcode
Vulnerability Itlb multihit:        KVM: Mitigation: Split huge pages
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Retbleed:             Mitigation; Enhanced IBRS
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] optree==0.11.0
[pip3] torch==2.3.0
[pip3] torchaudio==2.3.0
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.18.0
[pip3] triton==2.3.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py310h5eee18b_1  
[conda] mkl_fft                   1.3.8           py310h5eee18b_0  
[conda] mkl_random                1.2.4           py310hdb19cb5_0  
[conda] numpy                     1.26.4          py310h5f9d8c6_0  
[conda] numpy-base                1.26.4          py310hb5e798b_0  
[conda] optree                    0.11.0                   pypi_0    pypi
[conda] pytorch                   2.3.0           py3.10_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.3.0               py310_cu121    pytorch
[conda] torchelastic              0.2.2                    pypi_0    pypi
[conda] torchtriton               2.3.0                     py310    pytorch
[conda] torchvision               0.18.0              py310_cu121    pytorch

cc @bdhirsh @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @zou3519 @msaroufim

ezyang commented 2 months ago

I'd be interested to know whether this is a functionalization/reinplacing problem or an Inductor buffer-scheduling problem. Seeing the post-grad graph would help: the TORCH_LOGS=aot_graphs output would do, or, if you are willing to run with TORCH_TRACE=/tmp/logs and upload that entire directory, it would definitely have all the info I'm looking for.

HanGuo97 commented 2 months ago

Thanks for the help, and happy to! Is this what you are looking for?

trace.tar.gz

HanGuo97 commented 2 months ago

I spent a bit more time looking around the codebase, and it seems the problem is related to functionalization and reinplacing, as you suggested. Notably, it seems Inductor auto-functionalizes the custom function (by allocating a new buffer and copying the scratch-space workspace into it). However, these copies are not reinplaced (?), so the expensive copies are left as they are.
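What auto-functionalization does to a mutable argument can be modeled in a few lines of plain Python. This is a toy sketch: auto_functionalized_toy and custom_op are hypothetical stand-ins (lists play the role of tensors), not PyTorch's actual auto_functionalized higher-order op. It mirrors the pattern visible in the generated code: every argument the schema marks as mutable is cloned, the op runs on the clone, and the caller's input is only overwritten by a copy-back at the end.

```python
import copy
from typing import Any, Callable, List, Sequence, Tuple


def auto_functionalized_toy(op: Callable[..., Any], mutable_args: Sequence[str], **kwargs: Any) -> Tuple[Any, ...]:
    """Hypothetical model of auto-functionalization: never mutate the caller's data."""
    # Clone every argument the schema marks as mutable (the analogue of aten.clone).
    clones = {name: copy.deepcopy(kwargs[name]) for name in mutable_args}
    # Run the real (mutating) op on the clones instead of the originals.
    result = op(**{**kwargs, **clones})
    # Return the op's output plus the new values of the mutated arguments.
    return (result, *(clones[name] for name in mutable_args))


def custom_op(x: List[int], workspace: List[int]) -> int:
    """Stand-in for the CUDA kernel: scribbles on the workspace, then re-zeroes it."""
    for i, v in enumerate(x):
        workspace[i] = v
    total = sum(workspace)
    for i in range(len(workspace)):
        workspace[i] = 0
    return total


workspace = [0] * 4
out, new_ws = auto_functionalized_toy(custom_op, ["workspace"], x=[1, 2, 3], workspace=workspace)
# Only now may the functional program write back into the input,
# analogous to aten.copy_(arg4_1, getitem_5) in the trace.
workspace[:] = new_ws
print(out, new_ws)
```

Even though the workspace ends up all zeros either way, the functional form still pays for one clone and one copy-back, which is the traffic a reinplacing pass would need to remove.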

ezyang commented 2 months ago

🤔 the directory you uploaded is not what TORCH_TRACE creates LOL. I'll see if it has the info I need anyway though...

ezyang commented 2 months ago

OK, the trace does make it clear. In fx_graph_transformed.py, we see what has happened to the graph after we defunctionalized it:

class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "i16[3584, 4096]", arg1_1: "f16[14336, 32]", arg2_1: "f16[16]", arg3_1: "f32[16, 16, 1]", arg4_1: "u8[113247936]", arg5_1: "i16[3584, 4096]", arg6_1: "f16[14336, 32]", arg7_1: "f16[16]", arg8_1: "f32[16, 16, 1]", arg9_1: "i16[1024, 14336]", arg10_1: "f16[4096, 112]", arg11_1: "f16[16]", arg12_1: "f32[16, 16, 1]", arg13_1: "f16[1, 1, 4096]"):
        # No stacktrace found for following nodes
        as_strided_default: "u8[113247936]" = torch.ops.aten.as_strided.default(arg4_1, [113247936], [1], 0)
        clone_default: "u8[113247936]" = torch.ops.aten.clone.default(as_strided_default);  as_strided_default = None
        as_strided_default_1: "u8[113247936]" = torch.ops.aten.as_strided.default(clone_default, [113247936], [1], 0);  clone_default = None
        qgemm_simple_80_default_2: "f16[1, 14336]" = torch.ops.custom.custom_op.default(arg13_1, arg0_1, arg1_1, arg2_1, arg3_1, as_strided_default_1);  arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
        qgemm_simple_80_default_1: "f16[1, 14336]" = torch.ops.custom.custom_op.default(arg13_1, arg5_1, arg6_1, arg7_1, arg8_1, as_strided_default_1);  arg13_1 = arg5_1 = arg6_1 = arg7_1 = arg8_1 = None

Note that arg4_1 is the workspace (I determined this by inspecting fx_graph_readable.py, which includes kwarg names for the custom op call so I could identify the workspace argument). We see that we clone it once, restride it, and then use as_strided_default_1 for the rest of the custom_op calls.

This is the input mutation problem: it looks like defunctionalization is unwilling to directly mutate the input argument, pushing it out to the end when it copies it back in:

        # File: /workspace/main/custom/integrations.py:242 in forward, code: return cast(Callable, custom.custom_op)(
        copy_: "u8[113247936]" = torch.ops.aten.copy_.default(arg4_1, as_strided_default_1);  arg4_1 = as_strided_default_1 = None
        return (qgemm_simple_80_default,)

This is because, of course, this is what the functionalized graph does:

        getitem_4: "f16[1, 4096]" = auto_functionalized_2[0]
        getitem_5: "u8[113247936]" = auto_functionalized_2[1];  auto_functionalized_2 = None
        copy_: "u8[113247936]" = torch.ops.aten.copy_.default(arg4_1, getitem_5);  arg4_1 = getitem_5 = None
        return (getitem_4,)

since the functionalized graph isn't allowed to mutate the input until the very end.

I think this should be fixed by making defunctionalization smarter, since at this point we should be able to work out whether it is safe to push the mutations in.

cc @bdhirsh
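A toy version of the safety check being proposed might look like the sketch below. It is hypothetical and heavily simplified (not the actual reinplacing or defunctionalization pass): the graph is a list of (output, op, inputs) tuples whose names mirror the trace above, the as_strided restrides are dropped, and the input is assumed not to alias any other graph input. The key condition is that nothing between the clone and the final copy_ reads the original input, because after reinplacing such a reader would observe the mutated data.

```python
from typing import List, Optional, Tuple

# (output name, op name, input names): a deliberately tiny stand-in for an FX graph
Node = Tuple[str, str, List[str]]


def find_clone_and_copy_back(nodes: List[Node], inp: str, clone: str) -> Tuple[Optional[int], Optional[int]]:
    start = end = None
    for i, (out, op, args) in enumerate(nodes):
        if op == "clone" and out == clone and args == [inp]:
            start = i
        elif op == "copy_" and args == [inp, clone]:
            end = i
    return start, end


def can_reinplace(nodes: List[Node], inp: str, clone: str) -> bool:
    start, end = find_clone_and_copy_back(nodes, inp, clone)
    if start is None or end is None or start > end:
        return False
    # Between the clone and the copy-back, the functional graph still sees the
    # ORIGINAL value of inp; any reader there would be broken by mutating inp directly.
    return all(inp not in args for _, _, args in nodes[start + 1:end])


def reinplace(nodes: List[Node], inp: str, clone: str) -> List[Node]:
    if not can_reinplace(nodes, inp, clone):
        return list(nodes)
    start, end = find_clone_and_copy_back(nodes, inp, clone)
    rewritten = []
    for i, (out, op, args) in enumerate(nodes):
        if i in (start, end):
            continue  # drop the clone and the copy-back
        rewritten.append((out, op, [inp if a == clone else a for a in args]))
    return rewritten


graph: List[Node] = [
    ("ws", "clone", ["arg4_1"]),
    ("q2", "custom_op", ["arg13_1", "arg0_1", "ws"]),
    ("q1", "custom_op", ["arg13_1", "arg5_1", "ws"]),
    ("_", "copy_", ["arg4_1", "ws"]),
]
print(reinplace(graph, "arg4_1", "ws"))
```

For this graph the clone and copy_ disappear and both custom_op calls receive arg4_1 directly; inserting any node that reads arg4_1 between them makes the check refuse the rewrite.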

ezyang commented 2 months ago

By the way, you should benchmark whether the workspace actually helps. Because PyTorch has a caching allocator, repeatedly allocating and deallocating the workspace should actually be very cheap (unless some sort of fragmentation happens, but for a buffer of this size that should be pretty unlikely). If it isn't actually speeding things up, removing it would unblock you and eliminate the copy operation.
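The caching-allocator argument can be illustrated with a toy model: once a block of a given size has been freed, the next request for that size is served from the cache, so allocating the workspace on every call only hits the driver once. ToyCachingAllocator is hypothetical and deliberately simplistic; it is not PyTorch's allocator or its API.

```python
from collections import defaultdict
from typing import DefaultDict, List, Tuple

Block = Tuple[int, int]  # (size in bytes, block id)


class ToyCachingAllocator:
    """Illustrative model only: caches freed blocks by exact size and reuses them.

    PyTorch's real CUDA caching allocator also rounds sizes, splits and merges
    blocks, and tracks streams; none of that is modeled here.
    """

    def __init__(self) -> None:
        self._free: DefaultDict[int, List[int]] = defaultdict(list)
        self._next_id = 0
        self.device_mallocs = 0  # number of (expensive) trips to the driver

    def malloc(self, size: int) -> Block:
        if self._free[size]:
            return (size, self._free[size].pop())  # cache hit: no driver call
        self.device_mallocs += 1                   # cache miss: would be a cudaMalloc
        self._next_id += 1
        return (size, self._next_id)

    def free(self, block: Block) -> None:
        size, block_id = block
        self._free[size].append(block_id)  # keep the block for reuse instead of releasing it


alloc = ToyCachingAllocator()
for _ in range(3):  # e.g. three custom-op calls, each allocating a fresh workspace
    ws = alloc.malloc(113_247_936)
    alloc.free(ws)
print(alloc.device_mallocs)
```

Three allocate/free cycles of the same size result in a single simulated driver allocation.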

HanGuo97 commented 2 months ago

Oops, I actually couldn't get the TORCH_TRACE to output anything, so I used this instead.

import torch._inductor.config

torch._inductor.config.debug = True
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.debug_dir = "..."

I thought they output the same thing, but I was wrong I guess... Either way, glad that you are able to parse useful things out of it!

Regarding performance: yes, this did have a non-trivial (but not horrible) performance hit. For reference, when I simply remove the mutable annotation from the schema (i.e., remove (a!)), I see a ~4% speedup as measured with gpt-fast.

I agree that smarter defunctionalization could help, but I imagine it requires changes to the Inductor logic and won't be a quick fix. Is there anything I can do as a workaround in the meantime? For example, is using a schema that treats the workspace as non-mutable (i.e., Tensor rather than Tensor(a!)) an "okay-ish" hack for now?

Thanks again for taking the time during the weekend!

ezyang commented 2 months ago

TORCH_TRACE doesn't output anything to stderr, it goes to the directory you specified.

The main hazard of not declaring the input as mutable is an optimization pass that aliases the input to something else, thinking it doesn't matter; your secret mutation would then clobber that other thing. However, this should be pretty difficult to trigger, since you are passing in an input, and we're not going to alias an input with anything. Another potential hazard is if we decide it is OK to run two kernels in parallel because they have no write conflicts, when in fact they do conflict via this argument (but we don't do this optimization right now).
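The second hazard can be sketched with a hypothetical scheduler that decides independence purely from declared read and write sets. OpCall and may_run_concurrently are illustrative names, not Inductor APIs, and as noted above Inductor does not perform this kind of parallelization today.

```python
from dataclasses import dataclass, field
from typing import Set


@dataclass
class OpCall:
    name: str
    reads: Set[str] = field(default_factory=set)
    writes: Set[str] = field(default_factory=set)


def may_run_concurrently(a: OpCall, b: OpCall) -> bool:
    """Hypothetical scheduler rule: ops conflict if either writes what the other touches."""
    a_conflicts = a.writes & (b.reads | b.writes)
    b_conflicts = b.writes & (a.reads | a.writes)
    return not (a_conflicts or b_conflicts)


# Schema with Tensor(a!) workspace: the write is visible to the scheduler.
honest_1 = OpCall("custom_op#1", reads={"x", "w1"}, writes={"workspace"})
honest_2 = OpCall("custom_op#2", reads={"x", "w2"}, writes={"workspace"})

# Schema with a plain Tensor workspace: the scheduler only sees a read.
hidden_1 = OpCall("custom_op#1", reads={"x", "w1", "workspace"})
hidden_2 = OpCall("custom_op#2", reads={"x", "w2", "workspace"})

print(may_run_concurrently(honest_1, honest_2))  # conflict detected
print(may_run_concurrently(hidden_1, hidden_2))  # looks safe, but both kernels really write workspace
```

With the honest schema the two calls are serialized; with the workspace declared read-only they look independent, which is exactly the race that would appear if such parallelization existed.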

HanGuo97 commented 2 months ago

Got it! So maybe it's better to wait for the patch.

I think my question is fully addressed (thanks!). I'll leave this open for now in case you need a reminder, but feel free to close it if you prefer.

ezyang commented 2 months ago

Another thing you could do is a cublas style internal workspace stored in a global, and just not pass it in as an argument haha

HanGuo97 commented 2 months ago

I was actually thinking about the same thing, but I didn't realize that's what cuBLAS does too.

Would those lead to the same hazards as you mentioned earlier? (After all, it's the same thing, except the global variable is hidden from Python)

zou3519 commented 2 months ago

we might start needing a "module: auto-functionalized" label...

ezyang commented 2 months ago

@HanGuo97 no, because now the tensor doesn't show up in the graph. Might be bad if you want to deallocate the workspace when you're done though.

eellison commented 2 months ago

In this case, is the tensor even mutated if it ends up in exactly the same state (zeros)? I guess it would matter if we were to auto-parallelize, but we don't do that today.

HanGuo97 commented 2 months ago

@ezyang @eellison yes, that's the question I have.

From an API point of view, the workspace is not "mutated", in the sense that the custom op cleans it up: it is zero coming in and zero coming out. In that sense, handling it the cuBLAS way on the C++ side makes sense. But in an imagined near future where custom ops are auto-parallelized, this might cause a race condition (since it does involve mutations)?

zou3519 commented 2 months ago

The model in this case is just a MLP with 3 separate calls to this function

@HanGuo97 do you have the exact model used to repro this? Your code sample looks like just the inductor output.

zou3519 commented 2 months ago

@Chillee we talked about this for a bit earlier this week, my thoughts right now are: 1) we should still mark the buffer as being mutated (because it wouldn't be multi-thread safe). In general we should mark Tensors as being mutated if we need mutable_data_ptr access 2) I agree with @ezyang's assessment that this is a reinplacing pass problem. We can make it smarter.

HanGuo97 commented 2 months ago

The model is simply the MLP layer from gpt-fast.

zou3519 commented 2 months ago

Repro:

import torch

@torch.library.custom_op("mylib::foo", mutates_args={"out"})
def foo(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x.sin())

@torch.compile(backend="inductor", fullgraph=True)
def f(x, out):
    foo(x, out)
    foo(out, out)
    foo(out, out)

x = torch.randn(3)
out = torch.randn(3)
f(x, out)
assert torch.allclose(out, x.sin().sin().sin())

Using TORCH_LOGS=output_code gives:

cpp_fused_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_rzou/sk/cskh5dx62fglpphcrl6723dnmowdabouerrzy3dmqcngbxwfa7bv.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr0)
{
    {
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(3L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(x0)];
            out_ptr0[static_cast<long>(x0)] = tmp0;
        }
    }
}
''')

cpp_fused_1 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_rzou/sk/cskh5dx62fglpphcrl6723dnmowdabouerrzy3dmqcngbxwfa7bv.h"
extern "C" void kernel(const float* in_ptr0,
                       float* out_ptr0)
{
    {
        #pragma omp simd simdlen(8)
        for(long x0=static_cast<long>(0L); x0<static_cast<long>(3L); x0+=static_cast<long>(1L))
        {
            auto tmp0 = in_ptr0[static_cast<long>(x0)];
            out_ptr0[static_cast<long>(x0)] = tmp0;
        }
    }
}
''')

async_compile.wait(globals())
del async_compile

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    buf0 = empty_strided_cpu((3, ), (1, ), torch.float32)
    cpp_fused_0(arg0_1, buf0)
    # Source Nodes: [], Original ATen: []
    buf1 = torch.ops.mylib.foo.default(arg1_1, reinterpret_tensor(buf0, (3, ), (1, ), 0))
    del arg1_1
    # Source Nodes: [], Original ATen: []
    buf3 = torch.ops.mylib.foo.default(reinterpret_tensor(buf0, (3, ), (1, ), 0), reinterpret_tensor(buf0, (3, ), (1, ), 0))
    # Source Nodes: [], Original ATen: []
    buf5 = torch.ops.mylib.foo.default(reinterpret_tensor(buf0, (3, ), (1, ), 0), reinterpret_tensor(buf0, (3, ), (1, ), 0))
    cpp_fused_1(buf0, arg0_1)
    del arg0_1
    return ()

We can see that we're cloning arg0_1 before calling the mutable ops on it, and then copying the result back into arg0_1. This clone + copy is unnecessary.

zou3519 commented 1 month ago

~Can't repro with torch.library.Library. Maybe this is just a needs_input_stride_order problem:~

EDIT: we can repro with torch.library.Library

import torch

m = torch.library.Library("mylib", "FRAGMENT")

m.define("foo(Tensor x, Tensor(a!) out) -> ()")

def foo(x: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(x.sin())

m.impl("foo", foo, "CPU")

@torch.compile(backend="inductor", fullgraph=True)
def f(x, out):
    torch.ops.mylib.foo(x, out)
    torch.ops.mylib.foo(out, out)
    torch.ops.mylib.foo(out, out)

x = torch.randn(3)
out = torch.randn(3)
f(x, out)
assert torch.allclose(out, x.sin().sin().sin())

zou3519 commented 1 month ago

@HanGuo97 do you have a script we could run to reproduce your issue? I'm worried there are more bugs hiding here

laithsakka commented 1 month ago

I think this might be related to the memory regression I see when using vLLM with torch.compile.

What is happening is that the original code takes views of a tensor (kv_cache[0] and kv_cache[1]) and passes them to a custom op that mutates them:

    .....
    kv_cache: torch.Tensor
    key_cache = kv_cache[0]
    value_cache = kv_cache[1]
    # Reshape the input keys and values and store them in the cache.
    # If kv_cache is not provided, the new key and value tensors are
    # not cached. This happens during the initial memory profiling run.
    ops.reshape_and_cache_flash(
        key,
        value,
        key_cache,
        value_cache,
        attn_metadata.slot_mapping.flatten(),
        self.kv_cache_dtype,
    )

Those accesses are translated into:

....
[8/1_1] [__aot_graphs]          # File: /home/lsakka/vllm/vllm/attention/backends/flash_attn.py:289 in forward, code: key_cache = kv_cache[0]
[8/1_1] [__aot_graphs]         select: "f16[128248, 16, 12, 64][12288, 768, 64, 1]cuda:0" = torch.ops.aten.select.int(arg3_1, 0, 0)
[8/1_1] [__aot_graphs]         
[8/1_1] [__aot_graphs]          # File: /home/lsakka/vllm/vllm/attention/backends/flash_attn.py:290 in forward, code: value_cache = kv_cache[1]
[8/1_1] [__aot_graphs]         select_1: "f16[128248, 16, 12, 64][12288, 768, 64, 1]cuda:0" = torch.ops.aten.select.int(arg3_1, 0, 1)
[8/1_1] [__aot_graphs]         auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops._C_cache_ops.reshape_and_cache_flash.default, key = view_1, value = view_2, key_cache = select, value_cache = select_1, slot_mapping = arg4_1, kv_cache_dtype = 'auto');  view_1 = view_2 = select = select_1 = None

Then, when Inductor runs, the output code actually allocates new buffers and calls triton_poi_fused_0 to copy into the first one, and we run out of memory at the buf1 allocation:

torch.cuda.set_device(0)
[8/1_1] [__output_code]         buf0 = empty_strided_cuda((1575911424, ), (1, ), torch.float16)
[8/1_1] 
[8/1_1] [__output_code]         stream0 = get_raw_stream(0)
[8/1_1] [__output_code]         triton_poi_fused_0.run(arg3_1, buf0, 1575911424, grid=grid(1575911424), stream=stream0)
[8/1_1] [__output_code]         buf1 = empty_strided_cuda((3151822848, ), (1, ), torch.float16)
[8/1_1] [__output_code]         triton_poi_fused_1.run(arg3_1, buf1, 3151822848, grid=grid(3151822848), stream=stream0)
[8/1_1] [__output_code]         torch.ops._C_cache_ops.reshape_and_cache_flash.default(reinterpret_tensor(arg1_1, (256, 12, 64), (2304, 64, 1), 0), reinterpret_tensor(arg2_1, (256, 12, 64), (2304, 64, 1), 0), reinterpret_tensor(buf0, (128248, 16, 12, 64), (12288, 768, 64, 1), 0), reinterpret_tensor(buf1, (128248, 16, 12, 64), (12288, 768, 64, 1), 1575911424), arg4_1, 'auto')

cc @anijain2305
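The buffer sizes in this output code can be checked with a little arithmetic. The sketch below uses only numbers from the logs above (shapes from the select nodes, numels from empty_strided_cuda) and assumes 2 bytes per float16 element. It suggests that buf0 is a copy of one cache slice, while buf1 appears to be a copy of the whole kv_cache (value_cache is then a view into buf1 at offset 1575911424), so roughly 8.8 GiB of extra memory is requested on top of the cache itself.

```python
# Back-of-the-envelope: size of the temporaries in the vLLM output code above.
FP16_BYTES = 2
slice_numel = 128248 * 16 * 12 * 64  # shape of kv_cache[0] / kv_cache[1] from the select nodes
buf0_numel = 1_575_911_424           # from empty_strided_cuda for buf0
buf1_numel = 3_151_822_848           # from empty_strided_cuda for buf1

assert slice_numel == buf0_numel      # buf0 holds one cache slice
assert buf1_numel == 2 * slice_numel  # buf1 spans both slices, i.e. the whole kv_cache

GiB = 2**30
print(f"buf0: {buf0_numel * FP16_BYTES / GiB:.2f} GiB")
print(f"buf1: {buf1_numel * FP16_BYTES / GiB:.2f} GiB")
print(f"extra total: {(buf0_numel + buf1_numel) * FP16_BYTES / GiB:.2f} GiB")
```

It prints about 2.94 GiB for buf0, 5.87 GiB for buf1, and 8.81 GiB in total.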

zou3519 commented 1 month ago

Yeah this issue is pretty load-bearing