pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org
Other
84.74k stars 22.82k forks source link

Pointer argument (at 1) cannot be accessed from Triton (cpu tensor?) #121374

Closed msaroufim closed 3 months ago

msaroufim commented 9 months ago

🐛 Describe the bug

# Reproducer for this issue: torch.compile on a full SDXL pipeline fails with
# "ValueError: Pointer argument (at 1) cannot be accessed from Triton (cpu tensor?)".
# Requires a CUDA device (see pipe.to('cuda') below).
import torch

from diffusers import (
   StableDiffusionXLPipeline,
   AutoencoderKL
)

# Load VAE component (fp16 weights)
vae = AutoencoderKL.from_pretrained(
   "madebyollin/sdxl-vae-fp16-fix",
   torch_dtype=torch.float16
)

# Configure the pipeline, reusing the VAE loaded above
pipe = StableDiffusionXLPipeline.from_pretrained(
   "cagliostrolab/animagine-xl-3.0",
   vae=vae,
   torch_dtype=torch.float16,
   use_safetensors=True,
)
pipe.to('cuda')

# Define prompts and generate image
prompt = "1girl, arima kana, oshi no ko, solo, upper body, v, smile, looking at viewer, outdoors, night"
negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name"

# This compiles the whole pipeline's __call__, not just the UNet, so scheduler
# code runs under Inductor too. The traceback shows the failure inside
# scheduler.scale_model_input -> _init_step_index, where a generated Triton
# kernel is handed what the launcher reports as a CPU tensor.
# NOTE(review): compiling only `pipe.unet` would keep the scheduler out of the
# compiled region and may sidestep this error. This is untested here.
pipe = torch.compile(pipe)

image = pipe(
   prompt,
   negative_prompt=negative_prompt,
   width=832,
   height=1216,
   guidance_scale=7,
   num_inference_steps=28
).images[0]

Error logs


(hack) ubuntu@ip-172-31-36-68:~$ python anime.py 
/opt/conda/envs/hack/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
/opt/conda/envs/hack/lib/python3.10/site-packages/diffusers/utils/outputs.py:63: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.
  torch.utils._pytree._register_pytree_node(
Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.51it/s]
  0%|                                                                                                                                                                            | 0/28 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/home/ubuntu/anime.py", line 30, in <module>
    image = pipe(
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py", line 1235, in __call__
    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/diffusers/schedulers/scheduling_euler_discrete.py", line 276, in scale_model_input
    self._init_step_index(timestep)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/diffusers/schedulers/scheduling_euler_discrete.py", line 276, in resume_in_scale_model_input
    self._init_step_index(timestep)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 901, in forward
    return compiled_fn(full_args)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 81, in g
    return f(*args)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 94, in runtime_wrapper
    all_outs = call_func_at_runtime_with_args(
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 105, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 118, in rng_functionalization_wrapper
    return compiled_fw(args)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 864, in __call__
    return self.get_current_callable()(inputs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 611, in run
    return model(new_inputs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 892, in _run_from_cache
    return compiled_graph.compiled_artifact(inputs)
  File "/tmp/torchinductor_ubuntu/nq/cnq74ie4hdvlznv34mmne5b2z6szwcbswvs7evsobaj5uqhegaed.py", line 85, in call
    triton_poi_fused_add_div_pow_0.run(arg1_1, arg0_1, buf0, 126464, grid=grid(126464), stream=stream0)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 533, in run
    self.autotune_to_one_config(*args, grid=grid, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 437, in autotune_to_one_config
    timings = self.benchmark_all_configs(*args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 413, in benchmark_all_configs
    timings = {
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 414, in <dictcomp>
    launcher: self.bench(launcher, *args, **kwargs)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 385, in bench
    return do_bench(kernel_call, rep=40, fast_flush=True)
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/utils.py", line 167, in do_bench
    return triton_do_bench(*args, **kwargs)[0]
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/opt/conda/envs/hack/lib/python3.10/site-packages/torch/_inductor/triton_heuristics.py", line 378, in kernel_call
    launcher(
  File "<string>", line 8, in launcher
ValueError: Pointer argument (at 1) cannot be accessed from Triton (cpu tensor?)

Minified repro

No response

Versions

n/a

cc @ezyang @gchanan @zou3519 @kadeng @anijain2305 @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @bdhirsh @peterbell10 @aakhundov

nighting0le01 commented 6 months ago

Hi, are there any solutions for this?

ezyang commented 4 months ago

Here is a minimal repro

# Minimal repro: add a 0-dim CPU tensor (y.sum()) to a CUDA tensor under
# torch.compile. Eager mode accepts this device mix. The generated code quoted
# in this thread passes the CPU result to a Triton kernel as a pointer instead.
import torch

@torch.compile()
def f(x, y):
    """Return ``x + y.sum()``; in the call below ``x`` is on CUDA and ``y`` on the CPU."""
    return x + y.sum()

# x: 3 elements on CUDA; y: 3 elements on the default (CPU) device.
f(torch.randn(3, device='cuda'), torch.randn(3))

Eager PyTorch accepts a CUDA tensor plus a scalar (0-dim) CPU tensor, but we don't seem to codegen this correctly.

FindHao commented 4 months ago

Here is a minimal repro

# Minimal repro: add a 0-dim CPU tensor (y.sum()) to a CUDA tensor under
# torch.compile. Eager mode accepts this device mix. The generated code quoted
# in this thread passes the CPU result to a Triton kernel as a pointer instead.
import torch

@torch.compile()
def f(x, y):
    """Return ``x + y.sum()``; in the call below ``x`` is on CUDA and ``y`` on the CPU."""
    return x + y.sum()

# x: 3 elements on CUDA; y: 3 elements on the default (CPU) device.
f(torch.randn(3, device='cuda'), torch.randn(3))

Eager PyTorch accepts a CUDA tensor plus a scalar (0-dim) CPU tensor, but we don't seem to codegen this correctly.

Here is part of the generated output code.

def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 3
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (0))
    tmp2 = tl.broadcast_to(tmp1, [XBLOCK])
    tmp3 = tmp0 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)
''', device_str='cuda')

def call(args):
    """Inductor-generated entry point for the minimal repro (quoted excerpt).

    Takes ownership of ``args`` (the list is cleared) and returns a 1-tuple
    holding the CUDA result buffer. The reduction is computed into a CPU buffer,
    and that CPU buffer is then passed to a Triton kernel. This is the
    miscompilation under discussion.
    """
    arg0_1, arg1_1 = args
    # Drop the caller's list references so the `del`s below can release inputs
    # early (presumed intent of this generated-code convention).
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    # 0-dim float32 buffer allocated on the CPU. The C++ kernel writes the sum of
    # arg0_1 into it. arg0_1 is presumably `y` from the repro; it is the input
    # reduced on the CPU here.
    buf0 = empty_strided_cpu((), (), torch.float32)
    cpp_fused_sum_0(arg0_1, buf0)
    del arg0_1
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # 3-element float32 output buffer on CUDA device 0.
        buf1 = empty_strided_cuda((3, ), (1, ), torch.float32)
        # Source Nodes: [add], Original ATen: [aten.add]
        stream0 = get_raw_stream(0)
        # Positional args map to the kernel's in_ptr0 (arg1_1, presumably `x`),
        # in_ptr1 (buf0) and out_ptr0 (buf1).
        # BUG: buf0 lives on the CPU, but the kernel dereferences it as a device
        # pointer (tl.load(in_ptr1 + (0))). That is what triggers the launcher's
        # "Pointer argument ... cannot be accessed from Triton (cpu tensor?)" error.
        triton_poi_fused_add_1.run(arg1_1, buf0, buf1, 3, grid=grid(3), stream=stream0)
        del arg1_1
        del buf0
    return (buf1, )

@ezyang Do you have any suggestions? I think the easiest fix is to add a `.to()` operation on `buf0`, but that can introduce a copy. A better way is to use `buf0.item()` to obtain its scalar value and pass that value to the Triton kernel instead of the pointer `in_ptr1`. What do you think? Is the latter solution practical?

ezyang commented 4 months ago

It seems to me that a much more straightforward way to fix this is to adjust the decompositions. Specifically, when we have a GPU + CPU scalar, insert a conversion for the CPU scalar to GPU scalar and then run Inductor as normal.

eellison commented 4 months ago

That's an extra kernel launch

ezyang commented 4 months ago

I mean, if a user manually writes

import torch

@torch.compile()
def f(x, y):
    return x + y.sum().cuda()

f(torch.randn(3, device='cuda'), torch.randn(3))

If Inductor can eliminate this launch at all, I would expect it to be able to eliminate it here as well...

eellison commented 4 months ago

To my knowledge, you can't fuse a host-to-device (H2D) copy with an arbitrary CUDA kernel.

ezyang commented 4 months ago

If it's a scalar CPU tensor, you can promote it into a kernel argument. But that's a general optimization, and you would have to decide that you want to do it in general.

eellison commented 4 months ago

To get it to be a scalar CPU tensor, you would need the other PR.

mahao18cm commented 3 months ago
        with torch.cuda.device(grid_2.device):
            kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
                    kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2],  # cluster
                    kernel.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
                    CompiledKernel.launch_exit_hook, kernel,
                    *driver.assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args))

Can anyone help me? The kernel only accepts an int for grid_2. Passing torch.round(grid_2) fails with:

    File "/home/haoma/.conda/envs/vim/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
      kernel.run(grid_0, grid_1, torch.round(grid_2), kernel.num_warps, kernel.num_ctas, # number of warps/ctas per instance
    TypeError: only integer tensors of a single element can be converted to an index.

However, when I tried converting it to an int, I got:

    File "/home/haoma/.conda/envs/vim/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
      kernel.run(grid_0, grid_1, int(torch.round(grid_2).item()), kernel.num_warps, kernel.num_ctas, # number of warps/ctas per instance
    ValueError: Pointer argument (at 2) cannot be accessed from Triton (cpu tensor?)

So I have no idea how to deal with it. Please give some ideas.

FindHao commented 3 months ago
        with torch.cuda.device(grid_2.device):
            kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
                    kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2],  # cluster
                    kernel.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
                    CompiledKernel.launch_exit_hook, kernel,
                    *driver.assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args))

Can anyone help me? The kernel only accepts an int for grid_2. Passing torch.round(grid_2) fails with:

    File "/home/haoma/.conda/envs/vim/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
      kernel.run(grid_0, grid_1, torch.round(grid_2), kernel.num_warps, kernel.num_ctas, # number of warps/ctas per instance
    TypeError: only integer tensors of a single element can be converted to an index.

However, when I tried converting it to an int, I got:

    File "/home/haoma/.conda/envs/vim/lib/python3.10/site-packages/triton/runtime/jit.py", line 435, in run
      kernel.run(grid_0, grid_1, int(torch.round(grid_2).item()), kernel.num_warps, kernel.num_ctas, # number of warps/ctas per instance
    ValueError: Pointer argument (at 2) cannot be accessed from Triton (cpu tensor?)

So I have no idea how to deal with it. Please give some ideas.

Can you provide a full reproduce code for your issue?