Open KalaivaniMCW opened 2 days ago
Hi @rdjogoTT, I would like your thoughts on this kernel implementation of i1 (corresponding torch op: torch.special.i1) and on another approach here. In both cases k = 0 to 10, the same as in the i0 kernel.
In the former approach (in this PR), i1 is computed from the current i0 kernel and gives a PCC of up to 0.996. The latter approach sticks to the general formula, but it does not give a good PCC.
Is there any other approach we can try?
We need this op to implement the backward op for i0 efficiently.
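For reference, the "general formula" mentioned above is presumably the power series i1(x) = sum over k of (x/2)^(2k+1) / (k! (k+1)!). The following is a minimal host-side Python sketch of that series truncated at k = 0 to 10. It is not the SFPU kernel from the PR; the function name `i1_series` is hypothetical and the code runs in double precision, so it only illustrates the math.

```python
import math


def i1_series(x, terms=11):
    """Truncated power series for the modified Bessel function I1.

    terms=11 corresponds to k = 0..10. This is an illustrative reference,
    not the kernel implementation discussed in the PR.
    """
    half = x / 2.0
    total = 0.0
    for k in range(terms):
        total += half ** (2 * k + 1) / (math.factorial(k) * math.factorial(k + 1))
    return total


# I1(1) is approximately 0.565159103992485
print(i1_series(1.0))
```

The truncation error of this series grows with |x|, so a fixed k = 0 to 10 cutoff is most accurate for inputs near zero.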
@KalaivaniMCW I think the approach of implementing i1 as a polynomial is correct. I took a quick look at https://github.com/tenstorrent/tt-metal/pull/15325 and found that if I change line 17 in tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_i1.h to be:
((coef0 + (coef1 + (coef2 + (coef3 + (coef4 + (coef5 + (coef6 + (coef7 + (coef8 + (coef9 + coef10 * t2) * t2) * t2) * t2) * t2) * t2) * t2) * t2 ) * t2) * t2) * t2)
then the i1 test gets PCC ~= 0.999998 and the i0_bw test gets PCC ~= 0.9998. The change is that you might be missing a * t2 term in your POLYVAL10_I1. Can you double check that this is it? I haven't taken too close a look, so I might be wrong, but I compared the polynomial you have in i1 to the one you have in i0, and the i0 one has this extra term present.
Similarly, I see that in this PR, the POLYVAL10_DERIVATIVE you defined does not make use of the coef0 term. Not sure if this was intentional or not, so please check.
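To illustrate why a missing trailing * t2 matters, here is a small Python sketch of a nested (Horner) evaluation with and without the outer factor. The coefficients in `D` are an assumption derived from the I1 power series, not the values used in ckernel_sfpu_i1.h, and the names `horner`, `i1_with_outer_factor`, and `i1_missing_outer_factor` are hypothetical. The sketch assumes the form i1(x) = (x/2) * (1 + t2 * P(t2)) with t2 = x * x.

```python
import math

# Assumed coefficients (not taken from the PR):
# D[j] = 1 / (4**(j+1) * (j+1)! * (j+2)!), for j = 0..10
D = [
    1.0 / (4.0 ** (j + 1) * math.factorial(j + 1) * math.factorial(j + 2))
    for j in range(11)
]


def horner(coeffs, t):
    """Evaluate coeffs[0] + coeffs[1]*t + ... + coeffs[10]*t**10 by nesting."""
    acc = 0.0
    for c in reversed(coeffs):
        acc = acc * t + c
    return acc


def i1_with_outer_factor(x):
    t2 = x * x
    # The nested polynomial is multiplied by t2 once more at the end.
    return 0.5 * x * (1.0 + horner(D, t2) * t2)


def i1_missing_outer_factor(x):
    t2 = x * x
    # Same nested polynomial, but the final "* t2" is dropped.
    return 0.5 * x * (1.0 + horner(D, t2))


print(i1_with_outer_factor(2.0))     # close to I1(2), about 1.5906
print(i1_missing_outer_factor(2.0))  # visibly off from I1(2)
```

At x = 1 the two variants agree because t2 = 1, so a test that only samples inputs near 1 would not reveal the missing factor.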
@rdjogoTT In POLYVAL10_DERIVATIVE I pass the coef0 term only to compare with the current implementation. The coef0 term is multiplied by 0 (i.e. its k coefficient), so it gets ignored. Let me test POLYVAL10_I1 once again. Thank you for the feedback!
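The point about coef0 can be sketched in Python, assuming POLYVAL10_DERIVATIVE evaluates the term-by-term derivative of the i0 polynomial (this is an assumption; the macro body is not shown in this thread). Differentiating sum_k c_k * t^k gives sum_k k * c_k * t^(k-1), so the k = 0 term carries a factor of 0 and drops out. The names `A`, `poly_derivative`, and `i0_derivative` are hypothetical.

```python
import math

# Series coefficients of I0 in t = x*x: I0(x) = sum_k A[k] * t**k, k = 0..10
A = [1.0 / (4.0 ** k * math.factorial(k) ** 2) for k in range(11)]


def poly_derivative(coeffs, t):
    """d/dt of sum_k coeffs[k] * t**k.

    The k = 0 term would be 0 * coeffs[0], so it is skipped entirely.
    """
    return sum(k * c * t ** (k - 1) for k, c in enumerate(coeffs) if k > 0)


def i0_derivative(x):
    # Chain rule: t = x*x, so dt/dx = 2*x.
    return 2.0 * x * poly_derivative(A, x * x)


# d/dx I0(x) equals I1(x); I1(1) is approximately 0.565159103992485
print(i0_derivative(1.0))

# Changing coeffs[0] has no effect on the derivative.
print(poly_derivative([123.0] + A[1:], 0.5) == poly_derivative(A, 0.5))
```

Since the derivative of I0 is I1, this is consistent with using an i1 kernel to compute the i0 backward pass.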
Ticket
Link to Github Issue #13676
Problem description
The current implementation of i0_bw uses the reciprocal op, which has an ongoing issue (#14672).
What's changed
op,count,python min dispatch time (ms),python mean dispatch time (ms),python mean dispatch + sync time (ms),C++ mean dispatch time (ms)
ttnn.i0_bw,800,0.97,0.997,3.717,0.364 (main)
ttnn.i0_bw,800,0.241,0.244,0.884,0.086 (branch)
Checklist