ksasi opened this issue 1 year ago
Not in the very near future.
However, we could first build bindings around https://www.amd.com/en/graphics/servers-solutions-rocm in order to start working. I'm not too familiar with the workflow, but I do know there are some adapter layers for CUDA somewhere.
In any case I think we should aim for the same thing as CUDA: a bare minimum set of kernels, plus enabling users to write their own kernels.
If you want to start working on bindings (or know of existing up-to-date ones), we can keep an eye on it!
llama.cpp now supports ROCm: https://github.com/ggerganov/llama.cpp/pull/1087. The discussion in that PR may provide guidance.
My team sees less than 25% of the cost for the same performance using ROCm instead of CUDA, but we are stuck on the Python side. We would very much appreciate ROCm support in candle from Rust.
I am also looking into the possibility of running LLMs on ROCm-compatible AMD hardware (for potentially significant savings), and it seems like llama.cpp might be the only viable option. I have done a test integration with candle and would prefer that, but it looks like I may actually be going back to llama.cpp because of ROCm support. The integration with Rust is awkward though, and I would rather stick with a Rust solution if possible.
But overall this is such an amazing engineering effort and I really appreciate your work.
I'd love to contribute to the AMD support initiative for Candle. I'm wondering if HIP might not be a reasonable first pass.
Additionally, I propose prioritizing RDNA3 architecture cards due to their advanced features, like multi-precision capability and the AI Matrix Accelerator, which are crucial for ML. And AMD/ROCm seem to be starting with RDNA3 for serious ML/AI support.
Anyway, I'm ready to contribute my time and skills. I'd prefer not to lead the effort, but count me in for support!
FWIW: My background includes decades of low-level systems/embedded programming (C/C++), recent focus on Rust, experience with GPUs from game development, self-taught ML knowledge, and some familiarity with OpenCL and CUDA from my HPC days. While I'm less experienced with ROCm, I understand its significance.
This would be a great first project for my new System76 setup with a Ryzen 9 7950X, 128GB RAM, and a Radeon RX 7900XT. I plan to swap in dual 7900 XTXs for the 48GB GPU RAM.
I'm also happy to contribute to the AMD support, but there are two options for starting it. The first is to compile CUDA into HIP, and the second is to write the kernels in HIP directly. Which one is better?
A very elegant migration method. We may also need to handle the migration of flash-attention, but that shouldn't be laborious; we can copy it directly from the official AMD library.
If anyone is interested, I've made a POC for ROCm. You also need to use this fork of cudarc.
I'm able to run an example on an AMD gfx1030 GPU. The GPU arch is hardcoded inside candle-hip-kernels/build.rs, so if you have another GPU arch you must change it there.
Thank you for the work! Which example did you try @vberthet?
I tried on gfx1102 (RX7600)
HSA_OVERRIDE_GFX_VERSION='11.0.2' cargo run --example phi --features=hip --release -- --model phi-hermes --prompt "A skier slides down a frictionless slope of height 40m and length 80m. What's the skier speed at the bottom?"
I'm seeing the model loaded:
============================================ ROCm System Management Interface ============================================
====================================================== Concise Info ======================================================
Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
Name (20 chars) (Edge) (Avg) (Mem, Compute)
==========================================================================================================================
0 [0x240b : 0xcf] 50.0°C 145.0W N/A, N/A 2868Mhz 1124Mhz 32.94% auto 145.0W 79% 96%
0x7480
But the example never finishes and produces no results.
Edit: Now the yolo-v8 example runs on my AMD Radeon RX 7600.
HIP_VISIBLE_DEVICES=0 HSA_OVERRIDE_GFX_VERSION='11.0.2' cargo run --features hip --example yolo-v8 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
But the example never finishes and produces no results.
Same here, I tried the yolov8 one:
cargo run --features hip --example yolo-v8 --release -- candle-examples/examples/yolo-v8/assets/bike.jpg
It downloaded the model and then just got stuck at 100% CPU. I think I aborted after 15 minutes or so; the GPU did not show any signs of usage.
I'm able to run example on an AMD gfx1030 GPU.
So, the same question from my side: @vberthet, you wrote you were able to run an example; which one did you try?
It would be great to get candle to work with ROCm. I see there is quite some interest in doing that, so perhaps we should coordinate in some way and also figure out what kind of PR with ROCm support could get accepted.
The implementation takes the first available GPU, and some GPUs don't seem to work as expected. I see the same behavior on my machine: by default HIP selects the low-power GPU embedded in the CPU instead of the discrete GPU, and for some reason this default choice never returns a result... You can force HIP to use a specific GPU with the HIP_VISIBLE_DEVICES environment variable.
There is still work to do on the kernel ports to HIP; some half-precision operations don't work or don't compile (e.g. https://github.com/vberthet/candle/blob/2a0096af8013634479a3be0190286b60eb27205f/candle-hip-kernels/src/reduce.cu#L363).
A better approach would be to use Orochi to load CUDA/HIP dynamically at runtime. We should also avoid maintaining separate kernels for CUDA and HIP; llama.cpp achieves HIP compatibility inside the kernels using C++ macros (see ggml-cuda.cu). Unfortunately, I don't have much experience with C++.
The implementation takes the first available GPU, and some GPUs don't seem to work as expected. I see the same behavior on my machine: by default HIP selects the low-power GPU embedded in the CPU instead of the discrete GPU, and for some reason this default choice never returns a result... You can force HIP to use a specific GPU with the HIP_VISIBLE_DEVICES environment variable.
Unfortunately this did not help me. It clearly reacts to the variable (i.e. specifying a non-existing index will panic), but setting it (in my case to 0) did not help. At the moment I only have one GPU installed anyway.
I'll try to look into what is going on there.
A better approach would be to use Orochi to load CUDA/HIP dynamically at runtime.
This one looks interesting. I wonder if the candle maintainers would consider this an option for a PR?
Overall, unfortunately, I was not yet able to reproduce your success with the POC :(
I have zero experience with GPU programming, so maybe someone could chime in. I attached gdb to the yolo example; here's thread apply all bt:
(gdb) thread apply all bt
Thread 2 (Thread 0x7fe847c006c0 (LWP 2277) "yolo-v8"):
#0 __GI___ioctl (fd=fd@entry=3, request=request@entry=3222817548) at ../sysdeps/unix/sysv/linux/ioctl.c:36
#1 0x00007fe9129ed400 in kmtIoctl (fd=3, request=request@entry=3222817548, arg=arg@entry=0x7fe847bff380) at /usr/src/debug/hsakmt-1.0.6-38.rocm6.0.0.fc40.x86_64/src/libhsakmt.c:13
#2 0x00007fe9129ee7db in hsaKmtWaitOnMultipleEvents_Ext (event_age=0x7fe847bff430, Milliseconds=4294967294, WaitOnAll=<optimized out>, NumEvents=4, Events=0x7fe847bff4e0) at /usr/src/debug/hsakmt-1.0.6-38.rocm6.0.0.fc40.x86_64/src/events.c:409
#3 hsaKmtWaitOnMultipleEvents_Ext (Events=0x7fe847bff4e0, NumEvents=4, WaitOnAll=<optimized out>, Milliseconds=4294967294, event_age=0x7fe847bff430) at /usr/src/debug/hsakmt-1.0.6-38.rocm6.0.0.fc40.x86_64/src/events.c:380
#4 0x00007fe911c59a09 in rocr::core::Signal::WaitAny (satisfying_value=<optimized out>, wait_hint=<optimized out>, timeout=<optimized out>, values=<optimized out>, conds=<optimized out>, hsa_signals=<optimized out>, signal_count=7) at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/runtime/signal.cpp:324
#5 rocr::AMD::hsa_amd_signal_wait_any (signal_count=7, hsa_signals=<optimized out>, conds=<optimized out>, values=<optimized out>, timeout_hint=<optimized out>, wait_hint=<optimized out>, satisfying_value=<optimized out>) at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/runtime/hsa_ext_amd.cpp:587
#6 0x00007fe911c685f9 in rocr::core::Runtime::AsyncEventsLoop () at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/runtime/runtime.cpp:1136
#7 0x00007fe911c2541c in rocr::os::ThreadTrampoline (arg=<optimized out>) at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/util/lnx/os_linux.cpp:80
#8 0x00007fe912aa91f7 in start_thread (arg=<optimized out>) at pthread_create.c:447
#9 0x00007fe912b2b3ac in clone3 () at ../sysdeps/unix/sysv/linux/x86_64/clone3.S:78
Thread 1 (Thread 0x7fe95407ffc0 (LWP 2049) "yolo-v8"):
#0 0x00007fe911c5c942 in rocr::__rdtsc () at /usr/lib/gcc/x86_64-redhat-linux/14/include/ia32intrin.h:114
#1 rocr::timer::fast_clock::raw_now () at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/util/timer.h:149
#2 rocr::timer::fast_clock::now () at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/util/timer.h:140
#3 rocr::core::InterruptSignal::WaitRelaxed (this=0x560934e7c200, condition=HSA_SIGNAL_CONDITION_LT, compare_value=1, timeout=<optimized out>, wait_hint=HSA_WAIT_STATE_ACTIVE) at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/runtime/interrupt_signal.cpp:211
#4 0x00007fe911c5cb4e in rocr::core::InterruptSignal::WaitAcquire (this=<optimized out>, condition=<optimized out>, compare_value=<optimized out>, timeout=<optimized out>, wait_hint=<optimized out>) at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/runtime/interrupt_signal.cpp:251
#5 0x00007fe911c4c93f in rocr::HSA::hsa_signal_wait_scacquire (hsa_signal=..., condition=<optimized out>, compare_value=<optimized out>, timeout_hint=<optimized out>, wait_state_hint=<optimized out>) at /usr/src/debug/rocm-runtime-6.0.0-3.fc40.x86_64/src/core/runtime/hsa.cpp:1220
#6 0x00007fe95270e11b in roc::WaitForSignal<false> (forced_wait=false, active_wait=<optimized out>, signal=...) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocvirtual.hpp:70
#7 roc::VirtualGPU::HwQueueTracker::CpuWaitForSignal (this=<optimized out>, signal=0x5609352dc8f0) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocvirtual.cpp:558
#8 0x00007fe95273580e in roc::VirtualGPU::HwQueueTracker::WaitCurrent (this=<optimized out>) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocvirtual.hpp:240
#9 roc::DmaBlitManager::hsaCopyStaged (this=this@entry=0x560934490380, hostSrc=hostSrc@entry=0x560934dfd5b0 " ", hostDst=0x7fe744401000 "", size=<optimized out>, size@entry=16, staging=0x7fe746500000 " ", hostToDev=hostToDev@entry=true) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocblit.cpp:808
#10 0x00007fe952736455 in roc::DmaBlitManager::writeMemoryStaged (xferSize=16, totalSize=<synthetic pointer>: <optimized out>, offset=<synthetic pointer>: <optimized out>, origin=<optimized out>, xferBuf=..., dstMemory=..., srcHost=0x560934dfd5b0, this=0x560934490380) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocblit.cpp:229
#11 roc::DmaBlitManager::writeBuffer (this=0x560934490380, srcHost=<optimized out>, dstMemory=..., origin=..., size=..., entire=<optimized out>, copyMetadata=...) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocblit.cpp:314
#12 0x00007fe9527389b0 in roc::KernelBlitManager::writeBuffer (this=this@entry=0x560934490380, srcHost=srcHost@entry=0x560934dfd5b0, dstMemory=..., origin=..., size=..., entire=<optimized out>, copyMetadata=...) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocblit.cpp:1944
#13 0x00007fe95271164b in roc::VirtualGPU::submitWriteMemory (this=0x560934e083e0, cmd=...) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/device/rocm/rocvirtual.cpp:1705
#14 0x00007fe9526edcf4 in amd::Command::enqueue (this=0x560934eb46d0) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/rocclr/platform/command.cpp:391
#15 0x00007fe95259f7d6 in ihipMemcpy (dst=0x7fe744401000, src=0x560934dfd5b0, sizeBytes=<optimized out>, kind=hipMemcpyHostToDevice, stream=..., isHostAsync=false, isGPUAsync=true) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/hipamd/src/hip_memory.cpp:522
#16 0x00007fe9525c0706 in hipMemcpyHtoD (dstDevice=<optimized out>, srcHost=0x560934dfd5b0, ByteCount=16) at /usr/src/debug/rocclr-6.0.0-3.fc40.x86_64/hipamd/src/hip_memory.cpp:1435
#17 0x0000560932ca3f22 in cudarc::driver::safe::alloc::<impl cudarc::driver::safe::core::CudaDevice>::htod_copy ()
#18 0x0000560932c53efe in <candle_core::cuda_backend::CudaStorage as candle_core::backend::BackendStorage>::to_dtype ()
#19 0x0000560932c706ce in candle_core::tensor::Tensor::to_dtype ()
#20 0x0000560932c2d85e in <candle_core::safetensors::MmapedSafetensors as candle_nn::var_builder::SimpleBackend>::get ()
#21 0x0000560932c2ab1b in <alloc::boxed::Box<dyn candle_nn::var_builder::SimpleBackend> as candle_nn::var_builder::Backend>::get ()
#22 0x0000560932a299b2 in candle_nn::var_builder::VarBuilderArgs<B>::get_with_hints ()
#23 0x0000560932a4d43e in candle_nn::batch_norm::batch_norm ()
#24 0x0000560932a31335 in yolo_v8::model::ConvBlock::load ()
#25 0x0000560932a34b4a in yolo_v8::model::DarkNet::load ()
#26 0x0000560932a3f448 in <yolo_v8::model::YoloV8 as yolo_v8::Task>::load ()
#27 0x0000560932a50d3d in yolo_v8::main ()
#28 0x0000560932a2d233 in std::sys_common::backtrace::__rust_begin_short_backtrace ()
#29 0x0000560932a5dc8d in std::rt::lang_start::{{closure}} ()
#30 0x0000560932db6437 in std::rt::lang_start_internal ()
#31 0x0000560932a5dc7e in std::rt::lang_start ()
#32 0x00007fe912a3d088 in __libc_start_call_main (main=main@entry=0x560932a580a0 <main>, argc=argc@entry=2, argv=argv@entry=0x7ffdde43bdc8) at ../sysdeps/nptl/libc_start_call_main.h:58
#33 0x00007fe912a3d14b in __libc_start_main_impl (main=0x560932a580a0 <main>, argc=2, argv=0x7ffdde43bdc8, init=<optimized out>, fini=<optimized out>, rtld_fini=<optimized out>, stack_end=0x7ffdde43bdb8) at ../csu/libc-start.c:360
#34 0x00005609329f5535 in _start ()
I think it's pretty much this issue: https://github.com/ROCm/ROCm/issues/2715
For me it hangs at 100% CPU just as described there, and the backtrace around roc::DmaBlitManager::hsaCopyStaged looks somewhat similar, so I would cautiously guess that it may not be the POC's fault?
For those who are not suffering from the ROCm related CPU-hog bug, this project looks like a very interesting alternative: https://github.com/vosen/ZLUDA
If it does what it says it does, we could simply run unmodified candle code on AMD GPUs.
For those who are not suffering from the ROCm related CPU-hog bug, this project looks like a very interesting alternative: https://github.com/vosen/ZLUDA
If it does what it says it does, we could simply run unmodified candle code on AMD GPUs.
I gave it a spin. Compiling candle requires the NVIDIA libraries, both CUDA and cuDNN, and you need to set the env var CUDA_COMPUTE_CAP="80" for the project to compile.
Attempted to run with:
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/zluda CUDA_COMPUTE_CAP="80" cargo run --example wuerstchen --release --features cuda,cudnn -- --prompt "Anthropomorphic cat dressed as a fire fighter"
Once it compiles and runs, you are greeted with the error:
Error: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found") when loading is_u32_f32
I've updated the build script for the HIP kernels; it should now generate kernels with better GPU arch compatibility.
I've been able to run:
Phi seems to run into an infinite loop. Mistral needs bfloat16, and the HIP port of that part isn't done yet.
@cantor-set this error seems to also exist with CUDA; see #353
Error: DriverError(CUDA_ERROR_NOT_FOUND, "named symbol not found") when loading is_u32_f32
@vberthet I had to patch build.rs: "-parallel-jobs=15" did not work for me (Rawhide 40), so I had to remove it in order to compile.
clang: error: unknown argument: '-parallel-jobs=15'
thread 'main' panicked at candle-hip-kernels/build.rs:354:13:
nvcc error while compiling "src/affine.cu":
Unfortunately I am still experiencing the ROCm 100% CPU hog from https://github.com/ROCm/ROCm/issues/2715 so nothing works for me at the moment anyway :(
Which ROCm version are you running? I've only tried the latest version, 6.0.2.
I am currently on Rawhide, which provides ROCm 6.0.0, although judging from the comments in the issue this problem was present in 5.7 as well. I could perhaps try to rebuild 6.0.2 and see if it goes away, although I am almost inclined to downgrade to 5.7 in order to try ZLUDA, which afaik does not support the 6.x APIs yet. I think the last ROCm version that worked for me was 5.4, back on Fedora 38.
I finally got past the ROCm hanging memcpy issue. It turned out to be enough to export HSA_ENABLE_SDMA=0, though it took me a while to find this option.
So, once that worked I got back to trying candle. @vberthet I am getting a coredump when trying to run the yolo-v8 demo:
model loaded
processing candle-examples/examples/yolo-v8/assets/bike.jpg
generating predictions
Segmentation fault (core dumped)
The printout "generating predictions" was added by me, right before let predictions = model.forward(&image_t)?.squeeze(0)?;, which is where it seems to crash. I'll see that I get a usable trace.
OK, so... after a long time I finally got back to this, and I am not sure what changed; I guess I fixed my installation without realizing it. It did not crash, and I got the yolov8 example to run through! @vberthet, awesome!
The question is: what's next? Will you be maintaining the fork and perhaps coordinating the efforts? How could we turn the POC into a maintained version?
EDIT: I may have celebrated too early... I do not see any indication that the GPU was actually used; no spikes in nvtop while yolo is processing...
EDIT2: all good, yolo was simply too fast to be noticeable on the GPU in nvtop, with SDXL-Turbo I can see that the GPU is being used
I've got 2 systems, each with 8 AMD MI300Xs, and I'm pissed I can't use candle with them... Python is yucky.
Can somebody help me out?
======================================================= Concise Info =======================================================
Device Node IDs Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
(DID, GUID) (Junction) (Socket) (Mem, Compute, ID)
============================================================================================================================
0 26 0x74a1, 8554 37.0°C 131.0W NPS1, SPX, 0 132Mhz 900Mhz 0% manual 750.0W 0% 0%
1 27 0x74a1, 19011 38.0°C 130.0W NPS1, SPX, 0 132Mhz 900Mhz 0% manual 750.0W 0% 0%
2 25 0x74a1, 30036 39.0°C 132.0W NPS1, SPX, 0 132Mhz 900Mhz 0% manual 750.0W 0% 0%
3 24 0x74a1, 23964 36.0°C 132.0W NPS1, SPX, 0 132Mhz 900Mhz 0% manual 750.0W 0% 0%
4 30 0x74a1, 1197 37.0°C 131.0W NPS1, SPX, 0 132Mhz 900Mhz 0% manual 750.0W 0% 0%
5 31 0x74a1, 41351 35.0°C 130.0W NPS1, SPX, 0 131Mhz 900Mhz 0% manual 750.0W 0% 0%
6 29 0x74a1, 26775 40.0°C 134.0W NPS1, SPX, 0 132Mhz 900Mhz 0% manual 750.0W 0% 0%
7 28 0x74a1, 45536 35.0°C 133.0W NPS1, SPX, 0 132Mhz 900Mhz 0% manual 750.0W 0% 0%
============================================================================================================================
=================================================== End of ROCm SMI Log ====================================================
mastersplinter@turtle005:~/candle$ rocm-smi --showproductname
============================ ROCm System Management Interface ============================
====================================== Product Info ======================================
GPU[0] : Card Series: AMD Instinct MI300X OAM
GPU[0] : Card Model: 0x74a1
GPU[0] : Card Vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[0] : Card SKU: MI3SRIOV
GPU[0] : Subsystem ID: 0x74a1
GPU[0] : Device Rev: 0x00
GPU[0] : Node ID: 26
GPU[0] : GUID: 8554
GPU[0] : GFX Version: gfx942
GPU[1] : Card Series: AMD Instinct MI300X OAM
GPU[1] : Card Model: 0x74a1
GPU[1] : Card Vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[1] : Card SKU: MI3SRIOV
GPU[1] : Subsystem ID: 0x74a1
GPU[1] : Device Rev: 0x00
GPU[1] : Node ID: 27
GPU[1] : GUID: 19011
GPU[1] : GFX Version: gfx942
GPU[2] : Card Series: AMD Instinct MI300X OAM
GPU[2] : Card Model: 0x74a1
GPU[2] : Card Vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[2] : Card SKU: MI3SRIOV
GPU[2] : Subsystem ID: 0x74a1
GPU[2] : Device Rev: 0x00
GPU[2] : Node ID: 25
GPU[2] : GUID: 30036
GPU[2] : GFX Version: gfx942
GPU[3] : Card Series: AMD Instinct MI300X OAM
GPU[3] : Card Model: 0x74a1
GPU[3] : Card Vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[3] : Card SKU: MI3SRIOV
GPU[3] : Subsystem ID: 0x74a1
GPU[3] : Device Rev: 0x00
GPU[3] : Node ID: 24
GPU[3] : GUID: 23964
GPU[3] : GFX Version: gfx942
GPU[4] : Card Series: AMD Instinct MI300X OAM
GPU[4] : Card Model: 0x74a1
GPU[4] : Card Vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[4] : Card SKU: MI3SRIOV
GPU[4] : Subsystem ID: 0x74a1
GPU[4] : Device Rev: 0x00
GPU[4] : Node ID: 30
GPU[4] : GUID: 1197
GPU[4] : GFX Version: gfx942
GPU[5] : Card Series: AMD Instinct MI300X OAM
GPU[5] : Card Model: 0x74a1
GPU[5] : Card Vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[5] : Card SKU: MI3SRIOV
GPU[5] : Subsystem ID: 0x74a1
GPU[5] : Device Rev: 0x00
GPU[5] : Node ID: 31
GPU[5] : GUID: 41351
GPU[5] : GFX Version: gfx942
GPU[6] : Card Series: AMD Instinct MI300X OAM
@kennethdsheridan I'll be honest, I gave up on candle, because my goal was to learn and use an AI framework, not to spend time HIPifying CUDA code. Afaik they're now working on WGPU support, and WGPU does support ROCm, so it should work eventually, although I am not sure WGPU is the most performant backend at this time.
I switched to Burn: https://github.com/tracel-ai/burn
Apart from an already working WGPU backend, they also support the tch-rs backend, which is based on libtorch, and since PyTorch/libtorch supports ROCm you get AMD support through that backend too (it should be as fast as libtorch).
Nice system btw, I wish I had so many GPUs :)
also interested in rocm support in candle for screenpipe
@vberthet where did this end up for you? I'd like to try to help out. My Rust skills are very limited, and so is my GPU knowledge, but I'd like to help where I can with AMD support for candle.
Hi,
This library is cool. Rust for deep learning is nice, and great work from Hugging Face. I am curious whether there are plans for AMD hardware support for training and inference.
Thanks