Closed: d-kleine closed this pull request 3 months ago
Thanks for adding the error bars and the other reorg-related improvements! Regarding the change from
import torch

def time_pytorch_function(func, *input, num_repeats=1_000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup so one-off startup overhead is excluded from the measurement
    for _ in range(5):
        func(*input)
    torch.cuda.synchronize()

    # Time all repetitions in one interval and return the average per call
    start.record()
    for _ in range(num_repeats):
        func(*input)
        torch.cuda.synchronize()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / num_repeats
to
import numpy as np
import torch

def time_pytorch_function(func, *input, num_repeats=1_000):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup so one-off startup overhead is excluded from the measurement
    for _ in range(5):
        func(*input)
    torch.cuda.synchronize()

    # Time each call individually so mean and standard deviation
    # (for the error bars) can be computed across the repetitions
    times = []
    for _ in range(num_repeats):
        start.record()
        func(*input)
        end.record()
        torch.cuda.synchronize()
        times.append(start.elapsed_time(end))

    return np.mean(times), np.std(times)
I think this now looks incorrect. To compute the error bars, what we would want is another for-loop around this function call. Maybe "incorrect" is not the right word; rather, it's less ideal because each timing interval is now so small. But given how small the standard deviation turns out to be, it may be reasonable. Anyway, let me try it the other way and see what I get on the A100.
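For clarity, this is roughly what I mean by an outer loop: keep the original version of time_pytorch_function that returns a single averaged time, and repeat that whole measurement a few times to get the spread. Just an untested sketch; time_with_error_bars and num_outer are illustrative names, not code from the notebook:

import numpy as np

def time_with_error_bars(func, *input, num_outer=10, num_repeats=1_000):
    # Each outer iteration yields one averaged time over num_repeats calls
    # (using the original, scalar-returning time_pytorch_function above);
    # the error bars then reflect the variation between these averages.
    avg_times = [
        time_pytorch_function(func, *input, num_repeats=num_repeats)
        for _ in range(num_outer)
    ]
    return np.mean(avg_times), np.std(avg_times)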
I was thinking of the number of iterations and the number of repetitions in timeit. But in that case there would be virtually no error bars... Actually, let's keep your version 😊
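For reference, the timeit analogy I had in mind: timeit.repeat runs the statement number times per measurement and returns one total per repetition, so dividing by number averages away almost all per-call variation, which is why the error bars would end up tiny. A rough CPU-side sketch with arbitrary statement and counts, just to illustrate number vs. repeat:

import timeit

# Five repetitions, each timing 1,000 executions of the statement
totals = timeit.repeat("sum(range(1_000))", number=1_000, repeat=5)
per_call = [t / 1_000 for t in totals]  # averaging over `number` hides the per-call spread
print(min(per_call), max(per_call))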
Thanks for integrating the changes!
I don't know if that's useful for other users/readers, but it's interesting to see that there is quite a bit of variation across the FlexAttention repetitions, at least when not compiled.
This is just for internal testing; decline if not needed. Displays error bars for the MHA implementations.
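For reference, a minimal sketch of how the (mean, std) pair returned by time_pytorch_function can be turned into a bar chart with error bars. The callables (mha_sdpa, mha_flex) and the embeddings input are hypothetical placeholders, not the notebook's actual variable names:

import matplotlib.pyplot as plt

# Hypothetical callables and input; substitute the notebook's actual
# MHA implementations and input tensor here.
implementations = {
    "SDPA": mha_sdpa,
    "FlexAttention": mha_flex,
}

means, stds = [], []
for name, func in implementations.items():
    mean_ms, std_ms = time_pytorch_function(func, embeddings)
    means.append(mean_ms)
    stds.append(std_ms)

# The standard deviations become the error bars on each bar
plt.bar(list(implementations.keys()), means, yerr=stds, capsize=4)
plt.ylabel("Time per forward pass (ms)")
plt.show()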