fangwei123456 / spikingjelly

SpikingJelly is an open-source deep learning framework for Spiking Neural Network (SNN) based on PyTorch.
https://spikingjelly.readthedocs.io

Compatibility with Pytorch's quantization-aware training procedure #417

Open CloudyDory opened 1 year ago

CloudyDory commented 1 year ago

For faster response

You can @ the corresponding developers for your issue. Here is the division:

Features: Developers
Neurons and Surrogate Functions: fangwei123456, Yanqi-Chen
CUDA Acceleration: fangwei123456, Yanqi-Chen
Reinforcement Learning: lucifer2859
ANN to SNN Conversion: DingJianhao, Lyu6PosHao
Biological Learning (e.g., STDP): AllenYolk
Others: Grasshlw, lucifer2859, AllenYolk, Lyu6PosHao, DingJianhao, Yanqi-Chen, fangwei123456

We are glad to add new developers who are volunteering to help solve issues to the above table.

Issue type

SpikingJelly version

0.0.0.0.14

Description

I am hoping to use SpikingJelly to train an SNN with weight quantization in its linear and convolution layers. However, it seems that the linear and convolution modules in spikingjelly.activation_based.layer are not recognized by PyTorch's torch.quantization.prepare_qat() function, so no weight fake-quantization is attached to them.

Minimal code to reproduce the error/bug

import torch
import torch.nn as nn
from spikingjelly.activation_based import layer

class ANN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.layers = nn.Linear(8, 8)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        out = self.quant(x)
        out = self.layers(out)
        out = self.dequant(out)
        return out

class SNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.layers = layer.Linear(8, 8)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        out = self.quant(x)
        out = self.layers(out)
        out = self.dequant(out)
        return out

ann = ANN()
ann.eval()
ann.qconfig = torch.quantization.get_default_qat_qconfig('x86')
ann_fp32_prepared = torch.quantization.prepare_qat(ann.train())
print(ann_fp32_prepared)

snn = SNN()
snn.eval()
snn.qconfig = torch.quantization.get_default_qat_qconfig('x86')
snn_fp32_prepared = torch.quantization.prepare_qat(snn.train())
print(snn_fp32_prepared)

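This produces the following output. Note that the Linear layer in the ANN gets a weight_fake_quant submodule, while the one in the SNN is left unchanged: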

ANN(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (layers): Linear(
    in_features=8, out_features=8, bias=True
    (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_channel_symmetric, reduce_range=False
      (activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([]))
    )
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (dequant): DeQuantStub()
)

SNN(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (layers): Linear(in_features=8, out_features=8, bias=True)
  (dequant): DeQuantStub()
)
fangwei123456 commented 1 year ago

Hi, I am not familiar with the quantization modules in PyTorch, but you can try spikingjelly.activation_based.quantize. Another option is to check the source code of the conv/linear layers in SpikingJelly. They are almost identical to those in PyTorch, except that they also support running in multi-step mode. From there you can work out how to modify them to support PyTorch's quantization modules.
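
As a rough illustration of why a layer that is "almost identical" to the PyTorch one can still be skipped, here is a minimal pure-Python sketch. It assumes two things not confirmed in this thread: that prepare_qat() picks replacement modules by looking up the module's exact type in a mapping, and that layer.Linear is a subclass of nn.Linear. All class and function names below are hypothetical stand-ins, not real PyTorch or SpikingJelly APIs.

```python
# Stand-in classes (hypothetical; not the real PyTorch/SpikingJelly types).
class Linear:
    """Plays the role of torch.nn.Linear."""


class StepLinear(Linear):
    """Plays the role of spikingjelly's layer.Linear (adds a step mode)."""

    def __init__(self, step_mode="s"):
        self.step_mode = step_mode


class QATLinear:
    """Plays the role of the QAT replacement for Linear."""


# Swap table keyed by the exact float module class.
MAPPING = {Linear: QATLinear}


def swap_exact(module, mapping):
    """Replace `module` only if its exact type is a key in `mapping`."""
    target = mapping.get(type(module))
    return target() if target is not None else module


def swap_with_subclasses(module, mapping):
    """Variant that also matches subclasses by walking the class's MRO."""
    for cls in type(module).__mro__:
        if cls in mapping:
            return mapping[cls]()
    return module


plain = swap_exact(Linear(), MAPPING)
step = swap_exact(StepLinear(), MAPPING)
step_relaxed = swap_with_subclasses(StepLinear(), MAPPING)

print(type(plain).__name__)         # QATLinear: exact type is in the table
print(type(step).__name__)          # StepLinear: subclass is not swapped
print(type(step_relaxed).__name__)  # QATLinear: matched through the base class
```

If those assumptions hold, one direction to explore is the `mapping` argument of prepare_qat(), adding an entry for the SpikingJelly layer class. Two caveats: the target QAT class may itself check the source module's type when converting, and a plain swap like the one above would drop the step mode, so a dedicated QAT subclass that keeps multi-step support would probably be needed.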