hltcoe / sandle

Run a large language modeling SANDbox in your Local Environment

Handle prompt-too-long error #61

Open ccmaymay opened 2 years ago

ccmaymay commented 2 years ago

From @nweir127:

If I feed the API text that is longer than the acceptable context window (2048 tokens), the generator post-processor currently crashes.

A nice feature of the API would be to either give an appropriate error response or to perform front truncation.

And a code snippet:

input_token_ids = tokenizer(text, return_tensors='pt')['input_ids']
if hasattr(model.config, 'n_positions'):
    max_context = model.config.n_positions - max_new_tokens
    if input_token_ids.shape[1] > max_context:
        # Front-truncate: keep only the last max_context tokens.
        input_token_ids = input_token_ids[:, -max_context:]
        # Re-encode the decoded tail so the ids match what the model sees.
        text = tokenizer.batch_decode(input_token_ids)[0]
        input_token_ids = tokenizer(text, return_tensors='pt')['input_ids']
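The two behaviors requested above (an explicit error response vs. front truncation) can be sketched independently of any particular tokenizer or model. This is a minimal illustration operating on plain lists of token ids; the function name and `truncate` flag are hypothetical, not part of the sandle API:

```python
def check_or_truncate(token_ids, n_positions, max_new_tokens, truncate=True):
    """Ensure a prompt fits the model's context window.

    token_ids: list of prompt token ids
    n_positions: model's maximum context length (e.g. 2048)
    max_new_tokens: tokens reserved for generation
    """
    max_context = n_positions - max_new_tokens
    if len(token_ids) <= max_context:
        return token_ids
    if not truncate:
        # Option 1: surface a clear error to the API client
        # instead of crashing in the post-processor.
        raise ValueError(
            f"prompt is {len(token_ids)} tokens; at most {max_context} fit "
            f"alongside max_new_tokens={max_new_tokens}"
        )
    # Option 2: front truncation -- drop the oldest tokens, keep the tail.
    return token_ids[-max_context:]
```

For example, with `n_positions=2048` and `max_new_tokens=100`, a 3000-token prompt is cut down to its last 1948 ids, while `truncate=False` raises instead.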