tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

[Bug Report] Binary op with scalar is very slow #13644

Open dmakoviichuk-tt opened 2 weeks ago

dmakoviichuk-tt commented 2 weeks ago

Describe the bug: Every time we call a binary op with a scalar, we create a tensor from that scalar and then also call ttnn::repeat:

template <BinaryOpType binary_op_type>
Tensor BinaryOperation<binary_op_type>::invoke(
    uint8_t queue_id,
    const ttnn::Tensor &input_tensor_a,
    const float scalar,
    const std::optional<const DataType> &dtype,
    const std::optional<ttnn::MemoryConfig> &memory_config,
    const std::optional<Tensor> &optional_output_tensor,
    std::optional<unary::FusedActivations> activations,
    std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
    using namespace tt::constants;
    // Cast Float Scalar to a device tensor
    auto host_buffer = owned_buffer::create<::bfloat16>(static_cast<std::size_t>(TILE_HEIGHT * TILE_WIDTH));
    host_buffer[0] = scalar;
    Tensor scalar_tensor_host = Tensor(
        OwnedStorage{host_buffer},
        ttnn::Shape(std::array<std::uint32_t, 2>{1, 1}, std::array<std::uint32_t, 2>{TILE_HEIGHT, TILE_WIDTH}),
        DataType::BFLOAT16,
        Layout::TILE);
    Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device());
    // TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG
    return BinaryOperation::invoke(
        input_tensor_a,
        scalar_tensor_device,
        dtype,
        memory_config,
        optional_output_tensor,
        activations,
        input_tensor_a_activation);
}

We are using this in the optimizer step for each layer: https://github.com/tenstorrent/TT-Tron/blob/main/sources/ttml/optimizers/sgd.cpp.

SGD performance is 10 times slower than the PyTorch CPU version.

To Reproduce: Just run any binary op with a tensor and a scalar.

Expected behavior: The scalar parameter should be passed to the program as a runtime arg. We should never create a new tensor on the CPU for every call.
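To illustrate the expected behavior, here is a minimal host-side sketch of how a float scalar could be reduced to a single 32-bit word instead of a full tile-sized host buffer. Everything in it is an assumption for illustration: `float_to_bfloat16_bits` and `pack_scalar_runtime_arg` are hypothetical helpers, not tt-metal or ttnn APIs, and the layout of the word (bfloat16 value in the upper 16 bits) is just one possible convention a kernel could agree on.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper (not a tt-metal API): convert a float to its bfloat16
// bit pattern using round-to-nearest-even. NaN handling is omitted for brevity.
inline std::uint16_t float_to_bfloat16_bits(float value) {
    std::uint32_t bits = 0;
    std::memcpy(&bits, &value, sizeof(bits));  // safe type punning
    // Add 0x7FFF plus the lowest kept bit so that ties round to even.
    const std::uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<std::uint16_t>((bits + rounding_bias) >> 16);
}

// Hypothetical helper: pack the scalar into one 32-bit word that could be
// handed to a kernel as a runtime argument. The bfloat16 value sits in the
// upper 16 bits, so the word is also a valid float32 bit pattern.
inline std::uint32_t pack_scalar_runtime_arg(float scalar) {
    return static_cast<std::uint32_t>(float_to_bfloat16_bits(scalar)) << 16;
}

// Inverse of pack_scalar_runtime_arg, as a kernel might decode the word.
inline float unpack_scalar_runtime_arg(std::uint32_t word) {
    float value = 0.0f;
    std::memcpy(&value, &word, sizeof(value));
    return value;
}
```

With something like this, the per-call host work is a few integer operations rather than allocating a TILE_HEIGHT * TILE_WIDTH buffer and copying it to the device. How the word actually reaches the kernel depends on the program and kernel changes requested in this ticket.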

Additional context: @eyonland I assigned this ticket to you as the elementwise owner. My expectation is that you can drive it with the LLK team and make sure they and your team add the needed changes at both the metal and ttnn levels. If you cannot do it for some reason, please let me know and I'll find a new owner. This issue significantly reduces the performance of our training code.

umadevimcw commented 2 weeks ago

@dmakoviichuk-tt Can you provide details on how you collected the performance?

eyonland commented 2 weeks ago

@dmakoviichuk-tt, my assumption is that you swapped out the ttnn function with a direct pytorch function that runs on host and saw the overall perf difference. If a small tensor is used many times during training, I can see how CPU branch prediction would be blazingly fast compared to pushing the tensor on and off the device. Did you measure it by the overall training time?

eyonland commented 2 weeks ago

@dmakoviichuk-tt , what was the size of the Tensor?

dmakoviichuk-tt commented 1 week ago

@umadevimcw with a timer. @eyonland it doesn't matter. As I mentioned, in the optimizer we need to multiply gradients by scalars. Gradients have the shape of the weights, so it could be something like (1, 1, 512, 1024). But we use these ops for more than gradients, and in that case the shape could be (64, 1, 256, 2048), for example.
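For a sense of scale, the sketch below computes how many tiles and bytes a bfloat16 tensor of each shape mentioned above occupies, which is roughly what gets materialized if the scalar is expanded to the input's shape, as the report describes for ttnn::repeat. The 32 x 32 tile size and the helper names are assumptions for illustration, not taken from the ttnn sources quoted in this thread.

```cpp
#include <array>
#include <cstdint>

// Assumed tile dimensions (hypothetical constants for this sketch).
constexpr std::uint32_t kTileHeight = 32;
constexpr std::uint32_t kTileWidth = 32;

using Shape4D = std::array<std::uint32_t, 4>;

// Number of tiles covering a 4D shape; the last two dims are padded up to
// whole tiles, the leading dims act as batch dimensions.
constexpr std::uint64_t tiles_in_shape(const Shape4D& shape) {
    const std::uint64_t tile_rows = (shape[2] + kTileHeight - 1) / kTileHeight;
    const std::uint64_t tile_cols = (shape[3] + kTileWidth - 1) / kTileWidth;
    return static_cast<std::uint64_t>(shape[0]) * shape[1] * tile_rows * tile_cols;
}

// Bytes occupied by a tiled bfloat16 tensor (2 bytes per element).
constexpr std::uint64_t bfloat16_bytes(const Shape4D& shape) {
    return tiles_in_shape(shape) * kTileHeight * kTileWidth * 2;
}

// Shapes quoted in the comment above.
constexpr Shape4D kGradientShape{1, 1, 512, 1024};
constexpr Shape4D kActivationShape{64, 1, 256, 2048};
```

Under these assumptions, `tiles_in_shape(kGradientShape)` is 512 tiles (1 MiB of bfloat16 data) and `tiles_in_shape(kActivationShape)` is 32768 tiles (64 MiB), compared with a single 32-bit value if the scalar were passed as a runtime arg.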

dmakoviichuk-tt commented 1 week ago

@eyonland this is obviously really bad and slow code for such a simple operation. Why ask questions like that? I've already shown the two problems that make it so slow.

"@dmakoviichuk-tt, my assumption is that you swapped out the ttnn function with a direct pytorch function that runs on host and saw the overall perf difference."

Your assumption is wrong in every possible way. How can I swap something out for a pytorch call if I don't use pytorch?

Please be respectful to your colleagues. Right now it looks like you are trying to avoid fixing that obvious issue!

eyonland commented 1 week ago

Sorry for the misunderstanding here. I was trying to figure out how you measured it originally.

We absolutely should be passing the scalar as a runtime arg and never create a tensor. My time has been stretched thin between this issue and rebuilding the eltwise ops to handle broadcasting properly, given that bcast does not do this adequately and that the use of repeat is absolutely terrible because we make multiple calls.