Closed: msaroufim closed this issue 3 months ago.
Hi, are there any solutions for this?
Here is a minimal repro:

import torch

@torch.compile()
def f(x, y):
    return x + y.sum()

f(torch.randn(3, device='cuda'), torch.randn(3))

Eager PyTorch accepts adding a CUDA tensor and a 0-dim CPU tensor, but we don't seem to codegen this correctly.
Here is part of the generated output code:

def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK])
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''', device_str='cuda')
def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    buf0 = empty_strided_cpu((), (), torch.float32)
    cpp_fused_sum_0(arg0_1, buf0)
    del arg0_1
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf1 = empty_strided_cuda((3, ), (1, ), torch.float32)
        # Source Nodes: [add], Original ATen: [aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_1.run(arg1_1, buf0, buf1, 3, grid=grid(3), stream=stream0)
        del arg1_1
        del buf0
    return (buf1, )
@ezyang Do you have any suggestions? I think the easiest fix is to add a `to` operation on buf0, but that can introduce a copy. A better way would be to use buf0.item() to obtain its scalar value and pass a variable to the Triton kernel rather than in_ptr1. What do you think? Is the latter solution practical?
It seems to me that a much more straightforward way to fix this is to adjust the decompositions. Specifically, when we have a GPU + CPU scalar, insert a conversion for the CPU scalar to GPU scalar and then run Inductor as normal.
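A minimal sketch of that idea, using a made-up FakeTensor stand-in rather than Inductor's real IR (promote_binary_operands and FakeTensor here are hypothetical names, not PyTorch APIs):

```python
from dataclasses import dataclass

@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor's metadata: device and rank."""
    device: str
    ndim: int

def promote_binary_operands(a, b):
    """If exactly one operand is a 0-dim CPU tensor and the other lives on
    a CUDA device, move the scalar operand to that device -- a stand-in for
    inserting a to-device conversion in the decomposition."""
    def move(t, device):
        return FakeTensor(device=device, ndim=t.ndim)
    if a.device.startswith("cuda") and b.device == "cpu" and b.ndim == 0:
        b = move(b, a.device)
    elif b.device.startswith("cuda") and a.device == "cpu" and a.ndim == 0:
        a = move(a, b.device)
    return a, b
```

After this rewrite both operands live on the same device, so Inductor would codegen the add as usual.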
That's an extra kernel launch
I mean, if a user manually writes

import torch

@torch.compile()
def f(x, y):
    return x + y.sum().cuda()

f(torch.randn(3, device='cuda'), torch.randn(3))
I would also expect Inductor to be able to eliminate this launch, if it was eliminating it at all...
You can't fuse H2D + an arbitrary CUDA kernel, to my knowledge.
If it's a scalar CPU tensor, you can promote it into an argument. But that's a general optimization that you would have to decide you want to do in general.
To get it to be a scalar CPU tensor, you would need the other PR.
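A rough sketch of what "promote it into an argument" could mean: instead of passing a pointer to a 0-dim CPU buffer, read its value on the host and pass it as a plain scalar kernel argument. Everything here (ScalarBuffer, launch, add_scalar_kernel) is hypothetical illustration, not Inductor's actual API:

```python
class ScalarBuffer:
    """Hypothetical stand-in for a 0-dim CPU buffer holding one float."""
    def __init__(self, value):
        self._value = float(value)

    def item(self):
        # Mirrors torch.Tensor.item(): return the single value as a
        # Python scalar.
        return self._value

def launch(kernel, *args):
    """Hypothetical launcher: promote any 0-dim CPU buffer into a plain
    scalar argument instead of passing it as a pointer."""
    promoted = [a.item() if isinstance(a, ScalarBuffer) else a for a in args]
    return kernel(*promoted)

# The "kernel" then receives a scalar, not a pointer:
def add_scalar_kernel(xs, scalar):
    return [x + scalar for x in xs]

launch(add_scalar_kernel, [1.0, 2.0], ScalarBuffer(3.0))  # -> [4.0, 5.0]
```

Note that reading the value on the host (like .item()) forces a device synchronization if the buffer's producer ran asynchronously, which is part of why this is a policy decision rather than a free win.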
with torch.cuda.device(grid_2.device):
    kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
               kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2],  # cluster
               kernel.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
               CompiledKernel.launch_exit_hook, kernel,
               *driver.assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args))
Can anyone help me? The kernel only accepts grid_2 as an int:

File "/home/haoma/.conda/envs/vim/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
    kernel.run(grid_0, grid_1, torch.round(grid_2), kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
TypeError: only integer tensors of a single element can be converted to an index

However, when I tried to use int, it told me:

File "/home/haoma/.conda/envs/vim/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
    kernel.run(grid_0, grid_1, int(torch.round(grid_2).item()), kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
ValueError: Pointer argument (at 2) cannot be accessed from Triton (cpu tensor?)

So I have no idea how to deal with it. Please give some ideas.
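On the grid part of this: Triton launch-grid dimensions have to be plain host-side Python ints, so a grid size computed on the GPU has to be synced back to the host first. A hedged sketch (resolve_grid_dim is a made-up helper, and this only addresses the grid dimension; the later "Pointer argument (at 2)" error suggests a CPU tensor is also being passed where the kernel expects a device pointer):

```python
def resolve_grid_dim(dim):
    """Convert a launch-grid dimension to a host-side Python int.

    Accepts a plain int, or any 0-dim tensor-like object exposing .item()
    (calling .item() on a GPU tensor forces a device-to-host sync).
    """
    if isinstance(dim, int):
        return dim
    # 0-dim tensor-like: pull the value to the host, round, convert to int
    return int(round(float(dim.item())))
```

With something like this, grid_2 would be resolved before kernel.run, and every remaining tensor argument passed to the kernel must live on the CUDA device.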
Can you provide full reproduction code for your issue?
cc @ezyang @gchanan @zou3519 @kadeng @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @bdhirsh @peterbell10 @aakhundov