allenai / OLMo

Modeling, training, eval, and inference code for OLMo
https://allenai.org/olmo
Apache License 2.0

Attention and FF should run in parallel (sort of) #18

Closed: dirkgr closed this 1 year ago

dirkgr commented 1 year ago

The PaLM paper says running them in parallel improves throughput and doesn't hurt learning, at least not at large scales.
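For reference, the two block formulations can be written as follows, with LN denoting LayerNorm: the usual pre-LayerNorm (serial) block, and the parallel variant described in the PaLM paper. The point to notice is that in the parallel form both sublayers read the same LN(x), so the MLP no longer has to wait for the attention output.

```latex
\begin{aligned}
\text{serial:}\quad   & h = x + \mathrm{Attention}(\mathrm{LN}(x)), \qquad y = h + \mathrm{MLP}(\mathrm{LN}(h)) \\
\text{parallel:}\quad & y = x + \mathrm{Attention}(\mathrm{LN}(x)) + \mathrm{MLP}(\mathrm{LN}(x))
\end{aligned}
```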

dirkgr commented 1 year ago

I would love to ablate this one with a few billion tokens.

epwalsh commented 1 year ago

It's not clear to me that this would be beneficial at our scale of 70B. If we have to resort to model parallelism, it would definitely help; otherwise, I don't think so. Assuming either of these two operations (MLP(LayerNorm(x)) and Attention(LayerNorm(x))) saturates the GPU, which it probably will, there's really no way to get them to run in parallel on a single device.

CUDA streams are one way we could try to get these ops to run in parallel on a single GPU, but as I said, they won't actually run in parallel if either operation saturates the GPU. There's a good discussion of this in this thread; in particular, see this comment.

dirkgr commented 1 year ago

Is it true that we'd have to do the parallelization manually, with streams? I thought lazy execution would sort us out there.

epwalsh commented 1 year ago

I'm not sure about that. The PyTorch docs say this about async execution:

This allows us to execute more computations in parallel, including operations on CPU or other GPUs.

In general, the effect of asynchronous computation is invisible to the caller, because (1) each device executes operations in the order they are queued, and (2) PyTorch automatically performs necessary synchronization when copying data between CPU and GPU or between two GPUs. Hence, computation will proceed as if every operation was executed synchronously. ... A CUDA stream is a linear sequence of execution that belongs to a specific device. You normally do not need to create one explicitly: by default, each device uses its own “default” stream.

It sounds like independent operations on the same stream will always be executed in sequence, albeit asynchronously with respect to the CPU.

epwalsh commented 1 year ago

After reading through the PaLM-pytorch source code again, I realized there is a way to get these to run in parallel (sort of). The key is to fuse the query, key, and value projections and the MLP feed-forward input projection into a single matmul.
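To illustrate the idea, here is a minimal, dependency-free sketch (not the PaLM-pytorch implementation). Because the attention and MLP branches read the same LayerNorm(x), their input projections can be concatenated column-wise into one wide weight matrix: a single matmul replaces four, and the result is split afterwards. The helper names (`matmul`, `concat_columns`, `split_columns`) and the toy sizes are hypothetical, and plain Python lists stand in for tensors.

```python
# Sketch: fusing the Q, K, V and MLP input projections into one matmul.
# Helper names and sizes are hypothetical, chosen only for illustration.
import random


def matmul(x, w):
    """Multiply an (n x d) matrix by a (d x m) matrix, both given as lists of rows."""
    cols = list(zip(*w))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in x]


def concat_columns(*mats):
    """Concatenate matrices that have the same number of rows along the column axis."""
    return [sum((m[i] for m in mats), []) for i in range(len(mats[0]))]


def split_columns(mat, sizes):
    """Split each row of `mat` into consecutive chunks of the given widths."""
    out, start = [], 0
    for size in sizes:
        out.append([row[start:start + size] for row in mat])
        start += size
    return out


def random_matrix(rows, cols, rng):
    return [[rng.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
d_model, d_attn, d_ff, n_tokens = 4, 4, 8, 3  # hypothetical toy sizes

x = random_matrix(n_tokens, d_model, rng)  # stands in for LayerNorm(x)
w_q, w_k, w_v = (random_matrix(d_model, d_attn, rng) for _ in range(3))
w_ff = random_matrix(d_model, d_ff, rng)

# Unfused: four separate matmuls over the same normalized input.
separate = [matmul(x, w) for w in (w_q, w_k, w_v, w_ff)]

# Fused: one wide weight matrix, one matmul, then split the output.
w_fused = concat_columns(w_q, w_k, w_v, w_ff)
fused = split_columns(matmul(x, w_fused), [d_attn, d_attn, d_attn, d_ff])

# q, k, v would feed attention and ff_hidden the MLP activation; in the
# parallel block both branch outputs are then added to the residual x.
q, k, v, ff_hidden = fused
```

On a GPU, the gain comes from launching one large matmul instead of several smaller ones. The attention and MLP computations after the split still run one after the other on the same stream, which is why this is only parallel "sort of".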

https://github.com/lucidrains/PaLM-pytorch/blob/7164d13d5a831647edb5838544017f387130f987/palm_pytorch/palm_pytorch.py#L130