Given a model compiled with:

model = torch.compile(model, mode="reduce-overhead", fullgraph=True, dynamic=True)
where the bulk of the task is computing next-token logits for many different prompts (MMLU), memory usage grows until it hits an OOM.
Without torch.compile, there is no OOM.
I think this happens because different prompts have different lengths, so torch.compile records a separate path for each new sequence length, and the accumulated graphs eventually lead to the OOM.
I call model.setup_caches(max_batch_size=1, max_seq_length=model.config.block_size), so the caches themselves are never resized.
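To confirm that recompilation is really the cause, I think something like this should log each time a new graph is recorded (this is my assumption about torch._logging, not something I have verified on this workload):

import torch._logging

# Ask dynamo to print a message whenever a recompile is triggered,
# e.g. because a prompt with a new length shows up.
torch._logging.set_logs(recompiles=True)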
Is there a way to set up the prompt to always be model.config.block_size tokens long and mask out the irrelevant tokens, so that there is only one path through model.forward? Or should I avoid torch.compile and set up batching instead to achieve a speedup?
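Concretely, this is the kind of padding I have in mind (just a sketch; pad_id and reading the logits at position prompt_len - 1 are assumptions about the model I'm running, not tested code):

import torch

def pad_prompt(input_ids: torch.Tensor, block_size: int, pad_id: int):
    # input_ids: (1, prompt_len); pad on the right up to the fixed block size
    prompt_len = input_ids.shape[1]
    padded = torch.full((1, block_size), pad_id, dtype=input_ids.dtype, device=input_ids.device)
    padded[:, :prompt_len] = input_ids
    # Boolean mask marking the real tokens, so attention can ignore the padding
    mask = torch.zeros(1, block_size, dtype=torch.bool, device=input_ids.device)
    mask[:, :prompt_len] = True
    return padded, mask, prompt_len

# Every call to model.forward then sees the same (1, block_size) shape,
# and the next-token logits for the prompt are read from position prompt_len - 1.

If padding to the full block_size makes every forward pass as expensive as a maximum-length prompt, that also seems relevant to whether batching would be the better route.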