xdslproject / xdsl

A Python Compiler Design Toolkit

Precision limit in xDSL output #2875

Open gabrielrodcanal opened 3 months ago

gabrielrodcanal commented 3 months ago

Hi there, I have noticed that xDSL truncates the precision of floating-point numbers in its output. This is not specific to any pass: simply running the MLIR code through xdsl-opt reduces the precision. Here's a minimal example to reproduce it:

"builtin.module"() ({
  "memref.global"() <{constant, initial_value = dense<[[[[-0.155728638, 0.402353376]]], [[[0.224845588, -0.389391869]]], [[[-0.705623507, 0.498874038]]], [[[-0.173432678, 0.568215191]]]]> : tensor<4x1x1x2xf32>, sym_name = "__constant_4x1x1x2xf32", sym_visibility = "private", type = memref<4x1x1x2xf32>}> : () -> ()
}) {torch.debug_module_name = "ResnetBB"} : () -> ()

Running this code through xdsl-opt produces:

builtin.module attributes  {"torch.debug_module_name" = "ResnetBB"} {
  "memref.global"() <{"constant", "initial_value" = dense<[[[[-1.557286e-01, 4.023534e-01]]], [[[2.248456e-01, -3.893919e-01]]], [[[-7.056235e-01, 4.988740e-01]]], [[[-1.734327e-01, 5.682152e-01]]]]> : tensor<4x1x1x2xf32>, "sym_name" = "__constant_4x1x1x2xf32", "sym_visibility" = "private", "type" = memref<4x1x1x2xf32>}> : () -> ()
}
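The printed values are consistent with Python's `e` format specifier, which defaults to six digits after the decimal point (seven significant digits in total). The check below only reproduces the symptom in plain Python. It is not the actual code from `printer.py`.

```python
# Values from the memref.global above, exactly as written in the input.
inputs = [-0.155728638, 0.402353376, 0.224845588, -0.389391869]

for value in inputs:
    # The "e" presentation type defaults to 6 digits after the decimal point.
    truncated = f"{value:e}"
    # repr() gives the shortest string that round-trips through a Python float.
    shortest = repr(value)
    print(f"{truncated:>15} {shortest:>15}")
```

The left column matches the first four values in the xdsl-opt output. The right column gives back the input digits unchanged.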

NOTE: After a discussion on Zulip, @AntonLydike pointed out that the problem is at https://github.com/xdslproject/xdsl/blob/987abb56d45a710d16985c64b194cae342c1a191/xdsl/printer.py#L524

NOTE 2: since this example uses a dense attribute, the relevant code is at https://github.com/xdslproject/xdsl/blob/987abb56d45a710d16985c64b194cae342c1a191/xdsl/printer.py#L582, which truncates the values as well.

superlopuh commented 3 months ago

I initially added this to match the MLIR behaviour. It would be great to gather some more examples and match the printing for various float formats. A short-term solution would be to just use Python's default float printing.
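Below is a minimal sketch of width-aware printing. It assumes the goal is the shortest decimal string that parses back to the same value at the attribute's float width. `to_f32`, `format_f32` and `format_f64` are hypothetical helpers, not xDSL API. It also shows a caveat with plain `repr`. If a stored f32 value has already been rounded to binary32, `repr` prints the digits of its binary64 widening rather than the short form.

```python
import struct


def to_f32(value: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value.

    Raises OverflowError for finite values outside the binary32 range.
    """
    return struct.unpack("<f", struct.pack("<f", value))[0]


def format_f32(value: float) -> str:
    """Shortest 'g'-formatted string that round-trips through binary32.

    Nine significant digits always suffice for binary32, so the loop
    only falls through for NaN (which never compares equal to itself).
    """
    target = to_f32(value)
    for digits in range(1, 10):
        text = f"{target:.{digits}g}"
        if to_f32(float(text)) == target:
            return text
    return repr(target)


def format_f64(value: float) -> str:
    """Python's repr is already the shortest round-trip form for binary64."""
    return repr(value)
```

For example, `repr(to_f32(0.1))` gives `0.10000000149011612`, while `format_f32(0.1)` gives `0.1`. Neither loses information at f32 width, unlike the seven-digit `e` output above.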