keyvank / femtoGPT

Pure Rust implementation of a minimal Generative Pretrained Transformer
https://discord.gg/wTJFaDVn45
MIT License

AMD GPU training not working #19

Open rickytrevor opened 10 months ago

rickytrevor commented 10 months ago

Hello, I'm trying to run femtoGPT on my RX 6600 under Ubuntu Linux. I've installed the required ROCm OpenCL drivers, but when I run the program using

cargo run --release --features gpu

I get an index-out-of-bounds panic:

thread 'main' panicked at 'index out of bounds: the len is 0 but the index is 0', src/graph/gpu/mod.rs:119:22
note: run with RUST_BACKTRACE=1 environment variable to display a backtrace

Edit: this is the log with RUST_BACKTRACE=1:

   0: rust_begin_unwind
             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/std/src/panicking.rs:593:5
   1: core::panicking::panic_fmt
             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/core/src/panicking.rs:67:14
   2: core::panicking::panic_bounds_check
             at /rustc/5680fa18feaa87f3ff04063800aec256c3d4b4be/library/core/src/panicking.rs:162:5
   3: femto_gpt::graph::gpu::GpuGraph::new
   4: femto_gpt::main
note: Some details are omitted, run with RUST_BACKTRACE=full for a verbose backtrace.
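The message "the len is 0 but the index is 0" means GpuGraph::new indexed element 0 of an empty list, which suggests that no OpenCL device matched the configured brand (the code at src/graph/gpu/mod.rs:119 is not shown in this thread, so this is an inference from the panic text). A minimal sketch of a non-panicking alternative follows; DeviceInfo and select_device are hypothetical names, not femtoGPT's API.

```rust
// Hypothetical stand-in for an OpenCL device; the real code works with ocl::Device.
#[derive(Debug)]
struct DeviceInfo {
    name: String,
    vendor: String,
}

/// Return the first device accepted by `is_wanted`, or an error when none match.
/// Writing `matching[0]` on an empty list is what produces
/// "index out of bounds: the len is 0 but the index is 0".
fn select_device<'a, T, F>(devices: &'a [T], is_wanted: F) -> Result<&'a T, String>
where
    F: Fn(&T) -> bool,
{
    devices
        .iter()
        .find(|d| is_wanted(*d))
        .ok_or_else(|| "no OpenCL device matched the selected brand".to_string())
}

fn main() {
    // Example data only: a machine with a single AMD GPU.
    let devices = vec![DeviceInfo {
        name: "example AMD GPU".to_string(),
        vendor: "Advanced Micro Devices, Inc.".to_string(),
    }];

    // Filtering for NVIDIA finds nothing, so we get an Err instead of a panic.
    match select_device(&devices, |d| d.vendor.contains("NVIDIA")) {
        Ok(d) => println!("using {}", d.name),
        Err(e) => println!("error: {}", e),
    }
}
```

Returning a Result lets the caller report that no device of the requested brand was found, instead of aborting with a bare index error.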

keyvank commented 10 months ago

@rickytrevor Ah yes, you're right. I haven't tried this on AMD. Try changing the brand on line 119 (src/graph/gpu/mod.rs) to Brand::Amd and recompile.
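Rather than editing line 119 by hand for each vendor, the brand could be derived from the vendor string OpenCL reports for the device (CL_DEVICE_VENDOR). A minimal sketch, assuming NVIDIA drivers report "NVIDIA Corporation" and AMD drivers report "Advanced Micro Devices, Inc."; brand_from_vendor is a hypothetical helper, and the enum only mirrors the two Brand variants shown in this thread.

```rust
/// Mirrors the two variants of the Brand enum shown in this thread.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Brand {
    Nvidia,
    Amd,
}

/// Map an OpenCL vendor string to a Brand (case-insensitive substring match).
/// Assumed vendor strings: "NVIDIA Corporation" and "Advanced Micro Devices, Inc.".
fn brand_from_vendor(vendor: &str) -> Option<Brand> {
    let v = vendor.to_lowercase();
    if v.contains("nvidia") {
        Some(Brand::Nvidia)
    } else if v.contains("advanced micro devices") || v.contains("amd") {
        Some(Brand::Amd)
    } else {
        // Unknown vendor: let the caller decide instead of guessing.
        None
    }
}

fn main() {
    for vendor in &["NVIDIA Corporation", "Advanced Micro Devices, Inc.", "Intel(R) Corporation"] {
        println!("{} -> {:?}", vendor, brand_from_vendor(vendor));
    }
}
```

Returning None for unrecognized vendors keeps unsupported devices from silently falling into the NVIDIA or AMD code path.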

rickytrevor commented 10 months ago

I've tried changing it to Brand::Amd, and it just panics with "Not supported!". I briefly looked at the impl for Brand, and the AMD case is missing from the get_bus_id function:

impl Brand {
    pub fn get_bus_id(&self, d: ocl::Device) -> ocl::Result<u32> {
        match self {
            Brand::Nvidia => {
                const CL_DEVICE_PCI_BUS_ID_NV: u32 = 0x4008;
                let result = d.info_raw(CL_DEVICE_PCI_BUS_ID_NV)?;
                Ok(u32::from_le_bytes(result[..].try_into().unwrap()))
            }
            Brand::Amd => panic!("Not supported!"),
        }
    }
}

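In both arms, `result[..].try_into().unwrap()` panics if the driver returns a buffer that is not exactly 4 bytes long. A minimal sketch of a non-panicking decoder; decode_u32 is a hypothetical helper, and it uses from_ne_bytes on the assumption that the driver fills the buffer in host byte order (identical to from_le_bytes on x86).

```rust
use std::convert::TryFrom;

/// Decode a raw device-info buffer (such as the bytes returned by info_raw) as a u32.
/// Returns an error instead of panicking when the length is not exactly 4 bytes.
fn decode_u32(raw: &[u8]) -> Result<u32, String> {
    let bytes = <[u8; 4]>::try_from(raw)
        .map_err(|_| format!("expected 4 bytes, got {}", raw.len()))?;
    // Host byte order; on x86 this gives the same result as from_le_bytes.
    Ok(u32::from_ne_bytes(bytes))
}

fn main() {
    let ok = 42u32.to_ne_bytes();
    println!("{:?}", decode_u32(&ok));
    println!("{:?}", decode_u32(&[1, 2, 3]));
}
```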
rickytrevor commented 10 months ago

Update: I've managed to implement the case for AMD:

impl Brand {
    pub fn get_bus_id(&self, d: ocl::Device) -> ocl::Result<u32> {
        match self {
            Brand::Nvidia => {
                const CL_DEVICE_PCI_BUS_ID_NV: u32 = 0x4008;
                let result = d.info_raw(CL_DEVICE_PCI_BUS_ID_NV)?;
                Ok(u32::from_le_bytes(result[..].try_into().unwrap()))
            }
            Brand::Amd => {
                const CL_DEVICE_PCIE_ID_AMD: u32 = 0x4034;
                let result = d.info_raw(CL_DEVICE_PCIE_ID_AMD)?;
                Ok(u32::from_le_bytes(result[..].try_into().unwrap()))
            }
        }
    }
}
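As far as I can tell from AMD's cl_ext.h, 0x4034 (CL_DEVICE_PCIE_ID_AMD) reports the PCIe device ID rather than the bus number, so it may not correspond to what the NVIDIA arm returns. The bus number is usually read from CL_DEVICE_TOPOLOGY_AMD (0x4037), which fills a 24-byte cl_device_topology_amd union. A minimal sketch of decoding that buffer, assuming the `pcie` layout from cl_ext.h (a 4-byte type field, 17 unused bytes, then bus, device and function bytes); parse_topology_amd and PciAddress are hypothetical names.

```rust
// Assumed from AMD's cl_ext.h: CL_DEVICE_TOPOLOGY_TYPE_PCIE_AMD == 1 and
// sizeof(cl_device_topology_amd) == 24. Verify against your own headers.
const CL_DEVICE_TOPOLOGY_TYPE_PCIE_AMD: u32 = 1;
const TOPOLOGY_SIZE: usize = 24;

#[derive(Debug, PartialEq, Eq)]
struct PciAddress {
    bus: u8,
    device: u8,
    function: u8,
}

/// Decode the raw bytes returned for CL_DEVICE_TOPOLOGY_AMD (0x4037).
fn parse_topology_amd(raw: &[u8]) -> Result<PciAddress, String> {
    if raw.len() != TOPOLOGY_SIZE {
        return Err(format!("expected {} bytes, got {}", TOPOLOGY_SIZE, raw.len()));
    }
    // The `type` field is a cl_uint written in host byte order.
    let ty = u32::from_ne_bytes([raw[0], raw[1], raw[2], raw[3]]);
    if ty != CL_DEVICE_TOPOLOGY_TYPE_PCIE_AMD {
        return Err(format!("unexpected topology type {}", ty));
    }
    // Bytes 4..21 are `unused[17]`; the cl_char fields are read here as unsigned bytes.
    Ok(PciAddress {
        bus: raw[21],
        device: raw[22],
        function: raw[23],
    })
}

fn main() {
    // Hand-built example buffer only: PCIe topology with bus 3, device 0, function 0.
    let mut raw = [0u8; TOPOLOGY_SIZE];
    raw[0..4].copy_from_slice(&CL_DEVICE_TOPOLOGY_TYPE_PCIE_AMD.to_ne_bytes());
    raw[21] = 3;
    match parse_topology_amd(&raw) {
        Ok(a) => println!("{:02x}:{:02x}.{}", a.bus, a.device, a.function),
        Err(e) => println!("error: {}", e),
    }
}
```

Under these assumptions, the Brand::Amd arm would query 0x4037 through info_raw and return the bus field as a u32.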

The problem is that right after starting, it crashes with this error. Any clues?

Error: GpuError(OclError(Kernel(ArgTypeMismatch { idx: 0, arg_name: "out", ty_name: "float*", ty: ArgType { base_type: Float, cardinality: One, is_ptr: true } })))

keyvank commented 10 months ago

@rickytrevor I don't have an AMD card, so I can't debug this :smile: Happy to get a PR for AMD support, though :wink: