HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU training
BSD 2-Clause "Simplified" License

Faster QRNN #7

Closed · ClashLuke closed this 2 years ago

ClashLuke commented 2 years ago

Our long-context QRNN has n*log(n) complexity but a high constant iteration time. Each doubling of the context costs ~35ms, which is comparable to the ~50ms it takes to add an entire extra pointwise-convolution block. One way to reduce iteration time could be "unrolling": by running seven doubling steps "in parallel" using matrix multiplication, we might bring the cost of a doubling down to ~5ms.
This issue is about coming up with such a scheme and benchmarking it against our current language model.
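As a rough sanity check on these numbers: a context of ~2 million tokens corresponds to about 21 doublings, so the scan costs roughly 21 × 35ms ≈ 735ms per iteration today. If fusing seven doublings into one 128-wide matmul (2^7 = 128 positions) really brought each doubling down to ~5ms, the same scan would cost roughly 21 × 5ms ≈ 105ms.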

buttercutter commented 2 years ago

Might I ask why you infer that QRNN has n*log(n) complexity?

ClashLuke commented 2 years ago

I'm glad you asked! The original paper implements QRNN as a for loop that needs twice as many operations whenever it doubles its sequence length, and therefore gets unbearably slow for long-context models. However, after lots of work, I came up with a way of implementing their algorithm in log(n) parallel operations:

https://github.com/HomebrewNLP/HomebrewNLP-Jax/blob/57c8a7bc497d540c256e78502c26b0538c52c847/src/model.py#L108-L113

The core idea of the new implementation is that you can compute the interaction of items 0 to 1 and items 1 to 2 simultaneously. Similarly, you can compute 0 to 2 and 1 to 3 simultaneously, meaning that you're essentially building up a tree of operations with a depth of log2(n).

Using the following custom gradient:

https://github.com/HomebrewNLP/HomebrewNLP-Jax/blob/57c8a7bc497d540c256e78502c26b0538c52c847/src/model.py#L116-L138

we can also compute QRNN in linear memory, so the resulting model needs the same memory but much less run-time.
Unfortunately, the QRNN block is still too slow to use in a fast model with a context of 2 million items. Of course, it's more feasible than the transformer, but it is still relatively slow at 0.7s runtime for 256 features, 2 million tokens and a batch size of 1. Furthermore, if we wanted to use one QRNN layer in every layer of our model, we'd more than quadruple the cost of one block, from 200ms to 900ms, which makes training infeasible.
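The linked custom gradient is the mechanism actually used in the repository. Purely as an illustration of the same memory/compute trade-off (a sketch, not the repository's code), one can also get linear memory by wrapping the scan in `jax.checkpoint` (a.k.a. `jax.remat`), which recomputes the forward pass during backprop instead of storing every doubling step:

```python
import math

import jax
import jax.numpy as jnp


@jax.checkpoint  # recompute intermediates in the backward pass instead of storing them
def qrnn_scan(forget: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    # Same doubling scan as in src/model.py, reduced to plain
    # (batch, sequence, features) arrays so the example is self-contained.
    for i in range(int(math.log2(x.shape[1]))):
        x += jnp.concatenate([jnp.zeros((x.shape[0], 2 ** i, x.shape[2])),
                              x[:, :-2 ** i] * forget[:, 2 ** i:]], 1)
        forget *= jnp.concatenate([jnp.ones((x.shape[0], 2 ** i, x.shape[2])),
                                   forget[:, :-2 ** i]], 1)
    return x


# Peak memory for the backward pass now stays linear in the sequence length,
# at the cost of re-running the log-depth scan once during backprop.
grad_fn = jax.grad(lambda f, x: qrnn_scan(f, x).sum())
```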

Do you want to take up this problem?

buttercutter commented 2 years ago

> The core idea of the new implementation is that you can compute the interaction of items 0 to 1 and items 1 to 2 simultaneously. Similarly, you can compute 0 to 2 and 1 to 3 simultaneously, meaning that you're essentially building up a tree of operations with a depth of log2(n).

Are you expecting the for loop to be unrolled at runtime?

Note: I am not familiar with JAX; this is my first time looking at JAX code. I do not understand how the variable `forget` helps to achieve log2(n) complexity.

ClashLuke commented 2 years ago

Yes, the compiler unrolls the loop, but that doesn't help with complexity. You hit the nail on the head; that's precisely the issue. The original implementation had a for loop over the entire sequence, which is very slow. Our implementation, however, changes the size of the loop itself: instead of going over every element in the sequence sequentially, we parallelise as much as possible within each step. This way, we're using the same optimisations as others do for the "prefix sum" (cumsum), but applied to QRNN.
The forget gate you mention is the same as in the original paper and implementation. The names might differ slightly, but the algorithm is similar.
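To make the cumsum analogy concrete, here is a tiny illustrative sketch (not repository code) of the same doubling trick applied to a plain prefix sum; the QRNN scan is exactly this pattern with a multiplicative forget gate folded in:

```python
import math

import jax.numpy as jnp

x = jnp.arange(1.0, 9.0)  # n = 8 elements
for i in range(int(math.log2(x.shape[0]))):  # 3 steps instead of 8
    shift = 2 ** i
    # After this step, position t holds the sum of the last 2 ** (i + 1) elements.
    x = x + jnp.concatenate([jnp.zeros(shift), x[:-shift]])

print(x)  # [ 1.  3.  6. 10. 15. 21. 28. 36.], identical to jnp.cumsum(jnp.arange(1.0, 9.0))
```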

buttercutter commented 2 years ago

Is there a specific reason for using a hard sigmoid instead of just ReLU?

I have not done exact timing/profiling of the JAX code yet, but in general, division by a non-power-of-2 is discouraged for optimized performance.

By the way, what are the differences between f-pooling, fo-pooling, and ifo-pooling?

Note: I need to study how the forget gate works in the QRNN paper first.
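For reference (these definitions come from the original QRNN paper, not from this repository), the three pooling variants differ only in which gates they use; with candidate vectors z and gates f, o, i:

```latex
% QRNN pooling variants (Bradbury et al., 2016)
\text{f-pooling:}\quad   h_t = f_t \odot h_{t-1} + (1 - f_t) \odot z_t
\text{fo-pooling:}\quad  c_t = f_t \odot c_{t-1} + (1 - f_t) \odot z_t, \qquad h_t = o_t \odot c_t
\text{ifo-pooling:}\quad c_t = f_t \odot c_{t-1} + i_t \odot z_t, \qquad\quad\; h_t = o_t \odot c_t
```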

buttercutter commented 2 years ago

Your JAX implementation is very different from the pseudo-code in the original QRNN paper.

Would you mind adding a few more lines of code comments so that it is easier for other people to understand and check the for-loop unrolling process?

```python
def qrnn(ctx: Context, forget: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    dtype = forget.dtype
    # log2(sequence) doubling steps instead of one step per token
    for i in range(int(math.log2(ctx.dims.sizes.sequence))):
        # x[t] += forget[t] * x[t - 2**i]: merge the partial result ending at t
        # with the one of equal length ending at t - 2**i (zero-padded at the front)
        x += jnp.concatenate([jnp.zeros((x.shape[0], 2 ** i, x.shape[2])), x[:, :-2 ** i] * forget[:, 2 ** i:]], 1)
        # forget[t] *= forget[t - 2**i]: the accumulated gate now spans the doubled window
        forget *= jnp.concatenate([jnp.ones((x.shape[0], 2 ** i, x.shape[2])), forget[:, :-2 ** i]], 1)
    return x.astype(dtype)
```

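For readers following along, here is a small illustrative check (a sketch, not repository code) that the doubling scan above computes the same thing as the naive sequential recurrence h_t = f_t * h_{t-1} + x_t:

```python
import math

import jax.numpy as jnp
from jax import random

key_x, key_f = random.split(random.PRNGKey(0))
batch, seq, feat = 2, 16, 4
x = random.normal(key_x, (batch, seq, feat))
f = random.uniform(key_f, (batch, seq, feat), minval=0.1, maxval=0.9)

# Naive O(n) reference: one step per token.
h = [x[:, 0]]
for t in range(1, seq):
    h.append(f[:, t] * h[-1] + x[:, t])
h = jnp.stack(h, axis=1)

# Doubling scan, O(log n) steps, written like the snippet above.
xs, fs = x, f
for i in range(int(math.log2(seq))):
    s = 2 ** i
    xs = xs + jnp.concatenate([jnp.zeros((batch, s, feat)), xs[:, :-s] * fs[:, s:]], 1)
    fs = fs * jnp.concatenate([jnp.ones((batch, s, feat)), fs[:, :-s]], 1)

assert jnp.allclose(h, xs, atol=1e-5)
```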

ClashLuke commented 2 years ago

This issue is partially addressed by #50 and #52. They reduce the runtime by removing the first N iterations of the loop. On average, this will halve the execution time. This massive reduction allows us to square the sequence length at the same total cost.

ClashLuke commented 2 years ago

Technically, the following code would work:

```python
forget = forget.transpose(0, 2, 1)                                 # (batch, features, sequence)
forget = forget.reshape(batch, features, sequence // 128, 128, 1)  # split the sequence into 128-wide chunks
forget = jnp.tile(forget, (1, 1, 1, 1, 128))                       # repeat each gate 128 times -> 128x128 per chunk
forget = jnp.triu(forget)                                          # zero the lower triangle of each chunk
forget = lax.cumprod(forget, 4)                                    # cumulative product along the last axis
```

This way, we'd compute seven steps at a time using matmul. Unfortunately, while this might be faster on TPU, it'd also use 128x as much memory (-> 128 GiB/intermediate/device), which is too much for us, so we'll have to find a different way.
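The 128 GiB figure is consistent with the sizes mentioned earlier: assuming bf16 activations (2 bytes per element), a batch of 1, ~2 million tokens and 256 features, the tiled intermediate holds 2^21 × 256 × 128 × 2 B = 2^37 B = 128 GiB.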

buttercutter commented 2 years ago

I have a feeling that `forget = lax.cumprod(forget, 4)` might be contributing to the memory overhead, since it is a cumulative product?

May I know why you need `cumprod()`?

ClashLuke commented 2 years ago

The key idea here is that you transform the array of gates from something like [0.5, 0.6, 0.7, 1, 0.5] to a square matrix like

```
0.5  0.3  0.21  0.21  0.105
0    0.6  0.42  0.42  0.21
0    0    0.7   0.7   0.35
0    0    0     1     0.5
0    0    0     0     0.5
```

which we then multiply into x in one operation.
However, as this would require O(n^2) compute and memory, we shouldn't "unfold" the entire matrix at once but only create chunks up to the size where it's still a good compute/memory tradeoff. Unfortunately, TPUs require 128 features in the matmul dimension, meaning the minimum chunk size is 128, forcing us to allocate a 128x128 "square" matrix chunk.
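As a purely illustrative sketch (not the repository's code), the 5×5 matrix above can be reproduced from the toy gate vector with a cumulative product, as long as the lower triangle is filled with ones first so the running product isn't zeroed out:

```python
import jax.numpy as jnp

f = jnp.array([0.5, 0.6, 0.7, 1.0, 0.5])  # toy gates from the example above
n = f.shape[0]

tiled = jnp.tile(f[None, :], (n, 1))                      # row i, column j holds f[j]
mask = jnp.arange(n)[:, None] <= jnp.arange(n)[None, :]   # upper triangle, including diagonal
masked = jnp.where(mask, tiled, 1.0)                      # ones below the diagonal
cum = jnp.cumprod(masked, axis=1)                         # row i, column j: product of f[i..j]
mat = jnp.triu(cum)                                       # zero the strictly lower triangle

print(mat)  # reproduces the 5x5 matrix shown above
```

Multiplying this matrix into the corresponding chunk of x then applies all of that chunk's gates with a single matmul, which is exactly the 128x128 chunk trade-off described above.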

ClashLuke commented 2 years ago

Due to its slow execution time and poor convergence, I deprecated QRNN in favor of a hierarchical MLP Mixer. See #75 for more information.