openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Implement type inference for RealDynamicSliceOp #863

Open burmako opened 1 year ago

burmako commented 1 year ago

This month, I've been thinking about implementing a shape function for RealDynamicSliceOp, as well as shape functions for its friends: DynamicConvOp (#861), DynamicGatherOp (#690) and DynamicPadOp (#862).

Such a shape function would need to compute the result shape from runtime values (the start_indices, limit_indices and strides operands), which is an inherently imprecise process. We can make things more precise by using various static analysis techniques, e.g. partial evaluation, but we don't have to do that immediately. Simply starting with two cases would already be a good step forward from the state of the art, which doesn't support even that: 1) if all values are static, infer a fully static shape; 2) otherwise, infer a fully dynamic shape.
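To make the two cases concrete, here's a minimal sketch in plain Python rather than the actual StableHLO/MLIR APIs (the function name and the use of None to mark dynamic values are just for illustration). When all three index operands are constant it applies the usual strided-slice size formula, ceil((limit - start) / stride); otherwise it falls back to a fully dynamic shape of the same rank:

```python
def infer_real_dynamic_slice_shape(operand_shape, start_indices, limit_indices, strides):
    """Sketch of the two-case inference rule (None marks a dynamic value).

    start_indices / limit_indices / strides are lists of ints when the
    corresponding operands are compile-time constants, or None when their
    values are unknown at compile time.
    """
    rank = len(operand_shape)
    # Case 2: any index operand is not a constant => give up and infer a
    # fully dynamic shape of the same rank as the operand.
    if start_indices is None or limit_indices is None or strides is None:
        return [None] * rank
    # Case 1: everything is static => each result dimension is
    # ceil((limit - start) / stride), clamped at zero.
    return [
        max(0, -(-(limit - start) // stride))  # integer ceil division
        for start, limit, stride in zip(start_indices, limit_indices, strides)
    ]

# Examples:
#   infer_real_dynamic_slice_shape([10, 20], [0, 5], [10, 15], [1, 2]) == [10, 5]
#   infer_real_dynamic_slice_shape([10, 20], None, None, None) == [None, None]
```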

However, there's a logistical problem with this. Whenever we add a shape function to a StableHLO op, its autogenerated Python bindings change to no longer include the result type in the signature of the FooOp.__init__ function. That's fine for precise shape functions, but not so much for imprecise shape functions like this one. If the producer knows a more precise type than what the shape function can infer, then FooOp.__init__ won't work for them, and they will need to use the much less convenient FooOp.build_generic. Here's an example from JAX:

  # TODO(burmako): Fix overly conservative type inference of DynamicGatherOp.
  # For now use the build_generic so that we can specify the result type.
  # return hlo.DynamicGatherOp(
  #     operand, indices, mlir.shape_tensor(slice_sizes),
  #     dnums, indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results
  results = [mlir.aval_to_ir_type(aval_out)]
  operands = [operand, indices, mlir.shape_tensor(slice_sizes)]
  attributes = {
      "dimension_numbers": dnums,
      "indices_are_sorted": ir.BoolAttr.get(indices_are_sorted)
  }
  return hlo.DynamicGatherOp.build_generic(
      results=results, operands=operands, attributes=attributes).results

This problem prompted me to re-evaluate the path forward. Originally, I decided to work on this ticket because of #622. My plan was to add shape functions for as many StableHLO ops as possible and then use these shape functions to specialize the ops from dynamic shapes to static shapes where possible.

However, later on, I realized that having inferFooOp functions in TypeInference.h is good enough for my use case, even if these functions are not hooked into the upstream type inference machinery (#867 talks more about this). Sure, the logistical challenges described above can be solved (e.g. I imagine that #32 will fix the problem with Python bindings), but I wouldn't want to block #622 on that.
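As a rough illustration of that use case (again plain Python rather than the actual TypeInference.h API; refine_shape and the None-for-dynamic convention are hypothetical), a specialization pass could combine an op's current result shape with whatever the standalone helper infers, keeping the more precise dimension in each position:

```python
def refine_shape(current, inferred):
    """Pick the more precise of two shapes (None = dynamic dimension).

    Illustrative only: this mirrors how a dynamic-to-static specialization
    pass could use a standalone infer helper to tighten a result type,
    without the helper being hooked into any type inference machinery.
    """
    assert len(current) == len(inferred), "rank mismatch"
    refined = []
    for cur, inf in zip(current, inferred):
        if cur is not None and inf is not None and cur != inf:
            raise ValueError(f"incompatible dimensions: {cur} vs {inf}")
        refined.append(cur if cur is not None else inf)
    return refined

# Example, reusing the sketch above:
#   inferred = infer_real_dynamic_slice_shape([10, 20], [0, 5], [10, 15], [1, 2])
#   refine_shape([None, None], inferred) == [10, 5]
```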

With that in mind, I'll be limiting myself to implementing inferRealDynamicSliceOp in TypeInference.h as part of #622 and will not turn it into a shape function for RealDynamicSliceOp just yet. Later on, once the problem with Python bindings is resolved, we can approach this ticket again, but until then I'll be unassigning it.