tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

Add round_nearest_afz and round_nearest_even #1198

Open ddilbazTT opened 1 week ago

ddilbazTT commented 1 week ago

StableHLO ops: round_nearest_afz, round_nearest_even

Lower both ops to ttnn.round. We could also add general round support, since linalg needs round as well.

Convert from StableHLO to TTNN end-to-end.

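The tt-metal implementation of ttnn round is in third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp

The decimals argument would need to be used to choose between round_nearest_even and round_nearest_afz: with decimals == 0 the implementation does bankers' rounding (round half to even), while with decimals != 0 it scales the input by 10^decimals, adds 0.5, floors, and scales back.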

```cpp
Tensor _round(const Tensor& input, int32_t decimals, const std::optional<MemoryConfig>& output_mem_config) {
    auto arch = input.device()->arch();
    TT_FATAL(arch == tt::ARCH::WORMHOLE_B0, "Op is only supported on Wormhole");
    Tensor floor_res = ttnn::floor(input, output_mem_config);
    if (decimals != 0) {  // TODO: For decimal value!=0
        // Scale by 10^decimals, add 0.5, floor, then divide back by 10^decimals.
        Tensor power_10 =
            ttnn::power(ttnn::full_like(input, 10.0f), decimals, output_mem_config);
        Tensor rounded_non_half = ttnn::floor(
            ttnn::add(ttnn::multiply(input, power_10, std::nullopt, output_mem_config), 0.5, std::nullopt, output_mem_config),
            output_mem_config);
        rounded_non_half = ttnn::div(rounded_non_half, power_10);
        return rounded_non_half;
    } else {  // Bankers' Rounding
        // Non-half values: floor(x + 0.5), except inputs in [0.4, 0.5] add 0.4 instead.
        Tensor rounded_non_half = ttnn::floor(
            ttnn::add(
                input,
                ttnn::where(ttnn::logical_and(ttnn::ge(input, 0.4), ttnn::le(input, 0.5)), 0.4f, 0.5f, output_mem_config.value()),
                std::nullopt,
                output_mem_config),
            output_mem_config.value());
        // Exact halves: add is_odd(floor(x)) to floor(x), so ties go to the even neighbour.
        Tensor fractional_part = ttnn::subtract(input, floor_res, std::nullopt, output_mem_config);
        Tensor is_half = ttnn::eq(fractional_part, 0.5, std::nullopt, output_mem_config);
        Tensor rounded_half =
            ttnn::add(floor_res, is_odd(floor_res, output_mem_config), std::nullopt, output_mem_config);
        return ttnn::where(is_half, rounded_half, rounded_non_half, output_mem_config.value());
    }
}
```

sdjordjevicTT commented 1 week ago

Please provide a detailed description of the issue and specify what it needs to capture. For example, does this issue only address the XLA conversion to TTIR? Additionally, the component you assigned appears to be "tt-mlir", which does not correspond to any existing team owner. This makes it difficult to determine who should take on this work.

ddilbazTT commented 1 week ago

Hi @sdjordjevicTT! I followed the existing issue format, so I'm not sure how to respond to your comment about tt-mlir. I opened this issue in tt-mlir the same way as all the other StableHLO op bring-ups. Is that wrong? I would appreciate some clarity, since I am relatively new and might not know the process fully.

sdjordjevicTT commented 1 week ago

@mrakitaTT FYI

Hi @ddilbazTT, I am glad to jump in and help clarify. We should avoid putting issues in the tt-mlir component. Components should mirror our team split, and issues should go to the appropriate component owners. The tt-mlir component is just a placeholder we had previously, so we should stop adding issues there and try to remove it. Regarding StableHLO ops, I believe most of those issues are in the StableHLO Conversion component, so this one probably needs to go there: https://github.com/tenstorrent/tt-mlir/issues/1033

@staylorTT we should go over the issues in the tt-mlir component, move all of them out, and then remove the component so it doesn't confuse people.

mrakitaTT commented 1 week ago

@ddilbazTT Just to add some info to what Stefan said: components are set under the Projects settings on the right side of the issue page.

For issues related to conversion from the StableHLO dialect, we add them under the StableHLO Conversion component. If the issue requires changes in the TTIR/TTNN dialects, please also open an additional issue for that work and set its component to MLIR Dialects/Passes. That work is a subset of the changes required to support the conversion, so you will probably cover both issues with the same PR, but it is important to open the additional task so we can track changes in the dialects and avoid doing double work between teams. You can find more info here: https://tenstorrent.slack.com/archives/C07EELVS5HT/p1730466912671279

As for the issues that Aleks shared today regarding errors found in existing conversions: for each error you start working on, please open an issue under the StableHLO Conversion component and add the stablehlo conversion bug label (you can find labels on the right side, below the Assignees field). More info here: https://tenstorrent.slack.com/archives/C07EELVS5HT/p1731074231644839

I know there are a bunch of new procedures, but all of this is really important for us to be able to work and collaborate in teams without stepping on each other's toes :)

ddilbazTT commented 1 week ago

Hey @mrakitaTT, can you share an example MLIR Dialects/Passes issue? I will use it as a template and create one for round. I hope this issue now looks good for capturing the StableHLO Conversion work!

mrakitaTT commented 1 week ago

@ddilbazTT Looks good now! Here is an example issue for dialect changes: https://github.com/tenstorrent/tt-mlir/issues/639 It's nothing special: you just describe what changes you need to make in the TTIR/TTNN dialects/passes, add the MLIR Dialects / Passes component under the TT-Forge project, and add the MLIR Ops label if you are adding a new op.

sdjordjevicTT commented 5 days ago

@mrakitaTT thanks for jumping in and clarifying, really appreciate it!