Closed by hexd0t 1 week ago
I don't have an AMD GPU, but most of these are precision errors, where we should simply fix the test. The only exception is maybe the remainder operation in jit_fusion, where I don't really understand the 3.5 😅.
On the WGPU backend, the Remainder op is implemented like this:
```rust
Instruction::Remainder { lhs, rhs, out } => {
    f.write_fmt(format_args!("{out} = (({lhs} % {rhs}) + {rhs}) % {rhs};\n"))
}
```
If I change it to:
```rust
Instruction::Remainder { lhs, rhs, out } => {
    f.write_fmt(format_args!("{out} = (({lhs} % {rhs}) + {rhs} + 0.00001) % {rhs};\n"))
}
```
... the test passes. So it's also due to a cascading precision issue from the sum: if ({lhs} % {rhs}) + {rhs} is just an epsilon lower than {rhs}, we get a result ≈ {rhs} instead of 0 (which is strange too, but whatever). However, what really puzzles me is why it was implemented this way in the first place. FYI, this is the Modulo op:
```rust
Instruction::Modulo { lhs, rhs, out } => {
    f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
}
```
If I replace the Remainder implementation with this:
```rust
Instruction::Remainder { lhs, rhs, out } => {
    f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
}
```
... then there is no failing test related to this op anymore.
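The cascading precision issue described above is easy to reproduce on the CPU. Here is a hypothetical Rust sketch (plain f32 arithmetic, not the actual WGSL codegen) of the wrapped double-modulo form returning ≈ rhs instead of 0:

```rust
// Sketch of the wrapped remainder used by the WGSL codegen,
// evaluated in plain f32 on the CPU (not the actual shader).
fn wrapped_rem(lhs: f32, rhs: f32) -> f32 {
    ((lhs % rhs) + rhs) % rhs
}

fn main() {
    // In exact arithmetic, -0.6 is a multiple of 0.2, so the result should be 0.
    // In f32, `lhs % rhs` leaves a tiny negative residue; adding `rhs` lands
    // one ULP below `rhs`, and the outer `%` keeps it unchanged, so the
    // result is ≈ rhs rather than 0.
    let r = wrapped_rem(-0.6, 0.2);
    println!("wrapped_rem(-0.6, 0.2) = {r}"); // ≈ 0.2, not 0
    assert!(r > 0.19 && r < 0.2);
}
```

Whether a given GPU reproduces this exactly depends on its float handling, but the failure mode is the same shape.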
I don't remember why we didn't just use the % operator @louisfd, but maybe we should.
The complicated form supports floats and negatives; it was made to mimic the behaviour in PyTorch in #1597. I'm not sure it has proved useful yet, though.
@louisfd I think we might use the normal version for integers though.
@nathanielsimard In Cube I made UInt implement normal modulo and Numeric implement the more complex remainder
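For unsigned integers the two definitions coincide anyway, since both operands are non-negative. A small Rust sketch (with a hypothetical `floor_mod` helper, not CubeCL code) of why plain % is enough for UInt:

```rust
// Floor-mod written out explicitly (hypothetical helper, not CubeCL code).
// For non-negative operands it matches plain `%`.
fn floor_mod(a: i64, b: i64) -> i64 {
    a - b * (a as f64 / b as f64).floor() as i64
}

fn main() {
    // With both operands non-negative, truncation and floor agree,
    // so truncated `%` and floor-mod give identical results.
    for a in 0..20i64 {
        for b in 1..10i64 {
            assert_eq!(a % b, floor_mod(a, b));
        }
    }
    // With a negative dividend they differ: truncated `%` keeps the
    // dividend's sign, floor-mod follows the divisor's.
    assert_eq!(-7 % 3, -1);
    assert_eq!(floor_mod(-7, 3), 2);
}
```

Rust's built-in `rem_euclid` gives the same non-negative result as floor-mod when the divisor is positive.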
I still don't get why the special version is needed, as both the torch documentation and the wgpu spec use exactly the same algorithm. For me, torch's remainder() and wgpu's % operator should behave exactly the same for floats. If that's not the case, then shouldn't it mean instead that the implementation it was tested against in the referenced issue was (is?) non-compliant?
@agelas what do you think?
I think the rationale was laid out in #1597, but iirc the torch documentation and wgpu don't use the same algorithm for remainder. There's a difference between % and remainder. In the PR itself:
For burn-jit, I redirected the remainder to the modulo operator because in WGSL, it should be the same as PyTorch's definition. Here, the % for WGSL is defined as
The return value is computed as x − y*floor(x/y). As with min and max, y can be either a vector or a float. The mod function can be used as a substitute for the % operator
In PyTorch, the remainder is defined as
```python
torch.remainder(a, b) == a - a.div(b, rounding_mode="floor") * b
```
Later there's also this tidbit which speaks more to the second point:
In terms of what % means: in Rust (and in WGSL), % calculates the remainder, not the modulo. Things get a bit wonky when negative numbers are used. It's different in Python, hence the difference and the need for PyTorch's remainder operation.
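The wonkiness with negative numbers is easy to demonstrate in Rust, whose % has the same truncated semantics as WGSL's. A small sketch (`floor_rem` is a hypothetical helper mirroring PyTorch's definition):

```rust
// PyTorch-style remainder: a - b * floor(a / b) (hypothetical helper).
fn floor_rem(a: f32, b: f32) -> f32 {
    a - b * (a / b).floor()
}

fn main() {
    let (a, b) = (-3.0_f32, 2.0_f32);
    // Truncated remainder (Rust/WGSL `%`): sign follows the dividend.
    println!("{}", a % b); // -1
    // Floor-based remainder (PyTorch): sign follows the divisor.
    println!("{}", floor_rem(a, b)); // 1
}
```

For operands of the same sign the two agree; they only diverge when the signs differ, which is exactly where the failing tests sit.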
I tried replacing the Remainder op as suggested with this:
```rust
Instruction::Remainder { lhs, rhs, out } => {
    f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
}
```
While tests::jit::remainder::tests::should_be_zero passes, all of a sudden I start getting errors elsewhere, for example:
```
---- tests::jit_fusion::remainder::tests::should_support_remainder_basic stdout ----
thread 'tests::jit_fusion::remainder::tests::should_support_remainder_basic' panicked at crates/burn-wgpu/src/lib.rs:75:5:
Tensors are not approx eq:
  => Position 0: 1 != -1 | difference 2 > tolerance 0.0010000000000000002
  => Position 2: 1 != -1 | difference 2 > tolerance 0.0010000000000000002

---- tests::jit_fusion::remainder::tests::should_support_remainder_float stdout ----
thread 'tests::jit_fusion::remainder::tests::should_support_remainder_float' panicked at crates/burn-wgpu/src/lib.rs:75:5:
Tensors are not approx eq:
  => Position 0: -0.5 != 1 | difference 1.5 > tolerance 0.0010000000000000002
  => Position 1: -1 != 0.5 | difference 1.5 > tolerance 0.0010000000000000002
  => Position 3: -0.5 != 1 | difference 1.5 > tolerance 0.0010000000000000002
  => Position 4: -1 != 0.5 | difference 1.5 > tolerance 0.0010000000000000002
```
In fact, everything but tests::jit::remainder::tests::should_be_zero seems to fail:
```
failures:
    tests::jit::remainder::tests::should_be_negative
    tests::jit::remainder::tests::should_support_fp_dividends
    tests::jit::remainder::tests::should_support_large_divisor
    tests::jit::remainder::tests::should_support_remainder_basic
    tests::jit::remainder::tests::should_support_remainder_float
    tests::jit_fusion::remainder::tests::should_be_negative
    tests::jit_fusion::remainder::tests::should_support_fp_dividends
    tests::jit_fusion::remainder::tests::should_support_large_divisor
    tests::jit_fusion::remainder::tests::should_support_remainder_basic
    tests::jit_fusion::remainder::tests::should_support_remainder_float
```
In terms of spec compliance, should_support_remainder_basic and should_support_remainder_float are pulled directly from PyTorch's remainder examples, so I think we should be good there.
Okay, I've been misled by the PR.
As you cited from the PR about the % operator for WGSL (which references the GLSL spec as a source (?)):
The return value is computed as x − y*floor(x/y). As with min and max, y can be either a vector or a float. The mod function can be used as a substitute for the % operator
However, from the latest WGSL draft (https://www.w3.org/TR/WGSL/#syntax_sym-modulo):
If T is a floating point type, the result is equal to: e1 - e2 * trunc(e1 / e2)
Which makes it much clearer why it differs from PyTorch's definition: trunc() and floor() do not perform the same operation at all.
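One can check on the CPU that the WGSL draft's trunc-based formula really is the truncated remainder that Rust's % computes. A quick sketch (`trunc_rem` is a hypothetical helper spelling out the spec formula):

```rust
// e1 - e2 * trunc(e1 / e2): the WGSL draft's definition of `%` for floats
// (hypothetical helper spelling out the spec formula).
fn trunc_rem(e1: f32, e2: f32) -> f32 {
    e1 - e2 * (e1 / e2).trunc()
}

fn main() {
    // The spec formula and Rust's truncated `%` agree on these samples,
    // including mixed-sign operands (all values chosen to be exact in f32).
    for &(a, b) in &[(-3.5_f32, 2.0_f32), (3.5, 2.0), (-0.5, 1.5), (7.25, -2.0)] {
        assert_eq!(trunc_rem(a, b), a % b);
    }
    println!("trunc-based formula matches `%` on all samples");
}
```

So redirecting Remainder to % or to the trunc formula is the same change, which is consistent with both variants failing the same tests below.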
In fact, using either:

```rust
Instruction::Remainder { lhs, rhs, out } => f.write_fmt(format_args!(
    "{out} = {lhs} - {rhs} * trunc({lhs} / {rhs});\n"
)),
```

or

```rust
Instruction::Remainder { lhs, rhs, out } => {
    f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
}
```
I get:

```
tests::jit::remainder::tests::should_be_negative
tests::jit::remainder::tests::should_support_fp_dividends
tests::jit::remainder::tests::should_support_large_divisor
tests::jit::remainder::tests::should_support_remainder_basic
tests::jit::remainder::tests::should_support_remainder_float
tests::jit::remainder::tests::should_support_remainder_op
tests::jit_fusion::remainder::tests::should_be_negative
tests::jit_fusion::remainder::tests::should_support_fp_dividends
tests::jit_fusion::remainder::tests::should_support_large_divisor
tests::jit_fusion::remainder::tests::should_support_remainder_basic
tests::jit_fusion::remainder::tests::should_support_remainder_float
tests::jit_fusion::remainder::tests::should_support_remainder_op
```
Whereas using:

```rust
Instruction::Remainder { lhs, rhs, out } => f.write_fmt(format_args!(
    "{out} = {lhs} - {rhs} * floor({lhs} / {rhs});\n"
)),
```

... which matches PyTorch's remainder(), I don't get any test failure for this operator.
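The floor-based form also fixes the should_be_zero case discussed earlier: when the rounding errors of the division cancel, floor(a/b) recovers the true quotient and the residue vanishes. A CPU sketch in plain f32 (hypothetical helpers, not the actual WGSL codegen):

```rust
// Floor-based remainder (the PyTorch definition) vs the current wrapped
// double-modulo, both evaluated in plain f32 (hypothetical helpers).
fn floor_rem(a: f32, b: f32) -> f32 {
    a - b * (a / b).floor()
}

fn wrapped_rem(a: f32, b: f32) -> f32 {
    ((a % b) + b) % b
}

fn main() {
    // -0.6 is an exact multiple of 0.2 in real arithmetic.
    let (a, b) = (-0.6_f32, 0.2_f32);
    println!("floor:   {}", floor_rem(a, b));   // 0
    println!("wrapped: {}", wrapped_rem(a, b)); // ≈ 0.2
    assert_eq!(floor_rem(a, b), 0.0);
    assert!(wrapped_rem(a, b) != 0.0);
}
```

This is only a CPU reproduction; GPU rounding can differ, but it matches the pattern of failures reported in the thread.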
As a reminder, with the current burn implementation:

```rust
Instruction::Remainder { lhs, rhs, out } => {
    f.write_fmt(format_args!("{out} = (({lhs} % {rhs}) + {rhs}) % {rhs};\n"))
}
```
... these tests still fail:

```
tests::jit::remainder::tests::should_be_zero
tests::jit_fusion::remainder::tests::should_be_zero
```
So... what is the correct approach? Would it be better to switch to PyTorch's definition? We would go from two float modulo operations to one float division (plus a floor and a multiply); what would be the performance impact?
The current implementation was there to guarantee that the output follows the sign of the divisor. If PyTorch's definition is subbed in and it passes and is equivalent, then we should go with that, since it looks more hardware-agnostic.
@agelas @booti386 I think using the PyTorch definition of remainder is probably the best option. Performance shouldn't be a huge factor, since this op should pretty much never create a bottleneck and will have zero impact on real performance. Correctness and compatibility are more important. I created PR #1979.
Describe the bug
Running `run-checks.ps1 all` fails for burn-wgpu with 4 failed tests during the std tests.

To Reproduce
Check out the current `main` branch, open a PowerShell terminal, cd to the repo root, and execute `.\run-checks.ps1 all`.

Expected behavior
All tests should pass.