tracel-ai / cubecl

Multi-platform high-performance compute language extension for Rust.
https://burn.dev
Apache License 2.0
519 stars 21 forks source link

`gelu` example panics with CUDA backend #75

Open wbrickner opened 3 weeks ago

wbrickner commented 3 weeks ago
➜  RUST_BACKTRACE=1 cargo run --example gelu -F cuda
    Finished `dev` profile [optimized + debuginfo] target(s) in 0.22s
     Running `target/debug/examples/gelu`
thread 'main' panicked at /home/will/trinity_stuff/rust/cubecl/crates/cubecl-cuda/src/compute/server.rs:245:85:
called `Result::unwrap()` on an `Err` value: DriverError(CUDA_ERROR_INVALID_PTX, "a PTX JIT compilation failed")
stack backtrace:
   0: rust_begin_unwind
             at /rustc/feeba198f2c5455b31ca898c859385d5161f0bd8/library/std/src/panicking.rs:662:5
   1: core::panicking::panic_fmt
             at /rustc/feeba198f2c5455b31ca898c859385d5161f0bd8/library/core/src/panicking.rs:74:14
   2: core::result::unwrap_failed
             at /rustc/feeba198f2c5455b31ca898c859385d5161f0bd8/library/core/src/result.rs:1679:5
   3: core::result::Result<T,E>::unwrap
             at /rustc/feeba198f2c5455b31ca898c859385d5161f0bd8/library/core/src/result.rs:1102:23
   4: cubecl_cuda::compute::server::CudaContext<MM>::compile_kernel
             at ./crates/cubecl-cuda/src/compute/server.rs:245:17
   5: <cubecl_cuda::compute::server::CudaServer<MM> as cubecl_runtime::server::ComputeServer>::execute
             at ./crates/cubecl-cuda/src/compute/server.rs:143:13
   6: <cubecl_runtime::channel::mutex::MutexComputeChannel<Server> as cubecl_runtime::channel::base::ComputeChannel<Server>>::execute
             at ./crates/cubecl-runtime/src/channel/mutex.rs:67:9
   7: cubecl_runtime::client::ComputeClient<Server,Channel>::execute_unchecked
             at ./crates/cubecl-runtime/src/client.rs:98:9
   8: cubecl_core::compute::launcher::KernelLauncher<R>::launch_unchecked
             at ./crates/cubecl-core/src/compute/launcher.rs:107:9
   9: gelu::gelu_array::launch_unchecked
             at ./examples/gelu/src/lib.rs:3:1
  10: gelu::launch
             at ./examples/gelu/src/lib.rs:22:9
  11: gelu::main
             at ./examples/gelu/examples/gelu.rs:3:5
  12: core::ops::function::FnOnce::call_once
             at /rustc/feeba198f2c5455b31ca898c859385d5161f0bd8/library/core/src/ops/function.rs:250:5
note: Some details are omitted, run with `RUST_BACKTRACE=full` for a verbose backtrace.
➜  nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Mar_28_02:18:24_PDT_2024
Cuda compilation tools, release 12.4, V12.4.131
Build cuda_12.4.r12.4/compiler.34097967_0
wbrickner commented 3 weeks ago

The source is generated and looks fine:

// Kernel source generated by CubeCL for the `gelu` example, as quoted in this issue.
// This source compiles fine. The reported failure happens later, when the driver
// loads the resulting PTX, because that PTX targets sm_70, which is too new for the
// reporter's GTX 1050 Ti.
typedef unsigned int uint;

// Element-wise GELU: output_0[i] = input_0[i] * (1 + erf(input_0[i] / sqrt(2))) / 2.
//   input_0  - read at idxGlobal
//   output_0 - written at idxGlobal
//   info     - metadata buffer; info[0] is read as the rank. NOTE(review): the layout
//              of the remaining entries is not shown here and is assumed; TODO confirm
//              against the CubeCL CUDA code generator.
extern "C"
__global__ void kernel(
  float input_0[], float output_0[], uint info[]
) {

  // Per-axis global thread coordinates (block offset + thread offset).
  int3 absoluteIdx = make_int3(
    blockIdx.x * blockDim.x + threadIdx.x,
    blockIdx.y * blockDim.y + threadIdx.y,
    blockIdx.z * blockDim.z + threadIdx.z
  );

  // Flatten the 3D coordinates over the whole launch (x varies fastest, then y, then z).
  uint idxGlobal = (absoluteIdx.z * gridDim.x * blockDim.x * gridDim.y * blockDim.y) + (absoluteIdx.y * gridDim.x * blockDim.x) + absoluteIdx.x;
  // rank and rank_2 are emitted by the generator but are not used below.
  uint rank = info[0];
  uint rank_2 = rank * 2;
  uint l_0_0;
  bool l_0_1;
  float l_0_2;
  float l_0_3;
  // Bounds limit read from info[4 * rank + 1].
  // NOTE(review): presumably the element count of the buffers; TODO confirm.
  l_0_0 = info[(2 * 2 * info[0]) + 1];
  // Threads whose flattened index is at or past the limit do nothing.
  l_0_1 = idxGlobal < l_0_0;
  if (l_0_1) {
    l_0_2 = input_0[idxGlobal];     // x
    l_0_3 = sqrt(float(2.0));       // sqrt(2)
    l_0_3 = l_0_2 / l_0_3;          // x / sqrt(2)
    l_0_3 = erf(l_0_3);             // erf(x / sqrt(2))
    l_0_3 = l_0_3 + float(1.0);     // 1 + erf(x / sqrt(2))
    l_0_2 = l_0_2 * l_0_3;          // x * (1 + erf(x / sqrt(2)))
    l_0_2 = l_0_2 / float(2.0);     // ... / 2
    output_0[idxGlobal] = l_0_2;
  }
}

The decoded PTX:

//
// Generated by NVIDIA NVVM Compiler
//
// Compiler Build ID: CL-30672275
// Cuda compilation tools, release 11.5, V11.5.119
// Based on NVVM 7.0.1
//

.version 7.5
// Target arch comes from the server's minimum_arch_version, which the WMMA feature
// check raised to 70. The reporter's GTX 1050 Ti rejects this module at load time
// with CUDA_ERROR_INVALID_PTX; forcing the arch to 50 in a debugger made it load.
.target sm_70
.address_size 64

        // .globl       kernel

// Parameters follow the CUDA signature order: input_0, output_0, info (64-bit pointers).
.visible .entry kernel(
        .param .u64 kernel_param_0,
        .param .u64 kernel_param_1,
        .param .u64 kernel_param_2
)
{
        .reg .pred      %p<4>;
        .reg .f32       %f<31>;
        .reg .b32       %r<27>;
        .reg .b64       %rd<14>;

        // Load the three pointer parameters; info (%rd4) is converted to a global address first.
        ld.param.u64    %rd2, [kernel_param_0];
        ld.param.u64    %rd3, [kernel_param_1];
        ld.param.u64    %rd4, [kernel_param_2];
        cvta.to.global.u64      %rd5, %rd4;
        // %r1 = idxGlobal, the flattened 3D thread index (x fastest, then y, then z).
        mov.u32         %r2, %ntid.x;
        mov.u32         %r3, %ctaid.x;
        mov.u32         %r4, %tid.x;
        mad.lo.s32      %r5, %r3, %r2, %r4;
        mov.u32         %r6, %ntid.y;
        mov.u32         %r7, %ctaid.y;
        mov.u32         %r8, %tid.y;
        mad.lo.s32      %r9, %r7, %r6, %r8;
        mov.u32         %r10, %ntid.z;
        mov.u32         %r11, %ctaid.z;
        mov.u32         %r12, %tid.z;
        mad.lo.s32      %r13, %r11, %r10, %r12;
        mov.u32         %r14, %nctaid.y;
        mul.lo.s32      %r15, %r6, %r14;
        mad.lo.s32      %r16, %r15, %r13, %r9;
        mov.u32         %r17, %nctaid.x;
        mul.lo.s32      %r18, %r17, %r2;
        mad.lo.s32      %r1, %r18, %r16, %r5;
        // Bounds check: limit = info[(info[0] << 2) | 1] = info[4 * info[0] + 1];
        // exit if idxGlobal >= limit.
        ld.global.u32   %r19, [%rd5];
        shl.b32         %r20, %r19, 2;
        or.b32          %r21, %r20, 1;
        mul.wide.u32    %rd6, %r21, 4;
        add.s64         %rd7, %rd5, %rd6;
        ld.global.u32   %r22, [%rd7];
        setp.ge.u32     %p1, %r1, %r22;
        @%p1 bra        $L__BB0_4;

        // x = input_0[idxGlobal]; t = x / sqrt(2), with sqrt(2) folded to 0f3FB504F3.
        cvta.to.global.u64      %rd8, %rd2;
        cvt.u64.u32     %rd1, %r1;
        mul.wide.u32    %rd9, %r1, 4;
        add.s64         %rd10, %rd8, %rd9;
        ld.global.f32   %f1, [%rd10];
        div.rn.f32      %f2, %f1, 0f3FB504F3;
        // Inlined erf(t): the polynomial coefficients are selected by whether
        // |t| >= 0f3F8060FE (about 1.003).
        // Small |t|: %f30 = t + t * P(t*t) is erf(t) directly (branch to $L__BB0_3).
        // Large |t|: %f30 = -|t| - |t| * P(|t|) is used as an exponent below.
        abs.f32         %f6, %f2;
        setp.ltu.f32    %p2, %f6, 0f3F8060FE;
        setp.ge.f32     %p3, %f6, 0f3F8060FE;
        mul.f32         %f7, %f2, %f2;
        selp.f32        %f8, %f6, %f7, %p3;
        selp.f32        %f9, 0f3789CA3C, 0f38B1E96A, %p3;
        selp.f32        %f10, 0fB9F560B9, 0fBA574D20, %p3;
        fma.rn.f32      %f11, %f9, %f8, %f10;
        selp.f32        %f12, 0f3BAC840B, 0f3BAAD5EA, %p3;
        fma.rn.f32      %f13, %f11, %f8, %f12;
        selp.f32        %f14, 0fBD0C8162, 0fBCDC1BE7, %p3;
        fma.rn.f32      %f15, %f13, %f8, %f14;
        selp.f32        %f16, 0f3E1CF906, 0f3DE718AF, %p3;
        fma.rn.f32      %f17, %f15, %f8, %f16;
        selp.f32        %f18, 0f3F6A937E, 0fBEC093AC, %p3;
        fma.rn.f32      %f19, %f17, %f8, %f18;
        selp.f32        %f20, 0f3F20D842, 0f3E0375D3, %p3;
        fma.rn.f32      %f21, %f19, %f8, %f20;
        neg.f32         %f22, %f6;
        selp.f32        %f23, %f22, %f2, %p3;
        fma.rn.f32      %f30, %f21, %f23, %f23;
        @%p2 bra        $L__BB0_3;

        // Large |t|: erf(t) = copysign(1 - 2^%f30, t) (sign bit taken from t).
        ex2.approx.ftz.f32      %f24, %f30;
        mov.f32         %f25, 0f3F800000;
        sub.f32         %f26, %f25, %f24;
        mov.b32         %r23, %f26;
        mov.b32         %r24, %f2;
        and.b32         %r25, %r24, -2147483648;
        or.b32          %r26, %r25, %r23;
        mov.b32         %f30, %r26;

// output_0[idxGlobal] = x * (erf(t) + 1.0) * 0.5; the "/ 2.0" became a multiply by 0f3F000000.
$L__BB0_3:
        add.f32         %f27, %f30, 0f3F800000;
        mul.f32         %f28, %f1, %f27;
        mul.f32         %f29, %f28, 0f3F000000;
        cvta.to.global.u64      %rd11, %rd3;
        shl.b64         %rd12, %rd1, 2;
        add.s64         %rd13, %rd11, %rd12;
        st.global.f32   [%rd13], %f29;

$L__BB0_4:
        ret;
}

We then panic in `cudarc::driver::result::module::load_data(ptx.as_ptr() as *const _)`.

wbrickner commented 3 weeks ago

Okay, the issue is that the arch argument is wrong for my GPU (it was 70).

I am running an old GTX 1050 Ti. Setting arch to 50 in the debugger makes it work.

wbrickner commented 3 weeks ago

`arch` comes from `self.minimum_arch_version`, which is set in `CudaServer::new`:

    /// Create a new cuda server.
    ///
    /// Queries NVRTC for the architectures it can compile for and uses the first
    /// entry as the initial `minimum_arch_version`. The device itself is not
    /// initialized here (the state is `CudaServerState::Uninitialized`), so this list
    /// is NOT filtered by the GPU's actual compute capability.
    pub(crate) fn new(index: usize, init: Box<dyn Fn(usize) -> CudaContext<MM>>) -> Self {
        // Two-step NVRTC query: first the number of archs, then fill a buffer of that size.
        // NOTE(review): any status returned by these calls is discarded, and
        // `&mut archs[0]` panics if NVRTC reports zero archs.
        let archs = unsafe {
            let mut num_supported_arg: core::ffi::c_int = 0;
            cudarc::nvrtc::sys::lib()
                .nvrtcGetNumSupportedArchs(core::ptr::from_mut(&mut num_supported_arg));

            let mut archs: Vec<core::ffi::c_int> = vec![0; num_supported_arg as usize];
            cudarc::nvrtc::sys::lib().nvrtcGetSupportedArchs(core::ptr::from_mut(&mut archs[0]));
            archs
        };

        // Debug print. In the reported run it showed
        // [35, 37, 50, 52, 53, 60, 61, 62, 70, 72, 75, 80, 86, 87].
        println!("Supported archs: {:?}", archs);

        // Assumes NVRTC returns archs in ascending order (as in the reported output),
        // so archs[0] is the oldest compilable arch. This value can later be raised,
        // e.g. by the WMMA feature check in CudaRuntime::client.
        let minimum_arch_version = archs[0];

        Self {
            state: CudaServerState::Uninitialized {
                device_index: index,
                init,
            },
            logger: DebugLogger::new(),
            archs,
            minimum_arch_version,
        }
    }

This prints:

Supported archs: [35, 37, 50, 52, 53, 60, 61, 62, 70, 72, 75, 80, 86, 87]

and indeed, at that point, `minimum_arch_version` is 35.

It becomes 70 because of this code in `CudaRuntime::client`:

// Excerpt from CudaRuntime::client. When register_wmma_features returns
// Some(version) (presumably the minimum arch WMMA needs), the server's compile
// target is raised to at least that version.
// NOTE(review): server.archs is the unfiltered NVRTC list, so this can select an
// arch newer than the physical device (70 on a GTX 1050 Ti in this issue).
if let Some(wmma_minimum_version) = register_wmma_features(&mut features, &server.archs)
  {
      // Only ever raise the target, never lower it.
      server.minimum_arch_version =
          i32::max(server.minimum_arch_version, wmma_minimum_version);
  }

So, if I follow correctly, you are trying to emit WMMA instructions, which require a newer GPU architecture than the actual GPU supports.

wbrickner commented 3 weeks ago

I was working on a PR, but I think a somewhat complicated redesign is necessary.

The issue is that we can't tell the maximum arch for a device without initializing it. `CudaServer` is designed to initialize lazily, but `Features` must be constructed before that happens. We need to be able to get the device's compute capability in order to filter the archs.

I think this part of the system needs a rethink. I vote to either initialize eagerly or relocate `archs` / `minimum_arch_version` and redesign the types.

nathanielsimard commented 2 weeks ago

Thanks for the details. It's kind of weird that compiling for one of the supported archs generates invalid PTX.

wbrickner commented 2 weeks ago

The supported archs are simply queried from the NVIDIA runtime compiler (NVRTC) and are not filtered by the actual card's capability. This is because the card's capability is unknown until the device is initialized, and initialization is deferred in the current implementation.

wbrickner commented 2 weeks ago

(The compilation is not failing; the loading/execution is failing, because PTX targeting any arch newer than the device supports is invalid for that device.)

nathanielsimard commented 2 weeks ago

> the supported archs are just queried from the nvidia runtime compiler and are not filtered by the actual card capability. this is bc the card capability is unknown until it is initialized which is deferred in the current implementation

So maybe we should initialize the device earlier. I initially deferred it because of multi-threading issues, but now that I think those issues are resolved, we might not need to defer it anymore.