tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
8.56k stars, 422 forks

Text classification example gives "Shader validation error" when run on multiple GPUs #1745

Open joshhansen opened 5 months ago

joshhansen commented 5 months ago

Describe the bug

Running the text classification example's AG News training step on multiple discrete GPUs fails with "Shader validation error":

This error partly overlaps with the one reported in #1088.

To Reproduce

On a system with two or more discrete GPUs:

git clone https://github.com/tracel-ai/burn.git
cd burn/examples/text-classification

Edit examples/ag-news-train.rs as follows:

-        launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
+        launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![
+            WgpuDevice::DiscreteGpu(0),
+            WgpuDevice::DiscreteGpu(1),
+        ]);

cargo run --example ag-news-train --features wgpu

Expected behavior

Training proceeds, utilizing both GPUs.


nathanielsimard commented 5 months ago

Looking at experiment.log, the problem seems to come from Vulkan's validation layer rather than from a multi-device error. I tested on my system and can run the training with multiple devices. Maybe you can try disabling the Vulkan validation layer (branch wgpu-no-validation).

Also, you could test using the LibTorch backend instead.

joshhansen commented 5 months ago

Training does appear to work with the LibTorch GPU backend when multiple GPUs are specified. That may not help me much, though: I am specifically migrating away from LibTorch because of its lack of thread safety.

Surprisingly, running on the wgpu-no-validation branch results in the same validation error: experiment.log

nathanielsimard commented 5 months ago

@joshhansen My intuition suggests that the problem may come from a precision error, where wgpu can't convert the literal to a float32. If you change that value, does it work?

joshhansen commented 5 months ago

Change 0.00000000023283064365386963f? My apologies, I'm not familiar with Burn's compilation process. Where would that value "live" so that I could modify it?
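For context, 0.00000000023283064365386963 is the decimal expansion of 2^-32 (1 / 4294967296), a constant commonly used to scale a random 32-bit unsigned integer into [0, 1). Because it is a power of two well inside the normal f32 range, it is exactly representable as an f32, so rounding the literal to float32 should not by itself lose precision. The minimal sketch below checks this using only the Rust standard library. It only exercises Rust's own float parser; it says nothing about how a shader toolchain (wgpu/naga, the Vulkan validation layer) handles the literal, and the assumption that the literal comes from a Burn-generated shader is not confirmed by the thread.

```rust
/// The literal quoted in the thread (assumed, not confirmed, to come from
/// a generated shader).
const LITERAL: &str = "0.00000000023283064365386963";

/// IEEE 754 binary32 bit pattern of 2^-32:
/// biased exponent 127 - 32 = 95, mantissa all zeros.
const TWO_POW_NEG_32_BITS: u32 = 95 << 23;

/// 2^-32 built directly from its bit pattern.
fn two_pow_neg_32_f32() -> f32 {
    f32::from_bits(TWO_POW_NEG_32_BITS)
}

/// Parse the literal as f32; Rust's parser rounds to the nearest f32.
fn literal_as_f32() -> f32 {
    LITERAL.parse::<f32>().expect("valid float literal")
}

/// Parse the literal as f64 for comparison.
fn literal_as_f64() -> f64 {
    LITERAL.parse::<f64>().expect("valid float literal")
}

fn main() {
    let exact = two_pow_neg_32_f32();
    let parsed = literal_as_f32();

    println!("2^-32 as f32:       {:e}", exact);
    println!("literal as f32:     {:e}", parsed);
    // Compare bit patterns so the check is exact, not approximate.
    println!("bitwise identical:  {}", parsed.to_bits() == exact.to_bits());
    println!("is a normal f32:    {}", exact.is_normal());
    println!("f64 parse == 2^-32: {}", literal_as_f64() == 1.0 / 4294967296.0);
}
```

If all of these report true, the literal itself survives an f32 round trip, which would point the investigation toward how the shader source is emitted or validated rather than toward float32 precision.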

nathanielsimard commented 4 months ago

@joshhansen I guessed it was a constant defined by your code 😅