I am using `transformers==4.43.3` with `tensor_parallel==2.0.0`, and I loaded the Llama-3.1-8B-Instruct model. When I run inference, I get this error:
```
Traceback (most recent call last):
  File "/home/bc20/jingwei/topk/./exploration/eval/save_vectors.py", line 151, in <module>
    main(args)
  File "/home/bc20/jingwei/topk/./exploration/eval/save_vectors.py", line 120, in main
    evaluator.evaluate(model)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/bc20/jingwei/topk/./exploration/eval/save_vectors.py", line 57, in evaluate
    _ = model(batch).logits
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/tensor_parallel/pretrained_model.py", line 76, in forward
    return self.wrapped_model(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/tensor_parallel/tensor_parallel.py", line 159, in forward
    return parallel_apply(self.module_shards, inputs, kwargs_tup, self.devices)[self.output_device_index]
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 108, in parallel_apply
    output.reraise()
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/_utils.py", line 722, in reraise
    raise exception
AttributeError: Caught AttributeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/parallel/parallel_apply.py", line 83, in _worker
    output = module(*input, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1141, in forward
    outputs = self.model(
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/bc20/anaconda3/envs/dejavu/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 971, in forward
    next_cache = next_cache.to_legacy_cache()
AttributeError: 'tuple' object has no attribute 'to_legacy_cache'
```
Are there any insights on how to work around this? I don't want to downgrade the transformers library to 4.35, because I want to use the newest Llama-3.1 model.
This is mentioned in the transformers repository: https://github.com/huggingface/transformers/issues/28003 and https://github.com/huggingface/transformers/issues/28045
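For context, the failure is that `tensor_parallel`'s `parallel_apply` hands the Llama forward a KV cache that has been flattened back into a plain tuple, while newer transformers expects a `Cache` object with a `to_legacy_cache()` method. One avenue I am considering is disabling the cache entirely (`model(batch, use_cache=False)` or `model.config.use_cache = False`, at the cost of slower generation); another is guarding the failing conversion. The sketch below uses a hypothetical stub class, not the real transformers `DynamicCache`, purely to illustrate the guard pattern:

```python
class DynamicCacheStub:
    """Hypothetical stand-in for transformers' DynamicCache, used only to
    demonstrate the guard; the real class lives in transformers.cache_utils."""

    def __init__(self, layers):
        self.layers = list(layers)

    def to_legacy_cache(self):
        # The real method converts the Cache object to the legacy
        # tuple-of-tuples format.
        return tuple(self.layers)


def to_legacy(next_cache):
    # Defensive conversion: only call to_legacy_cache() when the object
    # actually is a Cache; a tuple (as produced after parallel_apply
    # flattens the replica outputs) passes through unchanged instead of
    # raising AttributeError.
    if hasattr(next_cache, "to_legacy_cache"):
        return next_cache.to_legacy_cache()
    return next_cache


# A Cache-like object is converted, a bare tuple is returned as-is:
assert to_legacy(DynamicCacheStub([("k0", "v0")])) == (("k0", "v0"),)
assert to_legacy((("k0", "v0"),)) == (("k0", "v0"),)
```

Patching this guard into `modeling_llama.py` (or monkeypatching the forward) is of course a hack; I would prefer a supported fix, hence the question.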