pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

torch.compile leads to OOM with different prompts. #81


samuelstevens commented 6 months ago

Given a model compiled with:

```python
model = torch.compile(model, mode="reduce-overhead", fullgraph=True, dynamic=True)
```

and a workload where the bulk of the work is computing next-token logits for many different prompts (MMLU), memory usage grows steadily until it hits an OOM.

Without torch.compile, there is no OOM.

I think this happens because the prompts have different lengths, so torch.compile ends up specializing (and, with reduce-overhead, capturing) a separate variant for each distinct prompt length, and the memory held by those accumulated variants eventually leads to the OOM.
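To check this, I can enable recompile logging. This is a minimal sketch assuming a recent PyTorch 2.x; the exact knobs may differ by version:

```python
import torch

# Print a message every time dynamo recompiles (e.g. for a new sequence length).
torch._logging.set_logs(recompiles=True)

# Optionally cap how many specializations are kept before falling back to eager,
# which bounds how much memory the compiled variants can hold.
torch._dynamo.config.cache_size_limit = 8
```

Setting the environment variable TORCH_LOGS="recompiles" should give the same logging without code changes.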

I call model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size) once up front, so the KV caches are never resized.

Is there a way to set up the prompt to always be model.config.block_size tokens long, then mask out the irrelevant tokens, so that there is only one path through model.forward? Or should I avoid torch.compile and set up batching to achieve a speedup?
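For reference, here is a rough sketch of the padding idea, assuming gpt-fast's forward(idx, input_pos) signature; pad_id is an arbitrary placeholder token id. Because attention is causal, right-padded positions cannot affect the logits at the last real token, so no extra masking is needed just to read next-token logits:

```python
import torch

def next_token_logits(model, prompt: torch.Tensor, pad_len: int, pad_id: int = 0) -> torch.Tensor:
    # Right-pad the prompt to a single fixed length so the compiled prefill
    # only ever sees one shape. prompt: (T,) LongTensor, with T <= pad_len.
    T = prompt.size(0)
    padded = torch.full((1, pad_len), pad_id, dtype=torch.long, device=prompt.device)
    padded[0, :T] = prompt
    input_pos = torch.arange(pad_len, device=prompt.device)
    logits = model(padded, input_pos)  # (1, pad_len, vocab_size)
    # Causal attention: position T-1 attends only to positions <= T-1,
    # so the right-padding cannot change these logits.
    return logits[0, T - 1]

# e.g. logits = next_token_logits(model, prompt_ids, pad_len=model.config.block_size)
```

Padding everything to block_size wastes compute on short prompts, so an alternative would be padding to a small set of bucket lengths (say, powers of two), which keeps the number of compiled shapes bounded while doing less extra work per prompt.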