pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.

shape fix for gptq #156

Closed. HDCharles closed this pull request 2 months ago.

HDCharles commented 2 months ago

Summary: this aligns the GPTQ path with the previous shape fixes (https://github.com/pytorch-labs/gpt-fast/pull/152).
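For readers outside the linked PRs: the shape fixes in this family generally make the per-group quantization reshape safe when a layer's in_features is not a multiple of the groupsize. The sketch below is a minimal illustration of that idea, not the exact diff in this PR; `pad_weight_for_groupsize` is a hypothetical name, though `find_multiple` mirrors the helper gpt-fast already uses.

```python
import torch

def find_multiple(n: int, k: int) -> int:
    # Smallest multiple of k that is >= n.
    if n % k == 0:
        return n
    return n + k - (n % k)

def pad_weight_for_groupsize(w: torch.Tensor, groupsize: int) -> torch.Tensor:
    # Hypothetical helper: zero-pad the in_features (last) dim so that the
    # per-group reshape to (out, in // groupsize, groupsize) is always valid,
    # even for layers whose in_features is not a multiple of groupsize.
    out_features, in_features = w.shape
    padded_in = find_multiple(in_features, groupsize)
    if padded_in != in_features:
        w = torch.nn.functional.pad(w, (0, padded_in - in_features))
    return w
```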

Test Plan:

export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4-gptq --calibration_tasks wikitext --calibration_limit 10
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4-gptq.g32.cuda.pth --tasks wikitext

wikitext: {'word_perplexity,none': 12.4647656874071, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6028703940149458, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6806577757911142, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
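As a sanity check, the reported metrics are internally consistent under the standard lm-eval-harness definitions: bits_per_byte is log2 of byte_perplexity, and word_perplexity is byte_perplexity raised to the average bytes per word. Plain arithmetic, no project code needed:

```python
import math

byte_ppl = 1.6028703940149458
word_ppl = 12.4647656874071

# bits_per_byte = log2(byte_perplexity)
print(math.log2(byte_ppl))                       # ~0.68066, matches the report
# implied average bytes per word on this wikitext split
print(math.log(word_ppl) / math.log(byte_ppl))   # ~5.35
```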

python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4
python eval.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.cuda.pth --tasks wikitext

wikitext: {'word_perplexity,none': 12.639992147818221, 'word_perplexity_stderr,none': 'N/A', 'byte_perplexity,none': 1.6070602521912754, 'byte_perplexity_stderr,none': 'N/A', 'bits_per_byte,none': 0.6844240198082908, 'bits_per_byte_stderr,none': 'N/A', 'alias': 'wikitext'}
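Comparing the two runs, the GPTQ checkpoint lands slightly below the plain int4 (round-to-nearest) baseline, which is the expected ordering once the shapes are handled correctly:

```python
gptq_ppl = 12.4647656874071    # int4-gptq, g32
rtn_ppl  = 12.639992147818221  # plain int4, g32

# GPTQ should recover some accuracy relative to round-to-nearest.
print(f"word perplexity reduction: {rtn_ppl - gptq_ppl:.3f}")  # ~0.175
```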
