Open yiliu30 opened 2 months ago
Hi
If you want to pre-process the input to the Module, I think a module forward pre-hook would work?
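def qdq_tensor(module, args):
    # `args` is the tuple of positional inputs passed to `module.forward`.
    # Return a modified tuple to replace them, or None to leave them unchanged.
    pass

your_layer.register_forward_pre_hook(qdq_tensor)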
Is that enough for you?
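For reference, a minimal runnable sketch of the forward pre-hook idea. The fixed `SCALE` value and the simple round-to-grid q-dq are assumptions made for illustration only, not something from this thread:

```python
import torch
from torch import nn

SCALE = 0.1  # hypothetical fixed scale, for illustration only


def qdq_input_hook(module, args):
    # `args` is the tuple of positional inputs to `module.forward`.
    (x,) = args
    # Quantize to a grid of step SCALE, then dequantize.
    return (torch.round(x / SCALE) * SCALE,)


layer = nn.Linear(4, 2)
handle = layer.register_forward_pre_hook(qdq_input_hook)

x = torch.randn(3, 4)
out = layer(x)   # the hook rewrites `x` before `forward` runs
handle.remove()  # detach the hook when it is no longer needed
```

Note that a pre-hook only sees the inputs of the module, not its weight.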
Oh, sorry. There are a few errors in the question. I want to pre-process the module's weight (for weight-only quantization).
import torch


class QDQLinear(torch.nn.Module):
    def __init__(self, orign_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orign_linear = orign_linear
        # The scale must be a floating-point tensor to require gradients.
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor, scale: torch.Tensor):
        # ... new qdq method that uses `self.trainable_scale` to update the q-dq tensor.
        # int_input = q(input, scale)
        # qdq_input = dq(int_input)
        # return qdq_input
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # q-dq the weight each time it is used
        qdq_weight = self.qdq_tensor(self._orign_linear.weight, self.trainable_scale)
        return torch.nn.functional.linear(input, qdq_weight, self._orign_linear.bias)


### replace all `Linear` with `QDQLinear`
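A runnable sketch of the module-swapping approach, with the q-dq step filled in by one possible choice. The symmetric 4-bit quantizer, the straight-through rounding in `round_ste`, the scale initialisation and the `swap_linears` helper are all assumptions for illustration; the actual q-dq method is not given in this thread:

```python
import torch
from torch import nn


def round_ste(x):
    # Round in the forward pass, pass gradients straight through in backward.
    return x + (torch.round(x) - x).detach()


def qdq_tensor(x, scale, bits=4):
    # Hypothetical symmetric fake quantization: quantize, clamp, dequantize.
    qmax = 2 ** (bits - 1) - 1
    qmin = -qmax - 1
    q = torch.clamp(round_ste(x / scale), qmin, qmax)
    return q * scale


class QDQLinear(nn.Module):
    def __init__(self, orig_linear: nn.Linear, bits: int = 4) -> None:
        super().__init__()
        self.orig_linear = orig_linear
        self.bits = bits
        # Assumed initialisation: map the largest weight magnitude to qmax.
        qmax = 2 ** (bits - 1) - 1
        init = orig_linear.weight.detach().abs().max() / qmax
        self.trainable_scale = nn.Parameter(init.clone())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = qdq_tensor(self.orig_linear.weight, self.trainable_scale, self.bits)
        return nn.functional.linear(x, w, self.orig_linear.bias)


def swap_linears(model: nn.Module) -> nn.Module:
    # Recursively replace every nn.Linear child with a QDQLinear wrapper.
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Linear):
            setattr(model, name, QDQLinear(child))
        else:
            swap_linears(child)
    return model


model = swap_linears(nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4)))
loss = model(torch.randn(5, 8)).pow(2).mean()
loss.backward()  # gradients reach each layer's trainable_scale
```

Because the scale multiplies the rounded value outside the detached rounding term, it receives a gradient even though rounding itself is not differentiable.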
If you do want to keep these two params and combine them only when mod.weight is used, I would suggest reparametrization:
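import torch
from torch import nn
from torch.nn.utils.parametrize import register_parametrization


class QDQParam(torch.nn.Module):
    def forward(self, orig_linear_weight, scale):
        # `qdq_tensor` is your own q-dq function; it is called whenever `m.weight` is accessed.
        return qdq_tensor(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        # Initial values of the two underlying params; the scale must be floating point.
        return orig_linear_weight, torch.tensor(1.0)


m = nn.Linear(2, 2)
register_parametrization(m, "weight", QDQParam())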
More details at https://pytorch.org/tutorials/intermediate/parametrizations.html
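A self-contained sketch of the parametrization route. The `qdq_tensor` body (symmetric 4-bit fake quantization with straight-through rounding) and the initial scale of `0.05` are assumptions for illustration, since the thread leaves the q-dq method open:

```python
import torch
from torch import nn
from torch.nn.utils import parametrize


def round_ste(x):
    # Round in the forward pass, pass gradients straight through in backward.
    return x + (torch.round(x) - x).detach()


def qdq_tensor(x, scale, bits=4):
    # Hypothetical symmetric fake quantization: quantize, clamp, dequantize.
    qmax = 2 ** (bits - 1) - 1
    qmin = -qmax - 1
    return torch.clamp(round_ste(x / scale), qmin, qmax) * scale


class QDQParam(nn.Module):
    def forward(self, orig_linear_weight, scale):
        # Recomputed whenever `m.weight` is read.
        return qdq_tensor(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        # Returning a tuple makes the parametrization store two tensors,
        # registered as `original0` (the weight) and `original1` (the scale).
        return orig_linear_weight, torch.tensor(0.05)


m = nn.Linear(2, 2)
parametrize.register_parametrization(m, "weight", QDQParam())

out = m(torch.randn(3, 2))  # uses the q-dq weight transparently
out.sum().backward()

scale = m.parametrizations.weight.original1  # the trainable scale
```

With this approach the forward code of the model stays untouched; only the way `m.weight` is produced changes.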
In a quantization scenario where fake quantization is used to assess the accuracy of a new algorithm with a trainable scale, we can implement it for an eager model by replacing the `Linear` module with `QDQLinear`.

However, some models use `torch.matmul` to do the same thing as `torch.nn.Linear`. We also want to apply the aforementioned QDQ method to `torch.matmul`, but this cannot be achieved through module swapping.

We could probably customize a new `TorchDispatchMode` to replace all `aten.mm` with `qdq - aten.mm`, applying qdq to all input tensors of `torch.matmul` or `torch.nn.Linear`. However, I'm currently unsure how to handle the `trainable_scale`. Do you happen to have any suggestions?

Thank you very much!
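One possible direction, sketched under stated assumptions: this swaps in `TorchFunctionMode` (from `torch.overrides`) instead of the `TorchDispatchMode` mentioned above. As far as I understand, function modes intercept calls at the Python API level, above autograd, so a scale held by the mode can still receive gradients. The `QDQMatmulMode` class, the single shared scale and the `qdq_tensor` quantizer are all hypothetical choices for illustration:

```python
import torch
from torch.overrides import TorchFunctionMode


def round_ste(x):
    # Round in the forward pass, pass gradients straight through in backward.
    return x + (torch.round(x) - x).detach()


def qdq_tensor(x, scale, bits=4):
    # Hypothetical symmetric fake quantization: quantize, clamp, dequantize.
    qmax = 2 ** (bits - 1) - 1
    qmin = -qmax - 1
    return torch.clamp(round_ste(x / scale), qmin, qmax) * scale


class QDQMatmulMode(TorchFunctionMode):
    """Hypothetical mode that q-dqs both operands of torch.matmul."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale  # one shared trainable scale (an assumption)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.matmul and len(args) == 2:
            a, b = args
            return func(qdq_tensor(a, self.scale), qdq_tensor(b, self.scale), **kwargs)
        # Every other function runs unchanged.
        return func(*args, **kwargs)


scale = torch.nn.Parameter(torch.tensor(0.1))
a, b = torch.randn(3, 4), torch.randn(4, 2)

with QDQMatmulMode(scale):
    out = torch.matmul(a, b)  # intercepted and fake-quantized

out.sum().backward()  # scale.grad is populated
```

This sketch only matches explicit `torch.matmul(a, b)` calls; the `@` operator, `Tensor.matmul` and `torch.nn.functional.linear` would need their own branches, and a real setup would probably want one scale per call site rather than a shared one.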