Shukla-Gaurav opened this issue 9 months ago
Can you give me the iree-run-module command that shows the error? That will help me repro the error. For example, the %1 above is strange (still correct and shouldn't affect correctness) but will help triage.
@MaheshRavishankar
~/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu conv2d.linalg.mlir > conv2d.bf16.vmfb 2>iree-compile.log
~/iree-build/tools/iree-run-module --module=conv2d.bf16.vmfb --input="2x8x12x16xbf16=@inference_input.0.bin.txt" > inference.log
I am attaching inference_input.0.bin.txt, which is all 1's.

`iree-run-module` runs fine, but its result mismatches the result of the conv2d PyTorch module (the input is all 1's in both cases).
import torch
import torch.nn as nn
class op_conv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(8, 10, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        )

    def forward(self, x):
        return self.layers(x)

model = op_conv2d()
model_bf16 = model.to(torch.bfloat16)
test_input_bf16 = torch.ones(2, 8, 12, 16).to(torch.bfloat16)
test_output_bf16 = model_bf16(test_input_bf16)
print("Input:", test_input_bf16)
print("Output:", test_output_bf16)
3. You can also use the repo https://github.com/nod-ai/SHARK-TestSuite to test/cross-check conv2d results with your fix:
git clone https://github.com/nod-ai/SHARK-TestSuite
# activate your iree_venv or torch_mlir_venv
cd e2eshark
python ./run.py --runupto inference --mode onnx -c ~/torch-mlir/build -i ~/iree-build --tests pytorch/operators/conv2d/ --hfhome ~/hf_home/ --verbose -d bf16 -r test-conv2d-bf16
Running conv2d with different precisions, keeping all the constants (weight and bias) the same: conv2d.bf16.compile.log, conv2d.fp32.compile.log
I think this is not really a codegen issue. This is really a bf16 issue. Are we comfortable closing this, or do we need to do more here?
Linear module for reference output (weight, bias, and input have been fixed to simplify comparison of IRs and outputs; all these values fit in bf16, so x.to(torch.bfloat16) won't change them):
import torch
import torch.nn as nn
class op_linear(nn.Module):
def __init__(self):
super().__init__()
self.linear_layer = nn.Linear(3, 4)
self.linear_layer.weight = nn.Parameter(
torch.tensor([[-0.4199, 0.4180, -0.0293],
[ 0.4297, -0.4434, 0.0162],
[-0.2061, -0.4004, 0.2773],
[-0.5469, 0.0449, 0.3242]],
dtype=torch.float32),
requires_grad=False)
self.linear_layer.bias = nn.Parameter(
torch.tensor([ 0.2236, 0.3184, -0.1709, -0.4883],
dtype=torch.float32),
requires_grad=False)
self.layers = nn.Sequential(self.linear_layer)
def forward(self, x):
return self.layers(x)
model = op_linear()
test_input = torch.tensor([[-0.4062, -0.6953,  1.8516],
                           [ 0.4961,  0.9609, -0.7500],
                           [ 0.4766, -1.4531,  0.6172],
                           [-0.4785,  1.0859, -0.9922],
                           [ 2.1094, -0.0107,  0.3496],
                           [-0.6562, -0.0116,  1.7812],
                           [ 0.0114, -0.1279,  1.7266],
                           [-0.1289,  0.6250,  1.3516]], dtype=torch.float32)
output = model(test_input)

test_input.to(torch.bfloat16)
model.to(torch.bfloat16)
bf16_golden_output = model(test_input)
2. Linalg IR:
[linear.bf16.linalg.mlir.txt](https://github.com/nod-ai/SHARK/files/14386289/linear.bf16.linalg.mlir.txt)
[linear.fp32.linalg.mlir.txt](https://github.com/nod-ai/SHARK/files/14386290/linear.fp32.linalg.mlir.txt)
3. Applying `--iree-global-opt-enable-demote-contraction-inputs-to-bf16` to the fp32 IR and `--iree-llvmcpu-enable-ukernels=all` to the bf16 IR to compare outputs from the different paths.
/home/gaurav/MLIRepos/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-global-opt-enable-demote-contraction-inputs-to-bf16 --iree-input-type=tm_tensor --mlir-print-ir-after-all --mlir-disable-threading linear.fp32.linalg.mlir > linear.contraction.vmfb 2>contraction-flag-linear-fp32-iree_compile.log
/home/gaurav/MLIRepos/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-enable-ukernels=all --mlir-print-ir-after-all --mlir-disable-threading --iree-input-type=tm_tensor linear.bf16.linalg.mlir > linear.ukernel.vmfb 2>ukernel-flag-linear-bf16-iree_compile.log
[contraction-flag-linear-fp32-iree_compile.log](https://github.com/nod-ai/SHARK/files/14386286/contraction-flag-linear-fp32-iree_compile.log)
[ukernel-flag-linear-bf16-iree_compile.log](https://github.com/nod-ai/SHARK/files/14386688/ukernel-flag-linear-bf16-iree_compile.log)
4. Running both modules:
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=linear.contraction.vmfb --input="8x3xf32=@inference_input.0.bin.txt"
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=linear.ukernel.vmfb --input="8x3xbf16=@inference_input_bf16.0.bin.txt"
[inference_input.0.bin.txt](https://github.com/nod-ai/SHARK/files/14386287/inference_input.0.bin.txt)
[inference_input_bf16.0.bin.txt](https://github.com/nod-ai/SHARK/files/14386288/inference_input_bf16.0.bin.txt)
5. Comparing different outputs
f32_golden_output (= model(input).to(bf16)):
tensor([[ 0.0493, 0.4824, 0.7031, 0.3027], [ 0.4395, 0.0933, -0.8672, -0.9609], [-0.6016, 1.1797, 0.4844, -0.6133], [ 0.9062, -0.3848, -0.7812, -0.5000], [-0.6758, 1.2344, -0.5039, -1.5312], [ 0.4414, 0.0703, 0.4629, 0.4473], [ 0.1147, 0.4082, 0.3574, 0.0596], [ 0.5000, 0.0078, -0.0198, 0.0483]], dtype=torch.bfloat16)

bf16_golden_output (= model.to(bf16); input.to(bf16); model(input)):
tensor([[ 0.0493, 0.4824, 0.7031, 0.3027], [ 0.4395, 0.0933, -0.8672, -0.9609], [-0.6016, 1.1797, 0.4844, -0.6133], [ 0.9062, -0.3848, -0.7812, -0.5000], [-0.6758, 1.2344, -0.5039, -1.5312], [ 0.4414, 0.0703, 0.4629, 0.4473], [ 0.1147, 0.4082, 0.3574, 0.0596], [ 0.5000, 0.0078, -0.0198, 0.0486]], dtype=torch.bfloat16)

contraction_flag_output (quite close to the golden output, but here we trace the f32 model (f32 linalg IR) and only demote certain ops to bf16):
tensor([[ 0.0493, 0.4824, 0.7031, 0.3027], [ 0.4395, 0.0933, -0.8672, -0.9609], [-0.6016, 1.1797, 0.4844, -0.6133], [ 0.9062, -0.3848, -0.7812, -0.5000], [-0.6758, 1.2344, -0.5039, -1.5312], [ 0.4414, 0.0703, 0.4629, 0.4473], [ 0.1147, 0.4082, 0.3574, 0.0596], [ 0.5000, 0.0079, -0.0198, 0.0486]], dtype=torch.bfloat16)

ukernel_flag_output (same as the IREE bf16 inference output without this flag; it mismatches the golden output):
tensor([[ 0.0498, 0.4824, 0.7031, 0.3047], [ 0.4395, 0.0938, -0.8672, -0.9570], [-0.6055, 1.1797, 0.4863, -0.6133], [ 0.9062, -0.3848, -0.7812, -0.7500], [-0.6797, 1.2344, -0.5039, -1.5234], [ 0.4434, 0.0723, 0.4629, 0.4492], [ 0.1147, 0.4082, 0.3574, 0.0586], [ 0.5000, 0.0078, -0.0195, 0.0469]], dtype=torch.bfloat16)
print(torch.allclose(bf16_golden_output, contraction_flag_output, atol=1e-04, rtol=1e-03))  # True
print(torch.allclose(bf16_golden_output, ukernel_flag_output, atol=1e-04, rtol=1e-03))  # False
I'm not sure how PyTorch handles the bf16 computation, but it is close to doing the computation in f32 and then demoting the result to bf16.
(NOTE: all the inputs are small enough to fit in the bf16 type.)
For the same inputs, the IREE bf16 inference result (the ukernel_flag_output) does not match within the above-mentioned tolerances.
(The contraction_flag_output is close, but we want to start from a bf16 model, not an f32 model.)
This problem needs a closer look at how IREE handles bf16 in the CPU backends. Thoughts?
@stellaraccident @kumardeepakamd @MaheshRavishankar @bjacob
I can create an onnx linear module with the same inputs and run it on onnx runtime to have one more reference output if it helps. Thanks!
> I think this is not really a codegen issue. This is really a bf16 issue. Are we comfortable closing this, or do we need to do more here?

I think we should not close this until we can conclude how to handle bf16 on CPU, i.e. how to verify that a model is producing the correct outputs through the ONNX pipeline.
Really, I don't know what the compiler itself can decide here. This is always going to mismatch because the reference is doing different things (and different references will do different things). What IREE is doing is basically "do what it is told to do". If the linalg op says the input type is bf16 and the output type is bf16, it is actually doing the accumulation in f32. IMO that is actually being too smart. It should be doing the accumulation in bf16 as well (because that's what it was told to do). So the only AI I can think of is to actually make IREE less smart.
Usually when I've seen these kinds of issues resolved before, it requires a much more careful drill down vs a high level this vs that. There is no abstract answer to these things at that precision: a rounding/truncation mode difference or 1ULP difference at any stage is enough to result in a 1% error for a datatype like this. If trying to get complete correspondence, then none of that can be ignored.
You'll have to dig deeper, and likely if you are still looking at results as textual floating point, you'll miss the difference.
One advantage to having an ONNX reference is that it is much easier to hack on at that level than PyTorch (i.e. you can build it, set a breakpoint or print specific values in a kernel, etc).
It seems the mismatch is due to the different rounding mechanisms used by PyTorch and IREE. I ran a few simple add/mul tests, and there is mostly a 1-bit difference in the outputs. PyTorch simply truncates the last 16 bits after computing the result, while IREE seems to round it. The following example illustrates this behavior:
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=mul.vmfb --input="1xbf16=[5]"
pytorch-cpu output: 6.8136e+30
IREE-cpu output: 6.8532e+30
Their binary representations differ exactly in the last bit. If we do the multiplication in f32 and drop the last 16 bits, we get exactly the same output as PyTorch. I used the following C++ code to check the 16-bit truncation result:
#include <iostream>
struct bfloat16{
unsigned short int data;
public:
bfloat16(){
data = 0;
}
//cast to float
operator float(){
unsigned int proc = data<<16;
return *reinterpret_cast<float*>(&proc);
}
//cast to bfloat16
bfloat16& operator =(float float_val){
data = (*reinterpret_cast<unsigned int *>(&float_val))>>16;
return *this;
}
};
//an example that enumerates all the possible values between 1.0f and 300.0f
using namespace std;

int main(){
    bfloat16 x;
    x = 6.8335e+30;
    // for(x=1.0f;x<300.0f;x.data++){
        cout<<x.data<<" "<<x<<endl;
    // }
    return 0;
}
Great analysis, thanks! Indeed, the f32 value 6.8335e+30 has the binary encoding 0x72ac8075. Truncating this to bf16 means replacing it by a value that is a multiple of 0x10000, so the two conceivable candidates are 0x72ac0000 and 0x72ad0000, which are respectively 6.81362e+30 and 6.85324e+30. However, since the next digit after 0x72ac... is an 8 (the 8 in 0x72ac8075), the nearest value really is the latter, 0x72ad0000 = 6.85324e+30. This isn't even a tie (it would be a tie if the value were 0x72ac8000), so there are no tie-break questions ("tie to nearest-even") here. So, IREE is being correct here, and PyTorch is incorrect. What PyTorch does here has a name, "rounding towards zero"; it can sometimes be useful for some really specialized uses, but it doesn't make sense as the default way to round all values in a workload.
I don't know what code path is used in the code that you ran, but I checked this PyTorch f32 -> bf16 rounding helper,
https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h#L76
and it does return the correct result:
#include <cstdio>
// round_to_nearest_even here is the helper from the PyTorch header linked above.
int main() {
  float f = 6.8335e+30;
  printf("round_to_nearest_even(%g) = 0x%x\n", f, round_to_nearest_even(f));
}
prints

round_to_nearest_even(6.8335e+30) = 0x72ad

and not 0x72ac as in your PyTorch result above. So it seems like different code paths within PyTorch don't agree with each other, and at least this one agrees with us.
> I used the following C++ code to check the 16-bit truncation result:

//cast to bfloat16
bfloat16& operator =(float float_val){
    data = (*reinterpret_cast<unsigned int *>(&float_val))>>16;
    return *this;
}
This implements the same incorrect rounding-towards-zero as we discussed above. Just dropping the bottom 16 bits like this fails to account for the possibility that their value might be >= 0x8000, requiring rounding upward to the next representable value.
(Side note: this also has undefined behavior in C++, as an unsigned int object coexists with a float object at the same memory location. The only way to implement a bitcast like this in C++ prior to C++20 is to copy the data with something like memcpy, or to go through one of the few types that support aliasing, such as char, unsigned char, or std::byte. Or, if using C++20, use the new std::bit_cast for that.)
@bjacob Thanks for the explanation! For the multiplication example, the IREE output is closer, since PyTorch is simply truncating the result. Can we add a way to explicitly specify the rounding mechanism we want in IREE, since we need the PyTorch results as our reference for the model outputs?
And I did the following for the PyTorch bf16 multiplication:
>>> x = torch.tensor([1.3667e30], dtype=torch.bfloat16)
>>> y = torch.tensor([5], dtype=torch.bfloat16)
>>> x*y
tensor([6.8136e+30], dtype=torch.bfloat16)
I also got a weird example, @bjacob
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=add.vmfb --input="1xbf16=[-0.4882]"
pytorch-cpu output: -0.5
IREE-cpu output: -0.75
Wow, funny bug that you found here! It appears to be a parsing bug in how iree-run-module parses the --input flag. Indeed, it produces expected results when the specified array element has no more than two digits after the decimal point, and it reproduces whenever it has 3 or more digits after the decimal point.
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.48]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.492188
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.49]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.5
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.488]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.75
FYI @benvanik
The parsing itself is correct, though: iree_hal_parse_element_unsafe does parse the correct value, and its caller iree_hal_parse_buffer_elements does store it in the destination buffer. And yet, something is producing incorrect results only when the --input parameter specifies more than two decimals...
The bug reproduces whenever the specified --input element rounds to -0.488281 as a bfloat16 (encoding 0xbefa). It does not reproduce whenever it rounds to the previous bfloat16 value, -0.486328 (encoding 0xbef9). In both cases, our f32 <-> bfloat16 conversion helpers produce correct results, and the parsing is correct too, as noted above. Just for some reason, the bfloat16 value -0.488281 (encoding 0xbefa) runs into some arithmetic bug elsewhere.
And the other operand, which is hardcoded as a constant in the above testcase, also matters. Here is a testcase taking both operands as arguments:
#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (0)>
#map2 = affine_map<(d0) -> (d0)>
module {
func.func @main_graph(%arg0: tensor<1xbf16>, %arg1: tensor<1xbf16>) -> tensor<1xbf16> {
%0 = tensor.empty() : tensor<1xbf16>
%1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xbf16>, tensor<1xbf16>) outs(%0 : tensor<1xbf16>) {
^bb0(%in: bf16, %in_0: bf16, %out: bf16):
%2 = arith.addf %in, %in_0 : bf16
linalg.yield %2 : bf16
} -> tensor<1xbf16>
return %1 : tensor<1xbf16>
}
}
With that, I find that for this to reproduce, the other operand needs to be bfloat16 0xbc31 or greater-negative (decimal value -0.010832).
This actually minimizes down to a testcase that performs no bfloat16 arithmetic and only a f32->bfloat16 truncf:
#map = affine_map<(d0) -> (d0)>
module {
func.func @main_graph(%arg0: tensor<1xf32>) -> tensor<1xbf16> {
%0 = tensor.empty() : tensor<1xbf16>
%1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs(%0 : tensor<1xbf16>) {
^bb0(%in0: f32, %out: bf16):
%3 = arith.truncf %in0 : f32 to bf16
linalg.yield %3 : bf16
} -> tensor<1xbf16>
return %1 : tensor<1xbf16>
}
}
~/iree-build tools/iree-run-module --module=/tmp/repro2.vmfb --input="1xf32=[-0.499081]" --device=local-task
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.75
@rsuderman this might be for you :-)
What --mlir-print-ir-after-all shows for the testcase in the previous comment:
// -----// IR Dump After CSE (cse) //----- //
module {
func.func @main_graph_dispatch_0_generic() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<f32>
memref.assume_alignment %0, 64 : memref<f32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<i16>
memref.assume_alignment %1, 64 : memref<i16>
%2 = memref.load %0[] : memref<f32>
%3 = arith.truncf %2 : f32 to bf16
%4 = arith.bitcast %3 : bf16 to i16
memref.store %4, %1[] : memref<i16>
return
}
}
// -----// IR Dump After ConvertToLLVM (iree-convert-to-llvm) //----- //
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-unknown-eabi-elf"} {
llvm.func @main_graph_dispatch_0_generic(%arg0: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg1: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg2: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}) -> i32 {
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = llvm.mlir.constant(16 : i32) : i32
%2 = llvm.mlir.constant(32768 : i32) : i32
%3 = llvm.mlir.constant(2130706432 : i32) : i32
%4 = llvm.mlir.constant(2139095040 : i32) : i32
%5 = llvm.mlir.constant(8388607 : i32) : i32
%6 = llvm.mlir.constant(31 : i32) : i32
%7 = llvm.mlir.constant(23 : i32) : i32
%8 = llvm.mlir.constant(63 : index) : i64
%9 = llvm.mlir.constant(0 : index) : i64
%10 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%11 = llvm.extractvalue %10[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%12 = llvm.load %11 : !llvm.ptr -> !llvm.ptr
%13 = llvm.ptrtoint %12 : !llvm.ptr to i64
%14 = llvm.and %13, %8 : i64
%15 = llvm.icmp "eq" %14, %9 : i64
"llvm.intr.assume"(%15) : (i1) -> ()
%16 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%17 = llvm.extractvalue %16[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
%18 = llvm.getelementptr %17[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
%19 = llvm.load %18 : !llvm.ptr -> !llvm.ptr
%20 = llvm.ptrtoint %19 : !llvm.ptr to i64
%21 = llvm.and %20, %8 : i64
%22 = llvm.icmp "eq" %21, %9 : i64
"llvm.intr.assume"(%22) : (i1) -> ()
%23 = llvm.load %12 : !llvm.ptr -> f32
%24 = llvm.bitcast %23 : f32 to i32
%25 = llvm.lshr %24, %6 : i32
%26 = llvm.sub %2, %25 : i32
%27 = llvm.and %24, %5 : i32
%28 = llvm.add %27, %26 : i32
%29 = llvm.lshr %28, %7 : i32
%30 = llvm.lshr %28, %29 : i32
%31 = llvm.and %24, %4 : i32
%32 = llvm.add %31, %28 : i32
%33 = llvm.and %32, %4 : i32
%34 = llvm.icmp "uge" %31, %3 : i32
%35 = llvm.select %34, %31, %33 : i1, i32
%36 = llvm.trunc %29 : i32 to i1
%37 = llvm.and %34, %36 : i1
%38 = llvm.select %37, %27, %30 : i1, i32
%39 = llvm.shl %25, %6 : i32
%40 = llvm.or %39, %35 : i32
%41 = llvm.or %40, %38 : i32
%42 = llvm.lshr %41, %1 : i32
%43 = llvm.trunc %42 : i32 to i16
llvm.store %43, %19 : i16, !llvm.ptr
llvm.return %0 : i32
}
}
In the first part of the above log, our arith.truncf op is still as in the original source. In the second part, IR Dump After ConvertToLLVM (iree-convert-to-llvm), it has been lowered to llvm ops implementing the truncation. The testcase shows that this lowering produces incorrect results when the value being truncated, here -0.499081, is just below an exponent threshold and rounding it to the nearest bfloat16 value makes it cross that threshold (-0.499081 becomes -0.5, bumping the exponent).
@rsuderman, here is what the equivalent f32->bf16 truncation code in the runtime does (it is actually generic in bit-widths, but in particular does f32->bf16) specifically to fix up this case:
https://github.com/openxla/iree/blob/01c4c57/runtime/src/iree/base/internal/math.h#L389-L390
@rsuderman , here is the much more concise and optimized way that the PyTorch runtime does it (I think that part was written by Marat and carried over from XNNPACK or some predecessor of it): https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L76
In the above-linked runtime code, I didn't bother to implement this magic trick because I wanted genericity and didn't need to chase performance. But in the compiler lowering, it would make sense to do the concise and efficient thing.
The link to IREE math.h in the previous comment has a comment explaining the magic trick here.
// Note: software implementations that try to be fast tend to get this
// conditional increment of exp and zeroing of mantissa for free by
// simplying incrementing the whole uint32 encoding of the float value,
// so that the mantissa overflows into the exponent bits.
Great :/, I was pretty sure I had managed to implement the rounding behavior correctly, but I did not have an aggressive test case to evaluate with. I assume this means there is an error in our bf16 truncf implementation? I am not sure I have time in the near future to debug the exact bit errors; is it possible you could take a look?
It's ok, I think I'll have the patch ready soon.
@Shukla-Gaurav, this seems to work. I'll fix up any unit tests that fail and send that for review, @rsuderman. https://github.com/llvm/llvm-project/pull/83180
Thanks a lot @bjacob for actively working on this. Will try the patch with other test cases/models as well.
https://github.com/llvm/llvm-project/pull/83180 is merged, so you'll get it in the next integrate or can cherry-pick it locally until then to verify it fixed your issue.
Thank you!
class op_conv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(8, 10, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        )

    def forward(self, x):
        return self.layers(x)

model = op_conv2d()
model_bf16 = model.to(torch.bfloat16)
test_input_bf16 = torch.randn(2, 8, 12, 16).to(torch.bfloat16)
test_output_bf16 = model_bf16(test_input_bf16)
print("Input:", test_input_bf16)
print("Output:", test_output_bf16)
#map = affine_map<(d0, d1, d2, d3) -> (d1)>
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
func.func @main_graph(%arg0: tensor<2x8x12x16xbf16>) -> tensor<2x10x7x16xbf16> {
%cst = arith.constant dense<"0x863DAA3D53BDADBBB93DF13C8BBD91BD94BDB03D66BD9DBD9C3D923B8ABD883BA43DADBDBABC953C253D6CBD98BD8CBD5D3D21BCB4BDA53D743D15BD1EBD953D5CBC6F3DA8BD3C3C61BD24BBCC3AD8BCB13D44BDE43C73BD303C1CBC663C5EBDA63DB03DD7BC523D82BDB93D563D77BD35BD523D78BD46BBAD3DA03DA5BDF93C68BCE5BC563D04BDB8BDB73D6D3D05BD9CBDB13CAEBC15BD89BC47BD4A3D75BDE6BC51BDA7BCAFBC8BBD4B3D5B3D2DBD513D88BCA93DC7BC18BC49BB27BDCABCB5BD253DADBD7E3D94BCBDBC343D4BBDA33DB8BB143D2CBCD23C213D16BD8BBC80BCD83C8C3D44BC3C3D37BD38BD8C3D373D46BC1CBDCC3C41BB743D7FBD15BB7BBD983B8CBD9E3D73BD033D8B3D533CAABDA63C853D2E3DB4BB83BD9B3CF5BC08BD49BD773D5FBD8D3C703D7FBD9D3D133D0E3D8C3DB0BD9E3D87BC74BDB5BD283CC5BB843D863D84BDDDBC9FBDABBC633D8A3D20BDD53C13BC453C5CBD4B3D94BDB83C463D6EBBEBBC89BDAD3D273D05BDAB3CD83B823DE4BBB73D69BCB1BC81BD9B3D573A5BBD32BD6F3DE43CEBBB95BD22BDB33B8E3DF7BC863C133D893C9ABC6BBD8C3D6B3D92BD983D22BB893C173D803D7C3C30BD2BBD07BDFFBCA8BD68BDB0BD653C91BDA0BCB93D9BBD973D56BDCC3CB8BB523B8D3DC4BB6FBDEB3C48BDA5BD753D443D79BCC83CAFBD273C423CD83C8B3D7FBD3ABDAA3D293D89BD2CBDF7BB1BBCEE3C2B3D67BD09BD50BDA73CB43D44BD95BDCABBA93AE03C913D8E3D503D97BC81BC45BB863C0C3D88BD9E3D333DC0BB8CBD9B3DF43CB93D2B3D0A3DF1BC32BC1ABD8BBD71BDA2BDB63D933DF8BCA7BD993D6E3D92BCA43D6BBD8DBDA0BD75BD86BB29BDAB3D8EBDA13DF5BC9B3C98BD143D93BDE23CB1BC753DAA3D693D2C3D1BBCB33D64BDC4BC27BD13BDA13CA4BD8FBC6B3C3ABD2A3DA23D323DD8BC3C3C70BDA23B673D9DBD84BD553D11BD1ABDA73D99BC5DBC9B3DBFBBF9BC6C3D9ABD8A3D45BD72BC0D3CD4BB9DBD1C3D933DAABD68BBF23C623DA83C26BD6EBC33BD943DBABD393C8C3D773D5CBD9A3DABBD82BDB13DB4BC91BD1A3D58BD233C053DD1BC963DA0BD4DBC45BD663DA83D4D3CA2BD5ABDCC3C933DB6BD4DBD23BC44BC7D3D45BC953D9ABD8CBD9F3BBA3D033D1D3D3DBD70BD7E3DB53C82BA153BF93C31BDAB3D543D843DA7BD743D24BDB8BC32BA903DAFBD83BD343C4CBC2F3CC9BCB63B29BD53BD3A3D23BD44BC2BBD893D87BC8BBDB13D643D7B3CA43CC73CFA3B16BC173C6DBDAC3D383C45BD7D3D8E3D49BC07BD903D65BD7C3D653DA7BDE83CB9BDB8BD7CBC96BC83BD8B3C5CBD813DA53D51BD94BD7E3CFC3CB03D95BD9B3D6E3B553D223D7DBC2B3D923DB83D8C3DB83D7B3CB0BD13BC7FBD7DBDB23C89BD8F3D26BD0D3C073D33BD193C01BD96BC213D2C3DA13D61BA56BDB23DBC3CA1BDB4BD643CBD3CF1BCA23B273D97BDA5BCAE3CB6BD543C943D97BD5DBC803D2DBC44BDEABBB13DB13A6CBD72BD7B3C0D3D4D3D7D3DA2BC883D433D48BD8D3C773D4DBD143D98BD77BDCBBC8F3D90BDA5BDE83BAF3D7F3D71BC01BD5DBD9F3D5FBDCA3CEC3CD73CA53D9C3D363D9B3C4F3DB73CF8BCAD3D97BD56BCBFBCAB3D73BD8B3D1ABDB93CAF3D2B3DD3BC9B3D2A3DB63D963DA1BD9CBD20BC2EBDED3C3CBDAC3D1B3DB83D1CBD043D073D78BD96BD84BD8E3D9CBC503DD43CFA3BA63D4EBDB73CA5BDAD3D81BD3D3D213D83BD11BD863A453C97BDC5BBDCBC103DA6BDB1BD14BDA83D7BBD57BC79BD273D8DBD253D863DB93B9ABD8C3DF63C48BD80BD7A3D953BB83DE9BCA53C2F3DA1BD513D04BD5A3DB0BC11BD343D513D9DBCBA3D233D6ABDA03C8CBD6C3DA63D803CCFBC4ABD3B3D8FBD3BBDB4BD983C9E3D823D303D49BD313C60BD653D4E3C8E3D44BDC7BCA5BD0FBD023D5A3C903B8BBDCABB713D6D3CA3BD06BD71BDD73C763DA1BCD7BC813D9FBDFBBC84BD3F3C803C55BD8CBDB13C8D3D8A3C8F3D45BD22BCA23C50BC423D933D9E3DAABC893D44BD4FBCC33CB63D7C3D153BA8BD91BD8B3DCB3C3BBDECBC95BDB03D51BC53BC913D623DA8BBABBDAA3B343DB2BD293BD6BCF93B8B3DA0BD5CBD073D9BBDB0BDD4BCBDBC0CBD8B3DA83D643D873DA9BDAF3DA4BD703C92BD1EBD213DBBBBACBDAD3D013DF23C933D71BDBF3CA63DC4BC183DAE3D17BD7D3DCDBC343D1F3D43BCABBD66BDB7BD5C3C3ABDA3BC193D8EBDA1BC983DA8BD8D3DC139363C88BD88BD97BD35BD833CD53C6BBD9FBDB83D28BD88BD5CBD92BD283DB7BD96BDB83D3F3D20BD683DA4BD313D02BDE13AB03DE73C47BD65BD3C3D523C5B3D853D29BC81BD45BCAE3D22BAB43DB9BD34BDB6BD9B3D36BDA73D5DBCFB3C42BD913D0EBC98BD8C3D76BC6DBC3ABD493D963D4EBD253DADBD88BDA83D69BD9B3CA7BBBABDD13C453D073D70BDDDBC623B80BD173DA23D48BCBC3C88BDD7BC803DF8BCB3BD0F3DB0BD
16BD653D963D313DFEBA443C0EBD22BCA8BDDB3C1E3C9CBC3F3D7B3C8D3B153D503D973D8DBD683D28BD2BBD163D35BD3CBCEC3B483D30BD353DCB3C75BC7CBC5E3D5A3D633DA53DB23D90BD243DB33D643DA93B66BD623D5DBDAB3D85BDB73D133DE93C20BD8A3D343DD23C07BD403D9A3D663D8BBC8DBD973D673DFC3B873D523D24BC19BC95BBF53C22BD9CBDABBDC83CA63D9D3C25BC13BC193D52BD61BD403D94BD0EBC763C513BB9BD9ABC793D29BC4DBDB7BD2A3D653CA1BDAA3DA83DA03D22BD953D36BD44BD7F3D7C3DB8BD1C3D8FBD63BDA0BDB73DB3BD86BD9D3D703D22BCD4BC6CBD6B3DA7BD023CAABC59BD953C043DC03BACBDF6BCBA3DA93D9A3C83BDB53DA53D99BD813D203B8F3D41BDAD3D32BDAF3C0FBCD73C2EBD1A3D85BDB0BCE83C6D3D0C3DA73D13BD30BDA6BC9F3D0BBDB93DDB3CBE3CF43C323D0BBCAB3DB53CB83D023B52BD433D1C3D013DEFBC8C3DC3BC54BDA2BD65BD393D7ABD933D073D15BD91BD813D743CB5BB763D9A3D5D3C733DB23D9F3DA73D9D3D753D72BD953C1B3D9D3D953D603D7ABC47BD75BDB9BD99BD393DDA3CA33D5EBDB4BD783D88BDA1BC0DBC8E3C84BD89BC353C25BD0A3DFC3C9D3DC9BC633D93BB6D3C7C3DA7BD1EBDAF3DC4BBBA3D513D903D9B3DCB3C9BBDA83D5BBCAEBC323CABBA533C29BDB83D573D153B94BA8B3CF63B783D84BD07BB723D463C4D3C12BD8D3D04BDB8BAB3BD68BDA43D56BC623DB63D903D023D32BDB9BD8E3D4E3CB8BDB93D6C3DFC3C64BDA6BC35BD3EBD813D6ABD4F3D543DA33D1B3D4D3D8D3C873D813ADE3C4ABDFE3B193C5C3D1C3A95BDB13D3CBA503D8D3DE4BB893D453D18BD1ABD9F3CDEBBA0BD843D81BCAC3D473C4B3D90BD65BD2C3C94BCBE3C833D"> : tensor<10x8x3x5xbf16>
%cst_0 = arith.constant dense<[8.056640e-02, 6.738280e-02, -3.637700e-02, -2.111820e-02, 7.568360e-02, 7.519530e-02, -3.112790e-02, 4.663090e-02, -4.589840e-02, 5.908200e-02]> : tensor<10xbf16>
%cst_1 = arith.constant 0.000000e+00 : bf16
%padded = tensor.pad %arg0 low[0, 0, 4, 2] high[0, 0, 4, 2] {
^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index):
tensor.yield %cst_1 : bf16
} : tensor<2x8x12x16xbf16> to tensor<2x8x20x20xbf16>
%0 = tensor.empty() : tensor<2x10x7x16xbf16>
%1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_0 : tensor<10xbf16>) outs(%0 : tensor<2x10x7x16xbf16>) {
^bb0(%in: bf16, %out: bf16):
linalg.yield %in : bf16
} -> tensor<2x10x7x16xbf16>
%2 = linalg.conv_2d_nchw_fchw {dilations = dense<[3, 1]> : vector<2xi64>, strides = dense<[2, 1]> : vector<2xi64>} ins(%padded, %cst : tensor<2x8x20x20xbf16>, tensor<10x8x3x5xbf16>) outs(%1 : tensor<2x10x7x16xbf16>) -> tensor<2x10x7x16xbf16>
return %2 : tensor<2x10x7x16xbf16>
}
}