halide / Halide

a language for fast, portable data-parallel computation
https://halide-lang.org

CUDA error: CUDA_ERROR_ILLEGAL_ADDRESS cuLaunchKernel failed #8318

Open. jansel opened 5 months ago

jansel commented 5 months ago

This 1-element (scalar) kernel works on CPU, but gives an Error: CUDA error: CUDA_ERROR_ILLEGAL_ADDRESS cuLaunchKernel failed on CUDA using both the Li2018 and Anderson2021 autoschedulers.

import halide as hl
from math import inf, nan

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

    def generate(g):
        in_ptr0 = g.in_ptr0
        out_ptr0 = g.out_ptr0
        tmp0 = in_ptr0[0,]
        tmp1 = hl.cast(hl.Float(32), hl.f64(0.12300000339746475))
        tmp2 = tmp1 + tmp0
        tmp3 = hl.sqrt(tmp2)
        out_ptr0[hl.Var(),] = hl.cast(hl.Float(32), tmp3)

        assert g.using_autoscheduler()
        in_ptr0.dim(0).set_min(0)
        in_ptr0.dim(0).set_stride(1)
        in_ptr0.dim(0).set_extent(1)
        in_ptr0.set_estimates([hl.Range(0, 1)])
        out_ptr0.set_estimates([hl.Range(0, 1)])

if __name__ == "__main__":
    hl.main()
mcourteaux commented 5 months ago

To help you figure out what could be going wrong, there are three options:

jansel commented 5 months ago

Running with both the debug target and HL_DEBUG_CODEGEN=1:

$ HL_DEBUG_CODEGEN=1 CUDA_LAUNCH_BLOCKING=1 python test/inductor/test_halide.py -k test_pow3_cuda 
Failed to load binary:python
JIT compiling shared runtime for x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-f16c-fma-jit-sse41
JIT compiling cuda for x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-f16c-fma-jit-sse41
Target triple of initial module: x86_64--linux-gnu
Generating llvm bitcode...
Module.compile(): object /tmp/tmp41a7dnia/halide-runtime-host-cuda-ekwqd6zia46hyjdrjgwbi43mwgyu3lbqsd6vlfc7epqa6gzx3ci/standalone_halide_runtime.a
emit_file.Compiling to native code...
Target machine is Position Independent!
dir_rmdir: /tmp/FascNt
Registering autoscheduler 'Li2018'...
Generator kernel has base_path /tmp/tmp41a7dnia/ni/cni4tskyquk2f2uv7hiz36a44k7g6vyehe3k2p55tasiewlrako7.halide/halide_kernel
compile_multitarget: single target is x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-cuda_capability_86-debug-f16c-fma-no_asserts-no_runtime-sse41-strict_float-user_context
Applying autoscheduler Li2018 to Generator kernel ...
[gradient_autoscheduler] Processing function:out_ptr0
[gradient_autoscheduler] Processing function:in_ptr0_im
out_ptr0.compute_root()
    .split(v0,v0,v1,32,GuardWithIf)
    .reorder(v1,v0)
    .gpu_blocks(v0)
    .gpu_threads(v1)
;
in_ptr0_im.compute_root()
    .split(_0,_0,v2,32,GuardWithIf)
    .reorder(v2,_0)
    .gpu_blocks(_0)
    .gpu_threads(v2)
;

Creating initial loop nests...
Injecting realization of { out_ptr0 }
Injecting realization of { in_ptr0_im }
Skipping injecting memoization...
Injecting tracing...
Adding checks for parameters
Computing bounds of each function's value
Clamping unsafe data-dependent accesses
Performing computation bounds inference...
Asserting that all split factors are positive...
Removing extern loops...
Performing sliding window optimization...
Uniquifying variable names...
Simplifying...
Simplifying correlated differences...
Performing allocation bounds inference...
Adding checks for images
Removing code that depends on undef values...
Performing storage folding optimization...
Injecting debug_to_file calls...
Injecting prefetches...
Discarding safe promises...
Dynamically skipping stages...
Forking asynchronous producers...
Destructuring tuple-valued realizations...
Canonicalizing GPU var names...
Bounding small realizations...
Performing storage flattening...
Adding atomic mutex allocation...
Unpacking buffer arguments...
Skipping rewriting memoized allocations...
Selecting a GPU API for GPU loops...
Injecting host <-> dev buffer copies...
Selecting a GPU API for extern stages...
Simplifying...
Reduce prefetch dimension...
Simplifying correlated differences...
Bounding constant extent loops...
Unrolling...
Vectorizing...
Injecting per-block gpu synchronization...
Detecting vector interleavings...
Partitioning loops to simplify boundary conditions...
Staging strided loads...
Trimming loops to the region over which they do something...
Rebasing loops to zero...
Hoisting loop invariant if statements...
Injecting early frees...
Simplifying correlated differences...
Bounding small allocations...
Injecting warp shuffles...
Simplifying...
Lowering unsafe promises...
Flattening nested ramps...
Removing dead allocations and moving loop invariant code...
Finding intrinsics...
Hoisting prefetches...
Stripping asserts...
Lowering after final simplification:
let in_ptr0 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)in_ptr0.buffer)
let out_ptr0 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)out_ptr0.buffer)
let out_ptr0.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)out_ptr0.buffer, 0)
let out_ptr0.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)out_ptr0.buffer, 0)
let out_ptr0.extent.0.required = min(min(out_ptr0.extent.0, 32) + (((out_ptr0.extent.0 + -1)/32)*32), out_ptr0.extent.0)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)in_ptr0.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)in_ptr0.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)in_ptr0.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct(0, 1, 1, 0), (uint64)0)
}
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)out_ptr0.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)out_ptr0.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)out_ptr0.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct(out_ptr0.min.0, out_ptr0.extent.0.required, 1, 0), (uint64)0)
}
if (!((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)in_ptr0.buffer) || (uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)out_ptr0.buffer))) {
 allocate in_ptr0_im[float32 * 1] if (uint1)0
 let in_ptr0_im.buffer = let t50 = (struct halide_dimension_t *)make_struct(0, 1, 1, 0) in (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)alloca(size_of_halide_buffer_t()), t50, reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, t50, (uint64)0)
 (void *)register_destructor("halide_device_free_as_destructor", in_ptr0_im.buffer)
 produce in_ptr0_im {
  halide_device_malloc(in_ptr0_im.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
  halide_copy_to_device((struct halide_buffer_t *)in_ptr0.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
  gpu_block<CUDA> (in_ptr0_im.s0._0._0.block_id_x, 0, 1) {
   gpu_thread<CUDA> (.thread_id_x, 0, 32) {
    if (.thread_id_x < 1) {
     in_ptr0_im[0] = (float32)strict_float(in_ptr0[0])
    }
   }
  }
  _halide_buffer_set_device_dirty(in_ptr0_im.buffer, (uint1)1)
 }
 produce out_ptr0 {
  consume in_ptr0_im {
   halide_device_malloc(in_ptr0_im.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
   halide_copy_to_device((struct halide_buffer_t *)out_ptr0.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
   let t47 = (out_ptr0.extent.0 + 31)/32
   let t48 = out_ptr0.extent.0/32
   let t49 = out_ptr0.extent.0 + out_ptr0.min.0
   gpu_block<CUDA> (out_ptr0.s0.v0.v0.block_id_x, 0, t47) {
    gpu_thread<CUDA> (.thread_id_x, 0, 32) {
     if (out_ptr0.s0.v0.v0.block_id_x < t48) {
      out_ptr0[(out_ptr0.s0.v0.v0.block_id_x*32) + .thread_id_x] = (float32)strict_float((float32)sqrt_f32((float32)strict_float((float32)strict_float(0.123000f) + (float32)strict_float(in_ptr0_im[0]))))
     } else if (((((out_ptr0.s0.v0.v0.block_id_x*32) + out_ptr0.min.0) + .thread_id_x) + 1) <= t49) {
      out_ptr0[(out_ptr0.s0.v0.v0.block_id_x*32) + .thread_id_x] = (float32)strict_float((float32)sqrt_f32((float32)strict_float((float32)strict_float(0.123000f) + (float32)strict_float(in_ptr0_im[0]))))
     }
    }
   }
   _halide_buffer_set_device_dirty((struct halide_buffer_t *)out_ptr0.buffer, (uint1)1)
   halide_device_free(in_ptr0_im.buffer)
   free in_ptr0_im
  }
 }
}

Skipping Hexagon offload...
Offloading GPU loops...
Generating llvm bitcode for kernel...
Generating llvm bitcode for kernel...
PTX kernel:
//
// Generated by LLVM NVPTX Back-End
//

.version 7.1
.target sm_86
.address_size 64

    // .globl   _kernel_in_ptr0_im_s0__0__0_block_id_x // -- Begin function _kernel_in_ptr0_im_s0__0__0_block_id_x
                                        // @_kernel_in_ptr0_im_s0__0__0_block_id_x
.visible .entry _kernel_in_ptr0_im_s0__0__0_block_id_x(
    .param .u64 _kernel_in_ptr0_im_s0__0__0_block_id_x_param_0,
    .param .u64 _kernel_in_ptr0_im_s0__0__0_block_id_x_param_1
)
{
    .reg .pred  %p<2>;
    .reg .b32   %r<2>;
    .reg .f32   %f<2>;
    .reg .b64   %rd<5>;

// %bb.0:                               // %entry
    mov.u32     %r1, %tid.x;
    setp.lt.s32     %p1, %r1, 1;
    @%p1 bra    $L__BB0_2;
    bra.uni     $L__BB0_1;
$L__BB0_2:                              // %then_bb
    ld.param.u64    %rd3, [_kernel_in_ptr0_im_s0__0__0_block_id_x_param_0];
    ld.param.u64    %rd4, [_kernel_in_ptr0_im_s0__0__0_block_id_x_param_1];
    cvta.to.global.u64  %rd1, %rd4;
    cvta.to.global.u64  %rd2, %rd3;
    ld.global.nc.f32    %f1, [%rd2];
    st.global.f32   [%rd1], %f1;
$L__BB0_1:                              // %after_bb
    ret;
                                        // -- End function
}
    // .globl   _kernel_out_ptr0_s0_v0_v0_block_id_x // -- Begin function _kernel_out_ptr0_s0_v0_v0_block_id_x
.visible .entry _kernel_out_ptr0_s0_v0_v0_block_id_x(
    .param .u64 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_0,
    .param .u64 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_1,
    .param .u32 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_2,
    .param .u32 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_3,
    .param .u32 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_4
)                                       // @_kernel_out_ptr0_s0_v0_v0_block_id_x
{
    .reg .pred  %p<3>;
    .reg .b32   %r<11>;
    .reg .f32   %f<7>;
    .reg .b64   %rd<9>;

// %bb.0:                               // %entry
    ld.param.u64    %rd4, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_0];
    ld.param.u64    %rd5, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_1];
    cvta.to.global.u64  %rd1, %rd5;
    cvta.to.global.u64  %rd2, %rd4;
    mov.u32     %r1, %ctaid.x;
    ld.param.u32    %r5, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_3];
    mov.u32     %r2, %tid.x;
    setp.lt.s32     %p1, %r1, %r5;
    @%p1 bra    $L__BB1_1;
    bra.uni     $L__BB1_2;
$L__BB1_1:                              // %then_bb
    ld.global.nc.f32    %f4, [%rd2];
    add.f32     %f5, %f4, 0f3DFBE76D;
    sqrt.rn.f32     %f6, %f5;
    shl.b32     %r9, %r1, 5;
    add.s32     %r10, %r9, %r2;
    mul.wide.s32    %rd7, %r10, 4;
    add.s64     %rd8, %rd1, %rd7;
    st.global.f32   [%rd8], %f6;
    bra.uni     $L__BB1_4;
$L__BB1_2:                              // %next_bb
    ld.param.u32    %r4, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_4];
    ld.param.u32    %r3, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_2];
    shl.b32     %r6, %r1, 5;
    add.s32     %r7, %r6, %r2;
    add.s32     %r8, %r7, %r3;
    setp.ge.s32     %p2, %r8, %r4;
    @%p2 bra    $L__BB1_4;
// %bb.3:                               // %then_bb1
    mul.wide.s32    %rd6, %r7, 4;
    add.s64     %rd3, %rd1, %rd6;
    ld.global.nc.f32    %f1, [%rd2];
    add.f32     %f2, %f1, 0f3DFBE76D;
    sqrt.rn.f32     %f3, %f2;
    st.global.f32   [%rd3], %f3;
$L__BB1_4:                              // %after_bb
    ret;
                                        // -- End function
}

Lowering Parallel Tasks...
Embedding image cuda_buf
Embedding image cuda_gpu_source_kernels
Target triple of initial module: x86_64--linux-gnu
Generating llvm bitcode...
Generating llvm bitcode prolog for function halide_kernel...
Generating llvm bitcode for function halide_kernel...
add_temp_object_file: /tmp/8o9cyq/halide_kernel.a.o
Module.compile(): temporary object /tmp/8o9cyq/halide_kernel.a.o
emit_file.Compiling to native code...
Target machine is Position Independent!
Module.compile(): static_library /tmp/tmp41a7dnia/ni/cni4tskyquk2f2uv7hiz36a44k7g6vyehe3k2p55tasiewlrako7.halide/halide_kernel.a
file_unlink: /tmp/8o9cyq/halide_kernel.a.o
dir_rmdir: /tmp/8o9cyq
Module.compile(): c_header /tmp/tmp41a7dnia/ni/cni4tskyquk2f2uv7hiz36a44k7g6vyehe3k2p55tasiewlrako7.halide/halide_kernel.h
Module.compile(): schedule /tmp/tmp41a7dnia/ni/cni4tskyquk2f2uv7hiz36a44k7g6vyehe3k2p55tasiewlrako7.halide/halide_kernel.schedule.h
dir_rmdir: /tmp/4tV52B
Entering Pipeline halide_kernel
Target: x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-cuda_capability_86-debug-f16c-fma-no_asserts-no_runtime-sse41-strict_float-user_context
 Input (void const *) __user_context: 0x0
 Input Buffer in_ptr0: 0x7ffd85510950 -> buffer(145176576, 0x72ef3f4af658, 0x0, 2, float32, {0, 1, 1})
 Output Buffer out_ptr0: 0x7ffd85510990 -> buffer(126369784660992, 0x72ef3f4af658, 0x0, 2, float32, {0, 1, 1})
Error: CUDA error: CUDA_ERROR_ILLEGAL_ADDRESS cuLaunchKernel failed
zsh: IOT instruction (core dumped)  HL_DEBUG_CODEGEN=1 CUDA_LAUNCH_BLOCKING=1 python test/inductor/test_halide.py
jansel commented 5 months ago

conceptual.stmt:

let in_ptr0 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)in_ptr0.buffer)
let out_ptr0 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)out_ptr0.buffer)
let out_ptr0.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)out_ptr0.buffer, 0)
let out_ptr0.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)out_ptr0.buffer, 0)
let out_ptr0.extent.0.required = min(min(out_ptr0.extent.0, 32) + (((out_ptr0.extent.0 + -1)/32)*32), out_ptr0.extent.0)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)in_ptr0.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)in_ptr0.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)in_ptr0.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct(0, 1, 1, 0), (uint64)0)
}
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)out_ptr0.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)out_ptr0.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)out_ptr0.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct(out_ptr0.min.0, out_ptr0.extent.0.required, 1, 0), (uint64)0)
}
if (!((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)in_ptr0.buffer) || (uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)out_ptr0.buffer))) {
 allocate in_ptr0_im[float32 * 1] if (uint1)0
 let in_ptr0_im.buffer = let t50 = (struct halide_dimension_t *)make_struct(0, 1, 1, 0) in (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)alloca(size_of_halide_buffer_t()), t50, reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, t50, (uint64)0)
 (void *)register_destructor("halide_device_free_as_destructor", in_ptr0_im.buffer)
 produce in_ptr0_im {
  halide_device_malloc(in_ptr0_im.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
  halide_copy_to_device((struct halide_buffer_t *)in_ptr0.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
  gpu_block<CUDA> (in_ptr0_im.s0._0._0.block_id_x, 0, 1) {
   gpu_thread<CUDA> (.thread_id_x, 0, 32) {
    if (.thread_id_x < 1) {
     in_ptr0_im[0] = (float32)strict_float(in_ptr0[0])
    }
   }
  }
  _halide_buffer_set_device_dirty(in_ptr0_im.buffer, (uint1)1)
 }
 produce out_ptr0 {
  consume in_ptr0_im {
   halide_device_malloc(in_ptr0_im.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
   halide_copy_to_device((struct halide_buffer_t *)out_ptr0.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
   let t47 = (out_ptr0.extent.0 + 31)/32
   let t48 = out_ptr0.extent.0/32
   let t49 = out_ptr0.extent.0 + out_ptr0.min.0
   gpu_block<CUDA> (out_ptr0.s0.v0.v0.block_id_x, 0, t47) {
    gpu_thread<CUDA> (.thread_id_x, 0, 32) {
     if (out_ptr0.s0.v0.v0.block_id_x < t48) {
      out_ptr0[(out_ptr0.s0.v0.v0.block_id_x*32) + .thread_id_x] = (float32)strict_float((float32)sqrt_f32((float32)strict_float((float32)strict_float(0.123000f) + (float32)strict_float(in_ptr0_im[0]))))
     } else if (((((out_ptr0.s0.v0.v0.block_id_x*32) + out_ptr0.min.0) + .thread_id_x) + 1) <= t49) {
      out_ptr0[(out_ptr0.s0.v0.v0.block_id_x*32) + .thread_id_x] = (float32)strict_float((float32)sqrt_f32((float32)strict_float((float32)strict_float(0.123000f) + (float32)strict_float(in_ptr0_im[0]))))
     }
    }
   }
   _halide_buffer_set_device_dirty((struct halide_buffer_t *)out_ptr0.buffer, (uint1)1)
   halide_device_free(in_ptr0_im.buffer)
   free in_ptr0_im
  }
 }
}
abadams commented 5 months ago

I don't see anything obviously wrong. Is it weird that the input device pointer (0x8a73800) is so much smaller than the output device pointer (0x72eec2200400)?

By the way, zero-dimensional Funcs/Buffers are a thing in Halide. They would generate simpler code in cases like this.

mcourteaux commented 5 months ago

@abadams I'm mostly concerned with the whole "produce input" block. The generator shouldn't produce the input; it should only produce the output and consume the input, right? It looks like it's making a global wrapper for the input in_ptr0 and naming it in_ptr0_im, which needlessly copies it.

It is funny to see how the autoscheduler tiled with a thread size of 32, given there is only one element to compute. Bad scheduling, but not wrong.

abadams commented 5 months ago

I think that's just a bad schedule. Input buffers come with wrapper Funcs (ending in "_im") that are normally just inlined. It's to support ImageParam::in(). Looks like this one has been compute_root'd.
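
For readers less familiar with these wrappers, here is a rough hand-written stand-in for what the autogenerated in_ptr0_im stage does (a sketch only: the real wrapper Func is created internally by Halide, and the names below are made up for illustration):

import halide as hl

x = hl.Var("x")
inp = hl.ImageParam(hl.Float(32), 1, "inp")

# Hand-written stand-in for the internal "_im" wrapper: a Func that simply
# forwards the ImageParam. Normally this wrapper is inlined away; here the
# Li2018 schedule compute_root()'d it, which is why the lowered IR contains a
# separate "produce in_ptr0_im" copy stage.
inp_im = hl.Func("inp_im")
inp_im[x] = inp[x]

out = hl.Func("out")
out[x] = hl.sqrt(inp_im[x] + hl.cast(hl.Float(32), hl.f64(0.123)))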

jansel commented 5 months ago

I can try the 0-D version; what is the syntax for that in Python? (a[] = ... is not valid Python syntax.)

I think there might be something wrong with the handling of 1-element tensors on CUDA. I'm seeing a similar error for every test using hl.random_float(seed_ptr[0]), maybe because the tests are loading a 1-element seed. I picked this test to open an issue because it looked like the simplest example of this error I could find.

abadams commented 5 months ago

I think the syntax is a[()] (i.e. index it with an empty tuple of Vars)
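
For reference, a minimal sketch of what the zero-dimensional variant of the reproducer might look like, assuming 0-D InputBuffer/OutputBuffer and the empty-tuple indexing above behave as described (untested; the estimates and autoscheduler setup are omitted):

import halide as hl

@hl.generator(name="kernel")
class Kernel:
    in_ptr0 = hl.InputBuffer(hl.Float(32), 0)    # 0-D: a single scalar
    out_ptr0 = hl.OutputBuffer(hl.Float(32), 0)  # 0-D: a single scalar

    def generate(g):
        tmp0 = g.in_ptr0[()]   # read with an empty tuple of Vars
        tmp1 = hl.cast(hl.Float(32), hl.f64(0.12300000339746475))
        g.out_ptr0[()] = hl.sqrt(tmp1 + tmp0)   # write with an empty tuple

if __name__ == "__main__":
    hl.main()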

We have a number of tests that output a scalar from CUDA, so I don't think it's just straight-up broken. Halide's bounds inference logic is very well tested, and this is a very simple case, so it's not going to be that. There could be a bug in the runtime's handling of device allocations, but I don't think you're using that path: you're just wrapping existing device pointers. 99% of the time, when there's an illegal-address exception it's because the input or output buffer is malformed. That's why I'm suspicious of your input device pointer. All the cuMalloc results I see are 12 hex digits and start with 7.

mcourteaux commented 5 months ago

Is it weird that the input device pointer (0x8a73800) is so much smaller than the output device pointer (0x72eec2200400)?

Where do you see that? I read this in his output:

Target: x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-cuda_capability_86-debug-f16c-fma-no_asserts-no_runtime-sse41-strict_float-user_context
 Input (void const *) __user_context: 0x0
 Input Buffer in_ptr0: 0x7ffd85510950 -> buffer(145176576, 0x72ef3f4af658, 0x0, 2, float32, {0, 1, 1})
 Output Buffer out_ptr0: 0x7ffd85510990 -> buffer(126369784660992, 0x72ef3f4af658, 0x0, 2, float32, {0, 1, 1})
Error: CUDA error: CUDA_ERROR_ILLEGAL_ADDRESS cuLaunchKernel failed
mcourteaux commented 5 months ago

Ah, so it's that first number, which is printed in decimal instead of hexadecimal. Why isn't this printed more clearly? Is there a reason, or is this open for change?

abadams commented 5 months ago

The device pointer comes first: 145176576 is 0x8a73800.

It's an opaque 64-bit handle represented as a uint64, so we print it as a uint64. We could print it as if it were a void *, but that's not trivial because it involves printing a 64-bit "pointer" on 32-bit platforms. The code is here: https://github.com/halide/Halide/blob/main/src/runtime/to_string.cpp#L305

Maybe it should be refactored to have a halide_uint64_to_hex_string, and the pointer-printing method can defer to that.
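
In the meantime, the decimal handles in the buffer dumps above can be decoded by hand, e.g. in Python:

# Decode the decimally-printed device handles from the debug output above.
print(hex(145176576))        # 0x8a73800      (in_ptr0 device pointer)
print(hex(126369784660992))  # 0x72eec2200400 (out_ptr0 device pointer)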

jansel commented 5 months ago

I think the 0x8a73800 pointer is a red herring coming from how the PyTorch CUDA caching allocator works. There is a fast path for scalar constants where PyTorch will stash them in a fixed memory region rather than allocating a new buffer for them.
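
As a rough illustration of the workaround described in the next paragraph (a sketch only; the actual Inductor test harness code differs), one can allocate an ordinary 1-element CUDA buffer and check that its device pointer looks like a normal CUDA allocation rather than the stashed constant address:

import torch

# Illustrative only: a freshly allocated CUDA buffer should report a
# "normal-looking" device pointer (per the comment above: 12 hex digits,
# starting with 7 on this setup), unlike the stashed scalar-constant address.
buf = torch.empty(1, dtype=torch.float32, device="cuda")
print(hex(buf.data_ptr()))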

I manually allocated a buffer, and now the pointers look normal. It still fails; however, the error message changed:

$ TORCHINDUCTOR_COMPILE_THREADS=1 HL_DEBUG_CODEGEN=1 CUDA_LAUNCH_BLOCKING=1 python test/inductor/test_halide.py -k test_pow3_cuda
Failed to load binary:python
JIT compiling shared runtime for x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-f16c-fma-jit-sse41
JIT compiling cuda for x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-f16c-fma-jit-sse41
Target triple of initial module: x86_64--linux-gnu
Generating llvm bitcode...
Module.compile(): object /tmp/tmpbr7y2zw_/halide-runtime-host-cuda-ekwqd6zia46hyjdrjgwbi43mwgyu3lbqsd6vlfc7epqa6gzx3ci/standalone_halide_runtime.a
emit_file.Compiling to native code...
Target machine is Position Independent!
dir_rmdir: /tmp/GIEBDl
Registering autoscheduler 'Li2018'...
Generator kernel has base_path /tmp/tmpbr7y2zw_/t4/ct4ns2no3iqzuxcxnbma37xodurfcgzqnwjt6lrep5m7pyih7ycw.halide/halide_kernel
compile_multitarget: single target is x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-cuda_capability_86-debug-f16c-fma-no_asserts-no_runtime-sse41-strict_float-user_context
Applying autoscheduler Li2018 to Generator kernel ...
[gradient_autoscheduler] Processing function:out_ptr0
[gradient_autoscheduler] Processing function:tmp3
out_ptr0.compute_root()
    .split(v0,v0,v1,32,GuardWithIf)
    .reorder(v1,v0)
    .gpu_blocks(v0)
    .gpu_threads(v1)
;
tmp3.compute_root()
;

Creating initial loop nests...
Injecting realization of { out_ptr0 }
Injecting realization of { tmp3 }
Skipping injecting memoization...
Injecting tracing...
Adding checks for parameters
Computing bounds of each function's value
Clamping unsafe data-dependent accesses
Performing computation bounds inference...
Asserting that all split factors are positive...
Removing extern loops...
Performing sliding window optimization...
Uniquifying variable names...
Simplifying...
Simplifying correlated differences...
Performing allocation bounds inference...
Adding checks for images
Removing code that depends on undef values...
Performing storage folding optimization...
Injecting debug_to_file calls...
Injecting prefetches...
Discarding safe promises...
Dynamically skipping stages...
Forking asynchronous producers...
Destructuring tuple-valued realizations...
Canonicalizing GPU var names...
Bounding small realizations...
Performing storage flattening...
Adding atomic mutex allocation...
Unpacking buffer arguments...
Skipping rewriting memoized allocations...
Selecting a GPU API for GPU loops...
Injecting host <-> dev buffer copies...
Selecting a GPU API for extern stages...
Simplifying...
Reduce prefetch dimension...
Simplifying correlated differences...
Bounding constant extent loops...
Unrolling...
Vectorizing...
Injecting per-block gpu synchronization...
Detecting vector interleavings...
Partitioning loops to simplify boundary conditions...
Staging strided loads...
Trimming loops to the region over which they do something...
Rebasing loops to zero...
Hoisting loop invariant if statements...
Injecting early frees...
Simplifying correlated differences...
Bounding small allocations...
Injecting warp shuffles...
Simplifying...
Lowering unsafe promises...
Flattening nested ramps...
Removing dead allocations and moving loop invariant code...
Finding intrinsics...
Hoisting prefetches...
Stripping asserts...
Lowering after final simplification:
let in_ptr0 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)in_ptr0.buffer)
let out_ptr0 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)out_ptr0.buffer)
let out_ptr0.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)out_ptr0.buffer, 0)
let out_ptr0.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)out_ptr0.buffer, 0)
let out_ptr0.extent.0.required = min(min(out_ptr0.extent.0, 32) + (((out_ptr0.extent.0 + -1)/32)*32), out_ptr0.extent.0)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)in_ptr0.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)in_ptr0.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)in_ptr0.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct(0, 1, 1, 0), (uint64)0)
}
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)out_ptr0.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)out_ptr0.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)out_ptr0.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct(out_ptr0.min.0, out_ptr0.extent.0.required, 1, 0), (uint64)0)
}
if (!((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)in_ptr0.buffer) || (uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)out_ptr0.buffer))) {
 let tmp3.buffer = let t43 = reinterpret<(struct halide_dimension_t *)>((uint64)0) in (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)alloca(size_of_halide_buffer_t()), t43, reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 0, t43, (uint64)0)
 halide_device_and_host_malloc(tmp3.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
 (void *)register_destructor("halide_device_and_host_free_as_destructor", tmp3.buffer)
 allocate tmp3[float32] in Heap
  custom_new { (void *)_halide_buffer_get_host(tmp3.buffer) }
  custom_delete { halide_device_host_nop_free(tmp3); }
 produce tmp3 {
  halide_copy_to_host((struct halide_buffer_t *)in_ptr0.buffer)
  tmp3[0] = (float32)strict_float((float32)sqrt_f32((float32)strict_float((float32)strict_float(0.123000f) + (float32)strict_float(in_ptr0[0]))))
  _halide_buffer_set_host_dirty(tmp3.buffer, (uint1)1)
 }
 produce out_ptr0 {
  consume tmp3 {
   halide_copy_to_device(tmp3.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
   halide_copy_to_device((struct halide_buffer_t *)out_ptr0.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
   let t40 = (out_ptr0.extent.0 + 31)/32
   let t41 = out_ptr0.extent.0/32
   let t42 = out_ptr0.extent.0 + out_ptr0.min.0
   gpu_block<CUDA> (out_ptr0.s0.v0.v0.block_id_x, 0, t40) {
    gpu_thread<CUDA> (.thread_id_x, 0, 32) {
     if (out_ptr0.s0.v0.v0.block_id_x < t41) {
      out_ptr0[(out_ptr0.s0.v0.v0.block_id_x*32) + .thread_id_x] = (float32)strict_float(tmp3[0])
     } else if (((((out_ptr0.s0.v0.v0.block_id_x*32) + out_ptr0.min.0) + .thread_id_x) + 1) <= t42) {
      out_ptr0[(out_ptr0.s0.v0.v0.block_id_x*32) + .thread_id_x] = (float32)strict_float(tmp3[0])
     }
    }
   }
   _halide_buffer_set_device_dirty((struct halide_buffer_t *)out_ptr0.buffer, (uint1)1)
   halide_device_and_host_free(tmp3.buffer)
   free tmp3
  }
 }
}

Skipping Hexagon offload...
Offloading GPU loops...
Generating llvm bitcode for kernel...
PTX kernel:
//
// Generated by LLVM NVPTX Back-End
//

.version 7.1
.target sm_86
.address_size 64

    // .globl   _kernel_out_ptr0_s0_v0_v0_block_id_x // -- Begin function _kernel_out_ptr0_s0_v0_v0_block_id_x
                                        // @_kernel_out_ptr0_s0_v0_v0_block_id_x
.visible .entry _kernel_out_ptr0_s0_v0_v0_block_id_x(
    .param .u64 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_0,
    .param .u64 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_1,
    .param .u32 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_2,
    .param .u32 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_3,
    .param .u32 _kernel_out_ptr0_s0_v0_v0_block_id_x_param_4
)
{
    .reg .pred  %p<3>;
    .reg .b32   %r<11>;
    .reg .f32   %f<3>;
    .reg .b64   %rd<9>;

// %bb.0:                               // %entry
    ld.param.u64    %rd4, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_0];
    ld.param.u64    %rd5, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_1];
    cvta.to.global.u64  %rd1, %rd5;
    cvta.to.global.u64  %rd2, %rd4;
    mov.u32     %r1, %ctaid.x;
    ld.param.u32    %r5, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_3];
    mov.u32     %r2, %tid.x;
    setp.lt.s32     %p1, %r1, %r5;
    @%p1 bra    $L__BB0_1;
    bra.uni     $L__BB0_2;
$L__BB0_1:                              // %then_bb
    ld.global.nc.f32    %f2, [%rd1];
    shl.b32     %r9, %r1, 5;
    add.s32     %r10, %r9, %r2;
    mul.wide.s32    %rd7, %r10, 4;
    add.s64     %rd8, %rd2, %rd7;
    st.global.f32   [%rd8], %f2;
    bra.uni     $L__BB0_4;
$L__BB0_2:                              // %next_bb
    ld.param.u32    %r4, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_4];
    ld.param.u32    %r3, [_kernel_out_ptr0_s0_v0_v0_block_id_x_param_2];
    shl.b32     %r6, %r1, 5;
    add.s32     %r7, %r6, %r2;
    add.s32     %r8, %r7, %r3;
    setp.ge.s32     %p2, %r8, %r4;
    @%p2 bra    $L__BB0_4;
// %bb.3:                               // %then_bb1
    mul.wide.s32    %rd6, %r7, 4;
    add.s64     %rd3, %rd2, %rd6;
    ld.global.nc.f32    %f1, [%rd1];
    st.global.f32   [%rd3], %f1;
$L__BB0_4:                              // %after_bb
    ret;
                                        // -- End function
}

Lowering Parallel Tasks...
Embedding image cuda_buf
Embedding image cuda_gpu_source_kernels
Target triple of initial module: x86_64--linux-gnu
Generating llvm bitcode...
Generating llvm bitcode prolog for function halide_kernel...
Generating llvm bitcode for function halide_kernel...
add_temp_object_file: /tmp/e2FV38/halide_kernel.a.o
Module.compile(): temporary object /tmp/e2FV38/halide_kernel.a.o
emit_file.Compiling to native code...
Target machine is Position Independent!
Module.compile(): static_library /tmp/tmpbr7y2zw_/t4/ct4ns2no3iqzuxcxnbma37xodurfcgzqnwjt6lrep5m7pyih7ycw.halide/halide_kernel.a
file_unlink: /tmp/e2FV38/halide_kernel.a.o
dir_rmdir: /tmp/e2FV38
Module.compile(): conceptual_stmt /tmp/tmpbr7y2zw_/t4/ct4ns2no3iqzuxcxnbma37xodurfcgzqnwjt6lrep5m7pyih7ycw.halide/halide_kernel.conceptual.stmt
Module.compile(): c_header /tmp/tmpbr7y2zw_/t4/ct4ns2no3iqzuxcxnbma37xodurfcgzqnwjt6lrep5m7pyih7ycw.halide/halide_kernel.h
Module.compile(): schedule /tmp/tmpbr7y2zw_/t4/ct4ns2no3iqzuxcxnbma37xodurfcgzqnwjt6lrep5m7pyih7ycw.halide/halide_kernel.schedule.h
dir_rmdir: /tmp/47iHJp
Entering Pipeline halide_kernel
Target: x86-64-linux-avx-avx2-avx512-avx512_cannonlake-avx512_skylake-cuda-cuda_capability_86-debug-f16c-fma-no_asserts-no_runtime-sse41-strict_float-user_context
 Input (void const *) __user_context: 0x0
 Input Buffer in_ptr0: 0x7ffef44e1b60 -> buffer(140496468967424, 0x7fc8620e0658, 0x0, 2, float32, {0, 1, 1})
 Output Buffer out_ptr0: 0x7ffef44e1ba0 -> buffer(140496468968960, 0x7fc8620e0658, 0x0, 2, float32, {0, 1, 1})
Error: CUDA error: CUDA_ERROR_INVALID_VALUE cuMemcpyDtoH failed
zsh: IOT instruction (core dumped)  TORCHINDUCTOR_COMPILE_THREADS=1 HL_DEBUG_CODEGEN=1 CUDA_LAUNCH_BLOCKING=1   -

I'm not sure why Halide is trying to call cuMemcpyDtoH here. Maybe the host pointer being NULL is a problem?

abadams commented 5 months ago

The schedule says:

tmp3.compute_root()

Not sure what tmp3 is, but that schedule says it is to be computed on the CPU. To do that, presumably the input needs to be copied to the CPU, but the host pointer is null.

abadams commented 5 months ago

I think the earlier failure, where the device pointer was only 32 bits, might hint that the buffer was in a different memory space than the global memory Halide expected. Halide emitted ld.global.nc.f32 to load it, but maybe it's actually in constant memory, in which case that's not the right instruction?