Oufattole opened this issue 3 days ago.
@Oufattole in what scenario would you need masking when doing autoregressive decoding?
I'm trying to do sliding-window inference, but my initial prompts have different lengths, so I think I should mask out the padding, as we do during autoregressive pretraining.
I'm applying transformers to medical trajectories as part of this open-source project, which provides ML tooling for modeling patient time-series data (you tokenize a patient's irregularly sampled time-series observations, such as medications, diagnoses, and procedures). I'm interested in generating future trajectories and evaluating them. Here is the relevant code I'm currently using to generate trajectories. At the moment I simply don't cache key-value pairs, so that I can apply masks, but that is prohibitively slow.
@Oufattole yes I see, so you are off the beaten path
sliding window decoding isn't supported here yet
@Oufattole you can do away with masking by slicing the cached key values before passing them back in
Ahhh, I see, thank you, I'll try that! With medical data, unlike in NLP and CV, many patient trajectories are very short and don't need a long sequence length at all. For example, in my dataset 80% of patients fall below the 512 max sequence length, but a small subset exceed 30k tokens. That is after extreme reductions in the vocabulary (i.e., in which time-series variables we model); before those reductions, some of these patients exceeded 300k.
I'm naively trying sliding windows, but if you'd recommend a better approach for handling such extreme variation in sequence length, I'd be happy to try it.
Wait, actually, I think you do support masking left-padded tokens with the `seq_start_pos` arg here, @lucidrains.
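A minimal plain-Python sketch of the idea being discussed: left-pad prompts of different lengths to a common length, derive a per-sequence start position, and build a boolean mask that is False on the padding. This assumes `seq_start_pos` means "index of the first real token in each row"; that reading is an assumption, not something confirmed from the library's internals. The helpers `left_pad`, `start_positions`, and `padding_mask` and the pad id `0` are hypothetical.

```python
from typing import List

PAD_ID = 0  # hypothetical padding token id


def left_pad(prompts: List[List[int]], pad_id: int = PAD_ID) -> List[List[int]]:
    """Left-pad variable-length prompts so every row ends at the same position."""
    max_len = max(len(p) for p in prompts)
    return [[pad_id] * (max_len - len(p)) + p for p in prompts]


def start_positions(prompts: List[List[int]]) -> List[int]:
    """Index of the first real (non-pad) token in each left-padded row (assumed seq_start_pos semantics)."""
    max_len = max(len(p) for p in prompts)
    return [max_len - len(p) for p in prompts]


def padding_mask(seq_start_pos: List[int], seq_len: int) -> List[List[bool]]:
    """True where a position holds a real token, False where it is left padding."""
    return [[pos >= start for pos in range(seq_len)] for start in seq_start_pos]


prompts = [[5, 6, 7], [8], [9, 10]]
padded = left_pad(prompts)
starts = start_positions(prompts)
mask = padding_mask(starts, len(padded[0]))
print(padded)  # [[5, 6, 7], [0, 0, 8], [0, 9, 10]]
print(starts)  # [0, 2, 1]
print(mask)    # [[True, True, True], [False, False, True], [False, True, True]]
```

Because padding sits on the left, newly decoded tokens are appended on the right, so the start positions stay fixed and only the mask's length grows.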
@Oufattole so that hyperparameter was actually built for variable prompt lengths iirc. i'll have to take a closer look to really know if it can be repurposed for what you are doing
during sliding-window decoding, you'll have to slice the cached key values once you decode past the window length
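A toy sketch of that slicing, in plain Python: before each decoding step, the cache is cut back to the newest `window - 1` positions, so the incoming token never attends to more than `window` positions. `fake_kv` and `toy_next_token` are hypothetical stand-ins for the model, and batch/head dimensions are omitted; in a real model the same slice would be applied along the sequence dimension of every layer's cached keys and values.

```python
from typing import List, Tuple

# One cached (key, value) pair per sequence position; batch and head dims omitted for clarity.
KV = Tuple[List[float], List[float]]


def trim_cache(cache: List[KV], window: int) -> List[KV]:
    """Keep only the newest `window - 1` positions, leaving room for the token about to be added."""
    start = max(len(cache) - max(window - 1, 0), 0)
    return cache[start:]


def fake_kv(token: int) -> KV:
    """Hypothetical stand-in for projecting one token to its key and value."""
    return [float(token)], [2.0 * token]


def toy_next_token(cache: List[KV]) -> int:
    """Hypothetical stand-in for the model: predicts the newest token's id plus one."""
    return int(cache[-1][0][0]) + 1


def sliding_window_decode(prompt: List[int], steps: int, window: int) -> Tuple[List[int], List[int]]:
    """Decode `steps` tokens while never letting the attended span exceed `window` positions."""
    cache: List[KV] = []
    for token in prompt[:-1]:  # prefill every prompt token except the last
        cache = trim_cache(cache, window)
        cache.append(fake_kv(token))

    current = prompt[-1]
    generated: List[int] = []
    spans: List[int] = []  # how many positions the fed-in token attends to at each step
    for _ in range(steps):
        cache = trim_cache(cache, window)  # slice off positions that slid out of the window
        cache.append(fake_kv(current))     # key/value for the token being fed in
        spans.append(len(cache))
        current = toy_next_token(cache)
        generated.append(current)
    return generated, spans


generated, spans = sliding_window_decode([1, 2, 3], steps=4, window=4)
print(generated)  # [4, 5, 6, 7]
print(spans)      # [3, 4, 4, 4]
```

The attended span grows until it reaches the window and then stays capped. This toy ignores positional encodings, which a real implementation has to handle consistently once positions start sliding.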
@Oufattole what specialty is this and what exactly are you trending in the EMR that hits 300k in length?
Yes, I think you already do this kv-cache slicing during generation here when restricting to the max_seq_length (i.e. in the sliding window setting). Am I correct about this?
I'll send you an email regarding the broader EHR modeling question, which I realize may be out of scope for this GitHub issue.
@Oufattole it has been a while; let me review it tomorrow morning and see if it can be made to work for your issue
Why isn't a mask supported when key-value caching is enabled here?
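One way to see that a padding mask and a key-value cache are not inherently incompatible: during cached decoding the query is a single new position, so no causal mask is needed, only a per-sequence mask over the cached positions, kept and sliced in lockstep with the keys and values. The plain-Python sketch below shows this under those assumptions; `masked_attention_step` and `trim_with_mask` are hypothetical helpers, not x-transformers functions, and batch/head dimensions are omitted. It does not answer why the library disallows the combination.

```python
import math
from typing import List, Tuple


def masked_attention_step(
    query: List[float],
    keys: List[List[float]],
    values: List[List[float]],
    key_mask: List[bool],
) -> List[float]:
    """Scaled dot-product attention for ONE new query over cached keys/values.

    Positions where key_mask is False (e.g. left padding) receive zero weight.
    Assumes at least one position is unmasked (the new token itself always is).
    """
    scale = 1.0 / math.sqrt(len(query))
    scores = [
        sum(q * k for q, k in zip(query, key)) * scale if keep else -math.inf
        for key, keep in zip(keys, key_mask)
    ]
    max_score = max(scores)
    exps = [math.exp(s - max_score) for s in scores]  # exp(-inf) == 0.0 for masked slots
    total = sum(exps)
    weights = [e / total for e in exps]
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(len(values[0]))]


def trim_with_mask(
    keys: List[List[float]], values: List[List[float]], mask: List[bool], window: int
) -> Tuple[List[List[float]], List[List[float]], List[bool]]:
    """Slice keys, values and mask together so the mask stays aligned with the cache."""
    start = max(len(keys) - max(window - 1, 0), 0)
    return keys[start:], values[start:], mask[start:]


# One left-padded sequence: position 0 is padding, positions 1-2 are real tokens.
keys = [[100.0, 100.0], [1.0, 0.0], [0.0, 1.0]]
values = [[999.0, 999.0], [1.0, 0.0], [0.0, 1.0]]
mask = [False, True, True]
query = [1.0, 0.0]

with_padding = masked_attention_step(query, keys, values, mask)
without_padding = masked_attention_step(query, keys[1:], values[1:], mask[1:])
# The padded slot gets zero weight, so both results match despite its large key/value.
```

Under this scheme, decoding a new token appends `True` to the mask alongside its key and value, and sliding the window applies the same slice to all three lists.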