jcrangel opened 11 months ago
It was because I had some out-of-index tokens; I had to remove them:

```python
def truncate_batch(batch, max_length=None):
    """
    Remove the padding token (id 32000), which causes
    Assertion `srcIndex < srcSelectDimSize` failed.
    """
    # Number of real (non-padding) tokens in each sequence
    lengths = batch['attention_mask'].sum(dim=1)
    # If max_length is not provided, truncate to the shortest sequence in the batch
    if max_length is None:
        max_length = lengths.min().item()
    # Slice both tensors to the common length
    batch['input_ids'] = batch['input_ids'][:, :max_length]
    batch['attention_mask'] = batch['attention_mask'][:, :max_length]
    return batch
```
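For reference, here is how the helper would be called on a tokenized batch. This is a minimal sketch; the pad-token setup and the sample texts are my assumptions, not from the original post:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b_v2")
tokenizer.pad_token = tokenizer.eos_token  # assumption: avoid an out-of-vocab pad id

texts = ["short prompt", "a noticeably longer second prompt"]
batch = tokenizer(texts, return_tensors="pt", padding=True)
batch = truncate_batch(batch)  # truncates every row to the shortest sequence in the batch
```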
I'm attempting to evaluate an OpenLlama model on a test dataset. Single-element inference is considerably slow, so I'm trying to use batching for efficiency. However, during batch inference I run into a CUDA error.

Error Message

Assertion `srcIndex < srcSelectDimSize` failed.
Code for Batch Inference
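The original snippet was not preserved here; below is a hedged sketch of what the batched-generation loop presumably looks like, reusing `truncate_batch` from above. The `test_dataset` variable, its `"text"` field, the batch size, and the generation parameters are all assumptions:

```python
import torch

device = "cuda"
prompts = [example["text"] for example in test_dataset]  # hypothetical field name

model.eval()
with torch.no_grad():
    for i in range(0, len(prompts), 8):  # batch size 8, chosen arbitrarily
        batch = tokenizer(prompts[i:i + 8], return_tensors="pt", padding=True)
        batch = truncate_batch(batch)  # drop the columns that trigger the assert
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model.generate(**batch, max_new_tokens=128)
        decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
```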
Additional Information
The base model is "openlm-research/open_llama_7b_v2", but I fine-tuned it using PEFT. So I load the model using:
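The exact loading code was not quoted; a plausible sketch using the standard PEFT API (the adapter path is a placeholder, and the dtype/device settings are assumptions):

```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_7b_v2",
    torch_dtype=torch.float16,
    device_map="auto",
)
model = PeftModel.from_pretrained(base, "path/to/adapter")  # hypothetical adapter path
tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b_v2")
```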
Any assistance on this issue would be greatly appreciated. Thank you in advance!