facebookresearch / encodec

State-of-the-art deep learning based audio codec supporting both mono 24 kHz audio and stereo 48 kHz audio.
MIT License

Motivation behind `layer_state` in `StreamingTransformerEncoder` #46

Open mxkrn opened 1 year ago

mxkrn commented 1 year ago

❓ Questions

I would like to know more about the motivation behind your use of `layer_state` in the `StreamingTransformerEncoder`.

My understanding so far is that, for each layer, the previous input `x_past` is concatenated with `x` to form the keys and values, but not the queries. Effectively, the query-key product then covers not only `x` itself but also part of the previous input `x_past`.
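A minimal NumPy sketch of the mechanism described above, for illustration only and not taken from EnCodec: it assumes a single head and no output projection, and the function name `streaming_self_attention`, the weight matrices, and the `past_context` window size are hypothetical. Queries come only from `x`, keys and values come from `[x_past, x]`, and a mask keeps each query causal with a bounded look-back.

```python
import numpy as np

# Hypothetical sketch, not EnCodec's implementation.
def streaming_self_attention(x, x_past, w_q, w_k, w_v, past_context):
    """Single-head attention over the current chunk plus cached past inputs.

    x:      (T, C) inputs of the current chunk
    x_past: (H, C) inputs this layer received on earlier steps
    """
    T, H = len(x), len(x_past)
    q = x @ w_q                                      # queries: current chunk only
    kv_input = np.concatenate([x_past, x], axis=0)   # (H + T, C)
    k, v = kv_input @ w_k, kv_input @ w_v            # keys/values: past + current

    # Queries sit at positions H..H+T-1, keys at positions 0..H+T-1.
    delta = np.arange(H, H + T)[:, None] - np.arange(H + T)[None, :]
    valid = (delta >= 0) & (delta <= past_context)   # causal, bounded look-back

    scores = np.where(valid, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, valid

rng = np.random.default_rng(0)
C = 8
w_q, w_k, w_v = (rng.standard_normal((C, C)) for _ in range(3))
x_past = rng.standard_normal((3, C))   # three cached steps
x = rng.standard_normal((2, C))        # two new steps
out, mask = streaming_self_attention(x, x_past, w_q, w_k, w_v, past_context=3)
print(out.shape)          # (2, 8)
print(mask.astype(int))
# [[1 1 1 1 0]
#  [0 1 1 1 1]]
```

Each mask row is one query: the first new step can still see all three cached steps, while for the second one the oldest cached step has fallen outside the window.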

I'm not entirely sure how to interpret this, and this may be because I don't know the details of your training strategy. If `x` and `x_past` are independent token sequences, it seems strange to let the transformer attend to their concatenation. If instead `x` and `x_past` come from the same audio clip, I don't understand why you wouldn't just increase the context length explicitly.

I tried to find other transformer implementations that do something similar, and the only one that came close is Transformer-XL. There is a major difference, however: Transformer-XL propagates the output of the transformer layer stack to the next step, whereas your implementation propagates the input.

I may be missing something entirely, so please excuse my ignorance in that case. Nonetheless, I would really appreciate it if you could shed some light on this 😇

adefossez commented 1 year ago

When we decode with entropy coding, we have to decode the current token before we can make sense of the rest of the bitstream. That means we cannot just pass all tokens at once; we have to decode them one by one.
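To illustrate this point, here is a hedged NumPy sketch (not EnCodec's code). It assumes single-head attention with a residual connection and no feed-forward or normalization, and `attention`, `stack_step`, the per-layer weight dicts, and `past_context` are hypothetical names. A two-layer stack processes a sequence either all at once or one token at a time while carrying a per-layer cache of past inputs; the outputs match, so a decoder can consume tokens one by one as they come out of the entropy decoder.

```python
import numpy as np

# Hypothetical sketch: single head, residual only (no FFN, no norm).
def attention(q_in, kv_in, q_pos, k_pos, w, past_context):
    q, k, v = q_in @ w["q"], kv_in @ w["k"], kv_in @ w["v"]
    delta = q_pos[:, None] - k_pos[None, :]
    valid = (delta >= 0) & (delta <= past_context)   # causal, bounded look-back
    scores = np.where(valid, q @ k.T / np.sqrt(q.shape[-1]), -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v

def stack_step(x, states, offset, layers, past_context):
    """Run one chunk through every layer; states[i] holds layer i's past inputs.

    Assumes past_context >= 1.
    """
    T = len(x)
    q_pos = np.arange(offset, offset + T)
    new_states = []
    for w, x_past in zip(layers, states):
        kv_in = np.concatenate([x_past, x], axis=0)
        k_pos = np.arange(offset - len(x_past), offset + T)
        new_states.append(kv_in[-past_context:])   # keep only what the window needs
        x = x + attention(x, kv_in, q_pos, k_pos, w, past_context)
    return x, new_states, offset + T

rng = np.random.default_rng(0)
C, T, past_context = 4, 10, 3
layers = [{name: rng.standard_normal((C, C)) / np.sqrt(C) for name in "qkv"}
          for _ in range(2)]
x = rng.standard_normal((T, C))

def empty_states():
    return [np.zeros((0, C)) for _ in layers]

# All at once: only possible if every token is already known.
full, _, _ = stack_step(x, empty_states(), 0, layers, past_context)

# One token at a time, as when each token must be decoded before the next.
states, offset, outputs = empty_states(), 0, []
for t in range(T):
    y, states, offset = stack_step(x[t:t + 1], states, offset, layers, past_context)
    outputs.append(y)
streamed = np.concatenate(outputs, axis=0)

print(np.allclose(full, streamed))  # True
```

Each layer caches its own inputs because the second layer's keys and values are computed from the first layer's outputs, which only exist once those earlier tokens have been processed.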

The same thing happens during generation (although for different reasons), e.g. see https://github.com/facebookresearch/llama/blob/main/llama/model.py#L132
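Keys and values are linear projections of a layer's input, so keeping a layer's past inputs (as described for `layer_state`) retains the same information as keeping already-projected keys and values, the variant typically used in generation code such as LLaMA's. The NumPy sketch below, with arbitrary sizes and random weights (all names hypothetical), checks this step by step. The trade-off is that an input cache stores one tensor per layer and re-projects it at every step, while a key/value cache stores two and skips the recomputation.

```python
import numpy as np

rng = np.random.default_rng(0)
C = 4
w_k, w_v = rng.standard_normal((C, C)), rng.standard_normal((C, C))
tokens = rng.standard_normal((6, C))

# Variant 1 (hypothetical): cache raw layer inputs, re-project at every step.
input_cache = np.zeros((0, C))
# Variant 2 (hypothetical): cache the already-projected keys and values.
k_cache, v_cache = np.zeros((0, C)), np.zeros((0, C))

for x in tokens:
    x = x[None, :]  # one new token, shape (1, C)
    input_cache = np.concatenate([input_cache, x])
    k1, v1 = input_cache @ w_k, input_cache @ w_v
    k_cache = np.concatenate([k_cache, x @ w_k])
    v_cache = np.concatenate([v_cache, x @ w_v])
    # Both caches yield the same keys and values at every step.
    assert np.allclose(k1, k_cache) and np.allclose(v1, v_cache)

print(k_cache.shape)  # (6, 4)
```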