I think it should be pretty simple: get rid of the Perceiver. Instead, just use a stack of nn.TransformerEncoderLayer modules (stacked on top of each other using the same mechanism as in Perceiver; maybe create a separate nn.Module called MultiLayerTransformerEncoder that abstracts this out).
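A minimal sketch of what that MultiLayerTransformerEncoder wrapper could look like (the dims and layer count here are illustrative, not fixed by the notes above):

```python
import torch
import torch.nn as nn


class MultiLayerTransformerEncoder(nn.Module):
    """A plain stack of nn.TransformerEncoderLayer modules applied in sequence."""

    def __init__(self, d_model: int, num_heads: int, num_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(
                    d_model=d_model,
                    nhead=num_heads,
                    batch_first=True,  # inputs are (batch, seq, d_model)
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each layer consumes the previous layer's output, same as stacking
        # encoder layers inside the Perceiver.
        for layer in self.layers:
            x = layer(x)
        return x
```

(PyTorch also ships nn.TransformerEncoder, which does essentially this stacking for you; the explicit ModuleList version just makes it easy to customize per-layer behavior later.)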
Stack (concatenate) (byte_array, query) along the sequence dimension and feed the result through the encoder.
Then the final output = attn_output[:, len(byte_array):]
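The stack-then-slice flow above, as a runnable sketch (all dims are illustrative; note that on a batched (batch, seq, dim) tensor the sequence length is shape[1], not len(), which would give the batch size):

```python
import torch
import torch.nn as nn

d_model, num_heads, num_layers = 64, 8, 2
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True),
    num_layers=num_layers,
)

batch = 4
byte_array = torch.randn(batch, 100, d_model)  # (batch, byte_len, d_model)
query = torch.randn(batch, 16, d_model)        # (batch, query_len, d_model)

# Stack along the sequence dimension, encode, then keep only the
# positions that correspond to the query.
x = torch.cat([byte_array, query], dim=1)
attn_output = encoder(x)
output = attn_output[:, byte_array.shape[1]:]

print(output.shape)  # torch.Size([4, 16, 64])
```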
Still TODO:
[x] Somehow allow us to use more heads. Need to give byte_array an embedding dimension that is exactly divisible by num_heads.
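One way to satisfy that divisibility constraint (a sketch; byte_dim and num_heads here are made-up values): project the raw byte embeddings up to the nearest dimension that num_heads divides evenly, since nn.TransformerEncoderLayer requires d_model % nhead == 0.

```python
import torch
import torch.nn as nn

byte_dim, num_heads = 100, 8

# Round byte_dim up to the nearest multiple of num_heads.
d_model = ((byte_dim + num_heads - 1) // num_heads) * num_heads  # 104
proj = nn.Linear(byte_dim, d_model)

byte_array = torch.randn(4, 50, byte_dim)  # (batch, seq, byte_dim)
x = proj(byte_array)                       # (batch, seq, d_model)
assert x.shape[-1] % num_heads == 0
```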