tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

[Bug Report] MorehAdamW loses performance overtime #14546

Open dmakoviichuk-tt opened 1 week ago

dmakoviichuk-tt commented 1 week ago

Describe the bug When the op is used for a long time with a large step value, it becomes significantly slower.

To Reproduce Steps to reproduce the behavior: compare performance with step == 1 vs. step == 1000.

Expected behavior Performance should be the same for the same shapes.


Additional context The kernel calls power internally:

// bias_correction1 = 1 - pow(beta1, step);
// cb_tmp2 = pow(beta1, step);
tile_regs_acquire();
cb_reserve_back(cb_tmp2, onetile);
copy_tile_init_with_dt(cb_scalar_args);
copy_tile(cb_scalar_args, beta1_tile, dst0);
power_tile_init();
power_tile(dst0, step);  // step is the exponent; the cost grows with step
tile_regs_commit();

But our power implementation has complexity linear in the exponent:

#pragma GCC unroll 8
    for (int d = 0; d < 8; d++) {
        vFloat in = dst_reg[0];
        vFloat result = 1.0f;
        // O(exponent) inner loop: one multiply per unit of the exponent
        for (uint i = 0; i < exponent; i++) {
            result *= in;
        }
        dst_reg[0] = result;
        dst_reg++;
    }
}

As a result, the op runs slower and slower as step increases.

But you don't need to call pow for beta1 and beta2 inside the kernel at all; it can be computed once on the host beforehand: curr_beta1 = std::pow(beta1, step);

My suggestion: remove the exponentiation code from the kernel, do not pass step to the kernel, compute the current beta exponents from step on the host, and pass them to the kernel instead. The kernel would then receive beta1, beta1_exponent, beta2, and beta2_exponent. And don't forget to check that the new betas do not become part of the program cache hash :)

DuongQLee commented 1 week ago

@dmakoviichuk-tt Could you please provide additional context on how the performance test is conducted? Are you running the test sequentially from step 1 through step 1000, or are you iterating over each step multiple times? Additionally, if you are using pytest, could you confirm whether enable_program_cache is enabled?

I’m attempting to reproduce the issues you noted but haven’t encountered the same bugs. If possible, could you share a snippet of your test code?

Cc @mrshaw01

dmakoviichuk-tt commented 1 week ago

@DuongQLee There's no need to do that. The issue is in the power op. If you compare performance when passing step = 1 vs. step = 10000, you will see that the second is much slower, because the pow LLK computes the power in a loop. But it can be removed entirely, as I suggested. Please read the additional context section. Could you confirm whether you passed different step values as input and saw no performance change?

DuongQLee commented 5 days ago

I have confirmed the performance loss by profiling with Tracy. I will proceed with your suggestion. Thank you.

DuongQLee commented 2 days ago

@dmakoviichuk-tt I have pushed my fix according to your suggestions to branch duong/fix_moreh_adamw. Changes:

I have confirmed the performance by profiling with Tracy and verified that the performance does not change with different step values. Link to PR: https://github.com/tenstorrent/tt-metal/pull/14927

CC: @mrshaw01