pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Model support for `mobilenet_v2_quantized_qat` with Torch_XLA2 #8155

Open ManfeiBai opened 2 months ago

ManfeiBai commented 2 months ago

Fix the model test for `mobilenet_v2_quantized_qat.py`:

  1. Set up the environment according to "Run a model under torch_xla2".
  2. Run the model test under `run_torchbench/` with `python models/your_target_model_name.py`.
  3. Fix the failure.
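The three steps above amount to a command sequence roughly like the following sketch. The editable-install path is an assumption based on the `experimental/torch_xla2` path visible in the traceback below; substitute the locations of your own checkouts.

```shell
# Step 1: install torch_xla2 from a pytorch/xla checkout (path is an
# assumption; follow the "Run a model under torch_xla2" guide for your setup).
pip install -e pytorch/xla/experimental/torch_xla2

# Step 2: run the failing model test from the run_torchbench directory.
cd run_torchbench
JAX_ENABLE_X64=true JAX_PLATFORMS=cpu python models/mobilenet_v2_quantized_qat.py

# Step 3: iterate on the failure the command reports.
```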

Please refer to this guide for how to fix it:

Also refer to these PRs:

barney-s commented 3 weeks ago

Requires lowering `aten::quantize_per_tensor.tensor_qparams`.

% JAX_ENABLE_X64=true JAX_PLATFORMS=cpu python models/mobilenet_v2_quantized_qat.py 
/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py:337: UserWarning: Device capability of jax unspecified, assuming `cpu` and `cuda`. Please specify it via the `devices` argument of `register_backend`.
  warnings.warn(
/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py:146: FutureWarning: Passing a QConfig dictionary to prepare is deprecated and will not be supported in a future version. Please pass in a QConfigMapping instead.
  prepared = prepare(
/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/ao/quantization/observer.py:229: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  warnings.warn(
/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/ao/quantization/utils.py:408: UserWarning: must run observer before calling calculate_qparams. Returning default values.
  warnings.warn(
Traceback (most recent call last):
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/mobilenet_v2_quantized_qat.py", line 61, in <module>
    sys.exit(main())
  File "/usr/local/google/home/barni/workspace/pytorch-tpu/run_torchbench/models/mobilenet_v2_quantized_qat.py", line 39, in main
    xla2_ans = model(*example)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/fx/graph_module.py", line 822, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/fx/graph_module.py", line 400, in __call__
    raise e
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/fx/graph_module.py", line 387, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/google/home/barni/miniconda3/envs/diffusion-models-2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.5", line 7, in forward
  File "/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 215, in __torch_function__
    return func(*args, **(kwargs or {}))
  File "/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 230, in __torch_dispatch__
    return self.env.dispatch(func, types, args, kwargs)
  File "/usr/local/google/home/barni/workspace/pytorch/xla/experimental/torch_xla2/torch_xla2/tensor.py", line 413, in dispatch
    raise OperatorNotFound(
torch_xla2.tensor.OperatorNotFound: Operator with name aten::quantize_per_tensor.tensor_qparams has no lowering
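Assuming torch_xla2 resolves aten ops through a per-op lowering registry (which is what the `OperatorNotFound` dispatch error suggests), a fix would add a lowering for `aten::quantize_per_tensor.tensor_qparams`. The math that lowering has to implement is standard affine quantization: `q = clamp(round(x / scale) + zero_point, quant_min, quant_max)`. A minimal NumPy sketch of those semantics follows; the function name and the quint8 default range are my assumptions for illustration, not torch_xla2 API.

```python
import numpy as np

def quantize_per_tensor(x, scale, zero_point, quant_min=0, quant_max=255):
    """Reference semantics of affine per-tensor quantization:
    q = clamp(round(x / scale) + zero_point, quant_min, quant_max).
    In the .tensor_qparams overload, scale and zero_point arrive as
    0-dim tensors, so they are read out as scalars first.
    """
    scale = np.asarray(scale).item()
    zero_point = np.asarray(zero_point).item()
    q = np.round(x / scale) + zero_point
    return np.clip(q, quant_min, quant_max).astype(np.uint8)

# Example: quantize with scale=0.1, zero_point=128 over the quint8 range;
# values below/above the representable range saturate at quant_min/quant_max.
x = np.array([-1.0, 0.0, 0.07, 1.0, 20.0], dtype=np.float32)
print(quantize_per_tensor(x, np.array(0.1), np.array(128)))
# → [118 128 129 138 255]
```

An actual torch_xla2 lowering would express the same arithmetic with `jax.numpy` and register it against `torch.ops.aten.quantize_per_tensor`, next to the existing aten lowerings in the torch_xla2 source tree.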