TransformerLensOrg / TransformerLens

A library for mechanistic interpretability of GPT-style language models
https://transformerlensorg.github.io/TransformerLens/
MIT License

Add mixed precision inference incl loading #104

Open neelnanda-io opened 1 year ago

neelnanda-io commented 1 year ago

Add the option to load models in bfloat16 and float16. This is especially important for large models like GPT-J and GPT-NeoX.

Ideally, load from HuggingFace in this low precision, do the weight processing on the CPU, and then move the processed model weights to the GPU. It might be easiest to do the weight processing once and cache the result to HF (see #103).
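To make the "process in high precision, then cast" idea concrete, here is a minimal NumPy sketch of folding a LayerNorm's scale and bias into the following linear layer in float32, casting the folded weights to float16 only at the end. This is an illustration, not TransformerLens's actual weight-processing code; `fold_ln_into_linear` and all shapes are hypothetical.

```python
import numpy as np


def fold_ln_into_linear(w_ln, b_ln, W, b):
    """Fold LayerNorm's affine params into the next linear layer (hypothetical helper).

    For a normalized input n: (n * w_ln + b_ln) @ W + b == n @ W_folded + b_folded.
    """
    W_folded = w_ln[:, None] * W
    b_folded = b_ln @ W + b
    return W_folded, b_folded


rng = np.random.default_rng(0)
d_model, d_out = 8, 4
# Weights as loaded, kept in float32 for the processing step.
w_ln = rng.normal(size=d_model).astype(np.float32)
b_ln = rng.normal(size=d_model).astype(np.float32)
W = rng.normal(size=(d_model, d_out)).astype(np.float32)
b = rng.normal(size=d_out).astype(np.float32)

# Do the processing in float32 (on the CPU, in the use case described above) ...
W_folded, b_folded = fold_ln_into_linear(w_ln, b_ln, W, b)

# ... and only cast the processed weights to low precision at the end.
W_folded16 = W_folded.astype(np.float16)
b_folded16 = b_folded.astype(np.float16)

# Sanity check: the folded layer matches the unfolded one on a normalized input.
n = rng.normal(size=(3, d_model)).astype(np.float32)
unfolded = (n * w_ln + b_ln) @ W + b
folded = n @ W_folded + b_folded
```

The design point is that rounding to float16 happens once, after processing, instead of at every intermediate step of the processing.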

neelnanda-io commented 1 year ago

Maybe covered by #125

glerzing commented 1 year ago

Solved with #298

Edit: actually, this is not solved yet. There are still problems with HookedTransformer.generate, and perhaps some optimizations to do. I'm preparing a commit.

tbenthompson commented 1 year ago

Before I found this issue, I didn't know @glerzing was working on #317, so I was planning to report this separately.

Anyway, despite the progress, I thought I'd share a demo where I get NaNs running in float16:

```python
import torch
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)

# Issue #1: there's no way to use float16 on initialization, so we're forced to
# load in float32 and then convert to float16.
model32 = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped")
print(repr(model32.to_string(model32(" Unable")[0, -1].argmax())))

# Issue #2: the converted float16 model outputs NaN logits (see the output below).
model16 = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped").to(
    torch.float16
)
print(model16(" Unable"))
```

Outputs:

```
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer
' to'
Using pad_token, but it is not set yet.
Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer
Changing model dtype to torch.float16
tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       dtype=torch.float16)
```

Thanks for the progress on #317 !!! Psyched to see it merged.

neelnanda-io commented 1 year ago

I believe that Pythia 70m can have attention scores as low as -100,000, which will get you NaNs in float16, since float16 can only represent magnitudes up to 65,504. Honestly, my take is that this is not our problem, and you should use bfloat16 instead, so long as HuggingFace also gives you NaNs. I have no clue why Pythia is this high lol.
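A small NumPy sketch of the failure mode described above, using a hypothetical row of attention scores around -100,000 (not real Pythia activations, and not TransformerLens's attention code): casting to float16 overflows every entry to -inf, and the usual max-subtraction inside softmax then computes -inf - (-inf) = NaN for the whole row.

```python
import numpy as np

# Largest finite float16 value: 65504 (not 65536 = 2**16).
FP16_MAX = float(np.finfo(np.float16).max)


def softmax(x: np.ndarray) -> np.ndarray:
    """Softmax with the usual max-subtraction for numerical stability."""
    shifted = x - x.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


# Hypothetical attention-score row with every entry around -100,000.
scores32 = np.array([-100_000.0, -100_500.0, -99_800.0], dtype=np.float32)

with np.errstate(over="ignore", invalid="ignore"):
    scores16 = scores32.astype(np.float16)  # every entry overflows to -inf
    probs32 = softmax(scores32)             # fine: the shift keeps values in range
    probs16 = softmax(scores16)             # -inf - (-inf) = nan, so the row is all NaN
```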

tbenthompson commented 1 year ago

Wow, that's fascinating about the giant attention scores!!

I'm seeing big differences in both bfloat16 and float16 between HuggingFace and TL on Pythia 410M. I was suspicious that the TL processing (fold LN, center unembed, etc.) was causing the differences, so I tried from_pretrained_no_processing, but the differences persist.

I'm gradually learning more about the internals of TL so if I have time soon, I'll dig in on this and try to figure out what's going on.

Source

```python
import torch
import transformers
from transformer_lens import HookedTransformer

torch.set_grad_enabled(False)

model_name = "EleutherAI/pythia-410m-deduped"
prompt = " Unable"
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")


def tl_last_logits(dtype):
    # No weight processing (fold LN, center unembed, ...), so TL should match HF.
    model = HookedTransformer.from_pretrained_no_processing(model_name).to(dtype)
    logits = model(prompt, prepend_bos=False)[0, -1]
    del model
    return logits


def hf_last_logits(dtype):
    model = transformers.GPTNeoXForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, torch_dtype=dtype
    ).cuda()
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].cuda()
    logits = model(input_ids).logits[0, -1]
    del model
    return logits


dtypes = [("float32", torch.float32), ("bfloat16", torch.bfloat16), ("float16", torch.float16)]
for impl, get_logits in [("TL", tl_last_logits), ("HF", hf_last_logits)]:
    for name, dtype in dtypes:
        logits = get_logits(dtype)
        p = torch.softmax(logits, dim=-1)
        # Decode each implementation's own argmax. The outputs posted in this thread
        # came from a version that decoded the TL logits on the HF rows, so their HF
        # top_token column shows TL's top token (only p comes from HF).
        top_token = repr(tokenizer.decode(logits.argmax()))
        print(f"{impl}, {name:<10} top_token={top_token:<16} p={p.max().item():.3f}")
```

Output

```
TL, float32    top_token=' to'            p=0.745
TL, bfloat16   top_token=' to'            p=0.641
TL, float16    top_token='\n'             p=0.000
HF, float32    top_token=' to'            p=0.745
HF, bfloat16   top_token=' to'            p=0.746
HF, float16    top_token='\n'             p=0.750
```

For HF: the float32 and bfloat16 results are "good". The float16 results are bad, but I'm assuming that's just because some internal activations are out of range for fp16! For TL: the float32 results are good, but something beyond the expected numerical issues is going wrong with both bfloat16 and float16.
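Comparing only the top token and its probability can hide differences elsewhere in the distribution. As a hedged sketch (NumPy, with a hypothetical helper `compare_next_token_dists` and made-up logits), one could summarize the gap between two implementations' next-token logits with top-1 agreement, the maximum absolute logit difference, and the KL divergence:

```python
import numpy as np


def log_softmax(logits):
    shifted = logits - logits.max()
    return shifted - np.log(np.exp(shifted).sum())


def compare_next_token_dists(ref_logits, test_logits):
    """Summarize how far test_logits (e.g. TL) are from ref_logits (e.g. HF)."""
    ref = np.asarray(ref_logits, dtype=np.float64)
    test = np.asarray(test_logits, dtype=np.float64)
    log_p, log_q = log_softmax(ref), log_softmax(test)
    return {
        "top1_agree": bool(ref.argmax() == test.argmax()),
        "max_abs_logit_diff": float(np.abs(ref - test).max()),
        # KL(P_ref || P_test) in nats
        "kl": float(np.sum(np.exp(log_p) * (log_p - log_q))),
    }


ref = np.array([2.0, 1.0, 0.5, -1.0])
same = compare_next_token_dists(ref, ref)
shifted = compare_next_token_dists(ref, ref + 5.0)  # constant shift: same distribution
perturbed = compare_next_token_dists(ref, np.array([1.0, 2.0, 0.5, -1.0]))
```

The `shifted` case shows why the KL term is useful alongside raw logit differences: logits can differ by a constant offset without changing the predicted distribution at all.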

neelnanda-io commented 1 year ago

I've also observed this. My weak guess is that it's due to implementation details like the use of einsum and the reshaping of attention matrices? Otherwise, I'm not super sure what would be implemented differently.

It would be interesting to me to carefully go through each activation and compare between TL and HF, and I'd be curious what you find!
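One way to do the activation-by-activation comparison suggested above, sketched in NumPy: given two dicts mapping activation names (in forward-pass order) to arrays, for example collected from a TransformerLens cache and from HuggingFace hooks, report the first activation whose relative error exceeds a tolerance. The function name, the TL-style key names, the synthetic data, and the tolerance are all hypothetical.

```python
import numpy as np


def first_divergence(ref_acts, test_acts, rtol=1e-2):
    """Return (name, rel_err) for the first activation exceeding rtol, else None.

    Both dicts are assumed to share keys, inserted in forward-pass order.
    """
    for name, ref in ref_acts.items():
        ref = np.asarray(ref, dtype=np.float64)
        test = np.asarray(test_acts[name], dtype=np.float64)
        rel_err = np.linalg.norm(ref - test) / max(np.linalg.norm(ref), 1e-12)
        if rel_err > rtol:
            return name, float(rel_err)
    return None


rng = np.random.default_rng(0)
ref_acts = {f"blocks.{i}.hook_resid_post": rng.normal(size=(2, 8)) for i in range(3)}
test_acts = {k: v.copy() for k, v in ref_acts.items()}
test_acts["blocks.1.hook_resid_post"] += 0.5  # simulate a divergence in block 1
```

Reporting only the first divergence matters because an error introduced early will propagate into every later activation.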

slavachalnev commented 1 year ago

I think these two changes fix the float16 issue:

  1. Keep LayerNorm in float32
  2. Apply the attention scale before computing the attention scores: instead of dividing the scores by attn_scale, divide both q and k by sqrt(attn_scale)
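
```python
import numpy as np
import torch
import torch.nn as nn
from fancy_einsum import einsum
from jaxtyping import Float


class LayerNorm(nn.Module):
    ...

    def forward(self, x):
        # Do the normalization in float32 and cast back to the input dtype at the end.
        x_type = x.dtype
        x = x.to(torch.float32)

        x = x - x.mean(dim=-1, keepdim=True)  # [batch, pos, length]
        scale: Float[torch.Tensor, "batch pos 1"] = self.hook_scale(
            (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
        )
        x = x / scale  # [batch, pos, length]

        return self.hook_normalized(x * self.w + self.b).to(x_type)


class Attention(nn.Module):
    def __init__(self, cfg):
        ...
        # Square root of the usual sqrt(d_head) scale: applied once to q and once to k.
        self.attn_scale = np.sqrt(np.sqrt(self.cfg.d_head))
        ...

    def forward(self, *args, **kwargs):
        ...  # q, k: [batch, pos, head_index, d_head], computed as before
        q = q / self.attn_scale
        k = k / self.attn_scale

        attn_scores = (
            einsum(
                "batch query_pos head_index d_head, \
                    batch key_pos head_index d_head \
                    -> batch head_index query_pos key_pos",
                q,
                k,
            )
            # / self.attn_scale  # REMOVE THIS LINE
        )
        ...
```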

Now running the above script gives

```
TL, float32    top_token=' to'            p=0.745
TL, bfloat16   top_token=' to'            p=0.641
TL, float16    top_token=' to'            p=0.740
```

It looks like LayerNorm should stay in float32: https://github.com/pytorch/pytorch/issues/66707
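A NumPy sketch of why upcasting LayerNorm helps, using a hypothetical `layer_norm` helper (unit weight, zero bias) that mirrors the structure of the snippet above: once residual-stream entries exceed about 256 in magnitude, their squares exceed the float16 maximum of 65,504, so the variance overflows to inf and the normalized output collapses to zeros. Upcasting to float32 for the computation avoids this. The input values are made up.

```python
import numpy as np


def layer_norm(x, compute_dtype, eps=1e-5):
    """LayerNorm (unit weight, zero bias) computed in compute_dtype, returned in x's dtype."""
    x_type = x.dtype
    x = x.astype(compute_dtype)
    x = x - x.mean(axis=-1, keepdims=True)
    scale = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return (x / scale).astype(x_type)


# Hypothetical residual-stream vector with large entries, stored in float16.
x16 = np.array([[300.0, -300.0, 300.0, -300.0]], dtype=np.float16)

with np.errstate(over="ignore"):
    out_fp16_math = layer_norm(x16, np.float16)  # 300**2 = 90,000 overflows to inf
out_fp32_math = layer_norm(x16, np.float32)      # upcast first, then cast back
```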

When running the above test script, I saw that the attention scores are reasonable for most of the forward pass but get very large (and negative) in the last two blocks. Without the attention fix, this results in -inf when running in float16. The HuggingFace implementation uses `torch.baddbmm`, which does both the matmul and the scaling in one operation.
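A toy NumPy sketch of the attention-scale fix: with d_head = 16 the usual scale is sqrt(d_head) = 4, so q and k are each divided by sqrt(4) = 2 instead. The vectors are hypothetical, chosen so that the raw dot product (160,000) overflows float16 while the pre-scaled one (40,000, mathematically equal to 160,000 / 4) does not.

```python
import numpy as np

d_head = 16
attn_scale = np.sqrt(d_head)  # 4.0, the usual sqrt(d_head) scale

# Hypothetical query/key vectors with large entries, stored in float16.
q = np.full(d_head, 100.0, dtype=np.float16)
k = np.full(d_head, 100.0, dtype=np.float16)

with np.errstate(over="ignore"):
    # Original order: dot product first, then divide. 16 * 100 * 100 = 160,000 > 65,504.
    score_after = (q * k).sum() / np.float16(attn_scale)

    # Fixed order: divide q and k by sqrt(attn_scale) first, then take the dot product.
    s = np.float16(np.sqrt(attn_scale))  # 2.0
    score_before = ((q / s) * (k / s)).sum()
```

Dividing after the fact cannot recover the value: once the sum has overflowed to inf, inf / 4 is still inf.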

wesg52 commented 1 year ago

> I believe that Pythia 70m can have attention scores as low as -100,000, which will get you nans in float16 because those can do max -65,536. Honestly, my take is that this is not our problem, and you should use bfloat16 instead, so long as HuggingFace also gives you nans. I have no clue why Pythia is this high lol.

So I think Theo Horsley might have discovered why: Pythia models have large bias vectors on the K and Q values (he said the K bias vector for one head had a norm of about 300, which is especially silly given it's just a constant offset). At least in normal supervised ML you don't apply L2 regularization to bias terms, so I assume there is similarly no weight decay on the attention biases, and so they end up large and blow up the attention scores.

neelnanda-io commented 1 year ago

Interesting! Note that Pythia uses rotary attention, where b_K does matter (the key gets rotated by the difference in positions, so it doesn't cancel out between different source tokens).
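To see why the key bias matters under rotary embeddings, here is a toy NumPy sketch in two dimensions with a single rotary frequency `theta` (real rotary embeddings use many frequencies over many dimension pairs; the vectors, positions, and bias here are made up). Without rotation, b_K adds the same q·b_K to every score in a query's row, so the softmax is unchanged; with rotation, the bias is rotated by a position-dependent angle, so its contribution differs across source tokens and changes the attention pattern.

```python
import numpy as np


def rotate(v, angle):
    """Rotate a 2D vector by `angle` radians (one rotary frequency pair)."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([c * v[0] - s * v[1], s * v[0] + c * v[1]])


def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()


theta = 0.5                           # hypothetical rotary frequency
q, query_pos = np.array([1.0, 0.0]), 3
keys = [np.array([0.2, 0.1]), np.array([-0.3, 0.4]),
        np.array([0.1, -0.2]), np.array([0.5, 0.0])]  # source positions 0..3
b_K = np.array([3.0, 0.0])            # hypothetical large key bias


def attention_row(bias, use_rotary):
    q_r = rotate(q, theta * query_pos) if use_rotary else q
    scores = []
    for pos, k in enumerate(keys):
        k_b = k + bias
        scores.append(q_r @ (rotate(k_b, theta * pos) if use_rotary else k_b))
    return softmax(np.array(scores))


no_bias = np.zeros(2)
plain_unchanged = np.allclose(attention_row(b_K, False), attention_row(no_bias, False))
rotary_unchanged = np.allclose(attention_row(b_K, True), attention_row(no_bias, True))
```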

tbenthompson commented 1 year ago

I just re-ran the test above with TL 1.5.0 and I'm getting much better results but there are still noticeable discrepancies from the HF implementation:

```
TL, float32    top_token=' to'            p=0.745
TL, bfloat16   top_token=' to'            p=0.730
TL, float16    top_token=' to'            p=0.737
HF, float32    top_token=' to'            p=0.745
HF, bfloat16   top_token=' to'            p=0.746
HF, float16    top_token=' to'            p=0.750
```

Since Pythia was trained in float16, we should probably ignore the bfloat16 comparison, but the discrepancy in float16 is still noticeable.
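The range-versus-precision trade-off behind "use bfloat16 instead" and "ignore the bfloat16 comparison" can be seen by emulating bfloat16 in NumPy, which has no native bfloat16 type. bfloat16 keeps float32's 8 exponent bits but only 7 explicit mantissa bits, so it can be approximated by rounding a float32 to its top 16 bits. The helper `to_bfloat16` is a hypothetical sketch (round-to-nearest-even, NaN/inf not handled). It shows that -100,000 stays finite in bfloat16 but overflows in float16, while 1 + 2^-8 is exact in float16 but rounds to 1.0 in bfloat16.

```python
import numpy as np


def to_bfloat16(x):
    """Round float32 values to the nearest bfloat16, returned as float32.

    Sketch only: round-to-nearest-even on the 16 dropped bits; NaN/inf not handled.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)


big = np.array([-100_000.0], dtype=np.float32)
small_step = np.array([1.0 + 2.0 ** -8], dtype=np.float32)

with np.errstate(over="ignore"):
    big_fp16 = big.astype(np.float16)      # -inf: outside float16's range
big_bf16 = to_bfloat16(big)                # finite, roughly -99,840
step_fp16 = small_step.astype(np.float16)  # exactly 1.00390625
step_bf16 = to_bfloat16(small_step)        # rounds to 1.0
```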

Thanks glerzing and slavachalnev for the improvements!

glerzing commented 1 year ago

I think the main explanation for the differences is the fact that TransformerLens uses einsum instead of Linear layers or Conv1D.
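One mechanism that can make mathematically equivalent implementations (einsum, Linear, Conv1D, fused kernels) disagree in low precision is that they may reduce in a different order or with a different accumulator precision. Which kernels each path actually uses was not established in this thread, so treat this as an assumption. A deterministic NumPy illustration: summing 4,096 ones sequentially with a float16 accumulator gets stuck at 2,048, because above 2,048 adjacent float16 values are 2 apart and 2,048 + 1 rounds back to 2,048. A float32 accumulator, or NumPy's built-in sum (which NumPy documents as using partial pairwise summation), gives 4,096.

```python
import numpy as np

ones16 = np.ones(4096, dtype=np.float16)

# Sequential sum with a float16 accumulator (rounds after every addition).
acc16 = np.float16(0)
for v in ones16:
    acc16 = np.float16(acc16 + v)

# Same data, float32 accumulator.
acc32 = np.float32(0)
for v in ones16:
    acc32 += np.float32(v)

# Same data and dtype, but a different reduction strategy (NumPy's built-in sum).
pairwise16 = ones16.sum()
```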

Lorenayannnnn commented 3 weeks ago

Just a follow-up on this: I'm observing inconsistencies when doing greedy decoding (using `.generate`) with a HookedTransformer. One example is:

The third answer is changed from "Stan" to "Without Me". Note that input1 is a prefix of input2, meaning that with greedy decoding I would expect output1 to have the same three answers as output2. However, it is giving me different results. In fact, if I just use "List only the name of three songs performed by Eminem: 1. \"" as the input (adding just one extra token \" to input1), the third answer changes from "Stan" to "Without Me".

The example above is from loading Mistral with dtype=float16. However, this type of inconsistency still happens with float32 for both Llama3 and Mistral. Any insights would be super helpful, thank you! :)
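Greedy decoding takes an argmax at every step. Any small numerical difference can therefore flip the chosen token when the top two logits are nearly tied, and every later token then diverges. Possible sources include a different dtype, or kernels that behave differently at different sequence lengths; that second one is a guess, not a diagnosis of this report. A minimal NumPy sketch with made-up logits: two candidates that differ by 0.002 are distinct in float32 but both round to 10.0 in float16, where adjacent values near 10 are 2^-7 apart, so the argmax falls back to the lower index. The `top2_margin` helper is hypothetical, a quick way to flag such near-ties.

```python
import numpy as np


def top2_margin(logits):
    """Gap between the largest and second-largest logit (small gap = fragile argmax)."""
    top2 = np.sort(np.asarray(logits, dtype=np.float64))[-2:]
    return float(top2[1] - top2[0])


logits32 = np.array([10.001, 10.003, 3.0], dtype=np.float32)
logits16 = logits32.astype(np.float16)

choice32 = int(np.argmax(logits32))  # index 1 wins by about 0.002
choice16 = int(np.argmax(logits16))  # both round to 10.0; argmax returns the first, index 0
```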