tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

Burn-WGPU tests fail on Windows with Radeon 6950 #1805

Closed: hexd0t closed this issue 1 week ago

hexd0t commented 1 month ago

Describe the bug: Running run-checks.ps1 all fails for burn-wgpu, with 4 failed tests during the std tests.

To Reproduce: Check out the current main branch, open a PowerShell terminal, cd to the repo root, and execute . .\run-checks.ps1 all

Expected behavior: All tests should pass.


Additional context

failures:

---- tests::jit::module_nearest_interpolate::tests::test_downsample_interpolation stdout ----
thread 'tests::jit::module_nearest_interpolate::tests::test_downsample_interpolation' panicked at crates\burn-wgpu\src\lib.rs:75:5:
Tensors are not approx eq:
  => Position 3: 7 != 6 | difference 1 > tolerance 0.0010000000000000002
  => Position 9: 161 != 160 | difference 1 > tolerance 0.0010000000000000002
  => Position 15: 315 != 314 | difference 1 > tolerance 0.0010000000000000002
  => Position 21: 469 != 468 | difference 1 > tolerance 0.0010000000000000002

---- tests::jit::remainder::tests::should_be_zero stdout ----
thread 'tests::jit::remainder::tests::should_be_zero' panicked at crates\burn-wgpu\src\lib.rs:75:5:
Tensors are not approx eq:
  => Position 0: 0 != 3.5 | difference 3.5 > tolerance 0.0010000000000000002
  => Position 1: 0 != 3.5 | difference 3.5 > tolerance 0.0010000000000000002
  => Position 2: 0 != 3.5 | difference 3.5 > tolerance 0.0010000000000000002

---- tests::jit_fusion::module_nearest_interpolate::tests::test_downsample_interpolation stdout ----
thread 'tests::jit_fusion::module_nearest_interpolate::tests::test_downsample_interpolation' panicked at crates\burn-wgpu\src\lib.rs:75:5:
Tensors are not approx eq:
  => Position 3: 7 != 6 | difference 1 > tolerance 0.0010000000000000002
  => Position 9: 161 != 160 | difference 1 > tolerance 0.0010000000000000002
  => Position 15: 315 != 314 | difference 1 > tolerance 0.0010000000000000002
  => Position 21: 469 != 468 | difference 1 > tolerance 0.0010000000000000002

---- tests::jit_fusion::remainder::tests::should_be_zero stdout ----
thread 'tests::jit_fusion::remainder::tests::should_be_zero' panicked at crates\burn-wgpu\src\lib.rs:75:5:
Tensors are not approx eq:
  => Position 0: 0 != 3.5 | difference 3.5 > tolerance 0.0010000000000000002
  => Position 1: 0 != 3.5 | difference 3.5 > tolerance 0.0010000000000000002
  => Position 2: 0 != 3.5 | difference 3.5 > tolerance 0.0010000000000000002

failures:
    tests::jit::module_nearest_interpolate::tests::test_downsample_interpolation
    tests::jit::remainder::tests::should_be_zero
    tests::jit_fusion::module_nearest_interpolate::tests::test_downsample_interpolation
    tests::jit_fusion::remainder::tests::should_be_zero

test result: FAILED. 1512 passed; 4 failed; 4 ignored; 0 measured; 0 filtered out; finished in 20.25s

nathanielsimard commented 1 month ago

I don't have an AMD GPU, but most of these are precision errors, where we should simply fix the test. The only exception may be the remainder operation in jit_fusion, where I don't really understand the 3.5 😅.

booti386 commented 2 weeks ago

On the WGPU backend, the Remainder op is implemented like this:

            Instruction::Remainder { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = (({lhs} % {rhs}) + {rhs}) % {rhs};\n"))
            }

If I change it to:

            Instruction::Remainder { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = (({lhs} % {rhs}) + {rhs} + 0.00001) % {rhs};\n"))
            }

... the test passes. So this is also caused by a cascading precision issue in the sum: if ({lhs} % {rhs}) + {rhs} is just an epsilon lower than {rhs}, we get a result ≈ {rhs} instead of 0 (which is strange too, but whatever). However, what really puzzles me is why it was implemented this way in the first place. FYI, this is the Modulo op:

            Instruction::Modulo { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
            }

If I replace the Remainder implementation with this:

            Instruction::Remainder { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
            }

... then there is no failing test related to this op anymore.
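The cascading effect described above can be modelled on the CPU. The sketch below is a minimal Rust emulation that assumes WGSL's float % behaves like x - y * trunc(q), where q is whatever quotient the GPU produced for x / y; the helper name trunc_rem_with_quotient is hypothetical, and the slightly-low quotient is an assumed rounding error, not a measured one. With lhs = 0 the sum is exactly 3.5, so one way to reproduce the observed 3.5 is an outer % whose quotient lands one ulp below 1; the + 0.00001 nudge keeps that quotient above 1.

```rust
/// Truncated float remainder `x - y * trunc(q)`, where `q` stands in for the
/// quotient the hardware produced for `x / y`. Passing `q` explicitly lets us
/// model a division that rounds slightly low (an assumption, not a measurement).
fn trunc_rem_with_quotient(x: f32, y: f32, q: f32) -> f32 {
    x - y * q.trunc()
}

fn main() {
    let rhs = 3.5_f32;
    // should_be_zero: lhs = 0, so the inner `0 % 3.5` is 0 and the sum is exactly 3.5.
    let sum = (0.0_f32 % rhs) + rhs;

    // Exact division: 3.5 / 3.5 == 1, trunc(1) == 1, so the outer `%` yields 0.
    println!("exact quotient: {}", trunc_rem_with_quotient(sum, rhs, sum / rhs));

    // Hypothetical quotient one ulp below 1: trunc gives 0 and the result is 3.5.
    let low_q = 1.0_f32 - f32::EPSILON / 2.0; // largest f32 below 1.0
    println!("low quotient: {}", trunc_rem_with_quotient(sum, rhs, low_q));

    // The `+ 0.00001` workaround keeps the quotient above 1 even if it is
    // rounded down a little, so trunc gives 1 and the result is only ~1e-5.
    let nudged = sum + 0.00001;
    let nudged_q = nudged / rhs - f32::EPSILON;
    println!("nudged: {}", trunc_rem_with_quotient(nudged, rhs, nudged_q));
}
```

The nudged result is not exactly 0, but it is far below the 0.001 tolerance used by the tests.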

nathanielsimard commented 2 weeks ago

@louisfd, I don't remember why we didn't just use the % operator, but maybe we should.

louisfd commented 2 weeks ago

The complicated form supports floats and negative numbers; it was made to mimic PyTorch's behaviour in #1597. I'm not sure it has proved useful yet, though.
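To see what the complicated form changes for negative operands, here is a minimal CPU sketch in Rust (the function name is hypothetical). It assumes Rust's f32 % is a fair stand-in for WGSL's float % on exact inputs, since both truncate toward zero; wrapping the result once more moves its sign from the dividend to the divisor, as torch.remainder does.

```rust
/// CPU emulation of the expression the WGSL backend emits for Remainder:
/// `((lhs % rhs) + rhs) % rhs`. Rust's f32 `%` truncates toward zero, so
/// `lhs % rhs` alone takes the sign of the dividend.
fn double_mod_remainder(lhs: f32, rhs: f32) -> f32 {
    ((lhs % rhs) + rhs) % rhs
}

fn main() {
    // Plain truncated `%`: the sign follows the dividend.
    println!("{}", -3.0_f32 % 2.0); // -1
    // Wrapped form: the sign follows the divisor.
    println!("{}", double_mod_remainder(-3.0, 2.0)); // 1
    println!("{}", double_mod_remainder(1.0, -1.5)); // -0.5
}
```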

nathanielsimard commented 1 week ago

@louisfd I think we could use the normal version for integers, though.

louisfd commented 1 week ago

@nathanielsimard In Cube, I made UInt implement the normal modulo and Numeric implement the more complex remainder.
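The split makes sense because, for unsigned values, a % b already lies in 0..b, so the extra wrapping in the complex form never changes the result. A minimal Rust sketch on plain u32 (not Cube's UInt type, whose API is not shown in this thread; both function names are hypothetical):

```rust
/// Normal modulo, as used for unsigned integers.
fn plain_mod(a: u32, b: u32) -> u32 {
    a % b
}

/// The wrapped form used for the general remainder. The values tested
/// below are small, so `(a % b) + b` cannot overflow u32.
fn wrapped_mod(a: u32, b: u32) -> u32 {
    ((a % b) + b) % b
}

fn main() {
    for a in 0..100u32 {
        for b in 1..20u32 {
            assert_eq!(plain_mod(a, b), wrapped_mod(a, b));
        }
    }
    println!("plain and wrapped modulo agree for every tested u32 pair");
}
```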

booti386 commented 1 week ago

I still don't get why the special version is needed, as both the torch documentation and the WGSL spec describe exactly the same algorithm. As far as I can tell, torch's remainder() and WGSL's % operator should behave exactly the same for floats. If that's not the case, doesn't it instead mean that the implementation it was tested against in the referenced issue was (is?) non-compliant?

louisfd commented 1 week ago

@agelas what do you think?

agelas commented 1 week ago

I think the rationale was laid out in #1597, but IIRC the torch documentation and WGSL don't use the same algorithm for remainder; there's a difference between % and remainder. From the PR itself:

For burn-jit, I redirected the remainder to the modulo operator because in WGSL, it should be the same as PyTorch's definition. Here, the % for WGSL is defined as

The return value is computed as x − y*floor(x/y). As with min and max, y can be either a vector or a float. The mod function can be used as a substitute for the % operator

In PyTorch, the remainder is defined as

torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
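The PyTorch definition quoted above translates directly into a floor-based formula; the GLSL-style formula quoted just before it, x − y*floor(x/y), is the same computation. A minimal Rust sketch on f32 (the function name is hypothetical), applied to the inputs of the examples in the torch.remainder documentation:

```rust
/// Floor-based remainder: a - floor(a / b) * b.
/// The result is zero or has the sign of the divisor `b`.
fn floor_remainder(a: f32, b: f32) -> f32 {
    a - (a / b).floor() * b
}

fn main() {
    // Inputs from the torch.remainder documentation examples.
    let basic: Vec<f32> = [-3.0_f32, -2.0, -1.0, 1.0, 2.0, 3.0]
        .iter()
        .map(|&a| floor_remainder(a, 2.0))
        .collect();
    println!("{:?}", basic); // [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]

    let float: Vec<f32> = [1.0_f32, 2.0, 3.0, 4.0, 5.0]
        .iter()
        .map(|&a| floor_remainder(a, -1.5))
        .collect();
    println!("{:?}", float); // [-0.5, -1.0, 0.0, -0.5, -1.0]
}
```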

Later, there's also this tidbit, which speaks more to the second point:

In terms of what % means, in Rust (and in WGSL), % calculates the remainder, not the modulo. Things get a bit wonky when negative numbers are used. It's different in Python, though, hence the difference and the need for PyTorch's remainder operation.
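The difference between Rust's remainder and Python's modulo can be checked directly with Rust's standard operations. This small sketch uses only % and the standard rem_euclid method; the corresponding Python results are noted in comments for comparison.

```rust
fn main() {
    // Rust's `%` is a remainder: the result takes the sign of the dividend.
    assert_eq!(-7 % 3, -1); // Python: -7 % 3 == 2
    assert_eq!(7 % -3, 1); // Python: 7 % -3 == -2
    assert_eq!(-7.5_f32 % 2.0, -1.5);

    // `rem_euclid` is always non-negative, so it matches Python's `%`
    // only when the divisor is positive.
    assert_eq!((-7i32).rem_euclid(3), 2);
    assert_eq!(7i32.rem_euclid(-3), 1); // Python gives -2 here

    println!("all remainder checks passed");
}
```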

I tried replacing the Remainder op as suggested with this:

            Instruction::Remainder { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
            }

While tests::jit::remainder::tests::should_be_zero passes, all of a sudden I start getting errors elsewhere, for example:

---- tests::jit_fusion::remainder::tests::should_support_remainder_basic stdout ----
thread 'tests::jit_fusion::remainder::tests::should_support_remainder_basic' panicked at crates/burn-wgpu/src/lib.rs:75:5:
Tensors are not approx eq:
  => Position 0: 1 != -1 | difference 2 > tolerance 0.0010000000000000002
  => Position 2: 1 != -1 | difference 2 > tolerance 0.0010000000000000002

---- tests::jit_fusion::remainder::tests::should_support_remainder_float stdout ----
thread 'tests::jit_fusion::remainder::tests::should_support_remainder_float' panicked at crates/burn-wgpu/src/lib.rs:75:5:
Tensors are not approx eq:
  => Position 0: -0.5 != 1 | difference 1.5 > tolerance 0.0010000000000000002
  => Position 1: -1 != 0.5 | difference 1.5 > tolerance 0.0010000000000000002
  => Position 3: -0.5 != 1 | difference 1.5 > tolerance 0.0010000000000000002
  => Position 4: -1 != 0.5 | difference 1.5 > tolerance 0.0010000000000000002

In fact, everything but tests::jit::remainder::tests::should_be_zero seems to fail:

failures:
    tests::jit::remainder::tests::should_be_negative
    tests::jit::remainder::tests::should_support_fp_dividends
    tests::jit::remainder::tests::should_support_large_divisor
    tests::jit::remainder::tests::should_support_remainder_basic
    tests::jit::remainder::tests::should_support_remainder_float
    tests::jit_fusion::remainder::tests::should_be_negative
    tests::jit_fusion::remainder::tests::should_support_fp_dividends
    tests::jit_fusion::remainder::tests::should_support_large_divisor
    tests::jit_fusion::remainder::tests::should_support_remainder_basic
    tests::jit_fusion::remainder::tests::should_support_remainder_float

In terms of spec compliance, should_support_remainder_basic and should_support_remainder_float are pulled directly from PyTorch's remainder examples, so I think we should be good there.

booti386 commented 1 week ago

Okay, I've been misled by the PR.

As you quoted from the PR about WGSL's % operator (the PR seems to cite the GLSL spec as its source):

The return value is computed as x − y*floor(x/y). As with min and max, y can be either a vector or a float. The mod function can be used as a substitute for the % operator

However, from the latest WGSL draft (https://www.w3.org/TR/WGSL/#syntax_sym-modulo):

If T is a floating point type, the result is equal to: e1 - e2 * trunc(e1 / e2)

That makes it much clearer why it differs from PyTorch's definition: trunc() and floor() are not the same operation (they differ whenever e1 / e2 is negative and not a whole number).
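The trunc/floor difference is enough to explain the earlier failures. A minimal Rust sketch (function names are hypothetical) evaluates both formulas on the divisor -1.5 example from the torch.remainder documentation, which, per the thread, should_support_remainder_float is based on. The trunc-based values (1 and 0.5 where PyTorch expects -0.5 and -1) line up with the positions reported in the failing log.

```rust
/// WGSL's float `%` per the current spec: e1 - e2 * trunc(e1 / e2).
fn wgsl_float_rem(e1: f32, e2: f32) -> f32 {
    e1 - e2 * (e1 / e2).trunc()
}

/// PyTorch's remainder: a - floor(a / b) * b.
fn torch_remainder(a: f32, b: f32) -> f32 {
    a - (a / b).floor() * b
}

fn main() {
    // Same dividends as the documentation example, divisor -1.5.
    for a in [1.0_f32, 2.0, 3.0, 4.0, 5.0] {
        println!(
            "a = {}: trunc-based {}, floor-based {}",
            a,
            wgsl_float_rem(a, -1.5),
            torch_remainder(a, -1.5)
        );
    }
}
```

Only a = 3.0 agrees, because 3 / -1.5 is a whole number and trunc and floor coincide there.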

In fact, using either:

            Instruction::Remainder { lhs, rhs, out } => f.write_fmt(format_args!(
                "{out} = {lhs} - {rhs} * trunc({lhs} / {rhs});\n"
            )),

or

            Instruction::Remainder { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
            }

I get these failures:

    tests::jit::remainder::tests::should_be_negative
    tests::jit::remainder::tests::should_support_fp_dividends
    tests::jit::remainder::tests::should_support_large_divisor
    tests::jit::remainder::tests::should_support_remainder_basic
    tests::jit::remainder::tests::should_support_remainder_float
    tests::jit::remainder::tests::should_support_remainder_op
    tests::jit_fusion::remainder::tests::should_be_negative
    tests::jit_fusion::remainder::tests::should_support_fp_dividends
    tests::jit_fusion::remainder::tests::should_support_large_divisor
    tests::jit_fusion::remainder::tests::should_support_remainder_basic
    tests::jit_fusion::remainder::tests::should_support_remainder_float
    tests::jit_fusion::remainder::tests::should_support_remainder_op

Whereas using:

            Instruction::Remainder { lhs, rhs, out } => f.write_fmt(format_args!(
                "{out} = {lhs} - {rhs} * floor({lhs} / {rhs});\n"
            )),

... which matches PyTorch's remainder(), I don't get any test failures for this operator.
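For reference, the floor-based variant can be written as a self-contained Display implementation. This is a simplified, hypothetical stand-in for the backend's instruction type: operands are plain strings and the variable names l_0, l_1 and l_2 are made up; only the format_args! pattern comes from the snippets in this thread.

```rust
use std::fmt;

/// Hypothetical, simplified instruction type (the real backend types differ).
enum Instruction {
    Remainder { lhs: String, rhs: String, out: String },
}

impl fmt::Display for Instruction {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Floor-based remainder, matching torch.remainder: the result
            // takes the sign of the divisor.
            Instruction::Remainder { lhs, rhs, out } => f.write_fmt(format_args!(
                "{out} = {lhs} - {rhs} * floor({lhs} / {rhs});\n"
            )),
        }
    }
}

fn main() {
    let instr = Instruction::Remainder {
        lhs: "l_0".to_string(),
        rhs: "l_1".to_string(),
        out: "l_2".to_string(),
    };
    print!("{}", instr); // l_2 = l_0 - l_1 * floor(l_0 / l_1);
}
```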

As a reminder, with the current burn implementation:

            Instruction::Remainder { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = (({lhs} % {rhs}) + {rhs}) % {rhs};\n"))
            }

This test still fails, in both the jit and jit_fusion suites:

tests::jit::remainder::tests::should_be_zero
tests::jit_fusion::remainder::tests::should_be_zero

So... what is the correct approach? Would it be better to switch to PyTorch's definition? We would go from two float modulo operations to one float division (plus a floor, a multiplication and a subtraction); what would the performance impact be?

agelas commented 1 week ago

The current implementation was meant to guarantee that the output follows the sign of the divisor. If PyTorch's definition is subbed in and it passes (i.e. is equivalent), then we should go with that, since it looks more hardware-agnostic.
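Equivalence can at least be sanity-checked on the CPU. The sketch below (function names are hypothetical) compares the current double-% expression, emulated with Rust's exact truncating f32 %, against the PyTorch formula on a grid of exactly representable values, and checks the sign-of-divisor property. It says nothing about GPU rounding behaviour, which is where the should_be_zero failure came from.

```rust
/// Current WGSL expression, emulated with Rust's (exact, truncating) f32 `%`.
fn double_mod(a: f32, b: f32) -> f32 {
    ((a % b) + b) % b
}

/// PyTorch's definition: a - b * floor(a / b).
fn floor_rem(a: f32, b: f32) -> f32 {
    a - b * (a / b).floor()
}

fn main() {
    let divisors = [-3.5_f32, -2.0, -1.5, 1.5, 2.0, 3.5];
    // Dividends -20.0, -19.5, ..., 20.0 are all exactly representable, so a
    // mismatch would come from the formulas rather than from rounding.
    for i in -40..=40 {
        let a = i as f32 * 0.5;
        for &b in &divisors {
            let (x, y) = (double_mod(a, b), floor_rem(a, b));
            assert_eq!(x, y, "a = {}, b = {}", a, b);
            // The result is zero or has the sign of the divisor, and |r| < |b|.
            assert!(x == 0.0 || x.signum() == b.signum());
            assert!(x.abs() < b.abs());
        }
    }
    println!("double-mod and floor-based remainder agree on the grid");
}
```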

nathanielsimard commented 1 week ago

@agelas @booti386 I think using the PyTorch definition of remainder is probably the best option. Performance shouldn't be a big factor here: this op should pretty much never become a bottleneck, so the change should have a negligible impact on real-world performance. Correctness and compatibility are more important. I created PR #1979.