nod-ai / SHARK

SHARK - High Performance Machine Learning Distribution
Apache License 2.0

bf16 result mismatch for Conv2D op #2090

Open Shukla-Gaurav opened 4 months ago

Shukla-Gaurav commented 4 months ago
  1. The following is the Conv2d PyTorch module:

```python
import torch
import torch.nn as nn

class op_conv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(8, 10, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        )

    def forward(self, x):
        return self.layers(x)

model = op_conv2d()
model_bf16 = model.to(torch.bfloat16)
test_input_bf16 = torch.randn(2, 8, 12, 16).to(torch.bfloat16)
test_output_bf16 = model_bf16(test_input_bf16)
print("Input:", test_input_bf16)
print("Output:", test_output_bf16)
```


2. This is the linalg IR of the above PyTorch module:

`#map = affine_map<(d0, d1, d2, d3) -> (d1)>`

`#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
module { ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor func.func @main_graph(%arg0: tensor<2x8x12x16xbf16>) -> tensor<2x10x7x16xbf16> { %cst = arith.constant dense<"0x863DAA3D53BDADBBB93DF13C8BBD91BD94BDB03D66BD9DBD9C3D923B8ABD883BA43DADBDBABC953C253D6CBD98BD8CBD5D3D21BCB4BDA53D743D15BD1EBD953D5CBC6F3DA8BD3C3C61BD24BBCC3AD8BCB13D44BDE43C73BD303C1CBC663C5EBDA63DB03DD7BC523D82BDB93D563D77BD35BD523D78BD46BBAD3DA03DA5BDF93C68BCE5BC563D04BDB8BDB73D6D3D05BD9CBDB13CAEBC15BD89BC47BD4A3D75BDE6BC51BDA7BCAFBC8BBD4B3D5B3D2DBD513D88BCA93DC7BC18BC49BB27BDCABCB5BD253DADBD7E3D94BCBDBC343D4BBDA33DB8BB143D2CBCD23C213D16BD8BBC80BCD83C8C3D44BC3C3D37BD38BD8C3D373D46BC1CBDCC3C41BB743D7FBD15BB7BBD983B8CBD9E3D73BD033D8B3D533CAABDA63C853D2E3DB4BB83BD9B3CF5BC08BD49BD773D5FBD8D3C703D7FBD9D3D133D0E3D8C3DB0BD9E3D87BC74BDB5BD283CC5BB843D863D84BDDDBC9FBDABBC633D8A3D20BDD53C13BC453C5CBD4B3D94BDB83C463D6EBBEBBC89BDAD3D273D05BDAB3CD83B823DE4BBB73D69BCB1BC81BD9B3D573A5BBD32BD6F3DE43CEBBB95BD22BDB33B8E3DF7BC863C133D893C9ABC6BBD8C3D6B3D92BD983D22BB893C173D803D7C3C30BD2BBD07BDFFBCA8BD68BDB0BD653C91BDA0BCB93D9BBD973D56BDCC3CB8BB523B8D3DC4BB6FBDEB3C48BDA5BD753D443D79BCC83CAFBD273C423CD83C8B3D7FBD3ABDAA3D293D89BD2CBDF7BB1BBCEE3C2B3D67BD09BD50BDA73CB43D44BD95BDCABBA93AE03C913D8E3D503D97BC81BC45BB863C0C3D88BD9E3D333DC0BB8CBD9B3DF43CB93D2B3D0A3DF1BC32BC1ABD8BBD71BDA2BDB63D933DF8BCA7BD993D6E3D92BCA43D6BBD8DBDA0BD75BD86BB29BDAB3D8EBDA13DF5BC9B3C98BD143D93BDE23CB1BC753DAA3D693D2C3D1BBCB33D64BDC4BC27BD13BDA13CA4BD8FBC6B3C3ABD2A3DA23D323DD8BC3C3C70BDA23B673D9DBD84BD553D11BD1ABDA73D99BC5DBC9B3DBFBBF9BC6C3D9ABD8A3D45BD72BC0D3CD4BB9DBD1C3D933DAABD68BBF23C623DA83C26BD6EBC33BD943DBABD393C8C3D773D5CBD9A3DABBD82BDB13DB4BC91BD1A3D58BD233C053DD1BC963DA0BD4DBC45BD663DA83D4D3CA2BD5ABDCC3C933DB6BD4DBD23BC44BC7D3D45BC953D9ABD8CBD9F3BBA3D033D1D3D3DBD70BD7E3DB53C82BA153BF93C31BDAB3D543D843DA7BD743D24BDB8BC32BA903DAFBD83BD343C4CBC2F3CC9BCB63B29BD53BD3A3D23BD44BC2BBD893D87BC8BBDB13D643D7B3CA43CC73CFA3B
16BC173C6DBDAC3D383C45BD7D3D8E3D49BC07BD903D65BD7C3D653DA7BDE83CB9BDB8BD7CBC96BC83BD8B3C5CBD813DA53D51BD94BD7E3CFC3CB03D95BD9B3D6E3B553D223D7DBC2B3D923DB83D8C3DB83D7B3CB0BD13BC7FBD7DBDB23C89BD8F3D26BD0D3C073D33BD193C01BD96BC213D2C3DA13D61BA56BDB23DBC3CA1BDB4BD643CBD3CF1BCA23B273D97BDA5BCAE3CB6BD543C943D97BD5DBC803D2DBC44BDEABBB13DB13A6CBD72BD7B3C0D3D4D3D7D3DA2BC883D433D48BD8D3C773D4DBD143D98BD77BDCBBC8F3D90BDA5BDE83BAF3D7F3D71BC01BD5DBD9F3D5FBDCA3CEC3CD73CA53D9C3D363D9B3C4F3DB73CF8BCAD3D97BD56BCBFBCAB3D73BD8B3D1ABDB93CAF3D2B3DD3BC9B3D2A3DB63D963DA1BD9CBD20BC2EBDED3C3CBDAC3D1B3DB83D1CBD043D073D78BD96BD84BD8E3D9CBC503DD43CFA3BA63D4EBDB73CA5BDAD3D81BD3D3D213D83BD11BD863A453C97BDC5BBDCBC103DA6BDB1BD14BDA83D7BBD57BC79BD273D8DBD253D863DB93B9ABD8C3DF63C48BD80BD7A3D953BB83DE9BCA53C2F3DA1BD513D04BD5A3DB0BC11BD343D513D9DBCBA3D233D6ABDA03C8CBD6C3DA63D803CCFBC4ABD3B3D8FBD3BBDB4BD983C9E3D823D303D49BD313C60BD653D4E3C8E3D44BDC7BCA5BD0FBD023D5A3C903B8BBDCABB713D6D3CA3BD06BD71BDD73C763DA1BCD7BC813D9FBDFBBC84BD3F3C803C55BD8CBDB13C8D3D8A3C8F3D45BD22BCA23C50BC423D933D9E3DAABC893D44BD4FBCC33CB63D7C3D153BA8BD91BD8B3DCB3C3BBDECBC95BDB03D51BC53BC913D623DA8BBABBDAA3B343DB2BD293BD6BCF93B8B3DA0BD5CBD073D9BBDB0BDD4BCBDBC0CBD8B3DA83D643D873DA9BDAF3DA4BD703C92BD1EBD213DBBBBACBDAD3D013DF23C933D71BDBF3CA63DC4BC183DAE3D17BD7D3DCDBC343D1F3D43BCABBD66BDB7BD5C3C3ABDA3BC193D8EBDA1BC983DA8BD8D3DC139363C88BD88BD97BD35BD833CD53C6BBD9FBDB83D28BD88BD5CBD92BD283DB7BD96BDB83D3F3D20BD683DA4BD313D02BDE13AB03DE73C47BD65BD3C3D523C5B3D853D29BC81BD45BCAE3D22BAB43DB9BD34BDB6BD9B3D36BDA73D5DBCFB3C42BD913D0EBC98BD8C3D76BC6DBC3ABD493D963D4EBD253DADBD88BDA83D69BD9B3CA7BBBABDD13C453D073D70BDDDBC623B80BD173DA23D48BCBC3C88BDD7BC803DF8BCB3BD0F3DB0BD16BD653D963D313DFEBA443C0EBD22BCA8BDDB3C1E3C9CBC3F3D7B3C8D3B153D503D973D8DBD683D28BD2BBD163D35BD3CBCEC3B483D30BD353DCB3C75BC7CBC5E3D5A3D633DA53DB23D90BD243DB33D643DA93B66BD623D5DBDAB3D85BDB73D133DE93C20BD8A3D343DD23C07BD403D9A3D663D8BBC8DBD973D673DFC3B873D523D24BC19BC95BBF53C22BD
9CBDABBDC83CA63D9D3C25BC13BC193D52BD61BD403D94BD0EBC763C513BB9BD9ABC793D29BC4DBDB7BD2A3D653CA1BDAA3DA83DA03D22BD953D36BD44BD7F3D7C3DB8BD1C3D8FBD63BDA0BDB73DB3BD86BD9D3D703D22BCD4BC6CBD6B3DA7BD023CAABC59BD953C043DC03BACBDF6BCBA3DA93D9A3C83BDB53DA53D99BD813D203B8F3D41BDAD3D32BDAF3C0FBCD73C2EBD1A3D85BDB0BCE83C6D3D0C3DA73D13BD30BDA6BC9F3D0BBDB93DDB3CBE3CF43C323D0BBCAB3DB53CB83D023B52BD433D1C3D013DEFBC8C3DC3BC54BDA2BD65BD393D7ABD933D073D15BD91BD813D743CB5BB763D9A3D5D3C733DB23D9F3DA73D9D3D753D72BD953C1B3D9D3D953D603D7ABC47BD75BDB9BD99BD393DDA3CA33D5EBDB4BD783D88BDA1BC0DBC8E3C84BD89BC353C25BD0A3DFC3C9D3DC9BC633D93BB6D3C7C3DA7BD1EBDAF3DC4BBBA3D513D903D9B3DCB3C9BBDA83D5BBCAEBC323CABBA533C29BDB83D573D153B94BA8B3CF63B783D84BD07BB723D463C4D3C12BD8D3D04BDB8BAB3BD68BDA43D56BC623DB63D903D023D32BDB9BD8E3D4E3CB8BDB93D6C3DFC3C64BDA6BC35BD3EBD813D6ABD4F3D543DA33D1B3D4D3D8D3C873D813ADE3C4ABDFE3B193C5C3D1C3A95BDB13D3CBA503D8D3DE4BB893D453D18BD1ABD9F3CDEBBA0BD843D81BCAC3D473C4B3D90BD65BD2C3C94BCBE3C833D"> : tensor<10x8x3x5xbf16> %cst_0 = arith.constant dense<[8.056640e-02, 6.738280e-02, -3.637700e-02, -2.111820e-02, 7.568360e-02, 7.519530e-02, -3.112790e-02, 4.663090e-02, -4.589840e-02, 5.908200e-02]> : tensor<10xbf16> %cst_1 = arith.constant 0.000000e+00 : bf16 %padded = tensor.pad %arg0 low[0, 0, 4, 2] high[0, 0, 4, 2] { ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index): tensor.yield %cst_1 : bf16 } : tensor<2x8x12x16xbf16> to tensor<2x8x20x20xbf16> %0 = tensor.empty() : tensor<2x10x7x16xbf16> %1 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%cst_0 : tensor<10xbf16>) outs(%0 : tensor<2x10x7x16xbf16>) { ^bb0(%in: bf16, %out: bf16): linalg.yield %in : bf16 } -> tensor<2x10x7x16xbf16> %2 = linalg.conv_2d_nchw_fchw {dilations = dense<[3, 1]> : vector<2xi64>, strides = dense<[2, 1]> : vector<2xi64>} ins(%padded, %cst : tensor<2x8x20x20xbf16>, tensor<10x8x3x5xbf16>) outs(%1 : tensor<2x10x7x16xbf16>) -> 
tensor<2x10x7x16xbf16> return %2 : tensor<2x10x7x16xbf16> } }
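
The shapes in the IR above can be checked by hand: `tensor.pad` adds 4 rows and 2 columns on each side (12x16 becomes 20x20), and the output size follows the formula given in the PyTorch `nn.Conv2d` documentation, `floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1`. A minimal Python sketch of that arithmetic (the helper `conv2d_out_size` is a hypothetical name, not part of PyTorch or IREE):

```python
def conv2d_out_size(size, kernel, stride, padding, dilation):
    """Output length of one spatial dimension, per the nn.Conv2d docs formula:
    floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    """
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# nn.Conv2d(8, 10, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
# applied to an input of shape (2, 8, 12, 16).
H, W = 12, 16
padded = (H + 2 * 4, W + 2 * 2)
out = (
    conv2d_out_size(H, kernel=3, stride=2, padding=4, dilation=3),
    conv2d_out_size(W, kernel=5, stride=1, padding=2, dilation=1),
)
print("padded H x W:", padded)  # (20, 20), as in tensor<2x8x20x20xbf16>
print("output H x W:", out)     # (7, 16), as in tensor<2x10x7x16xbf16>
```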


Running the above module through the IREE CPU backend produces results that do not match the PyTorch output.
MaheshRavishankar commented 4 months ago

Can you give me the `iree-run-module` command that shows the error? That will help me reproduce it. For example, the %1 above is strange (it is still correct and shouldn't affect correctness), but the exact command will help triage.

Shukla-Gaurav commented 4 months ago

@MaheshRavishankar

  1. Following are the commands:

```
~/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu conv2d.linalg.mlir > conv2d.bf16.vmfb 2>iree-compile.log
~/iree-build/tools/iree-run-module --module=conv2d.bf16.vmfb --input="2x8x12x16xbf16=@inference_input.0.bin.txt" > inference.log
```

I am attaching inference_input.0.bin, which is all ones: inference_input.0.bin.txt

  2. `iree-run-module` runs fine, but its result does not match the result of the conv2d PyTorch module (the input is all ones in both cases).

```python
import torch
import torch.nn as nn

class op_conv2d(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(8, 10, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        )

    def forward(self, x):
        return self.layers(x)

model = op_conv2d()
model_bf16 = model.to(torch.bfloat16)
test_input_bf16 = torch.ones(2, 8, 12, 16).to(torch.bfloat16)
test_output_bf16 = model_bf16(test_input_bf16)
print("Input:", test_input_bf16)
print("Output:", test_output_bf16)
```

3. You can also use the repo https://github.com/nod-ai/SHARK-TestSuite to test and cross-check conv2d results with your fix:

```
git clone https://github.com/nod-ai/SHARK-TestSuite
# activate your iree_venv or torch_mlir_venv
cd SHARK-TestSuite/e2eshark
python ./run.py --runupto inference --mode onnx -c ~/torch-mlir/build -i ~/iree-build --tests pytorch/operators/conv2d/ --hfhome ~/hf_home/ --verbose -d bf16 -r test-conv2d-bf16
```

Shukla-Gaurav commented 4 months ago

conv2d.bf16.linalg.mlir.txt conv2d.fp32.linalg.mlir.txt iree-compile-conv2d-bf16.log iree-compile-conv2d-fp32.log

Shukla-Gaurav commented 4 months ago

Running conv2d at different precisions, keeping all the constants (weight and bias) the same: conv2d.bf16.compile.log conv2d.fp32.compile.log

MaheshRavishankar commented 4 months ago

I think this is not really a codegen issue. This is really a bf16 issue. Are we comfortable closing this, or do we need to do more here?

Shukla-Gaurav commented 4 months ago
  1. Linear module for the reference output (the weight, bias and input have been fixed to simplify comparing IRs and outputs; all these values fit in bf16, so `x.to(torch.bfloat16)` won't change them):

```python
import torch
import torch.nn as nn

class op_linear(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_layer = nn.Linear(3, 4)
        self.linear_layer.weight = nn.Parameter(
            torch.tensor([[-0.4199,  0.4180, -0.0293],
                          [ 0.4297, -0.4434,  0.0162],
                          [-0.2061, -0.4004,  0.2773],
                          [-0.5469,  0.0449,  0.3242]],
                         dtype=torch.float32),
            requires_grad=False)
        self.linear_layer.bias = nn.Parameter(
            torch.tensor([ 0.2236,  0.3184, -0.1709, -0.4883],
                         dtype=torch.float32),
            requires_grad=False)
        self.layers = nn.Sequential(self.linear_layer)

    def forward(self, x):
        return self.layers(x)

model = op_linear()

# fp32 computation
test_input = torch.tensor(
    [[-0.4062, -0.6953,  1.8516],
     [ 0.4961,  0.9609, -0.7500],
     [ 0.4766, -1.4531,  0.6172],
     [-0.4785,  1.0859, -0.9922],
     [ 2.1094, -0.0107,  0.3496],
     [-0.6562, -0.0116,  1.7812],
     [ 0.0114, -0.1279,  1.7266],
     [-0.1289,  0.6250,  1.3516]],
    dtype=torch.float32)
output = model(test_input)

# bf16 computation (Tensor.to returns a new tensor, so reassign it;
# Module.to converts the parameters in place)
test_input = test_input.to(torch.bfloat16)
model.to(torch.bfloat16)
bf16_golden_output = model(test_input)
```

2. Linalg IR:
[linear.bf16.linalg.mlir.txt](https://github.com/nod-ai/SHARK/files/14386289/linear.bf16.linalg.mlir.txt)
[linear.fp32.linalg.mlir.txt](https://github.com/nod-ai/SHARK/files/14386290/linear.fp32.linalg.mlir.txt)

3. Apply `--iree-global-opt-enable-demote-contraction-inputs-to-bf16` to the fp32 IR and `--iree-llvmcpu-enable-ukernels=all` to the bf16 IR, to compare the outputs from the two paths:

```
/home/gaurav/MLIRepos/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-global-opt-enable-demote-contraction-inputs-to-bf16 --iree-input-type=tm_tensor --mlir-print-ir-after-all --mlir-disable-threading linear.fp32.linalg.mlir > linear.contraction.vmfb 2>contraction-flag-linear-fp32-iree_compile.log

/home/gaurav/MLIRepos/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu --iree-llvmcpu-enable-ukernels=all --mlir-print-ir-after-all --mlir-disable-threading --iree-input-type=tm_tensor linear.bf16.linalg.mlir > linear.ukernel.vmfb 2>ukernel-flag-linear-bf16-iree_compile.log
```

[contraction-flag-linear-fp32-iree_compile.log](https://github.com/nod-ai/SHARK/files/14386286/contraction-flag-linear-fp32-iree_compile.log)
[ukernel-flag-linear-bf16-iree_compile.log](https://github.com/nod-ai/SHARK/files/14386688/ukernel-flag-linear-bf16-iree_compile.log)

4.

```
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=linear.contraction.vmfb --input="8x3xf32=@inference_input.0.bin.txt"
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=linear.ukernel.vmfb --input="8x3xbf16=@inference_input_bf16.0.bin.txt"
```

[inference_input.0.bin.txt](https://github.com/nod-ai/SHARK/files/14386287/inference_input.0.bin.txt)
[inference_input_bf16.0.bin.txt](https://github.com/nod-ai/SHARK/files/14386288/inference_input_bf16.0.bin.txt)

5. Comparing the different outputs:

`f32_golden_output` (= `model(input).to(bf16)`):

```
tensor([[ 0.0493,  0.4824,  0.7031,  0.3027],
        [ 0.4395,  0.0933, -0.8672, -0.9609],
        [-0.6016,  1.1797,  0.4844, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.5000],
        [-0.6758,  1.2344, -0.5039, -1.5312],
        [ 0.4414,  0.0703,  0.4629,  0.4473],
        [ 0.1147,  0.4082,  0.3574,  0.0596],
        [ 0.5000,  0.0078, -0.0198,  0.0483]], dtype=torch.bfloat16)
```

`bf16_golden_output` (= `model.to(bf16)`; `input.to(bf16)`; `model(input)`):

```
tensor([[ 0.0493,  0.4824,  0.7031,  0.3027],
        [ 0.4395,  0.0933, -0.8672, -0.9609],
        [-0.6016,  1.1797,  0.4844, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.5000],
        [-0.6758,  1.2344, -0.5039, -1.5312],
        [ 0.4414,  0.0703,  0.4629,  0.4473],
        [ 0.1147,  0.4082,  0.3574,  0.0596],
        [ 0.5000,  0.0078, -0.0198,  0.0486]], dtype=torch.bfloat16)
```

`contraction_flag_output` (quite close to the golden output, but here we trace the f32 model (f32 linalg IR) and only demote certain ops to bf16):

```
tensor([[ 0.0493,  0.4824,  0.7031,  0.3027],
        [ 0.4395,  0.0933, -0.8672, -0.9609],
        [-0.6016,  1.1797,  0.4844, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.5000],
        [-0.6758,  1.2344, -0.5039, -1.5312],
        [ 0.4414,  0.0703,  0.4629,  0.4473],
        [ 0.1147,  0.4082,  0.3574,  0.0596],
        [ 0.5000,  0.0079, -0.0198,  0.0486]], dtype=torch.bfloat16)
```

`ukernel_flag_output` (the same as the IREE bf16 inference output without this flag; it does not match the golden output):

```
tensor([[ 0.0498,  0.4824,  0.7031,  0.3047],
        [ 0.4395,  0.0938, -0.8672, -0.9570],
        [-0.6055,  1.1797,  0.4863, -0.6133],
        [ 0.9062, -0.3848, -0.7812, -0.7500],
        [-0.6797,  1.2344, -0.5039, -1.5234],
        [ 0.4434,  0.0723,  0.4629,  0.4492],
        [ 0.1147,  0.4082,  0.3574,  0.0586],
        [ 0.5000,  0.0078, -0.0195,  0.0469]], dtype=torch.bfloat16)
```

```python
print(torch.allclose(bf16_golden_output, contraction_flag_output, atol=1e-04, rtol=1e-03))  # True
print(torch.allclose(bf16_golden_output, ukernel_flag_output, atol=1e-04, rtol=1e-03))  # False
```
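
For reading the `True`/`False` results above: per the PyTorch documentation, `torch.allclose(input, other, rtol, atol)` checks `|input - other| <= atol + rtol * |other|` elementwise. A pure-Python sketch of that criterion (the function `allclose` here is a hypothetical stand-in, not the torch API), applied only to the last row of each printed tensor. The printed values are rounded decimal renderings of the bf16 data, so this is indicative only:

```python
def allclose(a, b, atol, rtol):
    """Elementwise |a - b| <= atol + rtol * |b|, the torch.allclose criterion."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

# Last row of each printed output above.
bf16_golden = [0.5000, 0.0078, -0.0198, 0.0486]
contraction = [0.5000, 0.0079, -0.0198, 0.0486]
ukernel = [0.5000, 0.0078, -0.0195, 0.0469]

# 0.0001 difference vs. a tolerance of about 0.000108: within tolerance.
print(allclose(bf16_golden, contraction, atol=1e-4, rtol=1e-3))  # True
# 0.0003 difference vs. a tolerance of about 0.00012: out of tolerance.
print(allclose(bf16_golden, ukernel, atol=1e-4, rtol=1e-3))      # False
```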



I'm not sure how PyTorch handles the bf16 computation, but its result is close to computing in f32 and then demoting the result to bf16.
(NOTE: all the inputs are small enough to fit in the bf16 type.)
For the same inputs, the IREE bf16 inference result (`ukernel_flag_output`) does not match within the above tolerances.
(The `contraction_flag_output` is close, but we want to start from a bf16 model, not an f32 model.)

Does this point to how IREE handles bf16 loads in the CPU backends? Thoughts?
@stellaraccident @kumardeepakamd @MaheshRavishankar @bjacob

I can create an ONNX linear module with the same inputs and run it on ONNX Runtime to get one more reference output, if that helps. Thanks!
Shukla-Gaurav commented 4 months ago

> I think this is not really a codegen issue. This is really a bf16 issue. Are we comfortable closing this, or do we need to do more here?

I think we should not close this until we reach a conclusion on how bf16 is handled on CPU, i.e. how to verify that the model produces correct outputs through the ONNX pipeline.

MaheshRavishankar commented 4 months ago

Really, I don't know what the compiler itself can decide here. This is always going to mismatch because the reference is doing different things (and different references will do different things). What IREE does is basically "do what it is told to do". If the linalg op says the input type is bf16 and the output type is bf16, it actually does the accumulation in f32. IMO that is being too smart. It should do the accumulation in bf16 as well (because that's what it was told to do). So the only action item I can think of is to make IREE less smart.
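
A toy illustration of the accumulation point (not IREE's or PyTorch's actual code): accumulating in f32 and rounding to bf16 once can give a different answer from rounding to bf16 after every addition. The sketch emulates bf16 by rounding f32 bit patterns to nearest, ties to even, using Python's `struct` module. All helper names are hypothetical, and NaN/overflow handling is omitted:

```python
import struct

def to_f32(x):
    """Round a Python float (f64) to the nearest f32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def to_bf16(x):
    """Round an f32 value to the nearest bf16 value (ties to even), as a float."""
    u = struct.unpack("<I", struct.pack("<f", x))[0]
    u = (u + 0x7FFF + ((u >> 16) & 1)) & 0xFFFF0000  # finite inputs only
    return struct.unpack("<f", struct.pack("<I", u))[0]

def sum_f32_acc(values):
    acc = 0.0
    for v in values:
        acc = to_f32(acc + v)          # accumulate in f32 ...
    return to_bf16(acc)                # ... and round to bf16 once at the end

def sum_bf16_acc(values):
    acc = 0.0
    for v in values:
        acc = to_bf16(to_f32(acc + v))  # round to bf16 after every addition
    return acc

# All inputs are exactly representable in bf16. Each 2**-8 is half a bf16
# ULP at 1.0, so with bf16 accumulation every step is a tie that rounds
# back down to the even value 1.0.
values = [1.0] + [2.0 ** -8] * 16
print(sum_f32_acc(values))   # 1.0625
print(sum_bf16_acc(values))  # 1.0
```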

stellaraccident commented 4 months ago

Usually when I've seen these kinds of issues resolved before, it required a much more careful drill-down than a high-level this-vs-that comparison. There is no abstract answer to these things at that precision: a rounding/truncation mode difference or a 1 ULP difference at any stage is enough to cause a 1% error for a datatype like this. If you are trying to get complete correspondence, none of that can be ignored.

You'll have to dig deeper, and likely if you are still looking at results as textual floating point, you'll miss the difference.

One advantage to having an ONNX reference is that it is much easier to hack on at that level than PyTorch (i.e. you can build it, set a breakpoint or print specific values in a kernel, etc).

Shukla-Gaurav commented 4 months ago

It seems the mismatch is due to the different rounding mechanisms used by PyTorch and IREE. I ran a few simple add/mul tests, and the outputs mostly differ by 1 bit. PyTorch simply truncates the last 16 bits after computing the result, while IREE seems to round it. The following examples illustrate this behavior:

  1. Multiplication (1.3667e+30 * 5) ~ 6.8335e+30. Attaching the linalg IR and iree-compile log: mul.linalg.mlir.txt, mul-bf16-iree_compile.log

```
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=mul.vmfb --input="1xbf16=[5]"
```

    pytorch-cpu output: 6.8136e+30
    IREE-cpu output: 6.8532e+30

    Their binary representations differ exactly in the last bit. If we do the multiplication in f32 and drop the last 16 bits, we get exactly the same output as PyTorch. I used the following C++ code to check the 16-bit truncation result:

    
```cpp
#include <iostream>

struct bfloat16 {
  unsigned short int data;

 public:
  bfloat16() { data = 0; }

  // cast to float
  operator float() {
    unsigned int proc = data << 16;
    return *reinterpret_cast<float*>(&proc);
  }

  // cast to bfloat16
  bfloat16& operator=(float float_val) {
    data = (*reinterpret_cast<unsigned int*>(&float_val)) >> 16;
    return *this;
  }
};

// an example that enumerates all the possible values between 1.0f and 300.0f
using namespace std;

int main() {
  bfloat16 x;
  x = 6.8335e+30;
  // for (x = 1.0f; x < 300.0f; x.data++) {
  cout << x.data << " " << x << endl;
  // }
  return 0;
}
```

bjacob commented 4 months ago

Great analysis, thanks! Indeed, the f32 value 6.8335e+30 has the binary encoding 0x72ac8075. Truncating this to bf16 means replacing it by a value whose encoding is a multiple of 0x10000, so the two conceivable candidates are 0x72ac0000 and 0x72ad0000. These are respectively 6.81362e+30 and 6.85324e+30. However, as the next hex digit after 0x72ac... is an 8 (the 8 in 0x72ac8075), the nearest value really is the latter, 0x72ad0000 = 6.85324e+30. And this isn't even a tie (it would be a tie if the value were 0x72ac8000), so no tie-breaking question ("round to nearest, ties to even") arises here. So there is no question that the nearest value is 0x72ad0000 = 6.85324e+30. So IREE is correct here, and PyTorch is incorrect. What PyTorch does here has a name, "rounding towards zero"; it can sometimes be useful for some really specialized uses, but it doesn't make sense as the default way to round all values in a workload.
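
The argument above can be reproduced with a short Python sketch, assuming Python's `struct` module as a stand-in for a C++ bitcast (all helper names are hypothetical, and only finite inputs are handled). It compares plain truncation, which rounds toward zero, with picking the nearer of the two bf16 candidates and breaking exact ties toward the even one:

```python
import struct

def f32_bits(x):
    """IEEE-754 binary32 encoding of x (x is first rounded to f32)."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bf16_round_toward_zero(x):
    """What dropping the low 16 bits does: round toward zero."""
    return f32_bits(x) >> 16

def bf16_round_nearest_even(x):
    """Pick the nearer of the two bf16 candidates around a finite x."""
    bits = f32_bits(x)
    lower = bits >> 16        # the candidate obtained by truncation
    rest = bits & 0xFFFF      # the 16 bits that truncation discards
    if rest > 0x8000 or (rest == 0x8000 and lower & 1):
        return lower + 1      # round up in magnitude; may carry into the exponent
    return lower

x = 6.8335e30
print(hex(f32_bits(x)))                 # 0x72ac8075
print(hex(bf16_round_toward_zero(x)))   # 0x72ac
print(hex(bf16_round_nearest_even(x)))  # 0x72ad
```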

bjacob commented 4 months ago

I don't know what code path is used in the code that you ran, but I checked this PyTorch f32 -> bf16 rounding helper, https://github.com/pytorch/pytorch/blob/main/c10/util/BFloat16.h#L76

And it does return the correct result,

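```cpp
#include <cstdio>
// round_to_nearest_even: the helper from c10/util/BFloat16.h linked above,
// assumed to be in scope (e.g. copied into this file).

int main() {
    float f = 6.8335e+30;
    printf("round_to_nearest_even(%g) = 0x%x\n", f, round_to_nearest_even(f));
}
```

prints

```
round_to_nearest_even(6.8335e+30) = 0x72ad
```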

And not 0x72ac as in your above PyTorch result. So, it seems like different code paths within PyTorch don't agree with each other, and at least this one agrees with us.

bjacob commented 4 months ago

> I used following c++ code to check 16-bits truncation result:
>
> ```cpp
> //cast to bfloat16
> bfloat16& operator =(float float_val){
>    data = (*reinterpret_cast<unsigned int *>(&float_val))>>16;
>    return *this;
> }
> ```

This implements the same incorrect rounding-towards-zero as we discussed above. Just dropping the bottom 16 bits like this fails to account for the possibility that their value might be above 0x8000 (or exactly 0x8000 with an odd result, under ties-to-even), requiring rounding upwards to the next representable value.

(Side note: this also has undefined behavior in C++, as an unsigned int object and a float object are accessed at the same memory location. The only way to implement a bitcast like this in C++ prior to C++20 is to copy the data with something like memcpy, or to go through one of the few types that are allowed to alias other objects, such as char, unsigned char or std::byte. If using C++20, use the new std::bit_cast for that.)

Shukla-Gaurav commented 4 months ago

@bjacob Thanks for the explanation! For the multiplication example, the IREE output is closer, as PyTorch is simply truncating the result. Could we add a way to explicitly specify the rounding mode we want in IREE, since we need PyTorch results as our reference for the model outputs?

And I did the following for PyTorch bf16 multiplication:

```python
>>> x = torch.tensor([1.3667e30], dtype=torch.bfloat16)
>>> y = torch.tensor([5], dtype=torch.bfloat16)
>>> x*y
tensor([6.8136e+30], dtype=torch.bfloat16)
```

Shukla-Gaurav commented 4 months ago

I also got a weird example, @bjacob

  1. Addition (-0.0112 - 0.4882) ~ -0.4994. Attaching the linalg IR and iree-compile log: add.linalg.mlir.txt, add-bf16-iree_compile.log

```
/home/gaurav/MLIRepos/iree-build/tools/iree-run-module --module=add.vmfb --input="1xbf16=[-0.4882]"
```

    pytorch-cpu output: -0.5
    IREE-cpu output: -0.75

bjacob commented 4 months ago

Wow, funny bug that you found here! It appears to be a parsing bug in how iree-run-module parses the --input flag. Indeed, it produces the expected results when the specified array element has no more than two digits after the decimal point, and the bug reproduces whenever it has 3 or more digits after the decimal point.

```
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.48]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.492188
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.49]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.5
~/iree-build tools/iree-run-module --module=/tmp/add.vmfb --input="1xbf16=[-0.488]"
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.75
```

FYI @benvanik

bjacob commented 4 months ago

The parsing itself is correct, though - iree_hal_parse_element_unsafe does parse the correct value and its caller iree_hal_parse_buffer_elements does store it in the destination buffer.

And yet, something is producing incorrect results only when the --input parameter specified more than two decimals...

bjacob commented 4 months ago

The bug reproduces whenever the specified --input element rounds to -0.488281 as a bfloat16 (encoding 0xbefa). It does not reproduce whenever it rounds to the previous bfloat16 value -0.486328 (encoding 0xbef9). In both cases, our f32 <-> bfloat16 conversion helpers produce correct results, and the parsing is correct too as noted above. Just for some reason, the bfloat16 value -0.488281 (encoding 0xbefa) runs into some arithmetic bug elsewhere.
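
For reference, a bf16 encoding is just the upper 16 bits of an f32 encoding, so the two values mentioned above can be decoded with a shift and a bitcast. A small Python sketch (`bf16_bits_to_float` is a hypothetical helper); the two encodings are adjacent, one bf16 step (2^-9 at this exponent) apart:

```python
import struct

def bf16_bits_to_float(b):
    """Decode a 16-bit bf16 encoding by placing it in the top half of an f32."""
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

for enc in (0xBEFA, 0xBEF9):
    print(hex(enc), bf16_bits_to_float(enc))
# 0xbefa -0.48828125
# 0xbef9 -0.486328125
```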

bjacob commented 4 months ago

And the other operand, which is hardcoded as a constant in the above testcase, also matters. Here is a testcase taking both operands as arguments:

```mlir
#map = affine_map<(d0) -> (0)>
#map1 = affine_map<(d0) -> (0)>
#map2 = affine_map<(d0) -> (d0)>
module {
  func.func @main_graph(%arg0: tensor<1xbf16>, %arg1: tensor<1xbf16>) -> tensor<1xbf16> {
    %0 = tensor.empty() : tensor<1xbf16>
    %1 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<1xbf16>, tensor<1xbf16>) outs(%0 : tensor<1xbf16>) {
    ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
      %2 = arith.addf %in, %in_0 : bf16
      linalg.yield %2 : bf16
    } -> tensor<1xbf16>
    return %1 : tensor<1xbf16>
  }
}
```

With that, I find that for this to reproduce, the other operand needs to be bfloat16 0xbc31 or more negative (decimal value -0.010803).

bjacob commented 4 months ago

This actually minimizes down to a testcase that performs no bfloat16 arithmetic and only a f32->bfloat16 truncf:

```mlir
#map = affine_map<(d0) -> (d0)>
module {
  func.func @main_graph(%arg0: tensor<1xf32>) -> tensor<1xbf16> {
    %0 = tensor.empty() : tensor<1xbf16>
    %1 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<1xf32>) outs(%0 : tensor<1xbf16>) {
    ^bb0(%in0: f32, %out: bf16):
      %3 = arith.truncf %in0 : f32 to bf16
      linalg.yield %3 : bf16
    } -> tensor<1xbf16>
    return %1 : tensor<1xbf16>
  }
}
```

```
~/iree-build tools/iree-run-module --module=/tmp/repro2.vmfb --input="1xf32=[-0.499081]" --device=local-task
EXEC @main_graph
result[0]: hal.buffer_view
1xbf16=-0.75
```
bjacob commented 4 months ago

@rsuderman this might be for you :-)

What --mlir-print-ir-after-all shows for the testcase in the previous comment:

```mlir
// -----// IR Dump After CSE (cse) //----- //
module {
  func.func @main_graph_dispatch_0_generic() {
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<f32>
    memref.assume_alignment %0, 64 : memref<f32>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) : memref<i16>
    memref.assume_alignment %1, 64 : memref<i16>
    %2 = memref.load %0[] : memref<f32>
    %3 = arith.truncf %2 : f32 to bf16
    %4 = arith.bitcast %3 : bf16 to i16
    memref.store %4, %1[] : memref<i16>
    return
  }
}
```

```mlir
// -----// IR Dump After ConvertToLLVM (iree-convert-to-llvm) //----- //
module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-unknown-unknown-eabi-elf"} {
  llvm.func @main_graph_dispatch_0_generic(%arg0: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg1: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}, %arg2: !llvm.ptr {llvm.align = 16 : i64, llvm.noalias}) -> i32 {
    %0 = llvm.mlir.constant(0 : i32) : i32
    %1 = llvm.mlir.constant(16 : i32) : i32
    %2 = llvm.mlir.constant(32768 : i32) : i32
    %3 = llvm.mlir.constant(2130706432 : i32) : i32
    %4 = llvm.mlir.constant(2139095040 : i32) : i32
    %5 = llvm.mlir.constant(8388607 : i32) : i32
    %6 = llvm.mlir.constant(31 : i32) : i32
    %7 = llvm.mlir.constant(23 : i32) : i32
    %8 = llvm.mlir.constant(63 : index) : i64
    %9 = llvm.mlir.constant(0 : index) : i64
    %10 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
    %11 = llvm.extractvalue %10[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)> 
    %12 = llvm.load %11 : !llvm.ptr -> !llvm.ptr
    %13 = llvm.ptrtoint %12 : !llvm.ptr to i64
    %14 = llvm.and %13, %8  : i64
    %15 = llvm.icmp "eq" %14, %9 : i64
    "llvm.intr.assume"(%15) : (i1) -> ()
    %16 = llvm.load %arg1 : !llvm.ptr -> !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)>
    %17 = llvm.extractvalue %16[10] : !llvm.struct<"iree_hal_executable_dispatch_state_v0_t", (i32, i32, i16, i16, i32, i32, i16, i8, i8, ptr, ptr, ptr)> 
    %18 = llvm.getelementptr %17[1] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
    %19 = llvm.load %18 : !llvm.ptr -> !llvm.ptr
    %20 = llvm.ptrtoint %19 : !llvm.ptr to i64
    %21 = llvm.and %20, %8  : i64
    %22 = llvm.icmp "eq" %21, %9 : i64
    "llvm.intr.assume"(%22) : (i1) -> ()
    %23 = llvm.load %12 : !llvm.ptr -> f32
    %24 = llvm.bitcast %23 : f32 to i32
    %25 = llvm.lshr %24, %6  : i32
    %26 = llvm.sub %2, %25  : i32
    %27 = llvm.and %24, %5  : i32
    %28 = llvm.add %27, %26  : i32
    %29 = llvm.lshr %28, %7  : i32
    %30 = llvm.lshr %28, %29  : i32
    %31 = llvm.and %24, %4  : i32
    %32 = llvm.add %31, %28  : i32
    %33 = llvm.and %32, %4  : i32
    %34 = llvm.icmp "uge" %31, %3 : i32
    %35 = llvm.select %34, %31, %33 : i1, i32
    %36 = llvm.trunc %29 : i32 to i1
    %37 = llvm.and %34, %36  : i1
    %38 = llvm.select %37, %27, %30 : i1, i32
    %39 = llvm.shl %25, %6  : i32
    %40 = llvm.or %39, %35  : i32
    %41 = llvm.or %40, %38  : i32
    %42 = llvm.lshr %41, %1  : i32
    %43 = llvm.trunc %42 : i32 to i16
    llvm.store %43, %19 : i16, !llvm.ptr
    llvm.return %0 : i32
  }
}
```

In the first part of the above log, our arith.truncf op is still as in the original source. In the second part, IR Dump After ConvertToLLVM (iree-convert-to-llvm), it has been lowered to llvm ops implementing the truncation. The testcase shows that this lowering produces incorrect results when the value being truncated (here -0.499081) is just below an exponent threshold in magnitude, so that rounding it to the nearest bfloat16 value crosses that threshold (-0.499081 becomes -0.5, bumping the exponent).
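
To make the failure mode concrete, here is a hand transcription into Python of the `iree-convert-to-llvm` sequence above (%24 through %43), each step annotated with the SSA value it mirrors. It is an illustrative sketch, not IREE code; the helper names are hypothetical, and it assumes the i32 arithmetic never wraps for these inputs. When the rounded mantissa carries into bit 23, `%30` shifts that carried bit down into the top mantissa bit instead of discarding it, so the exponent is bumped (-0.5) and the mantissa is 0x40 (x1.5), giving -0.75:

```python
import struct

def f32_bits(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bf16_bits_to_float(b):
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def lowered_truncf(x):
    """Python transcription of the LLVM dialect ops in the dump above."""
    u = f32_bits(x)                               # %24: bitcast f32 -> i32
    sign = u >> 31                                # %25
    bias = 0x8000 - sign                          # %26: 32768 - sign
    mant = u & 0x7FFFFF                           # %27: & 8388607
    rounded = mant + bias                         # %28
    carry = rounded >> 23                         # %29: 1 if the mantissa overflowed
    shifted = rounded >> carry                    # %30: carried bit lands in bit 22
    exp = u & 0x7F800000                          # %31: & 2139095040
    new_exp = (exp + rounded) & 0x7F800000        # %32, %33: exponent bumped by carry
    huge = exp >= 0x7F000000                      # %34: uge 2130706432
    exp_out = exp if huge else new_exp            # %35
    mant_out = mant if (huge and carry == 1) else shifted  # %36-%38
    out = (sign << 31) | exp_out | mant_out       # %39-%41
    return (out >> 16) & 0xFFFF                   # %42, %43

bits = lowered_truncf(-0.499081)
print(hex(bits), bf16_bits_to_float(bits))  # 0xbf40 -0.75
```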

bjacob commented 4 months ago

@rsuderman , here is what the equivalent f32->bf16 truncation code in the runtime does to fix up this specific case (it is actually generic in bit widths, but in particular it handles f32->bf16):

https://github.com/openxla/iree/blob/01c4c57/runtime/src/iree/base/internal/math.h#L389-L390

bjacob commented 4 months ago

@rsuderman , here is the much more concise and optimized way that the PyTorch runtime does it (I think that part was written by Marat and carried over from XNNPACK or some predecessor of it): https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L76

In the above-linked runtime code, I didn't bother to implement this magic trick because I wanted genericity and didn't need to chase performance. But in the compiler lowering, it would make sense to do the concise and efficient thing.

The link to IREE math.h in the previous comment has a comment explaining the magic trick here.

        // Note: software implementations that try to be fast tend to get this
        // conditional increment of exp and zeroing of mantissa for free by
        // simplying incrementing the whole uint32 encoding of the float value,
        // so that the mantissa overflows into the exponent bits.
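
The "increment the whole uint32 encoding" trick described in that comment can be sketched in Python as follows. This is an approximation of the round-to-nearest-even approach the comments describe, not a copy of either codebase; the helper names are hypothetical and returning 0x7FC0 for NaN is a simplifying assumption. Because the rounding bias is added to the full 32-bit pattern, a mantissa overflow carries into the exponent field automatically, so -0.499081 correctly becomes -0.5:

```python
import math
import struct

def f32_bits(x):
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bf16_bits_to_float(b):
    return struct.unpack("<f", struct.pack("<I", (b & 0xFFFF) << 16))[0]

def truncf_f32_to_bf16(x):
    """Round-to-nearest-even f32 -> bf16 by adding a bias to the whole encoding."""
    if math.isnan(x):
        return 0x7FC0                   # assumed quiet-NaN result
    u = f32_bits(x)
    lsb = (u >> 16) & 1                 # lowest bit that survives in the bf16 result
    bias = 0x7FFF + lsb                 # exact ties round to the even candidate
    return ((u + bias) >> 16) & 0xFFFF  # a mantissa carry ripples into the exponent

for x in (-0.499081, 6.8335e30):
    b = truncf_f32_to_bf16(x)
    print(x, hex(b), bf16_bits_to_float(b))
# -0.499081 maps to 0xbf00 (-0.5); 6.8335e30 maps to 0x72ad
```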
rsuderman commented 4 months ago

> @rsuderman , here is the much more concise and optimized way that the PyTorch runtime does it (I think that part was written by Marat and carried over from XNNPACK or some predecessor of it): pytorch/pytorch@e1502c0/c10/util/BFloat16.h#L76
>
> In the above-linked runtime code, I didn't bother to implement this magic trick because I wanted genericity and didn't need to chase performance. But in the compiler lowering, it would make sense to do the concise and efficient thing.
>
> The link to IREE math.h in the previous comment has a comment explaining the magic trick here.
>
> ```
> // Note: software implementations that try to be fast tend to get this
> // conditional increment of exp and zeroing of mantissa for free by
> // simplying incrementing the whole uint32 encoding of the float value,
> // so that the mantissa overflows into the exponent bits.
> ```

Great :/, I was pretty sure I had managed to implement the rounding behavior correctly, but I did not have an aggressive test case to evaluate it with. I assume this means there is an error in our bf16 truncf implementation? I am not sure I have time in the near future to debug the exact bit errors; is it possible you could take a look?

bjacob commented 4 months ago

It's ok, I think I have the patch ready soon.

bjacob commented 4 months ago

@Shukla-Gaurav , this seems to work. I'll fix up any unit test that fails and send that for review @rsuderman . https://github.com/llvm/llvm-project/pull/83180

Shukla-Gaurav commented 4 months ago

Thanks a lot @bjacob for actively working on this. Will try the patch with other test cases/models as well.

bjacob commented 4 months ago

https://github.com/llvm/llvm-project/pull/83180 is merged, so you'll get it in the next integrate or can cherry-pick it locally until then to verify it fixed your issue.

kumardeepakamd commented 4 months ago

Thank you!
