rasbt / LLMs-from-scratch

Implement a ChatGPT-like LLM in PyTorch from scratch, step by step
https://www.amazon.com/Build-Large-Language-Model-Scratch/dp/1633437167

added std error bars #320

Closed d-kleine closed 3 months ago

d-kleine commented 3 months ago

This is just for internal testing; feel free to decline if not needed. It displays standard-deviation error bars for the MHA implementations.

rasbt commented 3 months ago

Thanks for adding the error bars and the other reorg-related improvements! Regarding the change from

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(5):
        func(*input)
    torch.cuda.synchronize()

    start.record()
    for _ in range(num_repeats):
        func(*input)
        torch.cuda.synchronize()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_repeats

to

import numpy as np

def time_pytorch_function(func, *input, num_repeats=1_000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup
    for _ in range(5):
        func(*input)
    torch.cuda.synchronize()

    times = []
    for _ in range(num_repeats):
        start.record()
        func(*input)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    return np.mean(times), np.std(times)

I think this now looks incorrect. To compute the error bars, what we would want is another for-loop around this function call, so that each sample is an average over many calls. Maybe "incorrect" is not the right word, but I think it's less ideal because each timed interval now covers only a single call and is therefore very short. But given how small the standard deviation is, it may be reasonable. Anyway, let me try it the other way and see what I get on the A100.
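For reference, the "another for-loop over this function call" idea can be sketched as follows. This is only a minimal illustration, not code from the PR: the name `time_in_batches` and the `num_trials` and `sync` parameters are hypothetical, and it uses `time.perf_counter` rather than CUDA events so that it also runs on a machine without a GPU.

```python
import statistics
import time


def time_in_batches(func, *args, num_repeats=1_000, num_trials=10,
                    sync=lambda: None):
    """Return (mean, std) of the per-call time in ms across `num_trials` batches.

    `sync` is a hypothetical hook; it defaults to a no-op and is meant to
    block until pending asynchronous work has finished.
    """
    # Warmup, as in the snippets quoted above
    for _ in range(5):
        func(*args)
    sync()

    per_call_ms = []
    for _ in range(num_trials):        # outer loop: one sample per batch
        start = time.perf_counter()
        for _ in range(num_repeats):   # inner loop: many calls per sample
            func(*args)
        sync()                         # wait before stopping the clock
        elapsed_s = time.perf_counter() - start
        per_call_ms.append(elapsed_s * 1_000 / num_repeats)

    # pstdev is the population standard deviation, like np.std's default
    return statistics.mean(per_call_ms), statistics.pstdev(per_call_ms)
```

On a GPU, one would presumably pass `sync=torch.cuda.synchronize` so that queued kernels are included in each batch. Because every sample already averages `num_repeats` calls, the spread between samples is expected to be much smaller than the per-call spread.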

rasbt commented 3 months ago

I was thinking of the number of iterations and the number of repetitions in timeit. But in that case there would be virtually no error bars... Actually, let's keep your version 😊
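For comparison, the timeit-style scheme mentioned here can be sketched with the standard library's `timeit.repeat`, where `number` is the count of calls per measurement and `repeat` is the count of measurements. The helper name `timeit_mean_std` is hypothetical. Note that timeit measures host wall-clock time, so for GPU code the timed callable would itself need to synchronize.

```python
import statistics
import timeit


def timeit_mean_std(func, number=1_000, repeat=5):
    """Return (mean, std) of the per-call time in ms over `repeat` measurements."""
    # timeit.repeat returns one total time in seconds per repetition,
    # each covering `number` calls of `func`
    totals_s = timeit.repeat(func, number=number, repeat=repeat)
    per_call_ms = [t * 1_000 / number for t in totals_s]
    return statistics.mean(per_call_ms), statistics.pstdev(per_call_ms)


mean_ms, std_ms = timeit_mean_std(lambda: sum(range(1_000)), number=200, repeat=5)
print(f"{mean_ms:.5f} ms ± {std_ms:.5f} ms per call")
```

Each of the `repeat` values is already a mean over `number` calls, which is why the resulting error bars tend to shrink as `number` grows.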

d-kleine commented 3 months ago

Thanks for integrating the changes!

Idk if that's useful for other users/readers, but it's interesting to see that there is quite a bit of variation in the FlexAttn repetitions, at least when not compiled.