I was attempting to run this on Windows (not WSL) and was not able to get it working. After some troubleshooting, here are the main takeaways:
- To avoid dependency issues, it is best to install `torch` before installing the rest of the requirements.
- The PyPI distribution of `flash-attn` is not compatible with Windows; to work around that I used the fork by jllllll.
- The fork unfortunately only supports up to v2.4.2 (I'm not sure whether its versioning is meant to match that of the official `flash-attn` package), which is incompatible with how `transformers` is used in this implementation.

Conclusion: a `flash-attn` build is needed that is compatible with both Windows and the `transformers` usage in this implementation. A rough sketch of the install order I ended up with is below.
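For reference, this is roughly the sequence that avoided the dependency issues for me. Treat it as a sketch: the CUDA index URL and the wheel filename are illustrative placeholders, not the exact ones you should use, and the wheel has to match your Python, CUDA, and torch versions.

```powershell
# Create and activate a virtual environment
python -m venv .venv
.\.venv\Scripts\Activate.ps1

# 1. Install torch on its own, before anything else
#    (cu121 is an example index; pick the one matching your CUDA toolkit)
pip install torch --index-url https://download.pytorch.org/whl/cu121

# 2. Install a Windows build of flash-attn from the jllllll fork's releases
#    (filename below is illustrative only; download the wheel matching your setup)
pip install flash_attn-2.4.2+cu121-cp310-cp310-win_amd64.whl

# 3. Only then install the remaining requirements
pip install -r requirements.txt
```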
Logs from running `run_phi3.ps1` (a PowerShell copy of `run_phi3.sh`):
(.venv) PS C:\Notes\writing-in-the-margins> .\run_phi3.ps1
Parameters: {'model_id': 'microsoft/Phi-3-medium-128k-instruct', 'attn_implementation':
'flash_attention_2', 'input_file': 'babilong_64k.json', 'user_header': '<|user|>\n',
'generation_header': '<|end|>\n<|assistant|>\n', 'dtype': 'bfloat16',
'min_tokens_segment': 4096, 'max_new_tokens_extractive_summary': 100,
'max_new_tokens_final_answer': 50, 'max_new_tokens_classification': 10, 'do_sample': True,'top_p': 0.9, 'temperature': 1.0, 'early_stopping': True, 'print_step_summary': True}
Loading checkpoint shards: 100%|███████████████████████████| 6/6 [00:00<00:00, 13.64it/s]
Some parameters are on the meta device device because they were offloaded to the disk and
cpu.
Number of segments in the context: 16
Traceback (most recent call last):
File "C:\Notes\writing-in-the-margins\run.py", line 281, in <module>
fire.Fire(main)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\fire\core.py", line 143, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\fire\core.py", line 477, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\fire\core.py", line 693, i component = fn(*varargs, **kwargs)
File "C:\Notes\writing-in-the-margins\run.py", line 147, in main
File "C:\Notes\writing-in-the-margins\wim.py", line 155, in prefill_text_with_kv_cache
outputs = self._prefill_tokens(input_ids, attention_mask, cache_positions, kv_cache)
File "C:\Notes\writing-in-the-margins\wim.py", line 43, in _prefill_tokens
outputs = self.model(
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\accelerate\hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\transformers\models\phi3\modeling_phi3.py", line 1203, in forward
outputs = self.model(
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\transformers\models\phi3\modeling_phi3.py", line 998, in forward
layer_outputs = decoder_layer(
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\accelerate\hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\transformers\models\phi3\modeling_phi3.py", line 735, in forward
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\accelerate\hooks.py", line 169, in new_forward
output = module._old_forward(*args, **kwargs)
File "C:\Notes\writing-in-the-margins\.venv\lib\site-packages\transformers\models\phi3\modeling_phi3.py", line 560, in forward
attn_output = _flash_attention_forward(
NameError: name '_flash_attention_forward' is not defined
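In case it helps anyone hitting the same `NameError`: my understanding (not verified against this exact `transformers` version) is that `_flash_attention_forward` is only imported when the library's flash-attn availability check passes, so a quick way to see what `transformers` thinks of the installed wheel is something like the following rough diagnostic:

```powershell
# Print the installed flash-attn version
python -c "import flash_attn; print('flash_attn', flash_attn.__version__)"

# is_flash_attn_2_available() is (as far as I can tell) the gate transformers
# checks before importing its flash-attention code paths
python -c "from transformers.utils import is_flash_attn_2_available; print('usable:', is_flash_attn_2_available())"
```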