tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

TTIR and TTNN Eltwise class interfaces #110

Closed: nsmithtt closed this issue 1 month ago

nsmithtt commented 1 month ago

Let's make the Eltwise class interfaces for these two dialects (TTIR and TTNN) more robust:

nsmithtt commented 1 month ago

Sync with @sdjordjevicTT, who is also interested in simpler Eltwise interfaces.

sdjordjevicTT commented 1 month ago

Yep, it would be great to extend the builders for Eltwise ops so they are easier to consume from PyBuda. For example, instead of using the default builder for TTIR_ElementwiseOp, which is generated from:

let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                     Variadic<AnyRankedTensor>:$outputs,
                     TT_OperandConstraintArrayAttr:$operand_constraints);

we should create builders for TTIR_ElementwiseBinaryOp that accept only two input params and one output param:

let arguments = (ins AnyRankedTensor:$lhs,
                     AnyRankedTensor:$rhs,
                     AnyRankedTensor:$output,
                     TT_OperandConstraintArrayAttr:$operand_constraints);
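To make the difference between the two builder shapes concrete, here is a minimal standalone C++ sketch that does not depend on MLIR. `Value`, `ElementwiseOpState`, `buildElementwise`, and `buildBinary` are hypothetical stand-ins (loosely for `mlir::Value`, `mlir::OperationState`, the default variadic builder, and the proposed binary convenience builder); in tt-mlir the real builders would be declared in TableGen. Reusing one constraint for all three operands is also an assumption made only for this example.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an SSA value (mlir::Value in real code).
struct Value {
  std::string name;
};

// Hypothetical stand-in for the state a builder fills in
// (loosely mirrors mlir::OperationState).
struct ElementwiseOpState {
  std::vector<Value> inputs;
  std::vector<Value> outputs;
  std::vector<std::string> operandConstraints;  // one entry per operand
};

// Generic builder mirroring the variadic signature: the caller assembles
// the operand lists and the matching constraint list by hand.
ElementwiseOpState buildElementwise(std::vector<Value> inputs,
                                    std::vector<Value> outputs,
                                    std::vector<std::string> constraints) {
  if (constraints.size() != inputs.size() + outputs.size())
    throw std::invalid_argument("expected one constraint per operand");
  return ElementwiseOpState{std::move(inputs), std::move(outputs),
                            std::move(constraints)};
}

// Binary convenience builder: exactly two inputs and one output, so an
// arity mistake becomes a compile error instead of a verifier failure.
ElementwiseOpState buildBinary(const Value &lhs, const Value &rhs,
                               const Value &output,
                               const std::string &constraint) {
  // Assumption for this sketch: one constraint is reused for all operands.
  return buildElementwise({lhs, rhs}, {output},
                          {constraint, constraint, constraint});
}
```

With this shape, a call such as `buildBinary({"a"}, {"b"}, {"out"}, "any")` always yields two inputs and one output, whereas the generic builder leaves the caller responsible for keeping the operand and constraint lists consistent.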

Reference on attaching custom builders to an op: https://mlir.llvm.org/docs/Tutorials/Toy/Ch-2/#attaching-build-methods

rpavlovicTT commented 1 month ago

Hey folks, I have a commit (referenced above) that adds a common interface for eltwise ops and one more level to the eltwise op hierarchy (unary and binary). The unary and binary classes come with additional builders that can be used.

Before opening a PR, though, I wanted to check whether we need this interface at all, since there is already an Elementwise trait attached to our eltwise ops. I've just found that it already implements some basic checks for elementwise ops:

file: mlir/lib/IR/Operation.cpp

LogicalResult OpTrait::impl::verifyElementwise(Operation *op) {
  auto isMappableType = [](Type type) {
    return llvm::isa<VectorType, TensorType>(type);
  };
  auto resultMappableTypes = llvm::to_vector<1>(
      llvm::make_filter_range(op->getResultTypes(), isMappableType));
  auto operandMappableTypes = llvm::to_vector<2>(
      llvm::make_filter_range(op->getOperandTypes(), isMappableType));

  // If the op only has scalar operand/result types, then we have nothing to
  // check.
  if (resultMappableTypes.empty() && operandMappableTypes.empty())
    return success();

  if (!resultMappableTypes.empty() && operandMappableTypes.empty())
    return op->emitOpError("if a result is non-scalar, then at least one "
                           "operand must be non-scalar");

  assert(!operandMappableTypes.empty());

  if (resultMappableTypes.empty())
    return op->emitOpError("if an operand is non-scalar, then there must be at "
                           "least one non-scalar result");

  if (resultMappableTypes.size() != op->getNumResults())
    return op->emitOpError(
        "if an operand is non-scalar, then all results must be non-scalar");

  SmallVector<Type, 4> types = llvm::to_vector<2>(
      llvm::concat<Type>(operandMappableTypes, resultMappableTypes));
  TypeID expectedBaseTy = types.front().getTypeID();
  if (!llvm::all_of(types,
                    [&](Type t) { return t.getTypeID() == expectedBaseTy; }) ||
      failed(verifyCompatibleShapes(types))) {
    return op->emitOpError() << "all non-scalar operands/results must have the "
                                "same shape and base type";
  }

  return success();
}
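The checks above can be restated as a small standalone C++ model, which may help when deciding what a custom eltwise verifier would add on top of the trait. This is a simplified sketch, not MLIR code: `Kind`, `SimpleType`, `compatibleShapes`, and `verifyElementwiseModel` are hypothetical names, types are reduced to a kind plus a ranked shape with `-1` marking a dynamic dimension, and unranked tensors (which the real `verifyCompatibleShapes` also handles) are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical simplified type model: scalars have no shape; tensors and
// vectors are the "mappable" kinds that verifyElementwise inspects.
enum class Kind { Scalar, Tensor, Vector };

struct SimpleType {
  Kind kind;
  std::vector<int64_t> shape;  // -1 marks a dynamic dimension
};

bool isMappable(const SimpleType &t) { return t.kind != Kind::Scalar; }

// Approximation of verifyCompatibleShapes for ranked shapes: all ranks
// match, and within each dimension every static size is the same.
bool compatibleShapes(const std::vector<SimpleType> &types) {
  const std::size_t rank = types.front().shape.size();
  for (const SimpleType &t : types)
    if (t.shape.size() != rank) return false;
  for (std::size_t d = 0; d < rank; ++d) {
    std::optional<int64_t> staticDim;
    for (const SimpleType &t : types) {
      int64_t dim = t.shape[d];
      if (dim == -1) continue;  // dynamic sizes are compatible with anything
      if (staticDim && *staticDim != dim) return false;
      staticDim = dim;
    }
  }
  return true;
}

// Returns "" on success, otherwise the message the matching branch of
// OpTrait::impl::verifyElementwise would emit.
std::string verifyElementwiseModel(const std::vector<SimpleType> &operands,
                                   const std::vector<SimpleType> &results) {
  std::vector<SimpleType> mappableOperands, mappableResults;
  for (const SimpleType &t : operands)
    if (isMappable(t)) mappableOperands.push_back(t);
  for (const SimpleType &t : results)
    if (isMappable(t)) mappableResults.push_back(t);

  if (mappableOperands.empty() && mappableResults.empty()) return "";
  if (mappableOperands.empty())
    return "if a result is non-scalar, then at least one operand must be non-scalar";
  if (mappableResults.empty())
    return "if an operand is non-scalar, then there must be at least one non-scalar result";
  if (mappableResults.size() != results.size())
    return "if an operand is non-scalar, then all results must be non-scalar";

  std::vector<SimpleType> all = mappableOperands;
  all.insert(all.end(), mappableResults.begin(), mappableResults.end());
  bool sameKind = true;
  for (const SimpleType &t : all)
    if (t.kind != all.front().kind) sameKind = false;
  if (!sameKind || !compatibleShapes(all))
    return "all non-scalar operands/results must have the same shape and base type";
  return "";
}
```

Like the original, the model compares only the base type (tensor vs. vector) and the shape, so element types are never compared: in MLIR, `tensor<2xf32>` and `tensor<2xi32>` share the `RankedTensorType` TypeID and pass this trait check.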

I implemented something very similar, so I'm not sure there is a need to check it in. Or should I just leave an empty verify function for future extensions?

nsmithtt commented 1 month ago

@rpavlovicTT, I think we should still keep your class hierarchy, since you already did the work. At the very least, it provides convenience builders for unary and binary ops, which is something @sdjordjevicTT was interested in. It also establishes the boilerplate, so it will be easy to add more verification in the future.
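The design agreed on here can be pictured outside TableGen as a small C++ class hierarchy: a common eltwise base whose per-op verification hook currently accepts everything (the empty verify function discussed above, kept as an extension point), plus unary and binary levels whose fixed-arity constructors play the role of convenience builders. This is a standalone sketch under assumptions: `EltwiseOp`, `UnaryEltwiseOp`, `BinaryEltwiseOp`, `NoAliasBinaryOp`, and `verifyExtra` are hypothetical names, operands are plain strings, and the aliasing rule is invented for illustration. In MLIR itself, op-specific checks would typically go in the `verify()` method that ODS generates when an op sets `hasVerifier = 1`.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical common base for all eltwise ops (think of a shared TableGen
// base class); operands are plain names to keep the sketch standalone.
class EltwiseOp {
public:
  virtual ~EltwiseOp() = default;

  // Checks shared by every eltwise op, then the per-op extension hook.
  bool verify() const {
    if (inputs_.empty() || outputs_.empty()) return false;
    return verifyExtra();
  }

  const std::vector<std::string> &inputs() const { return inputs_; }
  const std::vector<std::string> &outputs() const { return outputs_; }

protected:
  EltwiseOp(std::vector<std::string> inputs, std::vector<std::string> outputs)
      : inputs_(std::move(inputs)), outputs_(std::move(outputs)) {}

  // Extension point: deliberately accepts everything for now, so future
  // verification can be added here without touching every op.
  virtual bool verifyExtra() const { return true; }

private:
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};

// Unary level: the constructor acts as the convenience builder.
class UnaryEltwiseOp : public EltwiseOp {
public:
  UnaryEltwiseOp(std::string input, std::string output)
      : EltwiseOp({std::move(input)}, {std::move(output)}) {}
};

// Binary level: exactly two inputs and one output.
class BinaryEltwiseOp : public EltwiseOp {
public:
  BinaryEltwiseOp(std::string lhs, std::string rhs, std::string output)
      : EltwiseOp({std::move(lhs), std::move(rhs)}, {std::move(output)}) {}
};

// Example of a later extension: an illustrative rule (not an actual
// tt-mlir constraint) that rejects an output aliasing one of the inputs.
class NoAliasBinaryOp : public BinaryEltwiseOp {
public:
  using BinaryEltwiseOp::BinaryEltwiseOp;

protected:
  bool verifyExtra() const override {
    for (const std::string &in : inputs())
      if (in == outputs().front()) return false;
    return true;
  }
};
```

Because the shared checks live in the base `verify()`, a later rule only needs a `verifyExtra()` override in the class it applies to, which is the kind of boilerplate the comment above has in mind.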