HomebrewNLP / Olmax

HomebrewNLP in JAX flavour for maintainable TPU-Training
BSD 2-Clause "Simplified" License

Hierarchical Network #50

Closed ClashLuke closed 1 year ago

ClashLuke commented 2 years ago

Our network has a sequential structure through which it passes all messages. This structure implies that we always perform a dense computation and assume that features far away are less critical than closer ones. In reality, we might want to pass a "concept" or "general embedding" through the network more quickly than the local information.

We could achieve exactly that by using a hierarchical network, as proposed in Clockwork RNN. Parts with a higher "clock rate" would propagate information more locally (1 -> 2 -> 3), while layers with a lower clock rate would work with global contexts (1 -> 3 -> 5). This hierarchy can be nested to any depth and adds another hyperparameter. This issue aims to implement such a hierarchical network and benchmark it against the baseline without hierarchy in a long-context setting.

ClashLuke commented 2 years ago

One of the most significant constraints in implementing such an architecture is that it cannot use more memory than our current model.

The most sensible approach seems to be to reshape parts of the sequence dimension into the batch dimension, e.g. x.reshape(batch, sequence, clock_rate, features).transpose(0, 2, 1, 3).reshape(-1, sequence, features), and to apply the inverse operation before adding to the residual stream. We'd then compute something akin to a dilated convolution and a "dilated RNN" without explicitly using the slow primitives (see the sketch at the end of this comment).

Another option would be to introduce pooling and compute zero_pad(layer(avgpool(x, size=clock_rate)), (clock_rate - 1, 0)). This pooled method would reduce memory and time consumption drastically. Still, it might also decrease performance, as timesteps [4, 5, 6, 7] would all map to the same state after average pooling, even though 7 would be the only one allowed to be influenced by it. [4, 5, 6] would need to use the previous state generated by [0, 1, 2, 3] to avoid information leakage.
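For concreteness, here is a minimal sketch of the first (reshape) approach, assuming x is laid out as [batch, sequence, features] with sequence divisible by clock_rate; with_clock_rate and layer are illustrative names rather than the repo's API:

```python
import jax.numpy as jnp


def with_clock_rate(layer, x: jnp.ndarray, clock_rate: int) -> jnp.ndarray:
    """x: [batch, sequence, features], sequence divisible by clock_rate."""
    batch, sequence, features = x.shape
    chunks = sequence // clock_rate
    # Fold every clock_rate-th token into its own "virtual" batch entry, so
    # `layer` sees clock_rate strided sub-sequences of length sequence // clock_rate.
    strided = x.reshape(batch, chunks, clock_rate, features)
    strided = strided.transpose(0, 2, 1, 3).reshape(batch * clock_rate, chunks, features)
    out = layer(strided)
    # Inverse operation: unfold the virtual batch back into the original token order.
    out = out.reshape(batch, clock_rate, chunks, features)
    return out.transpose(0, 2, 1, 3).reshape(batch, sequence, features)
```

Because the strided sub-sequences land in the batch dimension, layer runs unchanged and at the same memory cost while effectively operating with dilation clock_rate.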

ClashLuke commented 2 years ago

I've started working on this in the hierarchical-network branch.

To solve the problem above, I propose using the first approach and running a different clock rate at every layer. This would mean that a layer with a dilation of 1 might be followed by one with a dilation of 2, 4, 8, etc. One of the biggest problems here is deciding which schedule to use.

Ideally, we would keep at least some form of locality bias, especially in the last layers, to ensure local coherence. However, the middle layers might want to explore the global context to improve consistency. That's why I'd propose changing the "dilation" after every block in a schedule of 1-2-4-8-4-2-1-1 (for a depth of 8).

Another approach could be to interleave local and global interactions to ensure features can still be propagated flawlessly. This way, we'd arrive at a schedule like 1-2-1-4-1-8-1-1. (Both schedules are sketched below.)
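To make the two options concrete, here is a hedged sketch of how such schedules could be generated for an arbitrary depth (function names and exact edge-case behaviour are illustrative, not the repo's):

```python
def ramp_schedule(depth: int) -> list:
    """e.g. depth=8 -> [1, 2, 4, 8, 4, 2, 1, 1]: ramp up to the middle, back down, end local."""
    up = [2 ** i for i in range(depth // 2)]
    return up + up[-2::-1] + [1] * (depth - 2 * len(up) + 1)


def interleaved_schedule(depth: int) -> list:
    """e.g. depth=8 -> [1, 2, 1, 4, 1, 8, 1, 1]: alternate local and increasingly global blocks."""
    out = []
    for i in range(1, depth // 2):
        out += [1, 2 ** i]
    return out + [1] * (depth - len(out))
```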

However, either of these schedules increases the clock rate exponentially with depth. A 64-block model would then have a maximal clock rate of 2^31 in its middle (or second-to-last) block, so we also have to put a maximum on the cycle length. This maximum will likely depend on the current context length. Something like 16 would already give us a context of 1600 with a single block. That would be way too much for classical language modeling with a context of 2048, whereas our long-context model from #33 would need to go up to 4096 to utilize the global context fully. With 13 increases, each super-block would see over 819,100 tokens.
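As a back-of-the-envelope check on these numbers, here is a tiny sketch assuming roughly 100 tokens of per-block context, an illustrative figure consistent with the 16 -> 1600 example above:

```python
# Approximate context reached by a single block at different clock rates,
# assuming ~100 tokens of local context per block (illustrative figure only).
block_context = 100
for clock_rate in (16, 2 ** 13, 2 ** 31):
    print(clock_rate, clock_rate * block_context)
# 16 -> 1,600 tokens; 2**13 -> 819,200 tokens; 2**31 -> ~2.1e11 tokens
```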

Additionally, considering that each block is already composed of three convolutions (or four operations, in the case of QRNN blocks), we could use that structure to retain both local and global connections. One straightforward approach would be to add dilation only to the long-context ("bottleneck") convolution and the QRNN. Taking this one step further, though, using three dilated convolutions might not be ideal. Instead, we could use two local convolutions and one dilated convolution in the bottleneck block to make maximal use of both global and local context. If we didn't "split" it this way, all three convolutions would act on essentially the same inputs, which might reduce the capacity of our model. Continuing this thought, we conclude that the "base dilation" for the intermediate layer of our convolution should be precisely the size of the preceding convolution's kernel.

By adding this dilation to the intermediate layer of our bottleneck convolution, we'd effectively increase the base context size of our model from 119 to 1225 tokens per block (see the sketch below).
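A hedged sketch of such a "split" bottleneck block, with the middle convolution dilated by the kernel size of the preceding one; kernel sizes, feature counts, and function names are illustrative, not the repo's actual hyperparameters:

```python
import jax
import jax.numpy as jnp


def causal_conv1d(x: jnp.ndarray, weight: jnp.ndarray, dilation: int = 1) -> jnp.ndarray:
    """x: [batch, sequence, features_in]; weight: [kernel, features_in, features_out]."""
    kernel = weight.shape[0]
    return jax.lax.conv_general_dilated(
        x, weight, window_strides=(1,),
        padding=[((kernel - 1) * dilation, 0)],  # left-pad only, so the convolution stays causal
        rhs_dilation=(dilation,),
        dimension_numbers=('NWC', 'WIO', 'NWC'))


def bottleneck_block(x, w_in, w_mid, w_out):
    y = causal_conv1d(x, w_in)                           # local convolution
    y = causal_conv1d(y, w_mid, dilation=w_in.shape[0])  # dilated by the preceding kernel size
    return causal_conv1d(y, w_out)                       # local convolution
```

With the middle layer dilated, the three convolutions no longer see essentially the same window, and the block's receptive field grows roughly multiplicatively instead of additively.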

After adding this inner dilation, we can't add a modeling hierarchy to our convolutions anymore, as their context size is already massive. However, the QRNN remains undilated, and the Clockwork RNN paper only applied its hierarchy to RNNs in the first place. That's why I believe we should only add dilation to our RNN blocks. The schedule and the other problems mentioned above still apply, though.

The simplest solution could be a configurable minimal context size (for example, 16) up to which our RNN scales in an exponentially increasing fashion with resets. This way, our model would have a clock-rate schedule of 1-2-4-1-2-4 for a context of 64 (see the sketch below).
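A hedged sketch of that reset schedule, under my reading that the cap on the clock rate is the context length divided by the configurable minimal context size (names are illustrative):

```python
def qrnn_clock_rates(depth: int, context: int, min_context: int = 16) -> list:
    """Exponentially increasing clock rates with resets, capped at context // min_context."""
    max_rate = max(context // min_context, 1)  # e.g. 64 // 16 = 4
    rates, rate = [], 1
    for _ in range(depth):
        rates.append(rate)
        rate = 1 if rate * 2 > max_rate else rate * 2
    return rates


# qrnn_clock_rates(6, context=64) -> [1, 2, 4, 1, 2, 4]
```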

I'll implement these changes, and we'll see how well it works.

ClashLuke commented 2 years ago

Thanks to our logarithmic QRNN implementation, this is trivial to implement. All we have to do is start the range at a number higher than 0: https://github.com/HomebrewNLP/HomebrewNLP-Jax/blob/1469d690b005028501f61c5b8d6cafd7fdab499d/src/model.py#L102-L107 This alteration increases the minimal step size from 2^0 to 2^start_step. Since the loop then skips the smallest steps, the QRNN also takes less time to run, which makes it a net positive.
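For reference, a minimal, hedged sketch of the idea (not the repo's src/model.py code): a logarithmic-depth QRNN scan in which starting the loop at start_step > 0 skips the small offsets, so every position only mixes with positions whose distance is a multiple of 2**start_step, i.e. the recurrence becomes dilated:

```python
import math

import jax.numpy as jnp


def log_qrnn(forget: jnp.ndarray, x: jnp.ndarray, start_step: int = 0) -> jnp.ndarray:
    """Computes h[t] = forget[t] * h[t-1] + x[t] in O(log sequence) steps.

    forget, x: [batch, sequence, features]; start_step > 0 yields a dilated recurrence.
    """
    sequence = x.shape[1]
    for step in range(start_step, math.ceil(math.log2(sequence))):
        offset = 2 ** step
        # Shift both tensors right along the sequence axis; the padded region
        # behaves like h = 0 and forget = 1 ("nothing before the start").
        shifted_x = jnp.pad(x, ((0, 0), (offset, 0), (0, 0)))[:, :sequence]
        shifted_f = jnp.pad(forget, ((0, 0), (offset, 0), (0, 0)), constant_values=1)[:, :sequence]
        x = x + forget * shifted_x
        forget = forget * shifted_f
    return x
```

Starting at a higher step also means fewer loop iterations, which is where the runtime saving comes from.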

ClashLuke commented 2 years ago

Addressed by #52. For more comments, see the PR.

ClashLuke commented 1 year ago

Addressed by #75