AleHD closed this 1 year ago
Tested with llama2, tp=4, pp=1 on two 8x 80GB A100 nodes (dp=4)
Good question. It's actually about 8% faster under normal circumstances (tp=4, pp=1, dp=4, 2 nodes w/ 8x 80GB A100; micro=5, global=100). The previous build stabilizes around 12.8 sec/iter; removing torchscript brings that down to 11.8 sec/iter :)
Turns out that, for some reason, applying torch.jit to the GLU activation was the culprit. Removing it seems to fully fix the problem in my tests.
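For context, a minimal sketch of what this change looks like. This is not the actual patch; the function name `glu_activation` and the SwiGLU-style formulation are assumptions for illustration. The fix amounts to dropping the `@torch.jit.script` decorator and leaving the activation as a plain eager-mode function:

```python
import torch
import torch.nn.functional as F

# Before (hypothetical, the version suspected of causing the slowdown):
# @torch.jit.script
# def glu_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
#     return F.silu(gate) * up

# After: identical math, but no TorchScript compilation involved.
def glu_activation(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # SwiGLU-style gating: silu(gate) elementwise-multiplied with the up projection
    return F.silu(gate) * up

gate = torch.randn(2, 8)
up = torch.randn(2, 8)
out = glu_activation(gate, up)
print(out.shape)
```

The numerical output is unchanged; only the TorchScript wrapper (and its apparent interaction with the distributed setup) is removed.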