lucidrains / x-transformers

A concise but complete full-attention transformer with a set of promising experimental features from various papers
MIT License

Support for a mask during autoregressive generation with Key-Value Caching #292

Open Oufattole opened 3 days ago

Oufattole commented 3 days ago

Why isn't a mask supported when key-value caching is enabled here?

lucidrains commented 3 days ago

@Oufattole in what scenario would you need masking when doing autoregressive decoding?

Oufattole commented 3 days ago

I'm trying to do sliding-window inference, but my initial prompts have different lengths, so I think I should mask out the padding, just as we do during autoregressive pretraining.

I'm applying transformers to medical trajectories as part of this open-source project, which provides ML tooling for modeling patient time-series data (you tokenize a patient's irregularly sampled time-series observations, such as medications, diagnoses, procedures, etc.). I'm interested in generating future trajectories and evaluating them. Here is the relevant code I am currently using to generate trajectories. For now I simply don't cache key-value pairs, so that I can apply masks, but that is prohibitively slow.
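
Roughly, the slow path I'm on looks like this (a simplified sketch, not my actual code; PAD_ID and the greedy decoding loop are placeholders from my side, not the library's API):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

PAD_ID = 0  # placeholder pad id from my vocabulary, not a library constant

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

tokens = torch.randint(1, 256, (2, 16))
tokens[1, :6] = PAD_ID                      # second trajectory is left-padded
mask = tokens != PAD_ID                     # (batch, seq) bool, False on padding

for _ in range(8):                          # decode 8 new tokens, greedily
    logits = model(tokens, mask = mask)     # full O(n^2) forward every step, no kv cache
    next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
    tokens = torch.cat((tokens, next_token), dim = -1)
    mask = torch.cat((mask, torch.ones_like(next_token, dtype = torch.bool)), dim = -1)
```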

lucidrains commented 3 days ago

@Oufattole yes I see, so you are off the beaten path

sliding windows aren't supported here yet

lucidrains commented 3 days ago

@Oufattole you can do away with masking by slicing the cached key values before passing them back in
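
untested sketch of what i mean, assuming the cache is a list of per-layer (k, v) tensors shaped (batch, heads, seq, dim_head) (the real intermediates object here is structured differently):

```python
import torch

def drop_padding_from_cache(cache, num_pads):
    # after prefilling a left-padded prompt, drop the first `num_pads`
    # positions from every layer's cached keys / values; the padded
    # positions then never participate in attention, so no mask is needed
    # on subsequent decoding steps. only valid when all samples in the
    # batch share the same pad count (e.g. batch size 1)
    return [(k[:, :, num_pads:], v[:, :, num_pads:]) for k, v in cache]
```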

Oufattole commented 3 days ago

Ahhh I see, thank you, I'll try that! With medical data, unlike in NLP and CV, many patient trajectories are very short, and you don't need a long sequence length at all. For example, in my dataset 80% of patients fall below the 512 max sequence length, but a small subset exceed 30k tokens (and this is after extreme reductions in the vocabulary, i.e. in which time-series variables we model; before that reduction, some of these patients exceeded 300k).

I'm naively trying sliding windows, but if there is a better approach you'd recommend for handling such extreme sequence-length variation, I'd be happy to try it.

Oufattole commented 3 days ago

Wait, actually, I think you do support masking the left-padded tokens with the seq_start_pos arg here, @lucidrains.
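
For example, something like this (assuming I'm reading the signature right, and that seq_start_pos is a (batch,) tensor giving each sample's first non-pad position in a left-padded batch):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 512,
    attn_layers = Decoder(dim = 128, depth = 2, heads = 4)
)

x = torch.randint(0, 256, (2, 512))
seq_start_pos = torch.tensor([0, 100])   # second sample starts after 100 pad tokens
logits = model(x, seq_start_pos = seq_start_pos)
```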

lucidrains commented 2 days ago

@Oufattole so that argument was actually built for variable prompt lengths, iirc. i'll have to take a closer look to really know whether it can be repurposed for what you are doing

during sliding-window decoding, you'll have to slice the cached key values as you decode past the window length
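
rough sketch (same assumed per-layer (k, v) cache layout as above; note this ignores position handling, with rotary or absolute positions you'd need to account for the offset when evicting):

```python
import torch

WINDOW = 512  # illustrative window length

def evict_out_of_window(cache):
    # keep only the most recent WINDOW positions of every layer's cached
    # keys / values once decoding runs past the window length
    return [(k[:, :, -WINDOW:], v[:, :, -WINDOW:]) for k, v in cache]

# inside the decode loop, after each cached forward step:
#   cache = evict_out_of_window(cache)
```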

lucidrains commented 2 days ago

@Oufattole what specialty is this and what exactly are you trending in the EMR that hits 300k in length?

Oufattole commented 2 days ago

Yes, I think you already do this KV-cache slicing during generation here, when restricting to max_seq_len (i.e. in the sliding-window setting). Am I correct about this?

I'll send you an email regarding the broader EHR modeling question, which I realize may be out of scope for this GitHub issue.

lucidrains commented 2 days ago

@Oufattole it has been a while, let me review it tomorrow morning and see if it can be made to work for your issue