tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low-level kernel programming model.
Apache License 2.0

Optimize `ttnn.round` with a direct implementation #13385

Open jdh8 opened 2 weeks ago

jdh8 commented 2 weeks ago

Rounding is only supported on Wormhole, and Wormhole has an exact function, `float_to_int16`, when the value is in range. https://github.com/tenstorrent/tt-metal/blob/main/docs/source/tt-metalium/tt_metal/apis/kernel_apis/sfpu/llk.rst#wormhole-only
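To make the "exact if the value is in range" condition concrete, here is a minimal Python sketch of a direct round built on a single float-to-int16 conversion. `float_to_int16_sim` and `round_direct` are hypothetical names, not tt-metal APIs. The sketch assumes the conversion rounds ties to even and that inputs are bfloat16 values. bfloat16 has 8 significant bits, so every finite value with |x| >= 128 is already an integer, and all smaller magnitudes fit in int16.

```python
import math

INT16_MIN, INT16_MAX = -32768, 32767


def float_to_int16_sim(x: float) -> int:
    """Hypothetical software stand-in for the hardware float_to_int16.

    Assumes ties round to even and that the conversion is only exact
    when the rounded result fits in int16.
    """
    r = round(x)  # Python rounds floats half to even
    if not INT16_MIN <= r <= INT16_MAX:
        raise OverflowError(f"{x} is outside the int16 range")
    return r


def round_direct(x: float) -> float:
    """Round to the nearest integer with a single conversion.

    Assumes bfloat16-valued inputs: with 8 significant bits, every finite
    bfloat16 with |x| >= 128 is already an integer, so only smaller
    magnitudes (which fit in int16) need converting.
    """
    if math.isnan(x) or abs(x) >= 128.0:
        return x  # NaN, infinities and large values pass through unchanged
    # copysign keeps the sign of zero, e.g. round_direct(-0.4) -> -0.0
    return math.copysign(float(float_to_int16_sim(x)), x)
```

For float32 inputs the pass-through threshold would have to be 2**23 rather than 128, so values between 2**15 and 2**23 would need a different path.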

However, `ttnn.round` is implemented as a combination of `ttnn.floor`, `ttnn.add`, etc. https://github.com/tenstorrent/tt-metal/blob/679b8d59834358e34a080cfa743b324930c8c364/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp#L640-L653
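A composite of several ops costs more than one native step and can also round differently. The sketch below contrasts a hypothetical floor-plus-add composite, `floor(x + 0.5)`, with a single round-half-to-even step. It is an illustration only, not the actual code in `unary_composite_op.cpp`. Besides resolving ties upward, the extra addition can itself round: `0.49999999999999994 + 0.5` becomes exactly `1.0` in float64.

```python
import math


def round_via_floor(x: float) -> float:
    # Hypothetical two-op composite: add 0.5, then floor (ties round up).
    return float(math.floor(x + 0.5))


def round_half_even(x: float) -> float:
    # One direct rounding step (ties to even), as a native kernel might do.
    return float(round(x))


for x in (0.5, 2.5, 0.49999999999999994):
    print(f"{x!r}: composite={round_via_floor(x)}, direct={round_half_even(x)}")
```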

In turn, `ttnn.floor` calls functions that effectively compute `ttnn.round`. https://github.com/tenstorrent/tt-metal/blob/679b8d59834358e34a080cfa743b324930c8c364/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_floor.h#L25-L26
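If the hardware primitive already rounds to nearest, floor only needs a one-step correction on top of it, so round is the more natural building block. A minimal sketch for finite inputs follows. `floor_from_round` is a hypothetical name, and the code is not a transcription of `ckernel_sfpu_floor.h`.

```python
def floor_from_round(x: float) -> float:
    # Round to nearest first (the step a native round would provide) ...
    r = float(round(x))
    # ... then step down by one whenever rounding went above x.
    return r - 1.0 if r > x else r
```

Building `round` out of `floor` and then `floor` out of a rounding primitive does this work twice.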

Rounding to the nearest integer is extremely useful for argument reduction. A direct implementation could be reused in other mathematical functions (mostly elementary functions) such as:
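As an illustration of argument reduction, the Python sketch below computes `exp(x)` by writing `x = k*ln2 + r` with `k = round(x / ln2)`. Rounding to nearest, rather than flooring, keeps `|r| <= ln2/2` (about 0.347), which is small enough for a short polynomial. `exp_reduced` is a hypothetical name. The sketch also skips the hi/lo split of `ln2` that production implementations use to keep `r` accurate for large `k`.

```python
import math

LN2 = math.log(2.0)


def exp_reduced(x: float) -> float:
    """exp(x) via argument reduction around round-to-nearest (sketch).

    x = k*ln2 + r with k = round(x / ln2), so |r| <= ln2/2, and
    exp(x) = 2**k * exp(r).
    """
    k = round(x / LN2)  # nearest integer, not floor, keeps |r| small
    r = x - k * LN2     # reduced argument (no hi/lo split of ln2 here)
    # Degree-13 Taylor polynomial of exp(r), evaluated with Horner's rule.
    p = 1.0
    for n in range(13, 0, -1):
        p = 1.0 + p * r / n
    return math.ldexp(p, k)  # scale by 2**k exactly
```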

mouliraj-mcw commented 1 day ago

Hi @jdh8, I examined your approach and found that it doesn't handle rounding to a specific number of decimal places (e.g., 2 or 3 decimal places).
Could you please share your thoughts on how this could be handled?

jdh8 commented 1 day ago

Thanks for pointing it out! I missed the `decimals` parameter.

It can be handled by multiplying by 10**n. Specifically,

round(x, n) = 10**-n * round(10**n * x)
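A minimal Python sketch of the scaling formula, including negative `n` (rounding to tens, hundreds, and so on), is shown below. `round_decimals` is a hypothetical name. The sketch divides by `10**n` instead of multiplying by `10**-n`, because `10**n` is exactly representable in float64 for `0 <= n <= 22` while `10**-n` is not, so division rounds once where the multiplication rounds twice.

```python
def round_decimals(x: float, n: int = 0) -> float:
    """round(x, n) = 10**-n * round(10**n * x), ties to even (sketch)."""
    if n >= 0:
        scale = float(10 ** n)   # exact for n <= 22
        return round(x * scale) / scale
    scale = float(10 ** -n)      # negative n: round to tens, hundreds, ...
    return round(x / scale) * scale
```

Decimal inputs such as 2.675 are not exactly representable in binary, so results near apparent ties can differ from what decimal arithmetic would give.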

jdh8 commented 1 day ago

I have two proposals:

  1. Implement a native `roundeven(x)` (conceptually `round(x, 0)`, named after C23 `roundeven`), then build `round(x, n)` on top of it.
  2. Make a direct, native `round(x, n)`.

Which approach looks better?

mouliraj-mcw commented 19 hours ago

I think approach two would be more suitable, as it has a straightforward structure.