Xilinx / brevitas

Brevitas: neural network quantization in PyTorch
https://xilinx.github.io/brevitas/

Using external activation functions #981

Closed Maya7991 closed 3 weeks ago

Maya7991 commented 3 months ago

I have a spiking convolutional neural network. It uses the Leaky (Leaky Integrate-and-Fire) neuron from the SNNTorch library as its activation function. Is it possible to use activation functions such as those from SNNTorch together with Brevitas? An example architecture is given below.

import torch
import torch.nn as nn
import torch.nn.functional as F

import brevitas.nn as qnn
import snntorch as snn
from snntorch import surrogate
from snntorch.functional import quant

class SpikingCNN(nn.Module):
    def __init__(self, beta=0.9, num_steps=25):
        super(SpikingCNN, self).__init__()

        self.num_steps = num_steps
        spike_grad = surrogate.fast_sigmoid()
        # 8-bit quantization of the LIF membrane state via SNNTorch
        qlif = quant.state_quant(num_bits=8, uniform=False, thr_centered=True)

        self.conv1 = qnn.QuantConv2d(in_channels=1, out_channels=8, kernel_size=5, padding=0, bias=False, weight_bit_width=8)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = qnn.QuantConv2d(in_channels=8, out_channels=16, kernel_size=3, padding=0, bias=False, weight_bit_width=8)
        self.conv3 = qnn.QuantConv2d(in_channels=16, out_channels=16, kernel_size=1, padding=0, bias=False, weight_bit_width=8)
        self.fc1 = qnn.QuantLinear(3 * 3 * 16, 256, bias=True, weight_bit_width=8)
        self.fc2 = qnn.QuantLinear(256, 10, bias=True, weight_bit_width=8)
        self.lif = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, state_quant=qlif)

    def forward(self, x):
        self.lif.reset_mem()

        x = self.pool(F.relu(self.conv1(x)))  # First conv + ReLU + max pool
        spk_sum = 0
        for step in range(self.num_steps):
            spk_sum = spk_sum + self.lif(self.conv2(x))  # Second conv + LIF, spikes summed over time steps
        x = F.relu(self.conv3(spk_sum))       # Third conv + ReLU
        x = x.view(-1, 3 * 3 * 16)            # Flatten to match fc1
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

spikingcnn_model = SpikingCNN()
spikingcnn_loss = trainNet(spikingcnn_model, 0.0003)   # training loop not included here

quantized_weights = {}
for name, module in spikingcnn_model.named_modules():
    if isinstance(module, qnn.QuantConv2d) and 'conv2' in name:
        quantized_weights[f'{name}.weight'] = module.quant_weight().int().detach().cpu().numpy()

Is it possible to use Brevitas along with such custom activation functions?

The purpose of quantizing my model is to extract the INT8 weights and use them in a simulation of a VHDL design I have written. I have recorded the INT8 weights of the quantized spiking convolutional layer (conv2). However, I observe a difference between the values measured after the activation function and the values I expect. I would like to know whether Brevitas supports custom activation functions and, if so, whether any additional configuration is needed.
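
For context, this is roughly how I pass the extracted weights on to my VHDL testbench; the file naming and text format are my own choices and have nothing to do with Brevitas:

import numpy as np

# quantized_weights comes from the snippet above: INT8 numpy arrays keyed by parameter name
for name, w in quantized_weights.items():
    # one flat text file per parameter; the VHDL testbench reads these back in
    np.savetxt(f"{name.replace('.', '_')}.txt", w.reshape(-1, w.shape[-1]), fmt="%d")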

Giuseppe5 commented 3 months ago

Hi,

Thanks for opening this issue. Brevitas layers are designed to work as drop-in replacements for the corresponding PyTorch ones, so based on your example I believe there should be no issue combining them with third-party libraries, even though we have no experience with SNNTorch in particular.
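
To illustrate what we mean by drop-in replacement, here is a minimal sketch; the snn.Leaky arguments are just placeholders taken from your snippet, not something we have validated:

import torch.nn as nn
import brevitas.nn as qnn
import snntorch as snn

class MixedBlock(nn.Module):
    # A Brevitas quantized conv feeding an external (SNNTorch) activation.
    # The conv behaves like nn.Conv2d from the outside, so any downstream module can consume it.
    def __init__(self):
        super().__init__()
        self.conv = qnn.QuantConv2d(8, 16, kernel_size=3, bias=False, weight_bit_width=8)
        self.act = snn.Leaky(beta=0.9, init_hidden=True)

    def forward(self, x):
        return self.act(self.conv(x))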

You are mentioning observed vs expected values. Where are the expected values coming from?

Maya7991 commented 2 months ago

Hi @Giuseppe5 ,

I apologize for the delay. I had to look into the basics of model quantization to be able to explain my questions here.

Use case: train a spiking CNN with a Leaky (LIF) activation function in PyTorch and SNNTorch, then use the INT8 weights of the trained model in my VHDL design.

  1. Performed QAT of a vanilla spiking CNN model using Brevitas and extracted the INT8 weights from the trained model.
  2. Manually calculated the convolution output for a single channel and the corresponding activation output.
  3. Compared the result with the output of the model (a rough sketch of this check is shown after this list).
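
The check in steps 2 and 3 looks roughly like this; the input shape, channel index, and tolerance are placeholders of mine, and only quant_weight(), .int() and .scale are actual Brevitas calls:

import torch
import torch.nn.functional as F

conv2 = spikingcnn_model.conv2

# integer weights and scale as Brevitas stores them
int_w = conv2.quant_weight().int().float().detach()   # INT8 values (cast back to float for conv2d)
scale = conv2.quant_weight().scale.detach()

# binary spike input: a MAC with 0/1 inputs reduces to summing the selected weights
spikes = (torch.rand(1, 8, 10, 10) > 0.5).float()

# expected: integer accumulation for output channel 0, rescaled afterwards
expected = F.conv2d(spikes, int_w)[:, 0] * scale.flatten()[0]

# observed: what the fake-quantized Brevitas layer computes in FP32
observed = conv2(spikes)[:, 0]

print(torch.allclose(observed, expected, atol=1e-5))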

From your previous reply I understand that there is no problem in using SNNTorch together with Brevitas. However, when I calculate some channel outputs manually and compare them with the output of the quantized model, I see a difference between the observed and the expected values.

I have an assumption about why this is happening: the quant and dequant steps between the layers of a fake-quantized model would not allow such a direct comparison. Since this is a fake-quantized model, I would have to generate a true INT8 model in order to compare it against my manual calculations.
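
Written out as a quick sanity check, this is what I am assuming about the fake-quantized weights; only quant_weight() and its fields are Brevitas APIs, the rest is my own reasoning:

import torch

qw = spikingcnn_model.conv2.quant_weight()

# in a fake-quantized model the stored FP32 weight has already been quantized and dequantized,
# so it should be reconstructable from the INT8 values and the scale
reconstructed = qw.int().float() * qw.scale
print(torch.allclose(qw.value, reconstructed, atol=1e-6))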

If that assumption is correct, can I generate a true INT8 model in which an inference pass uses only INT8 values and no FP32 values? I did not post my thoughts here for so long because I could not decide how much of this falls within the scope of Brevitas.

Note: the input to the spiking conv layer consists of 0s and 1s (spike or no spike), which makes the manual calculation easy; each MAC reduces to a plain accumulation.

Thank you!

Giuseppe5 commented 2 months ago

I still have a few questions about the setup. If you could share a reproducible script showing how you compute the observed vs expected results, it would be easier for us to help.

Giuseppe5 commented 3 weeks ago

If this is still an issue, please feel free to re-open and we'd be more than happy to help!