glassroom / heinsen_sequence

Code implementing "Efficient Parallelization of a Ubiquitous Sequential Computation" (Heinsen, 2023)
http://arxiv.org/abs/2311.06281
MIT License

Comparison to existing algorithm #1

Closed sustcsonglin closed 11 months ago

sustcsonglin commented 11 months ago

Hello, thanks for your work. I wonder what the difference is between the proposed algorithm and Sect. 1.4.1 in https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf.

fheinsen commented 11 months ago

Thank you! Blelloch's formulation is more general, expressed in terms of a binary operator $\oplus$ that is associative and a second binary operator $\otimes$ that either is associative or can be made associative via a transformation (which Blelloch shows as a third binary operator).

Our formulation applies only to the most common case, real numbers, with sum as the first operator and multiplication as the second operator, making each step non-associative. We find a nice, succinct, numerically stable expression implementable with widely available, highly-optimized software tools -- e.g., PyTorch's cumsum and logcumsumexp.
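
(For concreteness, a minimal sketch of such a log-space expression for the recurrence $x_t = a_t x_{t-1} + b_t$, assuming all coefficients and values are strictly positive so the logarithms stay real; the function name and tensor layout here are illustrative, not necessarily the repo's exact API:)

```python
import torch
import torch.nn.functional as F

def parallel_scan_log(log_coeffs, log_values):
    # log_coeffs: [..., n] holding log a_1 .. log a_n
    # log_values: [..., n + 1] holding log x_0, log b_1 .. log b_n
    a_star = F.pad(torch.cumsum(log_coeffs, dim=-1), (1, 0))  # prefix sums of log a_i, with A_0 = 0
    log_x = a_star + torch.logcumsumexp(log_values - a_star, dim=-1)
    return torch.exp(log_x)  # x_0, x_1, ..., x_n
```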

As soon as I get a chance, I will add another citation to Blelloch's work in the preprint.

Thank you again for pointing this out! :-)

sustcsonglin commented 11 months ago

Thank you for your response! I would also like to mention the work in https://arxiv.org/abs/1709.04057, whose authors actually implemented Blelloch's formulation in CUDA.

fheinsen commented 11 months ago

Thank you. I'll take a look.

fheinsen commented 10 months ago

PS. I took a quick look at https://arxiv.org/abs/1709.04057 and its associated repository. The CUDA code is specific to the RNNs discussed in the preprint and not at all optimized -- the code is littered with comments such as "TODO: parallel scan (tricky because number of blocks isn't necessarily smaller than number of warps that can fit in a single block)".

In contrast, my formulation is implementable in three lines of Python, and consists only of elementwise operations, which execute in parallel in a GPU, and two prefix sums that call Nvidia's highly-optimized, extensively-tested implementation. (To get a sense of how much work Nvidia has invested in optimizing its parallel prefix sum, see here.)

Thank you again for the link!

fheinsen commented 10 months ago

PPS. This just came across my screen: https://arxiv.org/abs/2312.00752 (repo: https://github.com/state-spaces/mamba/). Based on a first pass, the code for executing parallel scans looks well-written, although it's again specific to the models proposed in the paper. It remains to be seen whether its performance is as good as that of my formulation.

eamartin commented 9 months ago

Hi, I'm the first author of https://arxiv.org/abs/1709.04057 and I just stumbled across this issue. This repo is a nice implementation, especially for its simplicity and the fact that it does not require any new CUDA kernels beyond what is already present in PyTorch. It would be nice if this implementation were able to handle 0's in the "coeffs" sequence.

I would like to respond to comments made about my work:

The CUDA code is specific to the RNNs discussed in the preprint

This is false. https://github.com/eamartin/parallelizing_linear_rnns/blob/master/linear_recurrent_net/linear_recurrence.h computes exactly the same linear recurrence as discussed in this repo: h[t]=a[t]*h[t-1] + x[t]
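
(As a point of reference, a naive sequential evaluation of that recurrence, with illustrative names, is just:)

```python
import torch

def sequential_reference(a, x, h0):
    # h[t] = a[t] * h[t-1] + x[t], iterating over the leading (time) dimension
    h, out = h0, []
    for a_t, x_t in zip(a, x):
        h = a_t * h + x_t
        out.append(h)
    return torch.stack(out)
```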

and not at all optimized -- the code is littered with comments such as "TODO: parallel scan...

I'd say the code is not "fully optimized," but I think "not at all optimized" is overly harsh and inaccurate. The code spends the majority of its time with each of the many warps on the GPU scanning in parallel over its partition of the input data. There are some serial scan parts that could be parallelized (this is what my TODOs refer to), but those serial sections have a small constant size (<64 steps) that depends only on the GPU architecture, not on the data size.

I'd be curious to race these implementations. The implementation here will benefit from (probably) better scan kernels (either from PyTorch or Nvidia), but it will be slowed by the copies required to take logs of the input tensors, by the logs for the recurrence itself, by the double-width datatypes required for complex numbers, and by needing two scans over the data rather than one.

Regardless, the simplicity of this implementation makes it a great contribution and probably a sweet spot on the ease-of-use vs. speed trade-off curve.

fheinsen commented 9 months ago

Eric -- thank you. Looking at my comment again with the benefit of hindsight, I agree with you: My comment was too harsh. I could go back and delete or edit it, but I won't. Instead, let me apologize here, on the record.

The implementation here will benefit from (probably) better scan kernels (either from PyTorch or Nvidia), but it will be slowed by the copies required to take logs of the input tensors, by the logs for the recurrence itself, by the double-width datatypes required for complex numbers, and by needing two scans over the data rather than one.

Yes. Something else to consider is that, given how fundamental prefix sums are, I wouldn't be surprised if Nvidia has implemented optimizations at the hardware level. The information Nvidia has made public so far indicates they have thrown a lot of resources at optimizing prefix sums for a diverse set of training and inference scenarios.

Also, note that if the $a_i$'s are dynamically computed probabilities and $x_0$ and the $b_i$'s are logits, this formulation can become significantly cheaper if we compute the $\log a_i$'s directly with a LogSigmoid function and leave $x_0$ and the $b_i$'s unmodified. Doing so makes it unnecessary to compute any logarithms (other than the LogSigmoid's LSE) or use complex numbers.
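
(A sketch of that shortcut, reusing the illustrative parallel_scan_log above; the gate pre-activations, shapes, and names are assumptions for illustration, not code from this repo:)

```python
import torch
import torch.nn.functional as F

def scan_from_logits(gate_preacts, value_logits, x0_logit):
    # gate_preacts: [..., n] pre-activations whose sigmoids are the a_i's in (0, 1)
    # value_logits: [..., n] the b_i's, already in log space; x0_logit: [..., 1] is log x_0
    log_coeffs = F.logsigmoid(gate_preacts)                   # log a_i without an explicit log
    log_values = torch.cat([x0_logit, value_logits], dim=-1)  # already logs, left unmodified
    return parallel_scan_log(log_coeffs, log_values)
```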

Regardless, the simplicity of this implementation makes it a great contribution and probably a sweet spot on the ease-of-use vs. speed trade-off curve.

Yes. Benchmarking is a hard problem, so for production applications I would suggest always testing all options and using whichever performs better in the target production environment.
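
(If anyone wants to race implementations on their own hardware, a minimal CUDA timing harness along these lines usually suffices; the candidate function and its arguments below are placeholders:)

```python
import torch

def time_gpu(fn, *args, warmup=10, iters=100):
    # Average milliseconds per call of `fn(*args)` on the current CUDA device.
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```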

maximzubkov commented 9 months ago

Hey, great job, @fheinsen! I work on a similar problem as well, and approached it from a slightly different angle: I used the FFT (Fast Fourier Transform) to reduce the number of computations required. It might be interesting to benchmark our approaches; here is my implementation in PyTorch, along with some theory in the README: https://github.com/maximzubkov/fft-scan

Might be interesting for @eamartin as well

fheinsen commented 9 months ago

I work on a similar problem as well, and approached it from a slightly different angle.

Thank you. I'll take a look. If I'm interested in discussing your work further, I will respond separately, because I would like to keep the discussion here centered on the topic: Blelloch's classic solution and its implementations.

Otherwise, I wish you and everyone reading this today an enjoyable holiday season!

lezcano commented 8 months ago

Once https://github.com/openai/triton/pull/2947 is in and we expose something like torch.associative_scan, it will be possible to compute this sort of sequence (and higher-order ones) with torch.compile. It should generate a kernel that should be comparable in efficiency to using CUB, but you'll be able to write it in pure Python :)

You can follow https://github.com/pytorch/pytorch/issues/50688 and the links in https://github.com/pytorch/pytorch/issues/95408#issuecomment-1857665931 to track the progress
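
(For context on what such an associative_scan would consume in this setting: the per-step maps $h \mapsto a h + b$ compose associatively when carried as $(a, b)$ pairs, so the combine function, written generically and independent of whichever scan API eventually lands, is roughly:)

```python
def combine(left, right):
    # Compose two affine maps h -> a*h + b, applying `left` first, then `right`:
    # right(left(h)) = a2 * (a1 * h + b1) + b2 = (a1 * a2) * h + (a2 * b1 + b2)
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)
```

A prefix scan of the pairs $(a_t, b_t)$ under this combine, seeded with $(1, h_0)$, yields $h_t$ in the second component of each prefix.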

fheinsen commented 8 months ago

Mario -- thank you for the heads-up. An API for multi-input parallel associative scan would be a great addition to PyTorch! :-) I'll keep an eye out for it. Once it's released, I will add to the README here a note about it and a link to it.