openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Consider adding support for unknown scales and zero_points #1407

Open sdasgup3 opened 1 year ago

sdasgup3 commented 1 year ago

The goal of this ticket is to track support for unknown scales and zero points. This is required to represent, in a StableHLO graph, scales and zero points that are computed on the fly by the training program while quantizing activations.

Please refer to the relevant discussion here.
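For concreteness, here is a small Python sketch of how such on-the-fly quantization parameters are typically derived from observed activation statistics; `compute_qparams` and its signature are illustrative assumptions for this thread, not StableHLO API:

```python
# Hypothetical sketch (not StableHLO API): deriving an asymmetric int8
# scale and zero point from an observed activation range. These are the
# values the issue asks to leave "unknown" in the graph until runtime.

def compute_qparams(x_min, x_max, qmin=-128, qmax=127):
    """Derive scale and zero_point mapping [x_min, x_max] onto [qmin, qmax]."""
    # Extend the range to cover zero so that 0.0 quantizes exactly,
    # a common requirement for e.g. zero padding.
    x_min = min(x_min, 0.0)
    x_max = max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    # Clamp the zero point into the quantized domain.
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point
```

During QAT these two numbers change every step as the observed min/max moves, which is exactly why the graph cannot bake them in as constants.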

lgeiger commented 2 months ago

Support for unknown scales would be incredibly useful for quantization aware training (QAT). What is the current status on this?

For a bit of context: our use case is focused on QAT targeting a fully int8-quantized TFLite inference model. Currently we're relying on tf.quantization.fake_quant_with_min_max_vars on the training side. As far as I'm aware this is the only supported way at the moment, but it would be great to be able to directly output StableHLO from jax, or maybe even PyTorch, for greater flexibility and better usability.
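As a rough illustration of what fake-quantization does during QAT, here is a scalar Python sketch of the quantize-then-dequantize round trip; `fake_quant` is a hypothetical helper name, and note that the real tf.quantization.fake_quant_with_min_max_vars additionally nudges the min/max range, which this sketch omits:

```python
# Illustrative sketch of fake-quantization for QAT (not the TF op itself):
# quantize to an integer grid, then immediately dequantize, so training
# sees the rounding error that fully quantized inference will introduce.

def fake_quant(x, x_min, x_max, num_bits=8):
    """Quantize-then-dequantize a scalar over [x_min, x_max]."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = round(qmin - x_min / scale)
    q = round(x / scale) + zero_point
    q = max(qmin, min(qmax, q))        # clamp into the quantized range
    return (q - zero_point) * scale    # dequantize back to float
```

The output stays in floating point, so gradients can flow (in practice via a straight-through estimator), while the values are snapped to the grid an int8 model can represent.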

@abattery mentioned that he's interested in QAT as well, and the odml team seems to have a way to inject stablehlo.uniform_quantize ops, but I'm not sure what the latest status of these efforts is.
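For readers unfamiliar with the op: stablehlo.uniform_quantize roughly computes `clamp(round(x / scale) + zero_point)` into the storage integer type. The Python below is an illustrative scalar model of that arithmetic under those assumptions (hypothetical helper names, not the StableHLO spec itself):

```python
# Scalar model of uniform quantize/dequantize arithmetic, assuming an
# int8 storage type. Helper names are illustrative, not StableHLO API.

def uniform_quantize(x, scale, zero_point, storage_min=-128, storage_max=127):
    """Map a float to the int8 storage domain."""
    q = round(x / scale) + zero_point
    return max(storage_min, min(storage_max, q))  # clamp to storage range

def uniform_dequantize(q, scale, zero_point):
    """Map an int8 storage value back to float."""
    return (q - zero_point) * scale
```

The open question in this issue is precisely that `scale` and `zero_point` here are attributes of the quantized type, so today they must be known constants rather than runtime tensors produced by the training program.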

@sdasgup3 do you know whether there is interest in supporting QAT workflows via StableHLO from frontends like jax or PyTorch? I'd be very interested in getting involved and contributing towards any consolidated effort here, since QAT is much easier to deal with from an ML training standpoint than post-training quantization, which always risks introducing accuracy degradation if not done carefully.

sdasgup3 commented 2 months ago

@lgeiger Thanks for bringing this up and providing details about your use case. This has been on our radar for some time, but we didn't have a sufficiently motivating use case (or the bandwidth) to initiate work on it.

do you know whether there is interest in supporting QAT workflows via StableHLO from frontends like jax or PyTorch? I'd be very interested in getting involved and contributing towards any consolidated effort here

Your willingness to contribute is much appreciated! I will get back to you on the question.