Open · byjlw opened 1 month ago
You need to add --num-samples 3. The default value is 1, which means the compile overhead will exceed the compile perf gain.
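For example, the reported command rerun with extra samples so the compiled kernels get reused after the first sample:

python3 torchchat.py generate llama3.1 --prompt "what's the most recent historical event" --device cuda --compile --num-samples 3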
@Jack-Khuu, we should standardize the perf measurement script and check it in.
I don't recall having this issue with samples = 1 in the past. We should make sure we're breaking out time spent compiling from time spent during generation, so users can see how the time is being spent and what the real t/s is (see the timing sketch below).
Users are not going to want to have more than one sample unless they're benchmarking.
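A minimal sketch of what such a breakdown could look like, assuming a hypothetical generate_fn that returns the number of tokens produced (names are illustrative, not the actual torchchat internals):

    import time
    import torch

    def timed_samples(generate_fn, num_samples=3):
        # Report each sample separately; with torch.compile the first sample
        # pays the one-time JIT compilation cost, later samples show the real t/s.
        for i in range(num_samples):
            if torch.cuda.is_available():
                torch.cuda.synchronize()        # flush pending GPU work before timing
            start = time.perf_counter()
            num_tokens = generate_fn()          # assumed to return tokens generated
            if torch.cuda.is_available():
                torch.cuda.synchronize()        # wait for generation kernels to finish
            elapsed = time.perf_counter() - start
            label = "sample 1 (includes compile)" if i == 0 else f"sample {i + 1}"
            print(f"{label}: {num_tokens / elapsed:.1f} tok/s over {elapsed:.2f} s")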
Because --compile uses torch.compile, which is JIT compilation, there is no easy way to separate compile time from execution time. If cold-start time is a concern for users, they should choose the AOTI path.
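For reference, the torchchat README describes the AOTI flow roughly like this (a sketch; the exact flags and output path may differ by version):

python3 torchchat.py export llama3.1 --output-dso-path exportedModels/llama3.1.so --device cuda
python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "what's the most recent historical event" --device cuda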
🐛 Describe the bug
Eager gives me 31 tokens a second
python3 torchchat.py generate llama3.1 --prompt "what's the most recent historical event" --device cuda
torch.compile gives me 10 tokens a second
python3 torchchat.py generate llama3.1 --prompt "what's the most recent historical event" --device cuda --compile
I would expect compile at worst not to be slower than eager.
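For context on why a single short run can look slower, here is a self-contained sketch (toy model, unrelated to torchchat) showing that the first call to a torch.compile'd module includes JIT compilation while subsequent calls run at compiled speed:

    import time
    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = nn.Sequential(nn.Linear(4096, 4096), nn.GELU(), nn.Linear(4096, 4096)).to(device)
    compiled = torch.compile(model)
    x = torch.randn(8, 4096, device=device)

    with torch.no_grad():
        for i in range(3):
            if device == "cuda":
                torch.cuda.synchronize()
            start = time.perf_counter()
            compiled(x)
            if device == "cuda":
                torch.cuda.synchronize()
            print(f"call {i + 1}: {time.perf_counter() - start:.3f} s")  # call 1 includes compile time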
Versions
Operating System Information Linux Vikander 6.8.0-45-generic #45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
PRETTY_NAME="Ubuntu 24.04.1 LTS" NAME="Ubuntu" VERSION_ID="24.04" VERSION="24.04.1 LTS (Noble Numbat)" VERSION_CODENAME=noble ID=ubuntu ID_LIKE=debian HOME_URL="https://www.ubuntu.com/" SUPPORT_URL="https://help.ubuntu.com/" BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/" PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy" UBUNTU_CODENAME=noble LOGO=ubuntu-logo
Python Version Python 3.11.10
PIP Version pip 24.0 from /home/warden/source/torchchat/.venv/lib/python3.11/site-packages/pip (python 3.11)
Installed Packages absl-py==2.1.0 accelerate==1.0.1 aiohappyeyeballs==2.4.3 aiohttp==3.10.10 aiosignal==1.3.1 altair==5.4.1 annotated-types==0.7.0 antlr4-python3-runtime==4.9.3 anyio==4.6.2.post1 attrs==24.2.0 blinker==1.8.2 blobfile==3.0.0 cachetools==5.5.0 certifi==2024.8.30 chardet==5.2.0 charset-normalizer==3.4.0 click==8.1.7 cmake==3.30.4 colorama==0.4.6 DataProperty==1.0.1 datasets==3.0.1 dill==0.3.8 distro==1.9.0 evaluate==0.4.3 filelock==3.16.1 Flask==3.0.3 frozenlist==1.4.1 fsspec==2024.6.1 gguf==0.10.0 gitdb==4.0.11 GitPython==3.1.43 h11==0.14.0 httpcore==1.0.6 httpx==0.27.2 huggingface-hub==0.25.2 idna==3.10 itsdangerous==2.2.0 Jinja2==3.1.4 jiter==0.6.1 joblib==1.4.2 jsonlines==4.0.0 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 lm_eval==0.4.2 lxml==5.3.0 markdown-it-py==3.0.0 MarkupSafe==3.0.1 mbstrdecoder==1.1.3 mdurl==0.1.2 more-itertools==10.5.0 mpmath==1.3.0 multidict==6.1.0 multiprocess==0.70.16 narwhals==1.9.3 networkx==3.4.1 ninja==1.11.1.1 nltk==3.9.1 numexpr==2.10.1 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.6.77 nvidia-nvtx-cu12==12.1.105 omegaconf==2.3.0 openai==1.51.2 packaging==24.1 pandas==2.2.3 pathvalidate==3.2.1 peft==0.13.2 pillow==10.4.0 portalocker==2.10.1 propcache==0.2.0 protobuf==5.28.2 psutil==6.0.0 pyarrow==17.0.0 pybind11==2.13.6 pycryptodomex==3.21.0 pydantic==2.9.2 pydantic_core==2.23.4 pydeck==0.9.1 Pygments==2.18.0 pytablewriter==1.2.0 python-dateutil==2.9.0.post0 pytorch-triton==3.1.0+cf34004b8a pytz==2024.2 PyYAML==6.0.2 referencing==0.35.1 regex==2024.9.11 requests==2.32.3 rich==13.9.2 rouge-score==0.1.2 rpds-py==0.20.0 sacrebleu==2.4.3 safetensors==0.4.5 scikit-learn==1.5.2 scipy==1.14.1 sentencepiece==0.2.0 six==1.16.0 smmap==5.0.1 snakeviz==2.2.0 sniffio==1.3.1 sqlitedict==2.1.0 streamlit==1.39.0 sympy==1.13.1 tabledata==1.3.3 tabulate==0.9.0 tcolorpy==0.1.6 tenacity==9.0.0 threadpoolctl==3.5.0 tiktoken==0.8.0 tokenizers==0.20.1 toml==0.10.2 torch==2.6.0.dev20241002+cu121 torchao==0.5.0 torchtune==0.3.0.dev20240928+cu121 torchvision==0.20.0.dev20241002+cu121 tornado==6.4.1 tqdm==4.66.5 tqdm-multiprocess==0.0.11 transformers==4.45.2 typepy==1.3.2 typing_extensions==4.12.2 tzdata==2024.2 urllib3==2.2.3 watchdog==5.0.3 Werkzeug==3.0.4 word2number==1.1 xxhash==3.5.0 yarl==1.15.2 zstandard==0.23.0 zstd==1.5.5.1
PyTorch Version 2.6.0.dev20241002+cu121