wcshds opened 1 month ago
I tried printing the tensor converted from the input f32 slice in the console. It seems that the tensor's data has been corrupted from the very beginning.
@nathanielsimard @louisfd just be aware of this jit related bug on WASM/WebGPU
It would be great if we had WebGPU tests on CI (#810).
Is this bug caused by `Pool2dEagerKernel`? @nathanielsimard @louisfd
```
[START_KERNEL_COMPILATION]
name: burn_jit::kernel::pool::pool2d_shader::Pool2dEagerKernel<
    burn_jit::kernel::pool::max_pool2d::MaxPool<f32>,
    cubecl_wgpu::runtime::WgpuRuntime,
    f32,
>
cube_dim: (16, 16, 1)
shared_memory: 0 bytes
source:

@group(0)
@binding(0)
var<storage, read_write> input_0_global: array<f32>;

@group(0)
@binding(1)
var<storage, read_write> output_0_global: array<f32>;

@group(0)
@binding(2)
var<storage, read_write> info: array<u32>;

@group(0) @binding(3) var<storage, read_write> scalars_uint: array<u32, 6>;

const WORKGROUP_SIZE_X = 16u; const WORKGROUP_SIZE_Y = 16u; const WORKGROUP_SIZE_Z = 1u;

@compute
@workgroup_size(16, 16, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
[END_KERNEL_COMPILATION]
```
This bug is caused by the limited default precision of Rust's display of f32 values. It should be an easy fix.
I changed the code at https://github.com/tracel-ai/cubecl/blob/cfe0b0204380cbd0931f478194a053a6ac35d1cb/crates/cubecl-wgpu/src/compiler/wgsl/base.rs#L262-L263 from:

```rust
FloatKind::F32 => f.write_fmt(format_args!("{}f", *val as f32)),
FloatKind::F64 => f.write_fmt(format_args!("{}f", { *val })),
```

to:

```rust
FloatKind::F32 => f.write_fmt(format_args!("{:.9}f", *val as f32)),
FloatKind::F64 => f.write_fmt(format_args!("{:.17}f", { *val })),
```

This fixes the bug.
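The underlying behavior can be reproduced in plain Rust: `Display` for floats prints the shortest decimal string that round-trips, but never uses exponent notation, so extreme values expand into very long digit strings. A minimal sketch:

```rust
fn main() {
    // Rust's default Display prints the shortest decimal that round-trips
    // back to the same f32, but never in scientific notation, so f32::MIN
    // expands into a 39-digit literal.
    let literal = format!("{}f", f32::MIN);
    println!("{literal}");
    // prints: -340282350000000000000000000000000000000f

    // The string still parses back to exactly f32::MIN, but as an exact
    // decimal it is slightly larger in magnitude than f32::MIN, which is
    // what the WGSL compiler choked on in this issue.
    assert_eq!(literal.trim_end_matches('f').parse::<f32>().unwrap(), f32::MIN);
}
```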
I thought maybe it was due to incompatible versions between the main branch and the examples, but I have the same issue after downloading the 0.13.2 release out of the box, then running `cd examples/image-classification-web` and `./build-for-web.sh`. In fact, it would not even compile without first manually adding `features = ["autotune"]` to the `burn-wgpu` crate dependency.
Then I tried the 0.13.1 and 0.13.0 releases, and they produce the same issue as described above (the Candle backend does not work either, as per issue #1034). I am simply unable to reproduce `image-classification-web` with any 0.13+ version (I didn't look at earlier versions), except with the ndarray backend, which works seamlessly.
@antimora did you by any chance keep around the original repo you used to make your published version work? It would be very helpful to diff against it and check what's wrong.
For instance, how did you make the Candle backend work without the `AvgPool2d` op in the first place? Did you switch to a model other than squeezenet for the sake of this example? Also, how did you manage to get the wgpu backend to load? (If I understand correctly, the solution suggested above by @wcshds only applies to 0.14+ versions, which include the `cubecl` dependency, correct?)
By the way, thanks for sharing this great project :)
In fact, I've tried every release of Burn since 0.10.0 on my device, but the wgpu backend in the `image-classification-web` example has never worked. However, after the small modifications I mentioned above, I can now run this example successfully on the main branch.
Reproduction steps:

```shell
# Clone the repo and get into it
git clone git@github.com:tracel-ai/burn.git
cd burn

# Change the cubecl dependency to a revision that includes the suggested bugfix
# (https://github.com/tracel-ai/cubecl/commit/32feabc5140170d45d4365a56106db930ed79a33).
# For reproduction purposes I use the sd utility (https://github.com/chmln/sd),
# but one can just change it manually in the Cargo.toml for both cubecl AND cubecl-common.
sd '(cubecl.* rev =).*(\})' '$1 "32feabc5140170d45d4365a56106db930ed79a33" $2' Cargo.toml

# cd into the relevant example and compile it
cd examples/image-classification-web
./build-for-web.sh

# Run the server
./run-server.sh
```
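If you prefer not to use `sd`, the command above is equivalent to hand-editing the two `cubecl` entries in Burn's root Cargo.toml to pin the fixed revision. A sketch of the result (the git URL and surrounding keys are illustrative — keep whatever keys the manifest already has and only change `rev`):

```toml
cubecl = { git = "https://github.com/tracel-ai/cubecl", rev = "32feabc5140170d45d4365a56106db930ed79a33" }
cubecl-common = { git = "https://github.com/tracel-ai/cubecl", rev = "32feabc5140170d45d4365a56106db930ed79a33" }
```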
RESULTS:
✅ NdArray backend: working (slower than the 0.13.2 version by an order of magnitude, but still okay)
❗ Candle backend: the backend loads correctly but cannot run inference; Candle does not support excluding the padding count in pooling
❌ Wgpu backend: cannot load, fails with `An home directory should exist`
Did I miss something here?
@Jonarod You also need to disable the use of the autotune feature for burn-wgpu.
As for Candle, I don't know how to make it work either.
Hey @wcshds 👋
I was going to update the cubecl dep to the latest to include some fixes, but some wgpu tests are failing with your merged PR.
```
failures:
tests::jit::gradients::tests::should_update_tensor_when_grad_replace
tests::jit::kernel::bernoulli::tests::number_of_1_proportional_to_prob
tests::jit::kernel::bernoulli::tests::runs_test
tests::jit::kernel::normal::tests::empirical_mean_close_to_expectation
tests::jit::kernel::normal::tests::normal_respects_68_95_99_rule
tests::jit::kernel::uniform::tests::at_least_one_value_per_bin_int_uniform
tests::jit::kernel::uniform::tests::at_least_one_value_per_bin_uniform
tests::jit::kernel::uniform::tests::runs_test
tests::jit_fusion::gradients::tests::should_update_tensor_when_grad_replace
```
I tried to understand the problem in this issue, but I think I'm missing a bit of context. Could you explain why the precision change was required?
/edit: see PRs #2159 #2158 as reference.
Worked like a charm.
Thanks for your help.
I submitted corresponding PR that should solve this issue.
If you read this while the PR is not yet merged to main, the solution, as suggested by @wcshds, is basically to:

- change the `cubecl` and `cubecl-common` revisions to 32feabc5140170d45d4365a56106db930ed79a33 in Burn's root Cargo.toml
- remove `burn-wgpu`'s `features = ["autotune"]` from examples/image-classification-web/Cargo.toml
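For reference, the second change looks roughly like this in examples/image-classification-web/Cargo.toml (the dependency line is illustrative of the shape only — the exact path and other keys in the example's manifest may differ):

```diff
-burn-wgpu = { path = "../../crates/burn-wgpu", features = ["autotune"] }
+burn-wgpu = { path = "../../crates/burn-wgpu" }
```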
I think this can be closed now.
@laggui This is because I saw in the browser console that `-3.40282347E+38f32` was rounded to `-340282350000000000000000000000000000000f`, which caused a WGSL compilation error, so I believe this is an issue with Rust's default display precision. It's strange that the tests failed, but changing the precision to 13 decimal places solved the problem:

```rust
FloatKind::F32 => f.write_fmt(format_args!("{:.13}f", *val as f32)),
```

I originally thought that a precision of 9 decimal places would be sufficient for f32.
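One plausible reason 9 fixed decimal places is not enough (my reading, not confirmed in the thread): `{:.N}` in Rust controls digits after the decimal point, not significant digits, so small constants — e.g. probabilities used by the random kernels — can collapse to zero. A sketch:

```rust
fn main() {
    let tiny = 1e-10f32;

    // `{:.9}` keeps only 9 digits after the decimal point, so any value
    // smaller than ~5e-10 rounds to a plain zero and the constant is lost.
    let nine = format!("{:.9}", tiny);
    assert_eq!(nine.parse::<f32>().unwrap(), 0.0);

    // `{:.13}` keeps enough fractional digits for this value to survive.
    let thirteen = format!("{:.13}", tiny);
    assert!(thirteen.parse::<f32>().unwrap() != 0.0);

    println!("{nine} vs {thirteen}");
}
```

This would explain why statistical tests like `bernoulli` and `uniform` broke while ordinary kernels kept working.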
Weird 😅
Not sure if the fix is the proper way to address this or if it's just a patch for a more specific issue. It doesn't seem to happen anywhere else 🤔
/edit: fyi, we have decided to revert the changes applied to the precision for now. The current workaround is at least documented for users to try while we investigate why this happens in this specific example.
When I choose the wgpu backend, I get errors in the console. After I disable the `autotune` feature of `burn-wgpu`, the wgpu backend still cannot work. The live demo works fine, so I don't think it's my device's problem.