We can currently use the max_len parameter in decoders to avoid OOM exceptions at inference time.
However, this is not enough on its own, e.g. when batch_size is specified in number of tokens.
For simplicity, imagine a decoder-only scenario with a token-level batch_size of 9, which barely fits into memory, and max_len=3. We can get a batch of shape [4, 2] (batch_size, seq_len), i.e. 8 tokens. During inference we can easily generate a result of shape [4, 3], i.e. 12 tokens, which will cause an OOM.
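To make the arithmetic concrete, here is a small sketch in plain Python. The numbers and names (`token_budget`, `worst_case_tokens`) are illustrative only and not part of the codebase; the final snippet shows one possible mitigation (budgeting batches by the worst case), not necessarily what this PR implements.

```python
def worst_case_tokens(num_sequences: int, max_len: int) -> int:
    """Worst-case token count after decoding: every sequence may grow to max_len."""
    return num_sequences * max_len


# Hypothetical numbers from the example above.
token_budget = 9  # token-level batch_size that barely fits into memory
max_len = 3

# A batch of 4 sequences of length 2 holds 8 tokens and fits the budget...
num_sequences, seq_len = 4, 2
assert num_sequences * seq_len <= token_budget

# ...but decoding can extend every sequence to max_len, giving 4 * 3 = 12
# tokens, which exceeds the budget and can OOM.
assert worst_case_tokens(num_sequences, max_len) > token_budget

# One possible mitigation: budget batches by the worst case, i.e. allow at
# most token_budget // max_len sequences per batch.
max_sequences = token_budget // max_len  # 3 sequences of up to 3 tokens each
assert worst_case_tokens(max_sequences, max_len) <= token_budget
```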
This PR suggests one possible solution and is open for discussion.