tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

ttnn.mish operation kernel duration is high #12815

Closed · punithsekar closed this issue 2 days ago

punithsekar commented 1 month ago

Describe the bug: The ttnn.mish operation takes too much time (kernel duration) to compute.

To Reproduce: Steps to reproduce the behavior:

  1. Checkout to branch punith/mish_unittest
  2. Run the command: ./tt_metal/tools/profiler/profile_this.py -n mish_unit_test -c "pytest tests/ttnn/integration_tests/yolov4/unit_test_mish.py"


Additional context: Perf sheet for the same run, mish_perf_sheet.csv

punithsekar commented 1 month ago

fyi @saichandax

ruthreshx commented 1 month ago

Hi @esmalTT, @punithsekar, @eyonland,

On debugging, we found that Mish is computed with the following formula:

    **Mish(x) = x ∗ Tanh(Softplus(x))**

Looking further into profiler_result.csv, Softplus takes the most time to compute (208524 ns).
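
For reference, the decomposition can be sanity-checked against torch.nn.functional.mish; a minimal sketch (the range and point count are arbitrary):

import torch

x = torch.linspace(-10, 10, 1000)

# Mish(x) = x * tanh(softplus(x)) -- compare against torch's built-in mish
mish_ref = torch.nn.functional.mish(x)
mish_decomposed = x * torch.tanh(torch.nn.functional.softplus(x))

print(torch.max(torch.abs(mish_ref - mish_decomposed)).item())  # ~0 up to float rounding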

Workarounds:

Softplus Formula:

[Screenshot: Softplus formula]
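
For context, a rough NumPy sketch of the numerically stable softplus form that torch exposes (beta and threshold as in the torch API; this is not the kernel code):

import numpy as np

def softplus_ref(x, beta=1.0, threshold=20.0):
    # Softplus(x) = (1/beta) * log(1 + exp(beta * x)), reverting to the
    # linear function x once beta * x > threshold to avoid overflow.
    x = np.asarray(x, dtype=np.float64)
    z = np.minimum(beta * x, threshold)  # clamp so exp() stays finite
    return np.where(beta * x > threshold, x, np.log1p(np.exp(z)) / beta)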

@esmalTT, we would like to get your opinion on this. Thank you!

umadevimcw commented 1 month ago

@eyonland @esmalTT As @ruthreshx explained above, the existing piecewise approximation approach for softplus seems more optimised than the other implementations we tried.

@esmalTT Kindly share your inputs as well.

esmalTT commented 1 month ago

It's probably feasible to improve the softplus implementation by 40-50% by tweaking the current approximation (fewer segments without reducing accuracy) and reducing the number of immediates used.

If that won’t be fast enough, you’ll probably have to come up with a Mish approximation function and then implement a Mish LLK.
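
As an illustration of the segment-count trade-off (only a host-side sketch over an assumed range, not the LLK; the actual kernel presumably uses higher-order segments), the approximation error of a piecewise-linear fit of softplus can be checked like this:

import numpy as np

def softplus(x):
    # Numerically stable softplus
    return np.log1p(np.exp(-np.abs(x))) + np.maximum(x, 0.0)

x = np.linspace(-20.0, 20.0, 100_000)
exact = softplus(x)

for segments in (4, 8, 16, 32):
    # Piecewise-linear approximation: sample softplus at the segment
    # boundaries and interpolate linearly in between.
    knots = np.linspace(-20.0, 20.0, segments + 1)
    approx = np.interp(x, knots, softplus(knots))
    print(segments, "segments -> max abs error:", np.max(np.abs(approx - exact)))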

ruthreshx commented 1 month ago

Hi @esmalTT, can we use the formula below, since ELU (Exponential Linear Unit) is another smooth approximation of the Softplus function?

Formula: Mish(x) ≈ x ⋅ tanh(ELU(x))

I could see a performance improvement of up to 40-50%. perf_results_mish_unit_test.csv

What is your opinion on this?

esmalTT commented 1 month ago

> Hi @esmalTT, can we use the formula below, since ELU (Exponential Linear Unit) is another smooth approximation of the Softplus function? Formula: Mish(x) ≈ x ⋅ tanh(ELU(x)). I could see a performance improvement of up to 40-50%. perf_results_mish_unit_test.csv
>
> What is your opinion on this?

@ruthreshx I don’t see how that would work - they are different functions. You can always plot the error between this and softplus in PyTorch to see what happens.

ruthreshx commented 1 month ago

Hi @esmalTT, though it is a different function, I don't see much of a difference in the torch results. I have tried different ranges as well.

import matplotlib.pyplot as plt
import torch
import numpy as np

x = np.linspace(-1000, 1000, 10)

# Convert the NumPy array to a torch tensor
x = torch.from_numpy(x)
print(x)

# Torch ELU
t_elu = torch.nn.functional.elu(x, alpha=1.0)
# Torch Softplus
t_softplus = torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)

print("Torch ELU      ==> ", t_elu)
print("Torch SOFTPLUS ==> ", t_softplus)

# Plot
plt.plot(x.numpy(), t_elu.numpy(), "ob", label="torch elu")
plt.plot(x.numpy(), t_softplus.numpy(), "+r", label="torch softplus")
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Comparison of torch elu and torch softplus function')
plt.legend()
plt.grid(True)
plt.show()

Plot result: [plot comparing torch elu and torch softplus]

esmalTT commented 1 month ago

> Hi @esmalTT, though it is a different function, I don't see much of a difference in the torch results. I have tried different ranges as well. [comparison script and plot quoted from the previous comment]

You really need to look at the error instead of the absolute values. I would also check more than 10 data points (probably at least 100K for that range).

You can also check the PCC of these two plots. I would expect it to be at least 0.9999.
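
A minimal sketch of that check (the range and point count here are only placeholders), looking at the error directly and at the PCC via np.corrcoef:

import numpy as np
import torch

x = torch.linspace(-1000.0, 1000.0, 100_000)

t_softplus = torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)
t_elu = torch.nn.functional.elu(x, alpha=1.0)

# The absolute error approaches 1 as x -> -inf (softplus -> 0, elu -> -1),
# even though the PCC over a wide range stays close to 1.
err = (t_softplus - t_elu).abs()
pcc = np.corrcoef(t_softplus.numpy(), t_elu.numpy())[0, 1]

print("max abs error:", err.max().item())
print("PCC:", pcc)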

ruthreshx commented 1 month ago

@esmalTT Verified TTNN Softplus vs TTNN ELU and Torch Softplus vs Torch ELU with a PCC of 0.9999. I don't see any PCC drop here.

# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0

from loguru import logger
import random
import pytest
import torch
import ttnn

from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.ttnn.python_api_testing.sweep_tests import ttnn_ops

def run_eltwise_mish_tests(input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed, device):
    torch.manual_seed(data_seed)

    x = torch.Tensor(size=input_shape[0]).uniform_(-1e6, 1e6).to(torch.bfloat16)

    # TT PCC CHECK, PASSED 
    try:

        x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config, dtype[0])

        tt_result_softplus = ttnn.softplus(x, beta=1.0, threshold=20.0)
        tt_result_elu = ttnn.elu(x, alpha=1.0)
        tt_result_softplus = ttnn_ops.ttnn_tensor_to_torch(tt_result_softplus, output_mem_config)
        tt_result_elu = ttnn_ops.ttnn_tensor_to_torch(tt_result_elu, output_mem_config)

    except Exception as e:
        logger.warning(f"Operation execution crashed")
        raise e

    assert len(tt_result_softplus.shape) == len(tt_result_elu.shape)
    assert tt_result_softplus.shape == tt_result_elu.shape
    assert_with_pcc(tt_result_softplus, tt_result_elu, 0.9999)

    # TORCH REF PCC CHECK
    # try:

    #     # Torch SOFTPLUS
    #     t_softplus = torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)
    #     # Torch ELU
    #     t_elu = torch.nn.functional.elu(x, alpha = 1.0)

    # except Exception as e:
    #     logger.warning(f"Operation execution crashed")
    #     raise e

    # assert len(t_softplus.shape) == len(t_elu.shape)
    # assert t_softplus.shape == t_elu.shape
    # assert_with_pcc(t_softplus, t_elu, 0.9999)

test_sweep_args = [
    (
        [(1, 1, 102400, 32)],
        [ttnn.bfloat16],
        [ttnn.TILE_LAYOUT],
        (ttnn.DRAM_MEMORY_CONFIG),
        (ttnn.DRAM_MEMORY_CONFIG),
        17799073,
    ),
]

def test_eltwise_mish(device):
    for i in range(1):
        for input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed in test_sweep_args:
            run_eltwise_mish_tests(input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed, device)

umadevimcw commented 1 month ago

@ruthreshx Even though the PCC is close and there is a noticeable difference in performance, the two activation functions remain fundamentally distinct. Rather than focusing solely on performance outcomes, it’s crucial to evaluate their influence during both the training and inference phases. Random inputs may not capture all the nuances of these differences.

Softplus outputs only positive values, while ELU allows for negative values.
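
A quick illustration of that difference on negative inputs:

import torch

x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0])
print(torch.nn.functional.softplus(x))  # strictly positive, tends to 0 for very negative x
print(torch.nn.functional.elu(x))       # negative for x < 0, tends to -1 for very negative x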

ruthreshx commented 1 month ago

> @ruthreshx Even though the PCC is close and there is a noticeable difference in performance, the two activation functions remain fundamentally distinct. [...] Softplus outputs only positive values, while ELU allows for negative values.

That's right @umadevimcw. As @esmalTT suggested, we need to tweak the softplus implementation for a 40-50% improvement (fewer segments without reducing accuracy, and fewer immediates used), or else implement a Mish LLK.

dvartaniansTT commented 1 week ago

@umadevimcw @ruthreshx, is this still being tracked? Is there an ETA for this issue to be resolved? cc: @mbahnasTT

umadevimcw commented 1 week ago

@dvartaniansTT The current softplus code (which is part of mish) is already the optimised one. We tried different implementations for mish but didn't find any performance improvement, hence we are trying the approximation approach for mish as well. Will update here once the approach is approved and the results and performance are consistent.

dvartaniansTT commented 1 week ago

Thanks for the update @umadevimcw! Sounds good.

umadevimcw commented 1 week ago

@dvartaniansTT Please find the PR and the analysis here: #14270

jvasilje commented 3 days ago

@eyonland @umadevimcw @punithsekar is this resolved?

dvartaniansTT commented 3 days ago

@punithsekar please test the PR shared by @umadevimcw.

punithsekar commented 3 days ago

@dvartaniansTT, I see a performance improvement for the mish operation with the attached PR, as mentioned in the PR description. The FPS for the unit test was previously 3857.221; now, with the changes in the PR, it is 9707.795.

Current perf sheet for the unit test: mish_recent_perf.csv

Also, we can see a significant improvement in the yolov4 whole-model other-device-ops performance (FPS), from 110.958 to 156.503.

mouliraj-mcw commented 3 days ago

I’ve listed the performance metrics and limitations of the approach in the PDF attached below. Mish Analysis.pdf

mouliraj-mcw commented 2 days ago

Hi @jvasilje, the issue has been resolved and the fix has been merged into the main branch.