meijieru closed this pull request 1 year ago.
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).
View this failed invocation of the CLA check for more information.
For the most up to date status, view the checks section at the bottom of the pull request.
Hi @meijieru,
Thanks for your work, I appreciate it! I ran some quick performance tests to see the difference when using the PR's code. On my local machine, I observed a runtime increase of a bit more than 70%. I would attribute this mainly to how the confusion matrix is computed.
Given that we run this code in some time-critical environments (like the benchmark server), I would suggest not altering the code directly, but instead making a weighted STQ subclass that inherits from STQ. This way, we could keep the performance of the base case while also supporting weighted STQ.
What do you think, @aquariusjay?
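For illustration, the subclass approach could look roughly like the sketch below. It is a simplified, hypothetical stand-in, not the project's actual STQ code: the class names `SemanticQuality` and `WeightedSemanticQuality`, the `update_state`/`result` methods, and the per-pixel `weights` argument are all assumptions, and only the semantic (class-IoU) part of STQ is modeled; the association term is omitted. The base class keeps its unweighted `np.bincount` path untouched, and the subclass only takes a different path when weights are actually passed.

```python
import numpy as np


class SemanticQuality:
    """Hypothetical, simplified metric: only the semantic (IoU) part of STQ."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Accumulated confusion matrix: rows = ground truth, columns = prediction.
        self.confusion = np.zeros((num_classes, num_classes), dtype=np.float64)

    def _flat_index(self, y_true, y_pred):
        # Encode each (gt, pred) label pair as one bin index: gt * C + pred.
        gt = np.asarray(y_true, dtype=np.int64).ravel()
        pred = np.asarray(y_pred, dtype=np.int64).ravel()
        return gt * self.num_classes + pred

    def _add_counts(self, counts):
        # Fold the flat histogram back into a C x C matrix and accumulate it.
        self.confusion += counts.reshape(self.num_classes, self.num_classes)

    def update_state(self, y_true, y_pred):
        # Unweighted fast path: an integer histogram of label pairs.
        counts = np.bincount(self._flat_index(y_true, y_pred),
                             minlength=self.num_classes ** 2)
        self._add_counts(counts)

    def result(self):
        # Per-class IoU = TP / (TP + FP + FN), averaged over classes with a
        # non-empty union.
        tp = np.diag(self.confusion)
        union = self.confusion.sum(axis=0) + self.confusion.sum(axis=1) - tp
        valid = union > 0
        if not valid.any():
            return 0.0
        return float(np.mean(tp[valid] / union[valid]))


class WeightedSemanticQuality(SemanticQuality):
    """Hypothetical weighted variant: each pixel contributes its weight, not 1."""

    def update_state(self, y_true, y_pred, weights=None):
        if weights is None:
            # No weights given: reuse the unchanged base-class path.
            super().update_state(y_true, y_pred)
            return
        w = np.asarray(weights, dtype=np.float64).ravel()
        counts = np.bincount(self._flat_index(y_true, y_pred), weights=w,
                             minlength=self.num_classes ** 2)
        self._add_counts(counts)


if __name__ == "__main__":
    y_true = np.array([[0, 1], [1, 1]])
    y_pred = np.array([[0, 1], [0, 1]])
    base = SemanticQuality(num_classes=2)
    base.update_state(y_true, y_pred)
    weighted = WeightedSemanticQuality(num_classes=2)
    # Zero weight on the single mislabeled pixel at (1, 0).
    weighted.update_state(y_true, y_pred, weights=[[1, 1], [0, 1]])
    print(base.result(), weighted.result())
```

With `weights=None` the subclass falls back to the base path, so callers who do not need weighting keep the original code path.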
@markweberdev Thanks for reviewing the code. The suggestion sounds good.
@meijieru Could you please try to update the code to maintain the original computation speed? It would be great if wSTQ could also be computed efficiently. What do you think?
Thanks for the suggestions! I will check it soon.
Made several updates.
Could you please check again? Thanks.
Looks great on my end. Thanks for updating the code, @meijieru.
Please wait for input from @markweberdev.
Updated, thanks for the suggestions.
This PR adds a wSTQ implementation in NumPy. It also adds a unit test to ensure that the TF and NumPy implementations produce consistent results, and fixes a dtype error in the TF implementation.
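A compatibility test of this kind typically computes the same quantity in two independent ways and asserts that they agree. The sketch below shows that pattern for a weighted confusion matrix using NumPy only. `np.add.at` stands in for the second implementation, because the PR's actual test compares TF against NumPy and TensorFlow is not assumed to be installed here. The function names are hypothetical. Casting labels and weights to explicit dtypes (`int64`, `float64`) is a defensive choice and is not a claim about what the PR's dtype fix changed.

```python
import numpy as np


def confusion_bincount(y_true, y_pred, weights, num_classes):
    """Vectorized weighted confusion matrix via np.bincount (hypothetical helper)."""
    index = (y_true.ravel().astype(np.int64) * num_classes
             + y_pred.ravel().astype(np.int64))
    counts = np.bincount(index, weights=weights.ravel().astype(np.float64),
                         minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)


def confusion_add_at(y_true, y_pred, weights, num_classes):
    """Independent reference: scatter-add each pixel's weight with np.add.at."""
    confusion = np.zeros((num_classes, num_classes), dtype=np.float64)
    np.add.at(confusion, (y_true.ravel(), y_pred.ravel()),
              weights.ravel().astype(np.float64))
    return confusion


def check_implementations_agree(num_classes=5, shape=(4, 8, 8), seed=0):
    """Compare both implementations on random labels and weights."""
    rng = np.random.default_rng(seed)
    y_true = rng.integers(0, num_classes, size=shape)
    y_pred = rng.integers(0, num_classes, size=shape)
    weights = rng.random(size=shape)
    a = confusion_bincount(y_true, y_pred, weights, num_classes)
    b = confusion_add_at(y_true, y_pred, weights, num_classes)
    # Only floating-point summation order differs, so a tight tolerance suffices.
    np.testing.assert_allclose(a, b, rtol=1e-12, atol=1e-12)
    # Both implementations must preserve the total pixel weight.
    assert np.isclose(a.sum(), weights.sum())
    return a


if __name__ == "__main__":
    check_implementations_agree()
    print("weighted confusion matrices agree")
```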