replicate / cog-triton

A cog implementation of Nvidia's Triton server
Apache License 2.0

ensure that max/min new tokens doesn't exceed max seq len #30

Closed joehoover closed 4 months ago

joehoover commented 4 months ago

This PR adds validation for min/max new tokens so that users cannot accidentally request a number of new tokens that, combined with the prompt, exceeds the model's maximum sequence length. If users do that, the trt-llm engine crashes.
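For illustration, here is a minimal sketch of the kind of check described above; the constant, function, and parameter names are assumptions for the example, not the actual cog-triton code:

```python
# Hypothetical sketch of the validation this PR describes. MAX_SEQ_LEN and the
# function/parameter names are assumptions for the example, not cog-triton's code.
MAX_SEQ_LEN = 4096  # assumed engine limit for the example


def validate_token_budget(prompt_token_count: int, min_new_tokens: int, max_new_tokens: int) -> None:
    """Reject requests whose prompt plus requested new tokens exceed the limit."""
    if min_new_tokens > max_new_tokens:
        raise ValueError("min_new_tokens cannot be greater than max_new_tokens.")
    if prompt_token_count + max_new_tokens > MAX_SEQ_LEN:
        raise ValueError(
            f"Prompt has {prompt_token_count} tokens and max_new_tokens is "
            f"{max_new_tokens}, which exceeds the maximum sequence length of {MAX_SEQ_LEN}."
        )
```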

I attempted to solve this by simply tokenizing in predict.py and then passing the prompt tokens to Triton. However, I could not figure out how to serialize/format a list of ints so that Triton would accept them.
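For context, one way a list of token ids can be sent to a Triton server with the tritonclient Python package looks roughly like this; this is not what the PR ends up doing, and the input name, model name, and server URL are placeholders, not values taken from cog-triton:

```python
# Rough illustration of sending token ids to Triton with tritonclient; the
# input name "input_ids" and model name "preprocessing" are placeholders.
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

token_ids = np.array([[1, 15043, 3186]], dtype=np.int32)  # example prompt tokens
infer_input = httpclient.InferInput("input_ids", token_ids.shape, "INT32")
infer_input.set_data_from_numpy(token_ids)

result = client.infer(model_name="preprocessing", inputs=[infer_input])
```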

Accordingly, we now tokenize in both predict.py and preprocessing/1/model.py. This duplication is inefficient and should be fixed, but tokenizing a string of roughly 50k tokens only takes about 0.04 seconds, so I think it is acceptable for now.
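As a sanity check on that timing claim, tokenization cost can be measured with something like the snippet below; the tokenizer used here is an arbitrary stand-in, since the actual tokenizer depends on the model being served:

```python
# Quick, ad-hoc timing check for tokenizing a long string. The tokenizer used
# here ("gpt2") is an arbitrary stand-in, not the one cog-triton actually loads.
import time
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
text = "hello world " * 25_000  # yields tens of thousands of tokens

start = time.perf_counter()
input_ids = tokenizer(text)["input_ids"]
print(f"Tokenized {len(input_ids)} tokens in {time.perf_counter() - start:.3f} s")
```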