pytorch / ao

Create and integrate custom data types, layouts and kernels with up to 2x speedups and 65% less VRAM for inference and training

Feedback on `quantize()` API #384

Open gau-nernst opened 2 weeks ago

gau-nernst commented 2 weeks ago

Previously we did this

from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors

model = torch.compile(model, mode="max-autotune", fullgraph=True)
change_linear_weights_to_int8_woqtensors(model)
# compile after quant also works

With the new quantization API, we have to do this

from torchao.quantization.quant_api import quantize, int8wo, unwrap_tensor_subclass

model = quantize(model, int8wo())  # or "int8_weight_only"
model = unwrap_tensor_subclass(model)
model = torch.compile(model, mode='max-autotune', fullgraph=True)  # must compile after unwrap

I think the new API is less user-friendly than the previous one.

  1. int8wo() and int4wo() are a bit unintuitive. I understand it is a mechanism to pass params like group size to the quantization. Alternatives: a full-blown class with a __call__() method, e.g. Int8WeightOnlyConfig (kinda verbose, but the intention is clear), or just pass quant params as extra args/kwargs, e.g. quantize("int4wo", groupsize=128). (A rough sketch of the config-class option follows this list.)
  2. It's not clear what unwrap_tensor_subclass() does. Also, why do we need it now to compile the model, but not previously?
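
To make the config-class alternative concrete, here is a rough sketch (Int8WeightOnlyConfig and Int4WeightOnlyConfig are hypothetical names, not existing torchao APIs):

from dataclasses import dataclass

# Hypothetical config classes, only to illustrate the "full-blown class" option.
@dataclass
class Int8WeightOnlyConfig:
    pass  # int8 weight-only takes no extra parameters today

@dataclass
class Int4WeightOnlyConfig:
    groupsize: int = 128  # quantization group size

# Usage would then read
#   model = quantize(model, Int4WeightOnlyConfig(groupsize=128))
# instead of
#   model = quantize(model, int4wo(groupsize=128))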

@jerryzh168

msaroufim commented 2 weeks ago

So my understanding of unwrap_tensor_subclass() is that it is primarily there to deal with some limitation of torch.export(). Perhaps @tugsbayasgalan can shed some more light, but if that's the case then we should ONLY recommend people do that for export(), since the API is indeed strange: it introduces concepts like unwrapping and subclasses, which are implementation details. So I see 2 options here:

  1. Remove the call to unwrap() by default, only recommend people do it for export(), and link to an issue in export() explaining why it is needed so people can follow progress
  2. Call unwrap automatically as part of the quantize() function so end users aren't aware of it (a rough sketch of this follows the list)
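
To illustrate option 2, a minimal sketch of a convenience wrapper (quantize_and_unwrap is a hypothetical name, not an existing torchao API):

from torchao.quantization.quant_api import quantize, unwrap_tensor_subclass

# Hypothetical wrapper: fold the unwrap step into the quantize call so end
# users never have to know about tensor subclass unwrapping.
def quantize_and_unwrap(model, apply_tensor_subclass):
    model = quantize(model, apply_tensor_subclass)
    return unwrap_tensor_subclass(model)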

Regarding the point about int8wo(), the problem is that not all quantization algorithms share the same arguments, and kwargs make it difficult for users to figure out what's actually supported. Granted, we did explore some other ideas like:

  1. Don't have a top level user API for quantization
  2. Don't try to shorten the names, so for example just say Int8WeightOnlyConfig. I personally strongly dislike our abbreviations like wo and qat, since they're familiar to people who work with quantization all the time but no one else

On your point around compilation, it is indeed unclear when a user should vs. must compile, and we need to communicate the benefits. The necessity of compilation might also drive users back to a module swap API (sketched below).
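
For reference, a generic sketch of what a module-swap API refers to here (this is an illustration, not torchao's implementation; make_quantized_linear is a placeholder factory supplied by the caller):

import torch.nn as nn

# Walk the model and replace every nn.Linear with a quantized equivalent,
# so quantization works in eager mode without depending on torch.compile.
def swap_linear_modules(model: nn.Module, make_quantized_linear) -> nn.Module:
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            setattr(model, name, make_quantized_linear(child))
        else:
            swap_linear_modules(child, make_quantized_linear)
    return model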

gau-nernst commented 2 weeks ago

Using the new quantize() API, unwrap_tensor_subclass() is a MUST. Without it, I'm getting this error (running the snippet above)

  File "/home/---/code/ao/torchao/dtypes/aqt.py", line 240, in __torch_function__
    return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)
  File "/home/---/code/ao/torchao/dtypes/utils.py", line 25, in wrapper
    return func(*args, **kwargs)
  File "/home/---/code/ao/torchao/dtypes/aqt.py", line 685, in functional_linear
    weight_tensor = weight_tensor.dequantize()
  File "/home/---/code/ao/torchao/dtypes/aqt.py", line 160, in dequantize
    int_data, scale, zero_point = self.layout_tensor.get_plain()
torch._dynamo.exc.TorchRuntimeError: Failed running call_method forward(*(Linear(in_features=4096, out_features=4096, bias=False), FakeTensor(..., device='cuda:0', size=(1, 4096), dtype=torch.float16)), **{}):
'FakeTensor' object has no attribute 'get_plain'

from user code:
   File "/home/---/miniconda3/envs/dev_nightly/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 43, in inner
    return fn(*args, **kwargs)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

tugsbayasgalan commented 2 weeks ago

@msaroufim I am working on the support for unwrapping/wrapping nested tensor subclasses in PT2. In general, we expect we should be able to preserve the tensor subclasses if users are targeting our training IR, and they shouldn't have to rely on unwrap_tensor_subclass().
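
Until that lands, a minimal sketch of how export is typically combined with the unwrap step today (the toy model and input shapes below are placeholders, not taken from this issue):

import torch
from torchao.quantization.quant_api import quantize, int8wo, unwrap_tensor_subclass

# Placeholder model and example input, only to show the order of calls.
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False)).half().cuda()
example_inputs = (torch.randn(1, 4096, dtype=torch.float16, device="cuda"),)

model = quantize(model, int8wo())
model = unwrap_tensor_subclass(model)  # still required before export today
exported_program = torch.export.export(model, example_inputs)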

yiliu30 commented 2 weeks ago

Hi, I noticed that the GPTQ-related API was marked to be moved to prototype. Is there any alternative API to use, or are there any plans to support GPTQ formally?

jerryzh168 commented 1 week ago

@gau-nernst thanks for the feedback.

  1. Makes sense; Configs sound reasonable to me, I'll gather a bit more feedback on this one. I think we don't want to pass around kwargs since we can't document them
  2. The reason we need it for torch.compile now is that we have multiple levels of tensor subclasses now, but the previous implementation did not. This is a temporary workaround and I hope it gets fixed soon; @tugsbayasgalan is working on this one

jerryzh168 commented 1 week ago

> Hi, I noticed that the GPTQ-related API was marked to be moved to prototype. Is there any alternative API to use, or are there any plans to support GPTQ formally?

we are thinking of deprecating GPTQ when we make HQQ work. cc @HDCharles to confirm that hqq is better than GPTQ in general.

jerryzh168 commented 1 week ago

> Hi, I noticed that the GPTQ-related API was marked to be moved to prototype. Is there any alternative API to use, or are there any plans to support GPTQ formally?

Can you also describe your use case for GPTQ?

HDCharles commented 1 week ago

> Hi, I noticed that the GPTQ-related API was marked to be moved to prototype. Is there any alternative API to use, or are there any plans to support GPTQ formally?

@yiliu30 To add on to what @jerryzh168 is saying, we haven't seen a lot of people interested in this API at the moment, so it's not something we've invested a ton of effort into. There are some limitations in the existing API/implementation that make it not work on some parts of some models unless they're carefully handled (https://github.com/pytorch/ao/blob/main/torchao/_models/llama/model.py#L89-L96). We could fix those if we rewrote the whole thing, but until we do that, it hasn't been tested as thoroughly and isn't expected to work as widely as something like int8 weight-only quantization. If you have a significant use case for GPTQ, that may change what we do with it.

yiliu30 commented 1 week ago

> Hi, I noticed that the GPTQ-related API was marked to be moved to prototype. Is there any alternative API to use, or are there any plans to support GPTQ formally?
>
> can you also describe your use case for GPTQ as well?

@jerryzh168 @HDCharles My reason for keeping GPTQ support is that it is quite popular within the community :). For instance, Hugging Face currently includes 3000+ GPTQ models.

gau-nernst commented 6 days ago

@jerryzh168 Just revisiting this issue, particularly about unwrap_tensor_subclass(). When I tested with the latest main (96d49cd), unwrap_tensor_subclass() is still needed. Are there any drawbacks if we include it inside quantize() so that users don't need to care about it? (as suggested by @msaroufim https://github.com/pytorch/ao/issues/384#issuecomment-2171701261).

jerryzh168 commented 5 days ago

> @jerryzh168 Just revisiting this issue, particularly about unwrap_tensor_subclass(). When I tested with the latest main (96d49cd), unwrap_tensor_subclass() is still needed. Are there any drawbacks if we include it inside quantize() so that users don't need to care about it? (as suggested by @msaroufim #384 (comment)).

The main thing is that it makes it a bit harder to debug, I think. We'll be removing this soon though, in the next couple of days, stay tuned. We are waiting for https://github.com/pytorch/pytorch/pull/127431 to land, and then I'll put up a PR to remove it.

gau-nernst commented 5 days ago

@jerryzh168 That's good to hear! However, users of previous versions of PyTorch (e.g. v2.3) will still need to unwrap tensor subclasses? Might not be that important.
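
If that turns out to be the case, a minimal sketch of a version guard on the caller side (the 2.4 cut-off is a guess based on this thread, not a torchao recommendation):

import torch
from torchao.quantization.quant_api import quantize, int8wo, unwrap_tensor_subclass

model = quantize(model, int8wo())
# Rough guard: older PyTorch (e.g. 2.3) keeps the explicit unwrap step,
# newer versions may not need it once the PT2 fix lands.
if tuple(int(x) for x in torch.__version__.split(".")[:2]) < (2, 4):
    model = unwrap_tensor_subclass(model)
model = torch.compile(model, mode="max-autotune", fullgraph=True)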