pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

getting different acceptance prob when using `torch.compile` after making a small change. #184

Open kalradivyanshu opened 5 months ago

kalradivyanshu commented 5 months ago

I cloned the gpt-fast repo and tried it out with Llama-3. To set it up, I ran the following commands:

    pip install huggingface_hub[hf_transfer]
    export HF_HUB_ENABLE_HF_TRANSFER=1

    python3 -m pip install -r ./requirements.txt

    huggingface-cli login --token $HF_TOKEN

    huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --local-dir ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct

    python3 scripts/convert_hf_checkpoint.py --checkpoint_dir ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct --model_name Meta-Llama-3-8B-Instruct

    python3 quantize.py --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --mode int8

    python3 quantize.py --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --mode int4

Next, I ran generate.py with speculative decoding:

    python3 generate.py --speculate_k 5 --checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model.pth --draft_checkpoint_path ./checkpoints/meta-llama/Meta-Llama-3-8B-Instruct/model_int4.g32.pth --compile

I get:

    Acceptance probs: [0.06866952789699571, 0.03862660944206009, 0.055793991416309016, 0.06866952789699571, 0.030042918454935622, 0.7381974248927039]
    Mean Accepted: 4.167381974248927
    Average tokens/sec: 76.72
    Memory used: 22.38 GB

This makes sense. While playing around with the model, I added just one line to the Transformer's forward method:

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        self.inner_state = x  # NEW LINE: stash the final (post-norm) hidden state on the module
        logits = self.output(x)
        return logits

Now when I run the same generate.py command, I get a really low acceptance rate:

    Acceptance probs: [0.5620253164556962, 0.3632911392405063, 0.07088607594936709, 0.0037974683544303796, 0.0, 0.0]
    Mean Accepted: 0.5164556962025316
    Average tokens/sec: 24.87
    Memory used: 22.12 GB

But if I don't pass --compile, I get roughly the same acceptance rate as before:

    Acceptance probs: [0.07142857142857142, 0.05102040816326531, 0.04591836734693878, 0.05102040816326531, 0.07142857142857142, 0.7091836734693877]
    Mean Accepted: 4.127551020408164
    Average tokens/sec: 24.03
    Memory used: 22.10 GB

My question is: why does this one line cause such a drastic decline in quality when using --compile? Here is the commit with the change in my fork: https://github.com/kalradivyanshu/gpt-fast/commit/20bd67360daf1e778f4ca1289cfbf12225c42be7
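
To try to isolate the pattern, here is a minimal standalone sketch of what the change does: a toy module (names and shapes are mine, not from gpt-fast) that assigns an intermediate activation to a module attribute inside forward, compiled with mode="reduce-overhead" since, as far as I can tell, that is the mode generate.py uses for the decode step. I have not verified that this toy reproduces the divergence; it is only meant to show the pattern in isolation.

    import torch

    class Toy(torch.nn.Module):
        """Toy stand-in for the Transformer: norm -> stash hidden state -> output head."""

        def __init__(self):
            super().__init__()
            self.norm = torch.nn.LayerNorm(16)
            self.output = torch.nn.Linear(16, 16)
            self.inner_state = None

        def forward(self, x):
            x = self.norm(x)
            self.inner_state = x  # same kind of attribute assignment as my change above
            return self.output(x)

    torch.manual_seed(0)
    model = Toy().eval()
    x = torch.randn(1, 16)
    # Move model and x to CUDA to exercise the cudagraphs path that reduce-overhead uses on GPU.

    with torch.no_grad():
        eager_out = model(x)
        compiled_model = torch.compile(model, mode="reduce-overhead")
        out1 = compiled_model(x).clone()  # clone right away in case the output buffer is reused
        out2 = compiled_model(x).clone()

    print("eager vs compiled (1st call):", (eager_out - out1).abs().max().item())
    print("compiled 1st vs 2nd call:   ", (out1 - out2).abs().max().item())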

Any insights would be really appreciated. Thank you!
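
P.S. In case it matters for diagnosing this: all I ultimately want is to read out the final hidden state after a generation step. A hypothetical alternative I am considering (untested, and I am not sure how it interacts with torch.compile / fullgraph) is to capture it with a forward hook instead of assigning to an attribute inside forward:

    # Hypothetical alternative (untested): capture the post-norm hidden state via a hook,
    # assuming `model` is the Transformer loaded in generate.py.
    captured = {}

    def save_hidden(module, inputs, output):
        # Clone so we hold our own copy rather than a view into memory that the
        # compiled/cudagraph region may reuse on the next step.
        captured["hidden"] = output.detach().clone()

    hook_handle = model.norm.register_forward_hook(save_hidden)
    # ... run generation as usual, then read captured["hidden"] ...
    # hook_handle.remove()  # when done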