llvm / torch-mlir

The Torch-MLIR project aims to provide first class support from the PyTorch ecosystem to the MLIR ecosystem.

Find a better solution for backends that don't support i64/f64 #1615

Open silvasean opened 1 year ago

silvasean commented 1 year ago

Some backends, such as IREE and TOSA (by today's spec), don't support i64 or f64, or support them only at significant cost. This issue is to start collecting use cases and motivation for better supporting reduced bit widths for these targets in Torch-MLIR.

The general theme is that PyTorch's semantics are the 64-bit semantics, but sometimes it is "common sense" to want to drop those down to smaller bit widths. E.g., for a 100 kB model, a user might be quite comfortable using a 32-bit integer to hold the tensor sizes, even though technically it is supposed to be a 64-bit integer.
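For concreteness, here is a minimal TorchScript sketch (plain PyTorch, no Torch-MLIR APIs; the function name is just illustrative) showing where those 64-bit scalar values come from:

```python
import torch

@torch.jit.script
def flatten(t: torch.Tensor) -> torch.Tensor:
    # t.size(0) and t.size(1) are TorchScript `int`s, i.e. 64-bit by spec,
    # even when the model is tiny and the values easily fit in 32 bits.
    n = t.size(0) * t.size(1)
    return t.view(n)

print(flatten.graph)  # the size values feeding aten::view are typed `int` (64-bit)
```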

As we think about this, there are 4 basic cases that probably need separate treatment (see the sketch after this list):

  1. Scalar floats (!torch.float). This represents the PyTorch/TorchScript float type which is 64-bit.
    • Thoughts: I don't have specific insight here, but there are likely many scenarios where 64-bit is overkill (e.g. when deploying to a microcontroller). Some scientific codes might care about some factor being very precise though.
  2. Scalar ints (!torch.int). This represents the PyTorch/TorchScript int type, which is specced as 64-bit in TorchScript (though the Python type is arbitrary precision).
    • Thoughts: Often these are used to calculate view sizes and such. With large language models in the hundreds of GB these days, we cannot arbitrarily use 32-bit indexing (though perhaps individual tensor dimensions remain within 32-bit range?).
  3. Tensors with 64-bit floating point numbers (!torch.tensor<[10,12],f64>). This represents tensor computations on f64.
    • Thoughts: PyTorch defaults to f32, so if a user asks for f64 they probably actually want the extra precision (?).
  4. Tensors with 64-bit integers (e.g. !torch.tensor<[10,12],si64>). This is probably most common for embedding indices.
    • Thoughts: Most embeddings can likely be indexed with 32-bit indices, but they seem to be getting larger and larger, and it is not out of the question to need 64-bit indices there.
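For reference, a small plain-PyTorch sketch of where cases 1, 3, and 4 typically show up, plus a back-of-the-envelope check for case 2. The sizes and names are illustrative, not taken from any particular model:

```python
import torch

# Case 1: scalar floats. Python floats are 64-bit, and TorchScript's `float` is specced as f64.
scale: float = 0.5

# Case 3: f64 tensors. PyTorch defaults to f32, so f64 usually means the user asked for it.
x64 = torch.randn(10, 12, dtype=torch.float64)

# Case 4: i64 tensors. Embedding lookups default to int64 indices.
emb = torch.nn.Embedding(num_embeddings=30522, embedding_dim=768)
idx = torch.randint(0, 30522, (4, 16))   # dtype is torch.int64 by default
out = emb(idx)

# Case 2, back of the envelope: a model with hundreds of GB of f32 parameters has
# tens of billions of elements, which overflows int32 (~2.1e9) for flat indexing,
# even though each individual dimension may still fit in 32 bits.
print(2**31 - 1, 400e9 / 4)   # int32 max vs. element count of a hypothetical 400 GB f32 model
```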

We need to discuss this with the PyTorch devs, understand their thinking on it, and align on a solution.

silvasean commented 1 year ago

I've created a PyTorch dev discussion: https://dev-discuss.pytorch.org/t/how-to-approach-targets-that-dont-support-i64-f64/867

AmosLewis commented 1 year ago

https://github.com/llvm/torch-mlir/pull/1802