Xilinx / brevitas

Brevitas: neural network quantization in PyTorch
https://xilinx.github.io/brevitas/

Add TopK layer in the BNN-PYNQ Brevitas experiments #836

Open · abedbaltaji opened this issue 9 months ago

abedbaltaji commented 9 months ago

Hello,

We are working with FINN using the pre-trained models under the BNN-PYNQ Brevitas experiments, which classify images from the MNIST dataset. After loading the Brevitas model through the FINN APIs, we export it to a Brevitas ONNX model. The output of this model is a vector of dimension 1x10, representing the class probabilities for the digits 0 to 9.
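For context, a minimal sketch of that export flow, under the assumption that the model is one of the BNN-PYNQ test networks; the helper `get_test_model_trained` and the export function `export_qonnx` are what FINN and recent Brevitas versions provide, but the exact names and signatures may differ in your setup:

```python
import torch
# Both imports are assumptions based on the setup described above; exact module
# paths and function names may differ across FINN / Brevitas versions.
from finn.util.test import get_test_model_trained  # FINN helper for the BNN-PYNQ models
from brevitas.export import export_qonnx            # Brevitas ONNX (QONNX) export

# Load one of the pre-trained BNN-PYNQ MNIST networks, e.g. TFC with 1-bit
# weights and activations.
model = get_test_model_trained("TFC", 1, 1)

# Export to ONNX; the exported graph has a single output of shape (1, 10),
# one score per MNIST digit.
export_qonnx(model, args=torch.randn(1, 1, 28, 28), export_path="tfc_w1a1.onnx")
```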

My goal is to implement a TopK layer that gives me the index of the highest probability instead of the full probability vector. We would like your suggestion on the best way to modify the pre-trained model so that the generated FINN stitched IP ends up with an output of dimension 1x1 instead of 1x10.

[Screenshot attached: 2024-02-08 17-01-53]

Note that we have tried inserting a TopK node directly into the ONNX graph, but it did not integrate well. We would therefore like to know the best practice for modifying a Brevitas model, either at the ONNX-graph level or in the PyTorch model. Looking forward to your suggestions.

Giuseppe5 commented 9 months ago

Would it not be sufficient to do argmax on the output of the network?

If you would like that operation to be part of the export, it is enough to write a wrapper around the original network whose forward adds the extra argmax op, and then export the wrapper.

Not sure if I fully understood the issue, so let me know if this does not work.

abedbaltaji commented 9 months ago

I am new to Brevitas models. Could you give an example of where I have to add the wrapper?
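
For illustration, a minimal sketch of the kind of wrapper suggested above, assuming the same model-loading helper as before; `ArgMaxWrapper` is a placeholder name and the export call may differ depending on your Brevitas version:

```python
import torch
import torch.nn as nn
from finn.util.test import get_test_model_trained  # assumed FINN helper, as above
from brevitas.export import export_qonnx            # assumed Brevitas export entry point

class ArgMaxWrapper(nn.Module):
    """Wraps a trained network so it returns the index of the highest class
    score instead of the full 1x10 score vector."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x):
        scores = self.model(x)              # shape (N, 10)
        return torch.argmax(scores, dim=1)  # shape (N,), one class index per sample

# Wrap the pre-trained BNN-PYNQ model and export the wrapper, so the extra
# argmax op becomes part of the exported ONNX graph.
model = get_test_model_trained("TFC", 1, 1)
wrapped = ArgMaxWrapper(model)
export_qonnx(wrapped, args=torch.randn(1, 1, 28, 28), export_path="tfc_w1a1_argmax.onnx")
```

Wrapping at the PyTorch level keeps the extra op inside the normal export path, rather than patching the exported ONNX graph by hand.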