ClashLuke closed this issue 2 years ago
Might I ask why you infer that QRNN has n*log(n) complexity?
I'm glad you asked! The original paper implements QRNN as a for loop that needs twice as many sequential operations whenever the sequence length doubles, and therefore gets unbearably slow for long-context models. However, after lots of work, I came up with a way of implementing their algorithm in log(n) parallel operations. https://github.com/HomebrewNLP/HomebrewNLP-Jax/blob/57c8a7bc497d540c256e78502c26b0538c52c847/src/model.py#L108-L113 The core idea of the new implementation is that you can compute the interaction of items 0 to 1 and items 1 to 2 simultaneously. Similarly, you can compute 0 to 2 and 1 to 3 simultaneously, meaning that you're essentially building up a tree of operations with a depth of log2(n).
Using the following custom gradient: https://github.com/HomebrewNLP/HomebrewNLP-Jax/blob/57c8a7bc497d540c256e78502c26b0538c52c847/src/model.py#L116-L138 we can also compute QRNN in linear memory, so the resulting model needs the same memory but much less run-time.
Unfortunately, the QRNN block is still too slow to use in a fast model with a context of 2 million items. Of course, it's more feasible than the transformer, but it is still relatively slow at 0.7s of runtime for 256 features, 2 million tokens and a batch size of 1. Furthermore, if we wanted to use one QRNN layer in every layer of our model, we'd roughly quadruple the cost of one block from 200ms to 900ms, which makes training infeasible.
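For intuition, the same linear-memory behaviour can be reproduced generically with rematerialisation: instead of storing every intermediate of the loop for the backward pass, recompute them during it. A minimal sketch using `jax.checkpoint` (note: `qrnn_like` is a hypothetical stand-in loop, not the repository's hand-written custom gradient):

```python
import jax
import jax.numpy as jnp


def qrnn_like(forget: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    # Stand-in for the log2(n) doubling loop: several gated update steps,
    # each of which would normally store its output for the backward pass.
    for _ in range(4):
        x = x + forget * x
        forget = forget * forget
    return x


# jax.checkpoint (a.k.a. jax.remat) drops the intermediates and recomputes
# them during the backward pass, trading extra compute for linear memory.
low_mem = jax.checkpoint(qrnn_like)

grads = jax.grad(lambda f, x: low_mem(f, x).sum(), argnums=(0, 1))(
    jnp.full((8,), 0.5), jnp.ones((8,)))
```

The gradients match the un-checkpointed version exactly; only the memory profile of the backward pass changes.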
Do you want to take up this problem?
> The core idea of the new implementation is that you can compute the interaction of items 0 to 1 and items 1 to 2 simultaneously. Similarly, you can compute 0 to 2 and 1 to 3 simultaneously, meaning that you're essentially building up a tree of operations with a depth of log2(n).
Are you expecting that during runtime, the for loop will be unrolled?
Note: I am not familiar with JAX; this is my first time looking at JAX code. I do not understand how the variable forget helps to achieve log2(n) complexity.
Yes, the compiler unrolls the loop, but that doesn't help with complexity. You hit the nail on the head; that's precisely the issue: the original implementation had a for loop over the entire sequence, which is very slow. However, our implementation changes the size of the loop itself. Instead of going over every element in the sequence sequentially, we parallelise as much as possible within each step. This way, we're using the same optimisations as others do for "prefix sum" (cumsum), but apply them to QRNN.
The forget gate you mention is the same as in the original paper and implementation. The names might differ slightly, but the algorithm is similar.
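As a toy illustration of the same doubling idea applied to a plain prefix sum (a hypothetical NumPy sketch, not code from the repository):

```python
import numpy as np


def prefix_sum_doubling(a) -> np.ndarray:
    # Each pass adds the value from `shift` positions back; after
    # ceil(log2(n)) passes, out[t] == a[0] + ... + a[t].
    out = np.asarray(a, dtype=float).copy()
    shift = 1
    while shift < out.shape[0]:
        # The right-hand side is materialised before assignment,
        # so the overlapping slices are safe.
        out[shift:] = out[shift:] + out[:-shift]
        shift *= 2
    return out
```

QRNN does the same thing, except that each pass also scales the shifted values by the accumulated forget gates instead of using plain addition.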
Is there a specific reason for using hard sigmoid instead of just ReLU?
I have not done exact timing/profiling coverage of the JAX code yet, but in general, division by non-powers of 2 is discouraged for optimised performance.
By the way, what are the differences between f-pooling, fo-pooling, and ifo-pooling?
Note: I need to study how the forget gate works in the QRNN paper first.
Your JAX implementation is very different from the pseudocode in the original QRNN paper.
Do you mind adding a few more code comments so that it is easier for other people to understand and check the for-loop unrolling process?
def qrnn(ctx: Context, forget: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    dtype = forget.dtype
    # log2(sequence) doubling steps instead of one step per token
    for i in range(int(math.log2(ctx.dims.sizes.sequence))):
        # add the state from 2**i positions back, scaled by the current
        # (accumulated) forget gate; the first 2**i positions are zero-padded
        x += jnp.concatenate([jnp.zeros((x.shape[0], 2 ** i, x.shape[2])), x[:, :-2 ** i] * forget[:, 2 ** i:]], 1)
        # accumulate the forget gates over the doubled span (ones-padded)
        forget *= jnp.concatenate([jnp.ones((x.shape[0], 2 ** i, x.shape[2])), forget[:, :-2 ** i]], 1)
    return x.astype(dtype)
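To check the unrolling against the recurrence it implements, here is a self-contained sketch (hypothetical shapes, with the `Context` object replaced by the input's own sequence length) comparing the log2(n) loop to the naive sequential recurrence h_t = x_t + f_t * h_{t-1}:

```python
import math

import jax.numpy as jnp


def qrnn_scan(forget: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    # The same doubling loop as above, written functionally.
    batch, sequence, features = x.shape
    for i in range(int(math.log2(sequence))):
        shift = 2 ** i
        pad_x = jnp.zeros((batch, shift, features), x.dtype)
        pad_f = jnp.ones((batch, shift, features), forget.dtype)
        x = x + jnp.concatenate([pad_x, x[:, :-shift] * forget[:, shift:]], 1)
        forget = forget * jnp.concatenate([pad_f, forget[:, :-shift]], 1)
    return x


def qrnn_naive(forget: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    # Sequential reference: one step per token.
    h = [x[:, 0]]
    for t in range(1, x.shape[1]):
        h.append(x[:, t] + forget[:, t] * h[-1])
    return jnp.stack(h, 1)
```

Both functions produce the same output, but `qrnn_scan` only needs log2(n) passes over the sequence.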
This issue is partially addressed by #50 and #52. They reduce the runtime by removing the first N iterations of the loop. On average, this will halve the execution time. This massive reduction allows us to square the sequence length at the same total cost.
Technically, the following code would work:
forget = forget.transpose(0, 2, 1)
forget = forget.reshape(batch, features, sequence // 128, 128, 1)
forget = jnp.tile(forget, (1, 1, 1, 1, 128))
forget = jnp.triu(forget)
forget = lax.cumprod(forget, 4)
This way, we'd compute seven steps at a time using matmul. Unfortunately, while this might be faster on TPU, it'd also use 128x as much memory (-> 128 GiB/intermediate/device), which is too much for us, so we'll have to find a different way.
I have a feeling that forget = lax.cumprod(forget, 4) might have been contributing to the memory overhead, since it is a cumulative product/multiplication?
May I know why you would need cumprod()?
The key idea here is that you transform the array of gates from something like [0.5, 0.6, 0.7, 1, 0.5] to a square matrix like

    0.5  0.3  0.21  0.21  0.105
    0    0.6  0.42  0.42  0.21
    0    0    0.7   0.7   0.35
    0    0    0     1     0.5
    0    0    0     0     0.5

which we then multiply into x in one operation.
However, as this would require O(n^2) compute and memory, we shouldn't "unfold" the entire matrix at once but only create chunks up to the size where it's still a good compute/memory tradeoff. Unfortunately, TPUs require 128 features in the matmul dimension, meaning the minimum chunk size is 128, forcing us to allocate a 128x128 "square" matrix chunk.
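For the toy gate vector above, such a matrix can be built with a tile/mask/cumprod sequence. A single-chunk sketch (note: this differs slightly from the batched snippet earlier; in particular, the region below the diagonal must be filled with ones before the cumulative product, since zeros there would wipe out the running products):

```python
import jax.numpy as jnp
from jax import lax

gates = jnp.array([0.5, 0.6, 0.7, 1.0, 0.5])
n = gates.shape[0]

# m[i, j] = gates[j] on and above the diagonal, 1 below it, so that the
# running product of row i effectively starts at column i.
m = jnp.tile(gates[None, :], (n, 1))
m = jnp.where(jnp.triu(jnp.ones((n, n), bool)), m, 1.0)

# Prefix products along each row: m[i, j] = gates[i] * ... * gates[j].
m = lax.cumprod(m, axis=1)

# Zero the strictly-lower triangle to match the matrix above.
m = jnp.triu(m)
```

Multiplying x (as a row vector) by this matrix then applies all five recurrence steps in one matmul.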
Due to its slow execution time and poor convergence, I deprecated QRNN in favor of a hierarchical MLP Mixer. See #75 for more information.
Our long-context QRNN has n*log(n) complexity but a high constant iteration time. For every doubling of the context (~35ms), we could've also added another pointwise convolution block (~50ms). One way to reduce iteration time could be "unrolling": by running seven steps "in parallel" using matrix multiplication, we might reduce the time of a doubling to 5ms.
This issue is about coming up with an idea and benchmarking it against our current language model.