tenstorrent / tt-mlir

Tenstorrent MLIR compiler
https://tenstorrent.github.io/tt-mlir/
Apache License 2.0

[LLAMA] Implement conversion of stablehlo.complex to TTIR #748

Open mmanzoorTT opened 4 weeks ago

mmanzoorTT commented 4 weeks ago

The stablehlo.complex op specification can be found here.
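For reference, stablehlo.complex combines two real-valued tensors of the same shape (lhs holding the real parts, rhs holding the imaginary parts) into one complex-valued tensor, element by element. Here is a minimal C++ sketch of those semantics. It uses std::vector<float> as a stand-in for a tensor, which is an assumption for illustration and not the TTIR or TTNN representation. complexFromParts is a hypothetical name.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Element-wise reference for stablehlo.complex:
//   result[i] = complex(lhs[i], rhs[i])
// lhs supplies the real parts and rhs the imaginary parts.
// std::vector<float> stands in for a tensor (illustrative assumption).
std::vector<std::complex<float>> complexFromParts(const std::vector<float> &lhs,
                                                  const std::vector<float> &rhs) {
  // stablehlo.complex requires both operands to have the same shape.
  assert(lhs.size() == rhs.size() && "operands must have matching shapes");
  std::vector<std::complex<float>> result;
  result.reserve(lhs.size());
  for (std::size_t i = 0; i < lhs.size(); ++i) {
    result.emplace_back(lhs[i], rhs[i]);  // (real, imag)
  }
  return result;
}
```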

ddilbazTT commented 2 weeks ago

@nsmithtt @uazizTT @mmanzoorTT Should I create a new file under runtime/lib/ttnn/operations/eltwise for complex operations, or should I add a complexEltwiseBinaryOp to binary.cpp? I am assuming this function will use the ComplexTensor struct defined in third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp

Edit: changes so far --> https://github.com/tenstorrent/tt-mlir/compare/ddilbaz/issue-748?expand=1

ddilbazTT commented 2 weeks ago

Hi all, I need help because the "Tensor" and "ComplexTensor" structs are not interchangeable given how the codebase is currently designed. It seems that implementing ComplexOp would require extensive refactoring.

I created a folder named complex under runtime/lib/ttnn/operations, containing a header and a .cpp file. (Commit Link) The main purpose of this commit is to add a ComplexTensor variant of run. For example, this is the run function for "concat" (from runtime/lib/ttnn/operations/data_movement/concat.cpp):

void run(const ::tt::target::ttnn::ConcatOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  std::vector<::ttnn::Tensor> inputs;
  for (const auto &input : *op->inputs()) {
    inputs.push_back(tensorPool.at(input->global_id()));
  }
  int32_t dim = op->dim();
  ::ttnn::Tensor out = ::ttnn::concat(inputs, dim);
  tensorPool.insert_or_assign(op->out()->global_id(), out);
}

My void run(const ::tt::target::ttnn::ComplexOp *op, ProgramContext &context) overload is not correct yet, and fixing it depends on changes elsewhere in the codebase. I will explain the dilemma below.

If you look at /runtime/lib/ttnn/runtime.cpp, you will see the function

Event submit(Device deviceHandle, Binary executableHandle,
             std::uint32_t programIndex,
             std::vector<Tensor> const &inputHandles,
             std::vector<Tensor> const &outputHandles) 

which invokes

void runProgram(::ttnn::MeshDevice &meshDevice,
                ::tt::target::ttnn::Program const *program,
                std::vector<::ttnn::Tensor *> const &inputs,
                std::vector<::ttnn::Tensor *> const &outputs)

from runtime/lib/ttnn/program.cpp. In short, runProgram calls the matching run function for each op, as shown at the beginning of this comment.
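The pattern above can be sketched with mock types. Each op reads its inputs from the tensor pool by global id, computes, and writes its output back under the output's global id. runProgram then dispatches the ops in order. MockTensor, TensorPool, AddOp and runMockProgram are hypothetical names. std::variant/std::visit only stands in for however the real runtime dispatches on the serialized op type.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-in for ::ttnn::Tensor.
struct MockTensor {
  std::vector<float> data;
};

// Hypothetical stand-in for ProgramTensorPool: global id -> tensor.
using TensorPool = std::unordered_map<std::uint32_t, MockTensor>;

// A toy op descriptor: out = lhs + rhs (element-wise), all fields are global ids.
struct AddOp {
  std::uint32_t lhs, rhs, out;
};

// One run overload per op type: read inputs by global id, write the output back.
void run(const AddOp &op, TensorPool &pool) {
  const MockTensor &a = pool.at(op.lhs);
  const MockTensor &b = pool.at(op.rhs);
  assert(a.data.size() == b.data.size());
  MockTensor out;
  out.data.reserve(a.data.size());
  for (std::size_t i = 0; i < a.data.size(); ++i) {
    out.data.push_back(a.data[i] + b.data[i]);
  }
  pool.insert_or_assign(op.out, std::move(out));
}

using MockOp = std::variant<AddOp>;

// Plays the role of runProgram: walk the ops in order, dispatching each one
// to its matching run overload.
void runMockProgram(const std::vector<MockOp> &ops, TensorPool &pool) {
  for (const MockOp &op : ops) {
    std::visit([&pool](const auto &concreteOp) { run(concreteOp, pool); }, op);
  }
}
```

Because every op communicates only through global ids in the pool, anything an op needs must be expressible as entries in that pool. That is why a ComplexTensor does not fit directly.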

However, submit takes its input and output handles as the Tensor struct defined in runtime/include/tt/runtime/types.h:

struct Tensor : public detail::RuntimeCheckedObjectImpl {
  std::shared_ptr<void> data;
  Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
         DeviceRuntime runtime)
      : detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {}
};

Meanwhile, as discussed in https://github.com/tenstorrent/tt-mlir/pull/808, we decided to implement ComplexOp so that RealOp/ImagOp can use it, based on the ComplexTensor struct defined in third_party/tt-metal/src/tt-metal/ttnn/cpp/ttnn/operations/eltwise/complex/complex.hpp:

struct ComplexTensor {
    std::array<Tensor, 2> m_real_imag;

    const Tensor& operator[](uint32_t index) const;
    const Tensor& real() const;
    const Tensor& imag() const;
    void deallocate();
};
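The following self-contained mock has the same layout as the ComplexTensor declaration above and shows how two ordinary tensors get bundled into one. MockTensor, MockComplexTensor and makeComplex are hypothetical, and deallocate() is omitted. The mock assumes index 0 holds the real part and index 1 the imaginary part, as the member name m_real_imag suggests.

```cpp
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for ::ttnn::Tensor.
struct MockTensor {
  std::vector<float> data;
};

// Same shape as the quoted ComplexTensor: one std::array of two tensors,
// assumed to be [0] = real part, [1] = imaginary part.
struct MockComplexTensor {
  std::array<MockTensor, 2> m_real_imag;

  const MockTensor &operator[](std::uint32_t index) const { return m_real_imag[index]; }
  const MockTensor &real() const { return m_real_imag[0]; }
  const MockTensor &imag() const { return m_real_imag[1]; }
};

// Bundle two ordinary tensors into one complex value. The struct has no
// constructor and its only member is a std::array, so it must be
// aggregate-initialized with braces.
MockComplexTensor makeComplex(const MockTensor &re, const MockTensor &im) {
  return MockComplexTensor{{re, im}};
}
```

As declared, the struct needs the brace form ComplexTensor{{re, im}}. A parenthesized form such as ComplexTensor(re, im) does not compile, because parenthesized aggregate initialization does not allow brace elision into the std::array member.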

Thus, we need to figure out how to convert between the Tensor and ComplexTensor structs, and most likely the following files need to change:

I might be wrong in assuming that the Tensor/ComplexTensor structs are interchangeable. Regardless, I need help figuring out how to move forward. Please view the commits for this issue to see where I am. Thanks!

@nsmithtt @mmanzoorTT @uazizTT @jnie-TT @AleksKnezevic

jnie-TT commented 1 week ago

> Hi all, I need help because "Tensor" and "ComplexTensor" structs are not interchangeable with the way the codebase is designed right now. It seems implementing ComplexOp needs an extensive code refactoring.

Hey @ddilbazTT, I see the issue here. How/where are we constructing these complex tensors? If it's outside of the runtime, at the user level, could we call submit with 2 regular tensors representing the real and imaginary parts? I assume we should already have this info, since constructing a ComplexTensor would need it regardless. In the TTNN IR we can implement the real op (and any other complex op) so that it takes 2 regular tensors, one real and one imaginary. If needed, we can construct the ComplexTensor within the op's run API. Essentially, something like:

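void run(const ::tt::target::ttnn::ComplexOp *op, ProgramContext &context) {
  ProgramTensorPool &tensorPool = context.getTensorPool();
  // real() and img() are the proposed schema fields for the two input tensors.
  ::ttnn::Tensor &real = tensorPool.at(op->real()->global_id());
  ::ttnn::Tensor &img = tensorPool.at(op->img()->global_id());
  // ComplexTensor's only member is std::array<Tensor, 2>, so it needs
  // brace (aggregate) initialization: real part first, imaginary second.
  ComplexTensor complex{{real, img}};
  // Do whatever we need with the complex tensor...
}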

As long as the global ids of the tensors passed between ops are valid, we probably won't need to store the ComplexTensor in the program context either. Any op that outputs a complex tensor will just have 2 outputs: 1 real and 1 imaginary.
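Here is a minimal sketch of that two-output convention, using mock types. Every complex value is referenced by a pair of global ids in the tensor pool. ComplexOp then only publishes its operands under the result's two ids, and RealOp/ImagOp forward one of them. The op structs, ComplexIds, and all field names are hypothetical, not the actual flatbuffer schema.

```cpp
#include <cassert>
#include <cstdint>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-ins for ::ttnn::Tensor and the program tensor pool.
struct MockTensor {
  std::vector<float> data;
};
using TensorPool = std::unordered_map<std::uint32_t, MockTensor>;

// A complex value is just a pair of global ids: one real, one imaginary.
struct ComplexIds {
  std::uint32_t real, imag;
};

// Hypothetical op descriptors.
struct ComplexOp { std::uint32_t lhs, rhs; ComplexIds out; };
struct RealOp { ComplexIds in; std::uint32_t out; };
struct ImagOp { ComplexIds in; std::uint32_t out; };

// Forming a complex value: no new storage format, just two pool entries.
void run(const ComplexOp &op, TensorPool &pool) {
  MockTensor re = pool.at(op.lhs);  // copy before mutating the pool
  MockTensor im = pool.at(op.rhs);
  pool.insert_or_assign(op.out.real, std::move(re));
  pool.insert_or_assign(op.out.imag, std::move(im));
}

// Taking a complex value apart: forward the requested half.
void run(const RealOp &op, TensorPool &pool) {
  MockTensor re = pool.at(op.in.real);
  pool.insert_or_assign(op.out, std::move(re));
}

void run(const ImagOp &op, TensorPool &pool) {
  MockTensor im = pool.at(op.in.imag);
  pool.insert_or_assign(op.out, std::move(im));
}
```

If a TTNN complex kernel is needed later, its run function could still assemble a ComplexTensor locally from the two pool entries, as in the snippet above. The pool itself would never have to hold one.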

nsmithtt commented 1 week ago

An alternative is to not use complex at all. I guess it depends on how the tensor is used. Per my comment here, it seems like we could bypass the use of TTNN complex entirely: https://github.com/tenstorrent/tt-mlir/pull/808#discussion_r1776365343

nsmithtt commented 1 week ago

> An alternative is to not use complex at all. I guess it depends on how the tensor is used. Per my comment here, it seems like we could bypass the use of TTNN complex entirely: #808 (comment)

Ah, scratch that. This op forms a complex tensor rather than taking one apart. I suppose we need to support a complex element type in the TTIR and TTNN dialects?
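If the dialects get a complex element type, one open question is the storage layout. A common choice is interleaved (real, imag) pairs, which is also the layout std::complex<float> is guaranteed to have. ComplexTensor, by contrast, keeps two planar real-valued tensors. Whether TTNN would use an interleaved layout is an assumption here. This sketch only shows the conversion between the two representations, with std::vector<float> standing in for tensor storage. interleave and deinterleave are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Planar (two tensors, as in ComplexTensor) -> interleaved (one buffer of
// adjacent real/imag pairs, as a complex element type might be stored).
std::vector<float> interleave(const std::vector<float> &re, const std::vector<float> &im) {
  assert(re.size() == im.size());
  std::vector<float> packed;
  packed.reserve(2 * re.size());
  for (std::size_t i = 0; i < re.size(); ++i) {
    packed.push_back(re[i]);
    packed.push_back(im[i]);
  }
  return packed;
}

// Interleaved -> planar: split the pairs back into separate real and imag buffers.
void deinterleave(const std::vector<float> &packed, std::vector<float> &re,
                  std::vector<float> &im) {
  assert(packed.size() % 2 == 0);
  re.clear();
  im.clear();
  for (std::size_t i = 0; i + 1 < packed.size(); i += 2) {
    re.push_back(packed[i]);
    im.push_back(packed[i + 1]);
  }
}
```

Keeping the planar form would let existing real-valued eltwise kernels run on each half unchanged. An interleaved element type would instead need kernels that understand the pair layout.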