albanD / subclass_zoo


Question: apply trainable scale for qdq `linear` and `matmul` #51

Open yiliu30 opened 2 months ago

yiliu30 commented 2 months ago

In a quantization scenario where fake quantization is used to assess the accuracy of a new algorithm with a trainable scale, we can implement it for an eager-mode model by replacing the Linear module with QDQLinear, as demonstrated below:

import torch

class QDQLinear(torch.nn.Module):
    def __init__(self, orig_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orig_linear = orig_linear
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor):
        # ... new qdq method that uses `self.trainable_scale` to quantize-dequantize the tensor.
        # int_input = q(input)
        # qdq_input = dq(int_input)
        # return qdq_input
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        input = self.qdq_tensor(input)
        return torch.nn.functional.linear(input, self._orig_linear.weight, self._orig_linear.bias)

### replace all `Linear` with `QDQLinear`

However, some models use torch.matmul to play the same role as torch.nn.Linear. We also want to apply the QDQ method above to torch.matmul, but that cannot be achieved through module swapping.

We could probably write a custom TorchDispatchMode that replaces every aten.mm with a qdq + aten.mm sequence, so that QDQ is applied to all input tensors of torch.matmul and torch.nn.Linear. However, I'm currently unsure how to handle trainable_scale in that setting. Do you happen to have any suggestions?
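For illustration, roughly what I have in mind (just a sketch: `qdq_tensor`, `scale_for`, `model`, and `inputs` are placeholders, and how to associate a trainable scale with each call site is exactly the part I am unsure about):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class QDQDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # intercept matmul-like ops (aten.addmm etc. would need similar handling)
        if func in (torch.ops.aten.mm.default, torch.ops.aten.bmm.default):
            a, b = args
            a = qdq_tensor(a, scale_for(a))  # placeholder: where does the trainable scale come from?
            b = qdq_tensor(b, scale_for(b))
            return func(a, b, **kwargs)
        return func(*args, **kwargs)

with QDQDispatchMode():
    out = model(inputs)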

Thank you very much!

albanD commented 2 months ago

Hi

If you want to pre-process the input to the Module, I think a forward pre-hook on the module would work?

def qdq_pre_hook(module, args):
    # `args` is the tuple of positional inputs to forward(); returning a tuple replaces them
    input, = args
    return (qdq_tensor(input),)

your_layer.register_forward_pre_hook(qdq_pre_hook)
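And to cover every Linear in the model, something along these lines (a sketch; `model` and `qdq_tensor` are whatever you have on your side):

for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        module.register_forward_pre_hook(lambda mod, args: (qdq_tensor(args[0]),))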

Is that enough for you?

yiliu30 commented 2 months ago

Oh, sorry. There are a few errors in the question. I want to pre-process the module's weight (for weight-only quantization).


class QDQLinear(torch.nn.Module):
    def __init__(self, orig_linear: torch.nn.Module) -> None:
        super().__init__()
        self._orig_linear = orig_linear
        self.trainable_scale = torch.nn.Parameter(torch.tensor(1.0), requires_grad=True)

    def qdq_tensor(self, input: torch.Tensor, scale: torch.Tensor):
        # ... new qdq method that uses `scale` to quantize-dequantize the tensor.
        # int_input = q(input, scale)
        # qdq_input = dq(int_input)
        # return qdq_input
        pass

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        qdq_weight = self.qdq_tensor(self._orig_linear.weight, self.trainable_scale)  # <---------- q-dq weight
        return torch.nn.functional.linear(input, qdq_weight, self._orig_linear.bias)

### replace all `Linear` with `QDQLinear`
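For context, qdq_tensor would be something along these lines (only a sketch of symmetric per-tensor fake quantization with a straight-through estimator so that trainable_scale receives gradients; the q-range is illustrative):

def qdq_tensor(t: torch.Tensor, scale: torch.Tensor, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    # scale into the integer grid
    v = t / scale
    # straight-through estimator: forward uses round(v), backward treats round as identity,
    # so gradients still flow back to `scale`
    v = v + (torch.round(v) - v).detach()
    v = torch.clamp(v, qmin, qmax)
    # dequantize back to the original range
    return v * scale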

albanD commented 2 months ago

If you do want to keep these two parameters and only combine them when mod.weight is used, I would suggest a parametrization:

import torch
import torch.nn as nn
from torch.nn.utils.parametrize import register_parametrization

class QDQParam(torch.nn.Module):
    def forward(self, orig_linear_weight, scale):
        # your qdq_tensor from above
        return qdq_tensor(orig_linear_weight, scale)

    def right_inverse(self, orig_linear_weight):
        # split the weight into the two tensors that `forward` recombines;
        # they are registered as `original0` (weight) and `original1` (scale)
        return orig_linear_weight, torch.tensor(1.0)

m = nn.Linear(2, 2)
register_parametrization(m, "weight", QDQParam())
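After that, both tensors show up as trainable parameters and m.weight is recomputed from them on every access (a quick check, assuming a qdq_tensor like the sketch earlier in the thread):

print(m.parametrizations.weight.original0.shape)  # the underlying weight, torch.Size([2, 2])
print(m.parametrizations.weight.original1)        # the trainable scale, initialized to 1.0
print([name for name, _ in m.named_parameters()])
# includes 'parametrizations.weight.original0' and 'parametrizations.weight.original1',
# so an optimizer built from m.parameters() will also update the scale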

More details at https://pytorch.org/tutorials/intermediate/parametrizations.html