Closed punithsekar closed 2 days ago
fyi @saichandax
Hi @esmalTT , @punithsekar , @eyonland,
While debugging, we found that Mish is computed with the following formula:
**x∗Tanh(Softplus(x))**
Looking further into profiler_result.csv, Softplus took the most time to compute (208524 ns).
Workarounds:
Softplus Formula:
@esmalTT, we would like to get your opinion on this. Thank you!
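For readers following along, here is a minimal pure-Python sketch of the formula above. It is not the TTNN kernel: the `softplus` helper is an illustrative stand-in that assumes the usual definition log(1 + exp(beta*x)) / beta with the same `beta` and `threshold` cutoff semantics as `torch.nn.functional.softplus`.

```python
import math


def softplus(x, beta=1.0, threshold=20.0):
    """Reference softplus; falls back to the identity once beta*x exceeds threshold."""
    if beta * x > threshold:
        return x
    return math.log1p(math.exp(beta * x)) / beta


def mish(x):
    """Mish(x) = x * tanh(softplus(x))."""
    return x * math.tanh(softplus(x))


for v in (-5.0, -1.0, 0.0, 1.0, 5.0):
    print(f"x={v:+.1f}  softplus={softplus(v):.6f}  mish={mish(v):+.6f}")
```

Softplus sits inside the tanh, so its cost is paid on every element, which is consistent with it dominating the profile.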
@eyonland @esmalTT As @ruthreshx explained above, the existing piecewise approximation approach seems to be the most optimised one for softplus compared to the other implementations we tried.
@esmalTT Kindly share your inputs as well
It’s probably feasible to improve the softplus implementation by up to 40-50% by tweaking the current approximation (fewer segments without reducing accuracy) and reducing the number of immediates used.
If that won’t be fast enough, you’ll probably have to come up with a Mish approximation function and then implement a Mish LLK.
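To illustrate the segments-versus-accuracy trade-off mentioned above, here is a rough sketch that approximates softplus by linear interpolation on equal-width segments and reports the worst-case error for several segment counts. This is only an assumption-laden illustration: the range [-8, 8], the uniform knots, and the helper names are hypothetical and do not reflect how the actual kernel chooses its segments or coefficients.

```python
import math
from bisect import bisect_right


def softplus(x):
    return x if x > 20.0 else math.log1p(math.exp(x))


def make_piecewise(lo, hi, segments):
    """Linear interpolation of softplus on `segments` equal intervals over [lo, hi]."""
    knots = [lo + (hi - lo) * i / segments for i in range(segments + 1)]
    values = [softplus(k) for k in knots]

    def approx(x):
        if x <= lo:
            return 0.0  # softplus(x) tends to 0 for very negative x
        if x >= hi:
            return x  # softplus(x) tends to x for large x
        i = bisect_right(knots, x) - 1
        t = (x - knots[i]) / (knots[i + 1] - knots[i])
        return values[i] + t * (values[i + 1] - values[i])

    return approx


def max_abs_error(approx, lo=-10.0, hi=10.0, samples=20001):
    xs = (lo + (hi - lo) * i / (samples - 1) for i in range(samples))
    return max(abs(approx(x) - softplus(x)) for x in xs)


errors = {n: max_abs_error(make_piecewise(-8.0, 8.0, n)) for n in (4, 8, 16, 32)}
for n, err in errors.items():
    print(f"{n:2d} segments -> max abs error {err:.5f}")
```

With uniform knots the error is concentrated around x = 0, where softplus curves the most, so a non-uniform layout that spends segments there is one plausible way to use fewer segments at the same accuracy.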
Hi @esmalTT, can we use the formula below, since ELU (Exponential Linear Unit) is another smooth approximation of the Softplus function?
Formula:
Mish(x)≈x⋅tanh(ELU(x))
I could see a performance improvement of up to 40-50%. perf_results_mish_unit_test.csv
What is your opinion on this?
@ruthreshx I don’t see how that would work - they are different functions. You can always plot the error between this and softplus in PyTorch to see what happens.
Hi @esmalTT, though they are different functions, I don't see much difference in the torch results. I have tried different ranges as well.
import matplotlib.pyplot as plt
import torch
import numpy as np
x = np.linspace(-1000, 1000, 10)
# Torch numpy
x = torch.from_numpy(x)
print(x)
# Torch ELU
t_elu = torch.nn.functional.elu(x, alpha = 1.0)
# Torch SOFTPLUS
t_softplus = torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)
print("Torch ELU ==> ", t_elu)
print("Torch SOFTPLUS ==> ", t_softplus)
# Plot
plt.plot(x.numpy(), t_elu.numpy(), "ob", label="torch elu")
plt.plot(x.numpy(), t_softplus.numpy(), "+r", label="torch softplus")
plt.xlabel('Input')
plt.ylabel('Output')
plt.title('Comparison of torch elu and torch softplus function')
plt.legend()
plt.grid(True)
plt.show()
Plot result:
You really need to look at the error instead of the absolute values. I would also check more than 10 data points (probably at least 100K for that range).
You can also check the PCC of these two plots. I would expect it to be at least 0.9999.
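A sketch of the check suggested above, in plain Python so it runs without torch: sample 100K points, then report the maximum absolute error and the Pearson correlation (PCC) between softplus and ELU. The helper names (`linspace`, `pcc`, `compare`) are illustrative, and the two ranges are arbitrary choices for comparison.

```python
import math


def softplus(x):
    return x if x > 20.0 else math.log1p(math.exp(x))


def elu(x, alpha=1.0):
    return x if x > 0 else alpha * math.expm1(x)


def linspace(lo, hi, n):
    step = (hi - lo) / (n - 1)
    return [lo + i * step for i in range(n)]


def pcc(a, b):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((p - ma) * (q - mb) for p, q in zip(a, b))
    va = sum((p - ma) ** 2 for p in a)
    vb = sum((q - mb) ** 2 for q in b)
    return cov / math.sqrt(va * vb)


def compare(lo, hi, n=100_000):
    xs = linspace(lo, hi, n)
    sp = [softplus(x) for x in xs]
    el = [elu(x) for x in xs]
    max_err = max(abs(s - e) for s, e in zip(sp, el))
    return max_err, pcc(sp, el)


wide_err, wide_pcc = compare(-1000.0, 1000.0)
narrow_err, narrow_pcc = compare(-5.0, 5.0)
print(f"[-1000, 1000]: max abs error {wide_err:.4f}, PCC {wide_pcc:.7f}")
print(f"[-5, 5]:       max abs error {narrow_err:.4f}, PCC {narrow_pcc:.7f}")
```

Over a very wide range both functions are essentially `x` on the positive half, so that half dominates the correlation. The absolute error still approaches 1 for very negative inputs (softplus tends to 0, ELU tends to -1) and equals ln 2 at x = 0, so a high PCC on a wide range does not by itself show the functions are interchangeable.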
@esmalTT Verified TTNN Softplus vs TTNN ELU and Torch Softplus vs Torch ELU with the PCC of 0.9999. I don't see any PCC drop over here.
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
# SPDX-License-Identifier: Apache-2.0
from loguru import logger
import random
import pytest
import torch
import ttnn
from tests.ttnn.utils_for_testing import assert_with_pcc
from tests.ttnn.python_api_testing.sweep_tests import ttnn_ops
def run_eltwise_mish_tests(input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed, device):
    torch.manual_seed(data_seed)
    x = torch.Tensor(size=input_shape[0]).uniform_(-1e6, 1e6).to(torch.bfloat16)
    # TT PCC CHECK, PASSED
    try:
        x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config, dtype[0])
        tt_result_softplus = ttnn.softplus(x, beta=1.0, threshold=20.0)
        tt_result_elu = ttnn.elu(x, alpha=1.0)
        tt_result_softplus = ttnn_ops.ttnn_tensor_to_torch(tt_result_softplus, output_mem_config)
        tt_result_elu = ttnn_ops.ttnn_tensor_to_torch(tt_result_elu, output_mem_config)
    except Exception as e:
        logger.warning(f"Operation execution crashed")
        raise e
    assert len(tt_result_softplus.shape) == len(tt_result_elu.shape)
    assert tt_result_softplus.shape == tt_result_elu.shape
    assert_with_pcc(tt_result_softplus, tt_result_elu, 0.9999)
    # TORCH REF PCC CHECK
    # try:
    #     # Torch SOFTPLUS
    #     t_softplus = torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)
    #     # Torch ELU
    #     t_elu = torch.nn.functional.elu(x, alpha = 1.0)
    # except Exception as e:
    #     logger.warning(f"Operation execution crashed")
    #     raise e
    # assert len(t_softplus.shape) == len(t_elu.shape)
    # assert t_softplus.shape == t_elu.shape
    # assert_with_pcc(t_softplus, t_elu, 0.9999)
test_sweep_args = [
    (
        [(1, 1, 102400, 32)],
        [ttnn.bfloat16],
        [ttnn.TILE_LAYOUT],
        (ttnn.DRAM_MEMORY_CONFIG),
        (ttnn.DRAM_MEMORY_CONFIG),
        17799073,
    ),
]
def test_eltwise_mish(device):
    for i in range(1):
        for input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed in test_sweep_args:
            run_eltwise_mish_tests(input_shape, dtype, dlayout, in_mem_config, output_mem_config, data_seed, device)
@ruthreshx Even though the PCC is close and there is a noticeable difference in performance, the two activation functions remain fundamentally distinct. Rather than focusing solely on performance outcomes, it’s crucial to evaluate their influence during both the training and inference phases. Random inputs may not capture all the nuances of these differences.
Softplus outputs only positive values, while ELU allows for negative values
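The sign difference matters once the function is placed inside Mish. The sketch below (plain Python, with illustrative helper names; `mish_elu` is the proposed substitution, not an existing op) evaluates both variants at a few points.

```python
import math


def softplus(x):
    return x if x > 20.0 else math.log1p(math.exp(x))


def elu(x, alpha=1.0):
    return x if x > 0 else alpha * math.expm1(x)


def mish(x):
    return x * math.tanh(softplus(x))


def mish_elu(x):
    # Proposed variant: softplus replaced by ELU inside the tanh.
    return x * math.tanh(elu(x))


for v in (-5.0, -2.0, -0.5, 0.5, 2.0, 5.0):
    print(f"x={v:+.1f}  mish={mish(v):+.5f}  mish_elu={mish_elu(v):+.5f}")
```

For negative inputs, ELU is negative, so tanh(ELU(x)) is negative and the product with a negative x becomes positive. At x = -5 the ELU variant is roughly +3.8, while Mish itself is a small negative value close to zero. For large positive inputs the two variants agree closely, which is why a range dominated by large magnitudes can hide the mismatch.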
That's right @umadevimcw. As @esmalTT suggested, we need to tweak the softplus implementation to improve it by up to 40-50% (fewer segments without reducing accuracy, and fewer immediates), or else implement a Mish LLK.
@umadevimcw @ruthreshx , is this still being tracked? is there an ETA for this issue to be resolved? cc: @mbahnasTT
@dvartaniansTT The current softplus code (which is part of mish) is already the optimised one. We tried different implementations for mish but didn't find any performance improvement, so we are trying the approximation approach for mish as well. We will update here once the approach is approved and the results and performance are consistent.
thanks for the update @umadevimcw ! sounds good
@dvartaniansTT Please find the PR here and analysis as well #14270
@eyonland @umadevimcw @punithsekar is this resolved?
@punithsekar please test the shared PR by @umadevimcw.
@dvartaniansTT, I see an improvement in performance for the mish operation with the attached PR, as mentioned in the PR description. The FPS for the unit_test file was 3857.221 previously; with the changes in the PR, it is now 9707.795.
Current perf sheet for the unit_test: mish_recent_perf.csv
We can also see a significant improvement in the whole yolov4 model's device ops performance (FPS), from 110.958 to 156.503.
I’ve listed the performance metrics and limitations of the approach in the PDF attached below. Mish Analysis.pdf
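As a quick sanity check on the numbers reported above, the ratios can be computed directly. The variable names below are just labels for the figures quoted in this comment.

```python
# FPS figures quoted in the comment above.
unit_before, unit_after = 3857.221, 9707.795
model_before, model_after = 110.958, 156.503


def speedup(before, after):
    """Ratio of new throughput to old throughput."""
    return after / before


unit_speedup = speedup(unit_before, unit_after)
model_speedup = speedup(model_before, model_after)
print(f"mish unit test: {unit_speedup:.2f}x")
print(f"yolov4 model:   {model_speedup:.2f}x")
```

That is roughly a 2.5x gain on the isolated mish test and about 1.4x on the whole model, which is expected since mish is only one of the ops in yolov4.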
Hi @jvasilje The issue has been resolved and merged into the main branch.
Describe the bug The ttnn.mish operation takes too much time (kernel duration) to compute.
To Reproduce Steps to reproduce the behavior:
./tt_metal/tools/profiler/profile_this.py -n mish_unit_test -c "pytest tests/ttnn/integration_tests/yolov4/unit_test_mish.py"
Please complete the following environment information:
Additional context Perf sheet for the same, mish_perf_sheet.csv