facebookresearch / spiritlm

Inference code for the paper "Spirit-LM: Interleaved Spoken and Written Language Model".

using SPIRITLM_CHECKPOINTS_DIR for custom checkpoints dir #12

Closed: tarekabouzeid closed this pull request 1 month ago

tarekabouzeid commented 1 month ago

Downloaded checkpoints usually reside in a directory outside the repo. This PR adds support for an environment variable, SPIRITLM_CHECKPOINTS_DIR, that points to a custom location for the downloaded checkpoints. Backward compatibility is maintained: if the environment variable is not defined, the previous behavior is kept.

This should resolve #11.
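
The diff itself is not reproduced on this page. As a rough illustration only, the lookup described above might look like the following sketch. `resolve_checkpoints_dir` and its `default_dir`/`environ` parameters are hypothetical names, and the default location (a `checkpoints` folder in the repo root) is inferred from the test session below rather than taken from the PR's code.

```python
import os
from pathlib import Path
from typing import Mapping, Optional

# Name of the environment variable proposed in this PR.
ENV_VAR = "SPIRITLM_CHECKPOINTS_DIR"


def resolve_checkpoints_dir(
    default_dir: Path,
    environ: Optional[Mapping[str, str]] = None,
) -> Path:
    """Return the directory to load checkpoints from (hypothetical helper).

    If SPIRITLM_CHECKPOINTS_DIR is set to a non-empty value, use it;
    otherwise keep the old behavior and return ``default_dir``.
    """
    env = os.environ if environ is None else environ
    custom = env.get(ENV_VAR)
    if custom:
        # expanduser() also accepts values such as "~/models/checkpoints".
        return Path(custom).expanduser()
    return default_dir


if __name__ == "__main__":
    # Assumed default: a "checkpoints" folder at the repository root.
    default = Path("/home/jovyan/spirit_llm/meta-spiritlm") / "checkpoints"
    print(resolve_checkpoints_dir(default, {ENV_VAR: "/tmp/checkpoints/checkpoints/"}))
    # /tmp/checkpoints/checkpoints
    print(resolve_checkpoints_dir(default, {}))
    # /home/jovyan/spirit_llm/meta-spiritlm/checkpoints
```

Passing `environ` explicitly keeps the helper easy to unit-test. Treating an empty value as unset means an accidental `export SPIRITLM_CHECKPOINTS_DIR=` does not silently redirect loading to the current directory, since `Path("")` is the same as `Path(".")`.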

facebook-github-bot commented 1 month ago

Hi @tarekabouzeid!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g., your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

facebook-github-bot commented 1 month ago

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

hitchhicker commented 1 month ago

Thanks for the PR! Could you run `pytest tests` from the root folder of this repo to check that nothing is broken? Thanks!

tarekabouzeid commented 1 month ago

Hi @hitchhicker, thank you. I extracted the checkpoints to /tmp/checkpoints/checkpoints/, verified that there is no checkpoints dir in the root folder of the repo, ran export SPIRITLM_CHECKPOINTS_DIR=/tmp/checkpoints/checkpoints/, and then ran pytest. Below is the result:

(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ export SPIRITLM_CHECKPOINTS_DIR=/tmp/checkpoints/checkpoints/
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ rm -rf checkpoints
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ ls -ltrh
total 48K
-rw-r--r-- 1 jovyan users 3.5K Oct 23 13:05 CODE_OF_CONDUCT.md
-rw-r--r-- 1 jovyan users 1.3K Oct 23 13:05 CONTRIBUTING.md
-rw-r--r-- 1 jovyan users  12K Oct 23 13:05 LICENSE
-rw-r--r-- 1 jovyan users 5.3K Oct 23 13:05 MODEL_CARD.md
-rw-r--r-- 1 jovyan users 1.8K Oct 23 13:05 README.md
drwxr-xr-x 2 jovyan users   35 Oct 23 13:05 assets
drwxr-xr-x 3 jovyan users   22 Oct 23 13:05 data
-rw-r--r-- 1 jovyan users  319 Oct 23 13:05 env.yml
drwxr-xr-x 6 jovyan users  104 Oct 23 13:05 examples
-rw-r--r-- 1 jovyan users    6 Oct 23 13:05 requirements.dev.txt
-rw-r--r-- 1 jovyan users  137 Oct 23 13:05 requirements.txt
-rw-r--r-- 1 jovyan users 1.7K Oct 23 13:05 setup.py
drwxr-xr-x 6 jovyan users   93 Oct 23 13:13 spiritlm
drwxr-xr-x 2 jovyan users  130 Oct 24 07:19 spiritlm.egg-info
drwxr-xr-x 3 jovyan users   97 Oct 24 07:19 tests
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ pytest
================================================================================================================ test session starts ================================================================================================================
platform linux -- Python 3.9.20, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/jovyan/spirit_llm/meta-spiritlm
collected 33 items                                                                                                                                                                                                                                  

tests/test_spirit_model.py ............................                                                                                                                                                                                       [ 84%]
tests/test_tokenizer.py .....                                                                                                                                                                                                                 [100%]

================================================================================================================= warnings summary ==================================================================================================================
tests/test_tokenizer.py::test_expressive_tokenizer_encode_units
  /home/jovyan/spirit_llm/spirit_llm_env/lib/python3.9/site-packages/torchfcpe/models_infer.py:191: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    ckpt = torch.load(pt_path, map_location=torch.device(device))

tests/test_tokenizer.py::test_expressive_tokenizer_encode_units
tests/test_tokenizer.py::test_base_tokenizer_encode_units
  /home/jovyan/spirit_llm/spirit_llm_env/lib/python3.9/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
    WeightNorm.apply(module, name, dim)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================== 33 passed, 3 warnings in 26.19s ==========================================================================================================
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ 

Then, to double-check backward compatibility, I unset the environment variable, created a symlink to the checkpoints dir in the repo root instead, and reran pytest. Results:

(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ unset SPIRITLM_CHECKPOINTS_DIR
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ ln -s /tmp/checkpoints/checkpoints/ \
      /home/jovyan/spirit_llm/meta-spiritlm
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ ls -ltrh
total 48K
-rw-r--r-- 1 jovyan users 3.5K Oct 23 13:05 CODE_OF_CONDUCT.md
-rw-r--r-- 1 jovyan users 1.3K Oct 23 13:05 CONTRIBUTING.md
-rw-r--r-- 1 jovyan users  12K Oct 23 13:05 LICENSE
-rw-r--r-- 1 jovyan users 5.3K Oct 23 13:05 MODEL_CARD.md
-rw-r--r-- 1 jovyan users 1.8K Oct 23 13:05 README.md
drwxr-xr-x 2 jovyan users   35 Oct 23 13:05 assets
drwxr-xr-x 3 jovyan users   22 Oct 23 13:05 data
-rw-r--r-- 1 jovyan users  319 Oct 23 13:05 env.yml
drwxr-xr-x 6 jovyan users  104 Oct 23 13:05 examples
-rw-r--r-- 1 jovyan users    6 Oct 23 13:05 requirements.dev.txt
-rw-r--r-- 1 jovyan users  137 Oct 23 13:05 requirements.txt
-rw-r--r-- 1 jovyan users 1.7K Oct 23 13:05 setup.py
drwxr-xr-x 6 jovyan users   93 Oct 23 13:13 spiritlm
drwxr-xr-x 2 jovyan users  130 Oct 24 07:19 spiritlm.egg-info
drwxr-xr-x 3 jovyan users   97 Oct 24 07:19 tests
lrwxrwxrwx 1 jovyan users   29 Oct 24 07:42 checkpoints -> /tmp/checkpoints/checkpoints/
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ env |grep -i spirit
XML_CATALOG_FILES=file:///home/jovyan/spirit_llm/spirit_llm_env/etc/xml/catalog file:///etc/xml/catalog
PWD=/home/jovyan/spirit_llm/meta-spiritlm
CONDA_PREFIX=/home/jovyan/spirit_llm/spirit_llm_env
CONDA_PROMPT_MODIFIER=(/home/jovyan/spirit_llm/spirit_llm_env) 
CONDA_DEFAULT_ENV=/home/jovyan/spirit_llm/spirit_llm_env
PATH=/home/jovyan/spirit_llm/spirit_llm_env/bin:/opt/conda/condabin:/opt/conda/bin:/opt/java/openjdk/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/spark/bin
OLDPWD=/home/jovyan/spirit_llm
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ 
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ pytest
================================================================================================================ test session starts ================================================================================================================
platform linux -- Python 3.9.20, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/jovyan/spirit_llm/meta-spiritlm
collected 33 items                                                                                                                                                                                                                                  

tests/test_spirit_model.py ............................                                                                                                                                                                                       [ 84%]
tests/test_tokenizer.py .....                                                                                                                                                                                                                 [100%]

================================================================================================================= warnings summary ==================================================================================================================
tests/test_tokenizer.py::test_expressive_tokenizer_encode_units
  /home/jovyan/spirit_llm/spirit_llm_env/lib/python3.9/site-packages/torchfcpe/models_infer.py:191: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    ckpt = torch.load(pt_path, map_location=torch.device(device))

tests/test_tokenizer.py::test_expressive_tokenizer_encode_units
tests/test_tokenizer.py::test_base_tokenizer_encode_units
  /home/jovyan/spirit_llm/spirit_llm_env/lib/python3.9/site-packages/torch/nn/utils/weight_norm.py:143: FutureWarning: `torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.
    WeightNorm.apply(module, name, dim)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================== 33 passed, 3 warnings in 28.85s ==========================================================================================================
(/home/jovyan/spirit_llm/spirit_llm_env) jovyan@jupyter-user:~/spirit_llm/meta-spiritlm$ 
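
The symlink run exercises the fallback path: with the variable unset, the code looks for `checkpoints` in the repo root, and the operating system follows the symlink when files underneath it are opened. A self-contained sketch of that situation, using temporary directories and a condensed copy of a hypothetical `resolve_checkpoints_dir` helper (not the PR's actual code), with `marker.txt` standing in for a real checkpoint file:

```python
import tempfile
from pathlib import Path
from typing import Mapping, Tuple


def resolve_checkpoints_dir(default_dir: Path, environ: Mapping[str, str]) -> Path:
    # Hypothetical helper: use the env var if set and non-empty, else the default.
    custom = environ.get("SPIRITLM_CHECKPOINTS_DIR")
    return Path(custom).expanduser() if custom else default_dir


def symlink_fallback_demo() -> Tuple[bool, str]:
    with tempfile.TemporaryDirectory() as tmp_name:
        tmp = Path(tmp_name)
        # Checkpoints stored outside the repo, like /tmp/checkpoints/checkpoints/.
        real = tmp / "storage" / "checkpoints"
        real.mkdir(parents=True)
        (real / "marker.txt").write_text("ok")

        # Repo root containing only a "checkpoints" symlink, as created by `ln -s`.
        repo = tmp / "repo"
        repo.mkdir()
        link = repo / "checkpoints"
        link.symlink_to(real, target_is_directory=True)

        # Env var unset (empty mapping): the old in-repo default is used.
        resolved = resolve_checkpoints_dir(link, {})
        points_at_real = resolved.resolve() == real.resolve()
        # Reading through the symlink reaches the real file.
        return points_at_real, (resolved / "marker.txt").read_text()


if __name__ == "__main__":
    print(symlink_fallback_demo())  # (True, 'ok')
```

Comparing `resolve()` on both sides avoids false mismatches on systems where the temporary directory itself sits behind a symlink.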

hitchhicker commented 1 month ago

LGTM! Thanks for providing the detailed test outputs!