I'd be interested to know whether it is a functionalization/reinplacing problem or an Inductor buffer scheduling problem. Being able to see the post-grad graph would help (TORCH_LOGS=aot_graphs output, or, if you are willing, run with TORCH_TRACE=/tmp/logs and upload that entire directory; that would definitely have all the info I'm looking for).
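(For reference, a minimal sketch of how one might capture these logs; the script name repro.py below is a hypothetical placeholder:)

```python
# Shell usage:
#   TORCH_LOGS=aot_graphs python repro.py    # print the AOTAutograd graphs
#   TORCH_TRACE=/tmp/logs python repro.py    # write the structured trace directory
#
# Or turn on the same graph/output-code logging programmatically:
import torch

torch._logging.set_logs(aot_graphs=True, output_code=True)
```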
Thanks for the help, and happy to! Is this what you are looking for?
I spent a bit more time looking around the codebase, and it seems like the problem is related to functionalization and reinplacing, as you suggested. Notably, it seems like Inductor auto-functionalizes the custom function (by allocating and copying the scratch-space workspace). However, these are not reinplaced (?), and hence these expensive allocations and copies are left as they are.
🤔 the directory you uploaded is not what TORCH_TRACE creates LOL. I'll see if it has the info I need anyway though...
OK, the trace does make it clear. In fx_graph_transformed.py, we see what has happened to the graph after we defunctionalized it:
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "i16[3584, 4096]", arg1_1: "f16[14336, 32]", arg2_1: "f16[16]", arg3_1: "f32[16, 16, 1]", arg4_1: "u8[113247936]", arg5_1: "i16[3584, 4096]", arg6_1: "f16[14336, 32]", arg7_1: "f16[16]", arg8_1: "f32[16, 16, 1]", arg9_1: "i16[1024, 14336]", arg10_1: "f16[4096, 112]", arg11_1: "f16[16]", arg12_1: "f32[16, 16, 1]", arg13_1: "f16[1, 1, 4096]"):
# No stacktrace found for following nodes
as_strided_default: "u8[113247936]" = torch.ops.aten.as_strided.default(arg4_1, [113247936], [1], 0)
clone_default: "u8[113247936]" = torch.ops.aten.clone.default(as_strided_default); as_strided_default = None
as_strided_default_1: "u8[113247936]" = torch.ops.aten.as_strided.default(clone_default, [113247936], [1], 0); clone_default = None
qgemm_simple_80_default_2: "f16[1, 14336]" = torch.ops.custom.custom_op.default(arg13_1, arg0_1, arg1_1, arg2_1, arg3_1, as_strided_default_1); arg0_1 = arg1_1 = arg2_1 = arg3_1 = None
qgemm_simple_80_default_1: "f16[1, 14336]" = torch.ops.custom.custom_op.default(arg13_1, arg5_1, arg6_1, arg7_1, arg8_1, as_strided_default_1); arg13_1 = arg5_1 = arg6_1 = arg7_1 = arg8_1 = None
Note that arg4_1 is the workspace (I determined this by inspecting fx_graph_readable.py, which includes kwarg names for the custom op call, so I could identify the workspace argument). We see that we clone it once, restride it, and then use as_strided_default_1 for the rest of the custom_op calls.
This is the input mutation problem: it looks like defunctionalization is unwilling to directly mutate the input argument, pushing it out to the end when it copies it back in:
# File: /workspace/main/custom/integrations.py:242 in forward, code: return cast(Callable, custom.custom_op)(
copy_: "u8[113247936]" = torch.ops.aten.copy_.default(arg4_1, as_strided_default_1); arg4_1 = as_strided_default_1 = None
return (qgemm_simple_80_default,)
This is because, of course, this is what the functionalized graph does:
getitem_4: "f16[1, 4096]" = auto_functionalized_2[0]
getitem_5: "u8[113247936]" = auto_functionalized_2[1]; auto_functionalized_2 = None
copy_: "u8[113247936]" = torch.ops.aten.copy_.default(arg4_1, getitem_5); arg4_1 = getitem_5 = None
return (getitem_4,)
since the functionalized graph isn't allowed to mutate the input until the very end.
I think this should be fixed by making defunctionalization smarter, since at this point in time we should be able to work out whether it's safe to push mutations in.
cc @bdhirsh
By the way, you should benchmark whether the workspace actually helps. Because PyTorch has a caching allocator, repeatedly allocating/deallocating the workspace should actually be very cheap (unless some sort of fragmentation happens, but for a buffer of this size that should be pretty unlikely). If it isn't actually speeding things up, removing it would unblock you and remove the copy operation.
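For concreteness, a rough sketch of that alternative (the op name, shapes, and placeholder math are hypothetical; the real CUDA launch is elided):

```python
import torch

# Hypothetical sketch: allocate the scratch buffer inside the op on every call and
# let PyTorch's caching allocator recycle it, instead of threading a persistent,
# mutable `workspace` argument through the op schema.
@torch.library.custom_op("mylib::qgemm_no_workspace", mutates_args=())
def qgemm_no_workspace(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    workspace = torch.zeros(113247936, dtype=torch.uint8, device=x.device)  # scratch
    # ... launch the real CUDA kernel here, passing `workspace`; placeholder below ...
    del workspace  # returned to the caching allocator, so repeated calls stay cheap
    return x @ w

@qgemm_no_workspace.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.empty(x.shape[0], w.shape[1], dtype=x.dtype, device=x.device)
```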
Oops, I actually couldn't get TORCH_TRACE to output anything, so I used this instead.
torch._inductor.config.debug = True
torch._inductor.config.trace.enabled = True
torch._inductor.config.trace.debug_dir = "..."
I thought they output the same thing, but I was wrong I guess... Either way, glad that you are able to parse useful things out of it!
Regarding performance: yes, this did have a non-trivial (but not horrible) performance hit. For reference, when I simply remove the mutable annotation in the schema (i.e., removing (a!)), I see a ~4% speedup when measured through gpt-fast.
I agree a smarter defunctionalization could help, but I imagine this requires changes to the Inductor logic and won't be a quick fix. Is there anything I can do in the meantime as a workaround? Is, say, using a schema that treats the workspace as non-mutable (i.e., Tensor rather than Tensor (a!)) an "okay-ish" hack in the meanwhile?
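(For reference, the two schema variants in question would look roughly like this; the op names are hypothetical:)

```python
import torch

m = torch.library.Library("mylib", "FRAGMENT")

# Mutable workspace: semantically correct, but currently triggers the
# auto-functionalization clone + copy_ shown above.
m.define("qgemm(Tensor x, Tensor w, Tensor(a!) workspace) -> Tensor")

# The "okay-ish" hack being asked about: drop the (a!) annotation so the
# workspace looks read-only to the compiler.
m.define("qgemm_readonly_ws(Tensor x, Tensor w, Tensor workspace) -> Tensor")
```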
Thanks again for taking the time during the weekend!
TORCH_TRACE doesn't output anything to stderr; it goes to the directory you specified.
The main hazard of not specifying that the input is mutable is if we do some optimization pass where we alias the input to something else, thinking that it doesn't matter, but now you are clobbering something else by secretly mutating it. However, it should be pretty difficult for this to happen, since you are passing in an input and we're not going to alias that with anything. Another potential hazard is if we think it is OK to run two kernels in parallel because they have no write conflicts, when in fact they do conflict via this argument (but we don't do this optimization right now).
Got it! So maybe it's better to wait for the patch.
I think my question is fully addressed (thanks!), but I will leave this open for now in case you need a reminder; feel free to close it if you prefer.
Another thing you could do is a cuBLAS-style internal workspace stored in a global, and just not pass it in as an argument haha
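For concreteness, a rough sketch of that approach (names and placeholder math are hypothetical; the real CUDA launch is elided):

```python
from typing import Optional

import torch

# Hypothetical sketch of a cuBLAS-style internal workspace: the scratch buffer is a
# module-level global that the op grabs internally, so it never appears as a graph
# input and functionalization never sees a mutated argument.
_WORKSPACE: Optional[torch.Tensor] = None

def _get_workspace(nbytes: int, device: torch.device) -> torch.Tensor:
    global _WORKSPACE
    if _WORKSPACE is None or _WORKSPACE.numel() < nbytes or _WORKSPACE.device != device:
        _WORKSPACE = torch.zeros(nbytes, dtype=torch.uint8, device=device)
    return _WORKSPACE

@torch.library.custom_op("mylib::qgemm_global_ws", mutates_args=())
def qgemm_global_ws(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    ws = _get_workspace(113247936, x.device)  # hidden scratch space, reused across calls
    # ... launch the real CUDA kernel here, passing `ws`; placeholder below ...
    return x @ w

@qgemm_global_ws.register_fake
def _(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.empty(x.shape[0], w.shape[1], dtype=x.dtype, device=x.device)
```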
I was actually thinking about the same thing, but I didn't realize that's what cuBLAS does too.
Would those lead to the same hazards as you mentioned earlier? (After all, it's the same thing, except the global variable is hidden from Python)
we might start needing a "module: auto-functionalized" label...
@HanGuo97 no, because now the tensor doesn't show up in the graph. Might be bad if you want to deallocate the workspace when you're done though.
In this case, is the tensor even mutated if it ends up in the same exact state (zeros)? I guess it matters if we were to auto-parallelize, but we don't do that today.
@ezyang @eellison yes, that's the question I have.
From an API point of view, the workspace is not "mutated" in the sense that the custom op will clean up the workspace, so it was zero coming in and zero coming out. In that sense, handling this the cuBLAS way on the C++ side makes sense. But in the imaginary near future in which the custom ops will be auto-parallelized, this might cause a race condition (since it does involve mutations)?
The model in this case is just an MLP with 3 separate calls to this function
@HanGuo97 do you have the exact model used to repro this? Your code sample looks like just the inductor output.
@Chillee we talked about this for a bit earlier this week, my thoughts right now are: 1) we should still mark the buffer as being mutated (because it wouldn't be multi-thread safe). In general we should mark Tensors as being mutated if we need mutable_data_ptr access 2) I agree with @ezyang's assessment that this is a reinplacing pass problem. We can make it smarter.
Repro:
import torch
@torch.library.custom_op("mylib::foo", mutates_args={"out"})
def foo(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x.sin())
@torch.compile(backend="inductor", fullgraph=True)
def f(x, out): # E: Function is missing a type annotation [no-untyped-def]
foo(x, out)
foo(out, out)
foo(out, out)
x = torch.randn(3)
out = torch.randn(3)
f(x, out)
assert torch.allclose(out, x.sin().sin().sin())
Using TORCH_LOGS=output_code gives:
cpp_fused_0 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_rzou/sk/cskh5dx62fglpphcrl6723dnmowdabouerrzy3dmqcngbxwfa7bv.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(3L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
out_ptr0[static_cast<long>(x0)] = tmp0;
}
}
}
''')
cpp_fused_1 = async_compile.cpp_pybinding(['const float*', 'float*'], '''
#include "/tmp/torchinductor_rzou/sk/cskh5dx62fglpphcrl6723dnmowdabouerrzy3dmqcngbxwfa7bv.h"
extern "C" void kernel(const float* in_ptr0,
float* out_ptr0)
{
{
#pragma omp simd simdlen(8)
for(long x0=static_cast<long>(0L); x0<static_cast<long>(3L); x0+=static_cast<long>(1L))
{
auto tmp0 = in_ptr0[static_cast<long>(x0)];
out_ptr0[static_cast<long>(x0)] = tmp0;
}
}
}
''')
async_compile.wait(globals())
del async_compile
def call(args):
arg0_1, arg1_1 = args
args.clear()
assert_size_stride(arg0_1, (3, ), (1, ))
assert_size_stride(arg1_1, (3, ), (1, ))
buf0 = empty_strided_cpu((3, ), (1, ), torch.float32)
cpp_fused_0(arg0_1, buf0)
# Source Nodes: [], Original ATen: []
buf1 = torch.ops.mylib.foo.default(arg1_1, reinterpret_tensor(buf0, (3, ), (1, ), 0))
del arg1_1
# Source Nodes: [], Original ATen: []
buf3 = torch.ops.mylib.foo.default(reinterpret_tensor(buf0, (3, ), (1, ), 0), reinterpret_tensor(buf0, (3, ), (1, ), 0))
# Source Nodes: [], Original ATen: []
buf5 = torch.ops.mylib.foo.default(reinterpret_tensor(buf0, (3, ), (1, ), 0), reinterpret_tensor(buf0, (3, ), (1, ), 0))
cpp_fused_1(buf0, arg0_1)
del arg0_1
return ()
We can see that we're cloning arg0_1 before calling the mutable ops on it, and then copy-ing the result back into arg0_1. This clone + copy is unnecessary.
~Can't repro with torch.library.Library. Maybe this is just a needs_input_stride_order problem:~
EDIT: we can repro with torch.library.Library
import torch
m = torch.library.Library("mylib", "FRAGMENT")
m.define("foo(Tensor x, Tensor(a!) out) -> ()")
def foo(x: torch.Tensor, out: torch.Tensor) -> None:
out.copy_(x.sin())
m.impl("foo", foo, "CPU")
@torch.compile(backend="inductor", fullgraph=True)
def f(x, out): # E: Function is missing a type annotation [no-untyped-def]
torch.ops.mylib.foo(x, out)
torch.ops.mylib.foo(out, out)
torch.ops.mylib.foo(out, out)
x = torch.randn(3)
out = torch.randn(3)
f(x, out)
assert torch.allclose(out, x.sin().sin().sin())
@HanGuo97 do you have a script we could run to reproduce your issue? I'm worried there are more bugs hiding here
I think this might be related to a memory regression I see when using vLLM with torch.compile.
What is happening is that the original code is accessing some tensors:
.....
kv_cache: torch.Tensor
key_cache = kv_cache[0]
value_cache = kv_cache[1]
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping.flatten(),
self.kv_cache_dtype,
)
Those accesses are translated into:
....
[8/1_1] [__aot_graphs] # File: /home/lsakka/vllm/vllm/attention/backends/flash_attn.py:289 in forward, code: key_cache = kv_cache[0]
[8/1_1] [__aot_graphs] select: "f16[128248, 16, 12, 64][12288, 768, 64, 1]cuda:0" = torch.ops.aten.select.int(arg3_1, 0, 0)
[8/1_1] [__aot_graphs]
[8/1_1] [__aot_graphs] # File: /home/lsakka/vllm/vllm/attention/backends/flash_attn.py:290 in forward, code: value_cache = kv_cache[1]
[8/1_1] [__aot_graphs] select_1: "f16[128248, 16, 12, 64][12288, 768, 64, 1]cuda:0" = torch.ops.aten.select.int(arg3_1, 0, 1)
[8/1_1] [__aot_graphs] auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops._C_cache_ops.reshape_and_cache_flash.default, key = view_1, value = view_2, key_cache = select, value_cache = select_1, slot_mapping = arg4_1, kv_cache_dtype = 'auto'); view_1 = view_2 = select = select_1 = None
Then, when Inductor runs, we actually allocate tensors and call triton_poi_fused_0, and we run out of memory at the buf1 allocation:
torch.cuda.set_device(0)
[8/1_1] [__output_code] buf0 = empty_strided_cuda((1575911424, ), (1, ), torch.float16)
[8/1_1]
[8/1_1] [__output_code] stream0 = get_raw_stream(0)
[8/1_1] [__output_code] triton_poi_fused_0.run(arg3_1, buf0, 1575911424, grid=grid(1575911424), stream=stream0)
[8/1_1] [__output_code] buf1 = empty_strided_cuda((3151822848, ), (1, ), torch.float16)
[8/1_1] [__output_code] triton_poi_fused_1.run(arg3_1, buf1, 3151822848, grid=grid(3151822848), stream=stream0)
[8/1_1] [__output_code] torch.ops._C_cache_ops.reshape_and_cache_flash.default(reinterpret_tensor(arg1_1, (256, 12, 64), (2304, 64, 1), 0), reinterpret_tensor(arg2_1, (256, 12, 64), (2304, 64, 1), 0), reinterpret_tensor(buf0, (128248, 16, 12, 64), (12288, 768, 64, 1), 0), reinterpret_tensor(buf1, (128248, 16, 12, 64), (12288, 768, 64, 1), 1575911424), arg4_1, 'auto')
cc @anijain2305
Yeah this issue is pretty load-bearing
🐛 Describe the bug
We have a custom kernel with the following schema (linked here). workspace is a scratch space that we pass into the (CUDA) kernel for its use, and it's very large. As such, we pre-allocate the workspace at the beginning and simply pass the same workspace to every function call, regardless of the other arguments. The workspace is zero-initialized and will be zeroed/cleaned inside the kernel. Since it's only a scratch space, we do not return it, and simply mark it as a mutable input argument.
When we compile a model that uses this function (the model in this case is just an MLP with 3 separate calls to this function), we observe the following code is generated. Importantly, we note that the generated code includes one extra memory allocation for a temporary buffer the same size as the workspace. The code then copies the workspace into that buffer, passes that buffer to the kernel calls, and copies the buffer back into the workspace before deleting it. Given that the workspace has a very non-trivial size, these unnecessary copy/allocation operations significantly slow down the compiled function.
A workaround I found is simply not marking workspace as mutable at all, but this is a bit hacky. Is there a better way to handle this? Thanks in advance for your help!
Error logs
No response
Minified repro
No response
Versions
cc @bdhirsh @ezyang @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @zou3519 @msaroufim