tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
8.53k stars 419 forks source link

Wgpu backend of image-classification-web example cannot work #2118

Open wcshds opened 1 month ago

wcshds commented 1 month ago

When I choose the wgpu backend, I get errors in the console.

image

After I disable the autotune feature of burn-wgpu, the wgpu backend still does not work. The live demo works fine, so I don't think the problem is with my device.

image

wcshds commented 1 month ago

I tried printing the tensor converted from the input f32 slice in the console. It seems that the tensor's data has been corrupted from the very beginning. image

antimora commented 1 month ago

@nathanielsimard @louisfd just be aware of this jit related bug on WASM/WebGPU

antimora commented 1 month ago

It would be great if we had WebGPU tests on CI #810

wcshds commented 1 month ago

Is this bug caused by Pool2dEagerKernel? @nathanielsimard @louisfd

image

[START_KERNEL_COMPILATION]
name: burn_jit::kernel::pool::pool2d_shader::Pool2dEagerKernel<
    burn_jit::kernel::pool::max_pool2d::MaxPool<
        f32,
    >,
    cubecl_wgpu::runtime::WgpuRuntime,
    f32,
>
cube_dim: (16, 16, 1)
shared_memory: 0 bytes
source:

@group(0) @binding(0) var<storage, read_write> input_0_global: array<f32>;

@group(0) @binding(1) var<storage, read_write> output_0_global: array<f32>;

@group(0) @binding(2) var<storage, read_write> info: array<u32>;

@group(0) @binding(3) var<storage, read_write> scalars_uint: array<u32, 6>;

const WORKGROUP_SIZE_X = 16u; const WORKGROUP_SIZE_Y = 16u; const WORKGROUP_SIZE_Z = 1u;

@compute @workgroup_size(16, 16, 1) fn main(
@builtin(global_invocation_id) global_id: vec3, @builtin(num_workgroups) num_workgroups: vec3, ) {let id = (global_id.z num_workgroups.x WORKGROUP_SIZE_X num_workgroups.y WORKGROUP_SIZE_Y) + (global_id.y num_workgroups.x WORKGROUP_SIZE_X) + global_id.x; let rank: u32 = info[0]; let rank_2: u32 = rank 2u; var l_0_0: u32; var l_0_1: u32; var l_0_2: u32; var l_0_3: u32; var l_0_4: u32; var l_0_5: u32; var l_0_6: u32; var l_0_7: u32; var l_0_8: u32; var l_0_9: u32; var l_0_10: u32; var l_0_11: u32; var l_0_12: u32; var l_0_13: u32; var l_0_14: u32; var l_0_15: u32; var l_0_16: u32; var l_0_17: u32; var l_0_18: u32; var l_0_19: u32; var l_0_20: u32; var l_0_21: u32; var l_0_22: u32; var l_0_23: u32; var l_0_24: u32; var l_0_25: f32; var l_0_26: u32; var l_0_27: u32; var l_0_28: u32; var l_0_29: u32; var l_0_30: u32; var l_0_31: u32; var l_0_32: bool; var l_0_33: bool; var l_0_34: bool; var l_0_35: u32; var l_0_36: u32; var l_0_37: f32; l_0_0 = info[(0u rank_2) + 0u + 1u]; l_0_1 = info[(0u rank_2) + 1u + 1u]; l_0_2 = info[(0u rank_2) + 2u + 1u]; l_0_3 = info[(0u rank_2) + 3u + 1u]; l_0_4 = info[(0u rank_2) + rank + 2u + 1u]; l_0_5 = info[(0u rank_2) + rank + 3u + 1u]; l_0_6 = info[(0u rank_2) + rank + 2u + 1u]; l_0_7 = info[(0u rank_2) + rank + 3u + 1u]; l_0_8 = info[(1u rank_2) + 0u + 1u]; l_0_9 = info[(1u rank_2) + 1u + 1u]; l_0_10 = info[(1u rank_2) + 2u + 1u]; l_0_11 = info[(1u rank_2) + 3u + 1u]; l_0_12 = info[(1u rank_2) + rank + 0u + 1u]; l_0_13 = info[(1u rank_2) + rank + 1u + 1u]; l_0_14 = info[(1u rank_2) + rank + 2u + 1u]; l_0_15 = info[(1u rank_2) + rank + 3u + 1u]; l_0_16 = id / l_0_8; l_0_16 = l_0_16 % l_0_12; l_0_17 = id / l_0_9; l_0_17 = l_0_17 % l_0_13; l_0_18 = id / l_0_10; l_0_18 = l_0_18 % l_0_14; l_0_19 = id / l_0_11; l_0_19 = l_0_19 % l_0_15; l_0_35 = l_0_6 + scalars_uint[4]; l_0_36 = l_0_7 + scalars_uint[5]; l_0_27 = l_0_16 l_0_0; l_0_28 = l_0_17 l_0_1; l_0_37 = f32(-340282350000000000000000000000000000000f); l_0_20 = l_0_18 scalars_uint[0]; l_0_22 
= 0u scalars_uint[2]; l_0_20 = l_0_20 + l_0_22; l_0_32 = l_0_20 >= scalars_uint[4]; l_0_34 = l_0_20 < l_0_35; l_0_32 = l_0_32 && l_0_34; if l_0_32 { l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 0u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 1u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 2u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } } l_0_20 = l_0_18 scalars_uint[0]; l_0_22 = 1u scalars_uint[2]; l_0_20 = l_0_20 + l_0_22; l_0_32 = l_0_20 >= scalars_uint[4]; l_0_34 = 
l_0_20 < l_0_35; l_0_32 = l_0_32 && l_0_34; if l_0_32 { l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 0u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 1u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 2u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } } l_0_20 = l_0_18 scalars_uint[0]; l_0_22 = 2u scalars_uint[2]; l_0_20 = l_0_20 + l_0_22; l_0_32 = l_0_20 >= scalars_uint[4]; l_0_34 = l_0_20 < l_0_35; l_0_32 = l_0_32 && l_0_34; if l_0_32 { l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 
0u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 1u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } l_0_21 = l_0_19 scalars_uint[1]; l_0_22 = 2u scalars_uint[3]; l_0_21 = l_0_21 + l_0_22; l_0_33 = l_0_21 >= scalars_uint[5]; l_0_34 = l_0_21 < l_0_36; l_0_33 = l_0_33 && l_0_34; if l_0_33 { var l_2_0: bool; l_0_23 = l_0_20 - scalars_uint[4]; l_0_24 = l_0_21 - scalars_uint[5]; l_0_29 = l_0_23 l_0_2; l_0_31 = u32(l_0_29); l_0_31 = l_0_31 + l_0_24; l_0_30 = l_0_24 * l_0_3; l_0_26 = u32(l_0_27); l_0_26 = l_0_26 + l_0_28; l_0_26 = l_0_26 + l_0_29; l_0_26 = l_0_26 + l_0_30; l_0_25 = f32(input_0_global[l_0_26]); l_2_0 = l_0_25 > l_0_37; if l_2_0 { l_0_37 = f32(l_0_25); } } } output_0_global[id] = f32(l_0_37); }

[END_KERNEL_COMPILATION]
wcshds commented 1 month ago

This bug is caused by the limited default precision of Rust's display of f32 values. It should be an easy fix.

I changed the code from https://github.com/tracel-ai/cubecl/blob/cfe0b0204380cbd0931f478194a053a6ac35d1cb/crates/cubecl-wgpu/src/compiler/wgsl/base.rs#L262-L263 :

FloatKind::F32 => f.write_fmt(format_args!("{}f", *val as f32)),
FloatKind::F64 => f.write_fmt(format_args!("{}f", { *val })),

to:

FloatKind::F32 => f.write_fmt(format_args!("{:.9}f", *val as f32)),
FloatKind::F64 => f.write_fmt(format_args!("{:.17}f", { *val })),

It can fix the bug.

Jonarod commented 1 month ago

I thought maybe it was due to incompatible versions between the main branch and the examples, but I have the same issue after downloading the 0.13.2 release out-of-the-box then cd examples/image-classification-web then ./build-for-web.sh. In fact, it would not even compile without further adding the features=["autotune"] to the burn-wgpu crate dependency manually before compiling.

Then I tried with 0.13.1 release as well as the 0.13.0 release, and this produces the same issue as described in the first place (also, the Candle backend does not work either, as per issue #1034 ).

I am just not able to reproduce the image-classification-web example with any 0.13+ version (I didn't look at earlier versions), except for the ndarray backend, which works seamlessly.

@antimora did you by any chance keep the original repo you used to make your published version work? It would be very helpful to diff it and check what's wrong.

For instance, how did you make the Candle backend work without the AvgPool2d op in the first place? Did you switch to a model other than squeezenet for the sake of this example? Also, how did you manage to get the wgpu backend to load? (If I am correct, the solution suggested above by @wcshds only applies to 0.14+ versions, which include the cubecl dependency, correct?)

By the way, thanks for sharing this great project :)

wcshds commented 1 month ago

In fact, I've tried all the release versions of Burn since 0.10.0 on my device, and the wgpu backend in the image-classification-web example has never worked. However, after the small modifications I mentioned above, I can now run this example successfully on the main branch.

Jonarod commented 1 month ago

Reproduction steps:

# Clone repo
git clone git@github.com:tracel-ai/burn.git

# get into repo
cd burn

# Change the cubecl dependency to a revision that includes the suggested bugfix (https://github.com/tracel-ai/cubecl/commit/32feabc5140170d45d4365a56106db930ed79a33)
# For reproduction purposes, here I use the sd utility: (https://github.com/chmln/sd), but one can just change it manually in the Cargo.toml for both cubecl AND cubecl-common
sd '(cubecl.* rev =).*(\})' '$1 "32feabc5140170d45d4365a56106db930ed79a33" $2' Cargo.toml

# cd into the relevant example
cd examples/image-classification-web

# compile the example
./build-for-web.sh

# run the server
./run-server.sh

RESULTS:

NdArray backend: working (slower than 0.13.2 version by an order of magnitude, but still okay)

Candle backend: backend LOADS correctly but cannot do inference, Candle does not support excluding pad count in pooling

Wgpu backend: cannot load, An home directory should exist

Did I miss something here?

wcshds commented 1 month ago

@Jonarod You also need to disable the use of the autotune feature for burn-wgpu.

As for Candle, I don't know how to make it work either.

laggui commented 1 month ago

Hey @wcshds 👋

I was going to update the cubecl dep to the latest to include some fixes, but some wgpu tests are failing with your merged PR.

failures:
    tests::jit::gradients::tests::should_update_tensor_when_grad_replace
    tests::jit::kernel::bernoulli::tests::number_of_1_proportional_to_prob
    tests::jit::kernel::bernoulli::tests::runs_test
    tests::jit::kernel::normal::tests::empirical_mean_close_to_expectation
    tests::jit::kernel::normal::tests::normal_respects_68_95_99_rule
    tests::jit::kernel::uniform::tests::at_least_one_value_per_bin_int_uniform
    tests::jit::kernel::uniform::tests::at_least_one_value_per_bin_uniform
    tests::jit::kernel::uniform::tests::runs_test
    tests::jit_fusion::gradients::tests::should_update_tensor_when_grad_replace

Tried to understand the problem in this issue but I think I'm missing a bit of context.. could you explain why the precision change was required?

/edit: see PRs #2159 #2158 as reference.

Jonarod commented 1 month ago

Worked like a charm.

Thanks for your help.

I submitted corresponding PR that should solve this issue.

If you read this while PR is not merged to main, basically the solution, as suggested by @wcshds is to:

  1. change the cubecl and cubecl-common revisions to 32feabc5140170d45d4365a56106db930ed79a33 in Burn's root Cargo.toml

  2. remove burn-wgpu's features = [ "autotune" ] from examples/image-classification-web/Cargo.toml

I think this can be closed now.

wcshds commented 1 month ago

@laggui This is because I saw in the browser console that -3.40282347E+38f32 was rounded to -340282350000000000000000000000000000000f, which caused a WGSL compilation error, so I believe this is an issue with Rust's default display precision. It's strange that the test failed, but changing the precision to 13 decimal places solved the problem.

FloatKind::F32 => f.write_fmt(format_args!("{:.13}f", *val as f32)),

I originally thought that a precision of 9 decimal places would be sufficient for f32.

laggui commented 1 month ago

Weird 😅

Not sure if the fix is the proper way to address this or if it's just a patch for a more specific issue. It doesn't seem to happen anywhere else 🤔

/edit: fyi, we have decided to revert the changes applied to the precision for now. The current workaround is at least documented for users to try while we investigate why this happens in this specific example.