Thanks for the great work! When I pull the current version and test it, I get the error below. If I fix only the spot where the error occurs, the same error then shows up in other places. I would be very grateful if you could push a version without it.
```
MistralForCausalLM
using self extend
using MistralModel
attention implementation: self_extend
Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.23it/s]
 61%|██████▌   | 112/184 [00:04<00:02, 27.09it/s]
Traceback (most recent call last):
  File "/home/deephigh/git/mistral-7b/test_solar.py", line 39, in <module>
    generated_tokens = model.generate(**inputs, max_length=19264, pad_token_id=tokenizer.eos_token_id)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/transformers/generation/utils.py", line 1718, in generate
    return self.greedy_search(
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/transformers/generation/utils.py", line 2579, in greedy_search
    outputs = self(
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1222, in forward
    outputs = self.model(
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1106, in forward
    layer_outputs = decoder_layer(
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 819, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/deephigh/miniconda3/envs/mistral_16k/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 356, in forward
    g_q = self.apply_pos_emcode(q, s_g_pos, seq_len, self.dim, device)
TypeError: SelfExtendMistralAttention.apply_pos_emcode() takes 4 positional arguments but 6 were given
```
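For context, the call site and the method signature look out of sync: the attention forward passes five arguments to `self.apply_pos_emcode(q, s_g_pos, seq_len, self.dim, device)` (six positionals counting `self`), while the method apparently only declares three parameters besides `self`. Below is a minimal, self-contained sketch of the mismatch and one possible fix, widening the signature to match the call site. The parameter names (`x`, `pos`, `seq_len`, `dim`, `device`) are guesses inferred from the call site, not the repo's actual definition:

```python
# Hypothetical reproduction of the arity mismatch (names are guesses,
# not the actual code in this repo).
class SelfExtendAttentionSketch:
    # Old-style signature: 4 positional parameters including self,
    # so a 5-argument call raises the TypeError from the traceback.
    def apply_pos_emcode_old(self, x, pos, device):
        return x  # placeholder body

    # One possible fix: widen the signature to match the call site.
    def apply_pos_emcode(self, x, pos, seq_len, dim, device):
        # The real implementation would apply the (grouped) rotary
        # position embedding here; this placeholder returns x unchanged.
        return x


attn = SelfExtendAttentionSketch()

# The old signature cannot accept the call made at modeling_mistral.py:356.
try:
    attn.apply_pos_emcode_old("q", "s_g_pos", 19264, 128, "cpu")
    msg = None
except TypeError as e:
    msg = str(e)  # "...takes 4 positional arguments but 6 were given"

# The widened signature accepts the same call without error.
result = attn.apply_pos_emcode("q", "s_g_pos", 19264, 128, "cpu")
```

If the same fix is needed in several places (as the "fixing one spot moves the error elsewhere" behavior suggests), every `apply_pos_emcode` call site would need to agree with whichever signature is kept.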