mistralai / mistral-finetune


mistral-finetune creates consolidated/lora.safetensors for Mixtral 8x7B Instruct v0.1, but mistral-chat fails at inference: it complains that the loaded LoRA weights file is missing an expected key for one of the model layers. #75

Open · tensimixt opened this issue 5 days ago

tensimixt commented 5 days ago

Python Version

Python 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]

Pip Freeze

absl-py==2.1.0
annotated-types==0.7.0
anyio==4.0.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.1.0
Babel==2.13.1
beautifulsoup4==4.12.2
bleach==6.1.0
blinker==1.4
certifi==2022.12.7
cffi==1.16.0
charset-normalizer==2.1.1
click==8.1.7
cmake==3.25.0
comm==0.1.4
cryptography==3.4.8
dbus-python==1.2.18
debugpy==1.8.0
decorator==5.1.1
defusedxml==0.7.1
distro==1.7.0
docker-pycreds==0.4.0
docstring_parser==0.16
entrypoints==0.4
exceptiongroup==1.1.3
executing==2.0.1
fastjsonschema==2.18.1
filelock==3.9.0
fire==0.6.0
fqdn==1.5.1
fsspec==2024.6.1
gitdb==4.0.11
GitPython==3.1.43
grpcio==1.64.1
httplib2==0.20.2
idna==3.4
importlib-metadata==4.6.4
ipykernel==6.26.0
ipython==8.17.2
ipython-genutils==0.2.0
ipywidgets==8.1.1
isoduration==20.11.0
jedi==0.19.1
jeepney==0.7.1
Jinja2==3.1.2
json5==0.9.14
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.7.1
jupyter-archive==3.4.0
jupyter-contrib-core==0.4.2
jupyter-contrib-nbextensions==0.7.0
jupyter-events==0.8.0
jupyter-highlight-selected-word==0.2.0
jupyter-lsp==2.2.0
jupyter-nbextensions-configurator==0.6.3
jupyter_client==7.4.9
jupyter_core==5.5.0
jupyter_server==2.9.1
jupyter_server_terminals==0.4.4
jupyterlab==4.0.8
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.9
jupyterlab_server==2.25.0
keyring==23.5.0
launchpadlib==1.10.16
lazr.restfulclient==0.14.4
lazr.uri==1.0.6
lit==15.0.7
lxml==4.9.3
Markdown==3.6
MarkupSafe==2.1.2
matplotlib-inline==0.1.6
mistral_common==1.2.1
mistral_inference==1.1.0
mistune==3.0.2
more-itertools==8.10.0
mpmath==1.3.0
nbclassic==1.0.0
nbclient==0.8.0
nbconvert==7.10.0
nbformat==5.9.2
nest-asyncio==1.5.8
networkx==3.0
notebook==6.5.5
notebook_shim==0.2.3
numpy==1.24.1
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.5.82
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.0
overrides==7.4.0
packaging==23.2
pandocfilters==1.5.0
parso==0.8.3
pexpect==4.8.0
Pillow==9.3.0
platformdirs==3.11.0
prometheus-client==0.18.0
prompt-toolkit==3.0.39
protobuf==4.25.3
psutil==5.9.6
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
pydantic==2.6.1
pydantic_core==2.16.2
Pygments==2.16.1
PyGObject==3.42.1
PyJWT==2.3.0
pyparsing==2.4.7
python-apt==2.4.0+ubuntu3
python-dateutil==2.8.2
python-json-logger==2.0.7
PyYAML==6.0.1
pyzmq==24.0.1
referencing==0.30.2
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.10.6
safetensors==0.4.3
SecretStorage==3.3.1
Send2Trash==1.8.2
sentencepiece==0.1.99
sentry-sdk==2.7.1
setproctitle==1.3.3
simple_parsing==0.1.5
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
soupsieve==2.5
stack-data==0.6.3
sympy==1.12
tensorboard==2.17.0
tensorboard-data-server==0.7.2
termcolor==2.4.0
terminado==0.17.1
tinycss2==1.2.1
tomli==2.0.1
torch==2.2.0
torchaudio==2.0.2+cu118
torchvision==0.15.2+cu118
tornado==6.3.3
tqdm==4.66.4
traitlets==5.13.0
triton==2.2.0
types-python-dateutil==2.8.19.14
typing_extensions==4.12.2
uri-template==1.3.0
urllib3==1.26.13
wadllib==1.3.6
wandb==0.17.4
wcwidth==0.2.9
webcolors==1.13
webencodings==0.5.1
websocket-client==1.6.4
Werkzeug==3.0.3
widgetsnbextension==4.0.9
xformers==0.0.24
zipp==1.0.0

Reproduction Steps

1. Clone the repo.
2. Download Mixtral 8x7B Instruct v0.1 and put it in /mistral_models.
3. Download the v3 tokenizer and put it into /mistral_models.
4. Run the extend util, which generates /mistral_models_extended.
5. Put the v3 tokenizer into the /mistral_models_extended directory.
6. Put the data into /data.
7. Run data validation.
8. Train. This generates checkpoints; after 300 steps I get /workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors (a quick way to inspect it is sketched below).
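Before launching chat, the adapter file can be inspected directly. A minimal sketch, using the safetensors package already pinned in the pip freeze above and the checkpoint path from step 8, to list which modules actually received LoRA tensors:

# Minimal sketch: print every tensor key stored in the LoRA checkpoint so we
# can see which modules carry adapters. Path taken from step 8 above.
from safetensors import safe_open

CKPT = "/workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors"

with safe_open(CKPT, framework="pt") as f:
    for key in sorted(f.keys()):
        print(key)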

Finally, run mistral-chat: torchrun --nproc-per-node 2 --no-python mistral-chat /workspace/mistral_models_extended --max_tokens 256 --temperature 0.7 --instruct --lora_path workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors

This generates the following error:

[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] 
[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] *****************************************
[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
[2024-07-04 00:29:06,277] torch.distributed.run: [WARNING] *****************************************
Traceback (most recent call last):
  File "/usr/local/bin/mistral-chat", line 8, in <module>
    sys.exit(mistral_chat())
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 179, in mistral_chat
    fire.Fire(interactive)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 70, in interactive
    transformer.load_lora(Path(lora_path))
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 104, in load_lora
    self._load_lora_state_dict(state_dict, scaling=scaling)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 144, in _load_lora_state_dict
    lora_state_dict[name + ".lora_B.weight"]
KeyError: 'layers.16.feed_forward.gate.lora_B.weight'
Traceback (most recent call last):
  File "/usr/local/bin/mistral-chat", line 8, in <module>
    sys.exit(mistral_chat())
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 179, in mistral_chat
    fire.Fire(interactive)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 143, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 477, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/usr/local/lib/python3.10/dist-packages/fire/core.py", line 693, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/main.py", line 70, in interactive
    transformer.load_lora(Path(lora_path))
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 104, in load_lora
    self._load_lora_state_dict(state_dict, scaling=scaling)
  File "/usr/local/lib/python3.10/dist-packages/mistral_inference/lora.py", line 144, in _load_lora_state_dict
    lora_state_dict[name + ".lora_B.weight"]
KeyError: 'layers.0.feed_forward.gate.lora_B.weight'
[2024-07-04 00:29:21,301] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 3264) of binary: mistral-chat
Traceback (most recent call last):
  File "/usr/local/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
mistral-chat FAILED
------------------------------------------------------------
Failures:
[1]:
  time      : 2024-07-04_00:29:21
  host      : finetuning-latest-2-0
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 3265)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-07-04_00:29:21
  host      : finetuning-latest-2-0
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 3264)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
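Reading the traceback, _load_lora_state_dict in mistral_inference/lora.py indexes the checkpoint for a .lora_B.weight key per wrapped module, and the missing keys are all of the form layers.N.feed_forward.gate.*: the Mixtral MoE gate (router) linears apparently received no LoRA tensors during fine-tuning, while the inference loader still expects them. A hedged workaround sketch, not an official fix: pad the checkpoint with all-zero gate adapters (a zero B matrix makes the LoRA contribution a no-op), assuming Mixtral 8x7B geometry of 32 layers, hidden dim 4096, and 8 experts:

# Hedged workaround sketch, NOT an official fix: add all-zero LoRA tensors for
# the MoE gate modules the loader asks for, so loading stops raising KeyError
# while the adapter stays a no-op on those layers. Layer count, hidden dim,
# and expert count below are assumed Mixtral 8x7B values; check params.json.
import torch
from safetensors.torch import load_file, save_file

CKPT = "/workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors"
state = load_file(CKPT)

# Infer LoRA rank and dtype from a tensor already present in the file.
ref = next(k for k in state if k.endswith(".lora_A.weight"))
rank, dtype = state[ref].shape[0], state[ref].dtype

for layer in range(32):  # Mixtral 8x7B: 32 transformer layers (assumed)
    a_key = f"layers.{layer}.feed_forward.gate.lora_A.weight"
    b_key = f"layers.{layer}.feed_forward.gate.lora_B.weight"
    state.setdefault(a_key, torch.zeros(rank, 4096, dtype=dtype))  # gate in: hidden dim
    state.setdefault(b_key, torch.zeros(8, rank, dtype=dtype))     # gate out: 8 experts

save_file(state, CKPT.replace("lora.safetensors", "lora_padded.safetensors"))

If loading then succeeds with --lora_path pointed at lora_padded.safetensors, the root cause is the key mismatch above rather than a corrupt checkpoint.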

Expected Behavior

Expected the interactive chat prompt to appear in the terminal, but I get the above error instead.

tensimixt commented 3 days ago

@patrickvonplaten Hi, do you know whether mistral-inference works for LoRA + Mixtral 8x7B Instruct v0.1? It does work for LoRA + Mistral-7B v0.3, but for LoRA + Mixtral 8x7B Instruct v0.1 I get the error above about the loaded LoRA weights file missing an expected key for one of the model layers.

Is there something else required to make it work?
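One way to narrow it down might be to diff which modules each adapter covers; a hedged sketch, where the Mistral-7B checkpoint path is a placeholder:

# Hedged comparison sketch: report which LoRA-wrapped modules appear in one
# adapter file but not the other. The Mistral-7B path below is a placeholder
# for wherever that (working) fine-tune saved its checkpoint.
from safetensors import safe_open

def lora_modules(path):
    """Module names that have a lora_A tensor saved in the checkpoint."""
    with safe_open(path, framework="pt") as f:
        keys = f.keys()
    return {k.rsplit(".lora_A.weight", 1)[0] for k in keys if k.endswith(".lora_A.weight")}

mixtral = lora_modules("/workspace/mistral-finetune/experiment5/checkpoints/checkpoint_000300/consolidated/lora.safetensors")
mistral_7b = lora_modules("/path/to/mistral7b/consolidated/lora.safetensors")  # placeholder

print("only in Mixtral adapter:", sorted(mixtral - mistral_7b))
print("only in 7B adapter:", sorted(mistral_7b - mixtral))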

Thank you