huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy-to-use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

Gradients greatly change after `BetterTransformer.transform` application #1091

Open lengstrom opened 1 year ago

lengstrom commented 1 year ago

System Info

optimum: 1.8.6
transformers: 4.29.2
ubuntu/CUDA 11.7/pytorch 2.1

Who can help?

@fxmarty @younesbelkada

Reproduction

https://gist.github.com/lengstrom/84a5d0d95f8942eb4eca82d9eb5aa272

This code should reproduce the issue exactly with only PyTorch 2.1, optimum, datasets, numpy, and transformers installed. The output is (full log here: https://gist.github.com/lengstrom/84a5d0d95f8942eb4eca82d9eb5aa272?permalink_comment_id=4593821#gistcomment-4593821):

tensor([0.9919, 0.9719, 0.9756, 0.9666, 0.9827, 0.9827, 0.9839, 0.9857, 0.9909,
        0.9564, 0.9752, 0.9904, 0.9798, 0.9764, 0.9827, 0.9842])
cosine sim for gpt_neox.embed_in.weight:  tensor(0.0670) tensor(0.3000)
cosine sim for gpt_neox.layers.0.input_layernorm.weight:  tensor(0.7921) tensor(0.1659)
cosine sim for gpt_neox.layers.0.input_layernorm.bias:  tensor(0.8910) tensor(0.1325)
cosine sim for gpt_neox.layers.0.post_attention_layernorm.weight:  tensor(0.8930) tensor(0.0971)

This log shows that the gradients often differ substantially, and that for some parameter groups the compared gradients are nearly uncorrelated (i.e., cosine similarity close to 0).

Expected behavior

The per-sample gradients of a model with and without the BetterTransformer transformation should be nearly identical, but in practice they often differ substantially.

Considering a GPTNeoX model, we fix a sample and record the gradient of the BetterTransformer-transformed model and the gradient of the original model. We then measure the cosine similarity (a vector similarity metric, 0 = uncorrelated, 1 = perfectly correlated) and find that both:

Is there a reason for this behavior? Based on the documentation, the only difference between the transformed and untransformed models is different code paths being used, which should not change the underlying gradients.
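For reference, the per-parameter comparison described above can be sketched as follows. This is a minimal illustration, not the code from the gist: `grad_cosine_report` is a hypothetical helper, it assumes both models expose the same parameter names and that `backward()` has already run on each, and toy `nn.Sequential` models stand in for GPT-NeoX.

```python
import torch
import torch.nn as nn


def grad_cosine_report(model_a: nn.Module, model_b: nn.Module) -> dict:
    """Cosine similarity between the .grad of same-named parameters in two models.

    Hypothetical helper: assumes both models share parameter names and that
    backward() has already been called on each.
    """
    params_b = dict(model_b.named_parameters())
    report = {}
    for name, p_a in model_a.named_parameters():
        p_b = params_b[name]
        if p_a.grad is None or p_b.grad is None:
            continue  # parameter did not take part in the loss
        # Flatten to vectors and compare in float32 so fp16 grads compare fairly.
        g_a = p_a.grad.detach().flatten().float()
        g_b = p_b.grad.detach().flatten().float()
        report[name] = torch.nn.functional.cosine_similarity(g_a, g_b, dim=0).item()
    return report


# Usage: two identical toy models should give cosine similarity ~1.0 everywhere.
torch.manual_seed(0)
model_a = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
model_b = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
model_b.load_state_dict(model_a.state_dict())

x = torch.randn(16, 4)
model_a(x).mean().backward()
model_b(x).mean().backward()

report = grad_cosine_report(model_a, model_b)
for name, cos in report.items():
    print(f"{name}: {cos:.4f}")
```

A value near 1 means the two gradients point in the same direction; the check says nothing about their magnitudes, so it is best read alongside an absolute or relative difference.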

lengstrom commented 1 year ago

Intriguingly, the logits (i.e., just the result of the forward pass) are very close to one another:

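import torch as ch

# make_functional_with_buffers, output_logits, input_ids, bs, model and
# model_bettertfmer are assumed to come from the reproduction gist linked above.
def get_logits(m):
    fmodel, fweights, fbuffs, names = make_functional_with_buffers(m)
    # vmap over the batch dimension of input_ids only; weights/buffers are shared
    logits_fn = ch.func.vmap(output_logits, in_dims=(None, None, None, 0))
    logits = logits_fn(fmodel, fweights, fbuffs, input_ids)
    return logits

logits = get_logits(model)
logits_btfm = get_logits(model_bettertfmer)

# Per-sample cosine similarity between the flattened logits of the two models
ch.nn.functional.cosine_similarity(logits.reshape(bs, -1), logits_btfm.reshape(bs, -1), dim=1)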

yields

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       device='cuda:0', grad_fn=<SumBackward1>)
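If `make_functional_with_buffers` above is the functorch function, note that PyTorch 2.x also offers `torch.func.functional_call`, which can be combined with `torch.func.vmap` for the same per-sample forward pass. The sketch below is an assumption-laden stand-in: a tiny embedding-plus-linear module replaces the causal LM so it runs anywhere, and it checks that the vmapped per-sample logits agree with an ordinary batched forward pass using the same cosine-similarity test as above.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, vmap

# Toy stand-in for a causal LM: token ids -> per-position logits.
torch.manual_seed(0)
vocab, dim, bs, seq = 50, 16, 4, 7
model = nn.Sequential(nn.Embedding(vocab, dim), nn.Linear(dim, vocab))
input_ids = torch.randint(0, vocab, (bs, seq))

params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}


def logits_one(params, buffers, ids):
    # Run the model on a single sequence by adding and removing a batch dim.
    return functional_call(model, (params, buffers), (ids.unsqueeze(0),)).squeeze(0)


# vmap over the batch dimension of input_ids only; params/buffers are shared.
per_sample = vmap(logits_one, in_dims=(None, None, 0))(params, buffers, input_ids)

batched = model(input_ids)
cos = torch.nn.functional.cosine_similarity(
    per_sample.reshape(bs, -1), batched.reshape(bs, -1), dim=1
)
print(cos)
```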
fxmarty commented 1 year ago

Thank you, this is rather critical; apologies, this should have been tested up front. I tried reproducing with just torch.nn.functional.scaled_dot_product_attention and the boilerplate code from the PyTorch docs, and the gradients do match there. So the issue really is in our implementation. Let me have a look, debug, and add tests for this.
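A minimal sketch of that kind of check: an explicit causal softmax attention next to `torch.nn.functional.scaled_dot_product_attention`, with gradients compared on the query, key and value tensors. It runs on CPU in float64, so it exercises the reference math rather than the fused CUDA kernels; the function name `reference_causal_attention` and the tensor shape are illustrative choices.

```python
import math

import torch
import torch.nn.functional as F


def reference_causal_attention(q, k, v):
    """Explicit causal softmax attention (hypothetical reference implementation)."""
    L, S = q.size(-2), k.size(-2)
    scale = 1.0 / math.sqrt(q.size(-1))
    # Boolean mask: True where a query position may attend to a key position.
    allowed = torch.ones(L, S, dtype=torch.bool, device=q.device).tril()
    scores = (q @ k.transpose(-2, -1)) * scale
    scores = scores.masked_fill(~allowed, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
shape = (1, 8, 22, 64)  # (batch, heads, seq_len, head_dim)
# float64 on CPU, so any mismatch would point at the math rather than rounding.
q_ref, k_ref, v_ref = (
    torch.rand(shape, dtype=torch.float64, requires_grad=True) for _ in range(3)
)
q_sdpa, k_sdpa, v_sdpa = (
    t.detach().clone().requires_grad_(True) for t in (q_ref, k_ref, v_ref)
)

out_ref = reference_causal_attention(q_ref, k_ref, v_ref)
out_sdpa = F.scaled_dot_product_attention(q_sdpa, k_sdpa, v_sdpa, is_causal=True)

# Same scalar loss on both outputs, then compare the input gradients.
out_ref.mean().backward()
out_sdpa.mean().backward()

for name, a, b in [("q", q_ref, q_sdpa), ("k", k_ref, k_sdpa), ("v", v_ref, v_sdpa)]:
    print(name, torch.allclose(a.grad, b.grad), (a.grad - b.grad).abs().max().item())
```

Because this never touches the CUDA kernels, a clean result here only rules out the attention math itself, not a specific backend.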

fxmarty commented 1 year ago

I can't reproduce the issue with gpt2 and gpt-neo models, but I can with gpt-neox, where the gradients for attention.query_key_value.bias differ by a large relative margin. Your logs show the same, and things get fuzzy from there:

cosine sim for gpt_neox.layers.5.post_attention_layernorm.weight:  tensor(1.) tensor(8.2333e-08)
cosine sim for gpt_neox.layers.5.post_attention_layernorm.bias:  tensor(1.) tensor(8.3089e-08)
cosine sim for gpt_neox.layers.5.attention.query_key_value.weight:  tensor(0.9972) tensor(0.0091)
cosine sim for gpt_neox.layers.5.attention.query_key_value.bias:  tensor(0.7171) tensor(0.4962)
cosine sim for gpt_neox.layers.5.attention.dense.weight:  tensor(1.0000) tensor(4.8597e-07)
cosine sim for gpt_neox.layers.5.attention.dense.bias:  tensor(1.) tensor(8.8518e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_h_to_4h.weight:  tensor(1.) tensor(9.2877e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_h_to_4h.bias:  tensor(1.) tensor(8.5622e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_4h_to_h.weight:  tensor(1.) tensor(8.9708e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_4h_to_h.bias:  tensor(1.) tensor(8.8518e-08)
cosine sim for gpt_neox.final_layer_norm.weight:  tensor(1.) tensor(8.7570e-08)
cosine sim for gpt_neox.final_layer_norm.bias:  tensor(1.) tensor(8.9339e-08)
cosine sim for embed_out.weight:  tensor(0.9811) tensor(0.1333)

Here's a simpler script (yours does not work when I change the model):

from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GPT2LMHeadModel
from optimum.bettertransformer import BetterTransformer

import torch
import copy
import functools

def recurse_getattr(obj, attr: str):
    def _getattr(obj, attr):
        return getattr(obj, attr)

    return functools.reduce(_getattr, [obj] + attr.split("."))

model_id = "EleutherAI/pythia-160m"
model_vanilla = AutoModelForCausalLM.from_pretrained(model_id)
model_sdpa = AutoModelForCausalLM.from_pretrained(model_id)

for name, param_vanilla in model_vanilla.named_parameters():
    param_sdpa = recurse_getattr(model_sdpa, name)

    maxdiff = torch.max(torch.abs(param_vanilla - param_sdpa)).item()
    relativediff = torch.mean(torch.abs(param_vanilla - param_sdpa) / torch.abs(param_vanilla))
    print(f"{name} param match:", torch.allclose(param_vanilla, param_sdpa), f"Maxdiff: {maxdiff}, relativediff: {relativediff}")

tokenizer = AutoTokenizer.from_pretrained(model_id)

model_sdpa = BetterTransformer.transform(model_sdpa)

inp = tokenizer(["This is just to test gradients"] * 16, return_tensors="pt")

model_vanilla = model_vanilla.train()
model_sdpa = model_sdpa.train()
model_vanilla.zero_grad()
model_sdpa.zero_grad()

model_vanilla = model_vanilla.eval()
model_sdpa = model_sdpa.eval()

torch.set_printoptions(threshold=100000000000)

res_vanilla = model_vanilla(**inp).logits

print("\n\n\n\n RUNNING SDPA")
res_sdpa = model_sdpa(**inp).logits

print("Res match:", torch.allclose(res_vanilla, res_sdpa), "Maxdiff:", torch.max(torch.abs(res_vanilla - res_sdpa)).item())

loss_vanilla = res_vanilla.mean()
loss_sdpa = res_sdpa.mean()

loss_vanilla.backward()
loss_sdpa.backward()

print("Loss match:", torch.allclose(loss_vanilla, loss_sdpa), "Maxdiff:", torch.max(torch.abs(loss_vanilla - loss_sdpa)).item())

for name, param_vanilla in model_vanilla.named_parameters():
    param_sdpa = recurse_getattr(model_sdpa, name)

    maxdiff = torch.max(torch.abs(param_vanilla.grad - param_sdpa.grad)).item()
    relativediff = torch.mean(torch.abs(param_vanilla.grad - param_sdpa.grad) / torch.abs(param_vanilla.grad))
    cosine = torch.nn.functional.cosine_similarity(param_vanilla.grad.flatten(), param_sdpa.grad.flatten(), dim=0)
    print(f"{name} grad match:", torch.allclose(param_vanilla.grad, param_sdpa.grad), f"Maxdiff: {maxdiff}, relativediff: {relativediff}, cosine={cosine}")
fxmarty commented 1 year ago

Actually, I don't think there is any issue with gpt-neox either. The only fuzzy bit I see is these gradients, but they are on the order of 10e-14 here.


Otherwise, after adding the cosine similarity to the printout (I updated the code above), the results look fine:

gpt_neox.embed_in.weight grad match: False Maxdiff: 0.0004773736000061035, relativediff: nan, cosine=1.0
gpt_neox.layers.0.input_layernorm.weight grad match: False Maxdiff: 6.480515003204346e-05, relativediff: 0.0008785824757069349, cosine=1.0
gpt_neox.layers.0.input_layernorm.bias grad match: False Maxdiff: 2.306140959262848e-05, relativediff: 0.002862692577764392, cosine=1.0
gpt_neox.layers.0.post_attention_layernorm.weight grad match: False Maxdiff: 2.3115426301956177e-05, relativediff: 0.0007874640286900103, cosine=1.0000001192092896
gpt_neox.layers.0.post_attention_layernorm.bias grad match: False Maxdiff: 1.6219913959503174e-05, relativediff: 0.0012927413918077946, cosine=1.0000001192092896
gpt_neox.layers.0.attention.query_key_value.weight grad match: False Maxdiff: 0.00034546852111816406, relativediff: 0.0016394504345953465, cosine=1.0
gpt_neox.layers.0.attention.query_key_value.bias grad match: False Maxdiff: 3.819167613983154e-05, relativediff: nan, cosine=1.0000001192092896
gpt_neox.layers.0.attention.dense.weight grad match: False Maxdiff: 7.398426532745361e-05, relativediff: 0.001369286677800119, cosine=0.9999998211860657
gpt_neox.layers.0.attention.dense.bias grad match: False Maxdiff: 7.408857345581055e-05, relativediff: 0.0011199495056644082, cosine=1.0000001192092896
gpt_neox.layers.0.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0016028881072998047, relativediff: 0.001413618098013103, cosine=1.000000238418579
gpt_neox.layers.0.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 4.233419895172119e-05, relativediff: 0.0007718717097304761, cosine=1.0
gpt_neox.layers.0.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00015889108180999756, relativediff: 0.001413600635714829, cosine=0.9999999403953552
gpt_neox.layers.0.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 7.408857345581055e-05, relativediff: 0.0011199495056644082, cosine=1.0000001192092896
gpt_neox.layers.1.input_layernorm.weight grad match: False Maxdiff: 1.6685575246810913e-05, relativediff: 0.0006055051344446838, cosine=0.9999999403953552
gpt_neox.layers.1.input_layernorm.bias grad match: False Maxdiff: 1.432374119758606e-05, relativediff: 0.0005490509211085737, cosine=0.9999999403953552
gpt_neox.layers.1.post_attention_layernorm.weight grad match: False Maxdiff: 3.9480626583099365e-05, relativediff: 0.001803021994419396, cosine=0.9999999403953552
gpt_neox.layers.1.post_attention_layernorm.bias grad match: False Maxdiff: 1.3155164197087288e-05, relativediff: 0.0043691834434866905, cosine=1.0
gpt_neox.layers.1.attention.query_key_value.weight grad match: False Maxdiff: 0.00036919116973876953, relativediff: 0.0015037399716675282, cosine=0.9999999403953552
gpt_neox.layers.1.attention.query_key_value.bias grad match: False Maxdiff: 3.212690353393555e-05, relativediff: nan, cosine=1.0000001192092896
gpt_neox.layers.1.attention.dense.weight grad match: False Maxdiff: 0.00013083219528198242, relativediff: 0.0013935145689174533, cosine=0.9999999403953552
gpt_neox.layers.1.attention.dense.bias grad match: False Maxdiff: 6.993114948272705e-05, relativediff: 0.0010215052170678973, cosine=0.9999999403953552
gpt_neox.layers.1.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0004469156265258789, relativediff: nan, cosine=1.0
gpt_neox.layers.1.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.7128495275974274e-05, relativediff: nan, cosine=0.9999999403953552
gpt_neox.layers.1.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 7.201731204986572e-05, relativediff: nan, cosine=1.0
gpt_neox.layers.1.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 6.993114948272705e-05, relativediff: 0.0010215052170678973, cosine=0.9999999403953552
gpt_neox.layers.2.input_layernorm.weight grad match: False Maxdiff: 1.7311424016952515e-05, relativediff: 0.001300215837545693, cosine=1.0
gpt_neox.layers.2.input_layernorm.bias grad match: False Maxdiff: 1.7337501049041748e-05, relativediff: 0.001227947068400681, cosine=1.0
gpt_neox.layers.2.post_attention_layernorm.weight grad match: False Maxdiff: 3.092736005783081e-05, relativediff: 0.0005665196222253144, cosine=1.0
gpt_neox.layers.2.post_attention_layernorm.bias grad match: False Maxdiff: 1.5107914805412292e-05, relativediff: 0.003649575635790825, cosine=0.9999998807907104
gpt_neox.layers.2.attention.query_key_value.weight grad match: False Maxdiff: 0.0003069639205932617, relativediff: 0.0019483822397887707, cosine=0.9999999403953552
gpt_neox.layers.2.attention.query_key_value.bias grad match: False Maxdiff: 3.221631050109863e-05, relativediff: nan, cosine=1.0
gpt_neox.layers.2.attention.dense.weight grad match: False Maxdiff: 8.071213960647583e-05, relativediff: 0.001911609317176044, cosine=0.9999999403953552
gpt_neox.layers.2.attention.dense.bias grad match: False Maxdiff: 6.04093074798584e-05, relativediff: 0.0007769144140183926, cosine=1.0000001192092896
gpt_neox.layers.2.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0002313852310180664, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.2.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.1599233150482178e-05, relativediff: nan, cosine=0.9999999403953552
gpt_neox.layers.2.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.0001385807991027832, relativediff: nan, cosine=1.0
gpt_neox.layers.2.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 6.04093074798584e-05, relativediff: 0.0007769144140183926, cosine=1.0000001192092896
gpt_neox.layers.3.input_layernorm.weight grad match: False Maxdiff: 3.2275915145874023e-05, relativediff: 0.0011378898052498698, cosine=1.0000001192092896
gpt_neox.layers.3.input_layernorm.bias grad match: False Maxdiff: 1.817941665649414e-05, relativediff: 0.001662743859924376, cosine=0.9999998807907104
gpt_neox.layers.3.post_attention_layernorm.weight grad match: False Maxdiff: 1.9088387489318848e-05, relativediff: 0.0008994439267553389, cosine=1.0
gpt_neox.layers.3.post_attention_layernorm.bias grad match: False Maxdiff: 1.938268542289734e-05, relativediff: 0.0015516605926677585, cosine=0.9999999403953552
gpt_neox.layers.3.attention.query_key_value.weight grad match: False Maxdiff: 0.00041878223419189453, relativediff: 0.0024944001343101263, cosine=1.0
gpt_neox.layers.3.attention.query_key_value.bias grad match: False Maxdiff: 3.2745301723480225e-05, relativediff: nan, cosine=1.0000001192092896
gpt_neox.layers.3.attention.dense.weight grad match: False Maxdiff: 0.00010060705244541168, relativediff: 0.0022460015024989843, cosine=1.0
gpt_neox.layers.3.attention.dense.bias grad match: False Maxdiff: 5.4156407713890076e-05, relativediff: 0.0030462404247373343, cosine=1.0
gpt_neox.layers.3.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00016620755195617676, relativediff: 0.0018187703099101782, cosine=1.0
gpt_neox.layers.3.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.2606924176216125e-05, relativediff: 0.0009581153281033039, cosine=1.0
gpt_neox.layers.3.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.0003941059112548828, relativediff: 0.0019543024245649576, cosine=1.0
gpt_neox.layers.3.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 5.4156407713890076e-05, relativediff: 0.0030462404247373343, cosine=1.0
gpt_neox.layers.4.input_layernorm.weight grad match: False Maxdiff: 9.484589099884033e-06, relativediff: 0.002448969054967165, cosine=0.9999999403953552
gpt_neox.layers.4.input_layernorm.bias grad match: False Maxdiff: 1.4883698895573616e-05, relativediff: 0.0007243411964736879, cosine=0.9999998807907104
gpt_neox.layers.4.post_attention_layernorm.weight grad match: False Maxdiff: 2.366676926612854e-05, relativediff: 0.0014065144350752234, cosine=1.0000001192092896
gpt_neox.layers.4.post_attention_layernorm.bias grad match: False Maxdiff: 1.585017889738083e-05, relativediff: 0.0007562157697975636, cosine=0.9999998807907104
gpt_neox.layers.4.attention.query_key_value.weight grad match: False Maxdiff: 0.0006879568099975586, relativediff: 0.002234308049082756, cosine=1.0
gpt_neox.layers.4.attention.query_key_value.bias grad match: False Maxdiff: 3.507733345031738e-05, relativediff: inf, cosine=1.0000001192092896
gpt_neox.layers.4.attention.dense.weight grad match: False Maxdiff: 6.670504808425903e-05, relativediff: 0.0063442327082157135, cosine=1.0
gpt_neox.layers.4.attention.dense.bias grad match: False Maxdiff: 4.736892879009247e-05, relativediff: 0.0013332836097106338, cosine=1.0
gpt_neox.layers.4.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00012412108480930328, relativediff: 0.003574290545657277, cosine=1.0
gpt_neox.layers.4.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.1204352378845215e-05, relativediff: 0.001701479428447783, cosine=1.0
gpt_neox.layers.4.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00010700523853302002, relativediff: 0.002418705727905035, cosine=1.0
gpt_neox.layers.4.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 4.736892879009247e-05, relativediff: 0.0013332836097106338, cosine=1.0
gpt_neox.layers.5.input_layernorm.weight grad match: False Maxdiff: 1.722574234008789e-05, relativediff: 0.0006909793592058122, cosine=1.0000001192092896
gpt_neox.layers.5.input_layernorm.bias grad match: False Maxdiff: 2.1550804376602173e-05, relativediff: 0.000694852031301707, cosine=1.0
gpt_neox.layers.5.post_attention_layernorm.weight grad match: False Maxdiff: 1.827254891395569e-05, relativediff: 0.0037720382679253817, cosine=1.0000001192092896
gpt_neox.layers.5.post_attention_layernorm.bias grad match: False Maxdiff: 1.3113021850585938e-05, relativediff: 0.0005836377968080342, cosine=1.0
gpt_neox.layers.5.attention.query_key_value.weight grad match: False Maxdiff: 0.00030303001403808594, relativediff: 0.0020995934028178453, cosine=1.0
gpt_neox.layers.5.attention.query_key_value.bias grad match: False Maxdiff: 2.4404376745224e-05, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.5.attention.dense.weight grad match: False Maxdiff: 3.871321678161621e-05, relativediff: 0.002867215545848012, cosine=1.0
gpt_neox.layers.5.attention.dense.bias grad match: False Maxdiff: 3.5896897315979004e-05, relativediff: 0.0045297760516405106, cosine=0.9999999403953552
gpt_neox.layers.5.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0001646876335144043, relativediff: 0.0037182371597737074, cosine=1.0
gpt_neox.layers.5.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.504885196685791e-05, relativediff: 0.0012974942801520228, cosine=0.9999999403953552
gpt_neox.layers.5.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 9.982287883758545e-05, relativediff: 0.00209479290060699, cosine=0.9999999403953552
gpt_neox.layers.5.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 3.5896897315979004e-05, relativediff: 0.0045297760516405106, cosine=0.9999999403953552
gpt_neox.layers.6.input_layernorm.weight grad match: False Maxdiff: 2.893805503845215e-05, relativediff: 0.0009801711421459913, cosine=1.0
gpt_neox.layers.6.input_layernorm.bias grad match: False Maxdiff: 1.376122236251831e-05, relativediff: 0.0006134507711976767, cosine=1.0
gpt_neox.layers.6.post_attention_layernorm.weight grad match: False Maxdiff: 5.0127506256103516e-05, relativediff: 0.007451085839420557, cosine=1.0
gpt_neox.layers.6.post_attention_layernorm.bias grad match: False Maxdiff: 1.2849457561969757e-05, relativediff: 0.0008298815810121596, cosine=1.0
gpt_neox.layers.6.attention.query_key_value.weight grad match: False Maxdiff: 0.0002187490463256836, relativediff: 0.0022141069639474154, cosine=0.9999999403953552
gpt_neox.layers.6.attention.query_key_value.bias grad match: False Maxdiff: 2.746284008026123e-05, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.6.attention.dense.weight grad match: False Maxdiff: 4.579313099384308e-05, relativediff: 0.0017882962711155415, cosine=1.0000001192092896
gpt_neox.layers.6.attention.dense.bias grad match: False Maxdiff: 3.5293400287628174e-05, relativediff: 0.0008495993097312748, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00011885911226272583, relativediff: 0.0021896674297749996, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.1107494831085205e-05, relativediff: 0.0006742423865944147, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 9.671971201896667e-05, relativediff: 0.00230384455062449, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 3.5293400287628174e-05, relativediff: 0.0008495993097312748, cosine=1.0000001192092896
gpt_neox.layers.7.input_layernorm.weight grad match: False Maxdiff: 3.8079917430877686e-05, relativediff: 0.002206353936344385, cosine=0.9999999403953552
gpt_neox.layers.7.input_layernorm.bias grad match: False Maxdiff: 1.3178214430809021e-05, relativediff: 0.0006930383387953043, cosine=1.0000001192092896
gpt_neox.layers.7.post_attention_layernorm.weight grad match: False Maxdiff: 1.747533679008484e-05, relativediff: 0.0013376493006944656, cosine=1.0
gpt_neox.layers.7.post_attention_layernorm.bias grad match: False Maxdiff: 1.3608485460281372e-05, relativediff: 0.0008480687974952161, cosine=1.0
gpt_neox.layers.7.attention.query_key_value.weight grad match: False Maxdiff: 0.00023524463176727295, relativediff: 0.002312550786882639, cosine=0.9999999403953552
gpt_neox.layers.7.attention.query_key_value.bias grad match: False Maxdiff: 2.3789703845977783e-05, relativediff: inf, cosine=1.0000001192092896
gpt_neox.layers.7.attention.dense.weight grad match: False Maxdiff: 4.3258070945739746e-05, relativediff: 0.0021618057508021593, cosine=1.0
gpt_neox.layers.7.attention.dense.bias grad match: False Maxdiff: 3.759562969207764e-05, relativediff: 0.0006909780204296112, cosine=0.9999999403953552
gpt_neox.layers.7.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00015154480934143066, relativediff: 0.0025800946168601513, cosine=0.9999999403953552
gpt_neox.layers.7.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.353265881538391e-05, relativediff: 0.0012763891136273742, cosine=1.0
gpt_neox.layers.7.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 7.880479097366333e-05, relativediff: 0.0030845303554087877, cosine=1.0
gpt_neox.layers.7.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 3.759562969207764e-05, relativediff: 0.0006909780204296112, cosine=0.9999999403953552
gpt_neox.layers.8.input_layernorm.weight grad match: False Maxdiff: 2.9072165489196777e-05, relativediff: 0.002014160854741931, cosine=0.9999998807907104
gpt_neox.layers.8.input_layernorm.bias grad match: False Maxdiff: 5.7248398661613464e-06, relativediff: 0.0009236705373041332, cosine=1.0
gpt_neox.layers.8.post_attention_layernorm.weight grad match: False Maxdiff: 3.810226917266846e-05, relativediff: 0.0014925239374861121, cosine=0.9999998807907104
gpt_neox.layers.8.post_attention_layernorm.bias grad match: False Maxdiff: 1.50240957736969e-05, relativediff: 0.0015558138256892562, cosine=0.9999998807907104
gpt_neox.layers.8.attention.query_key_value.weight grad match: False Maxdiff: 0.0007872581481933594, relativediff: 0.009567000903189182, cosine=0.9999999403953552
gpt_neox.layers.8.attention.query_key_value.bias grad match: False Maxdiff: 9.194482117891312e-06, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.8.attention.dense.weight grad match: False Maxdiff: 5.197152495384216e-05, relativediff: 0.0020594012457877398, cosine=1.0
gpt_neox.layers.8.attention.dense.bias grad match: False Maxdiff: 2.5920569896697998e-05, relativediff: 0.0008061685948632658, cosine=0.9999999403953552
gpt_neox.layers.8.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00012427568435668945, relativediff: 0.004034997895359993, cosine=0.9999999403953552
gpt_neox.layers.8.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.0071864128112793e-05, relativediff: 0.0011649620719254017, cosine=1.0
gpt_neox.layers.8.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 9.659864008426666e-05, relativediff: 0.0026284982450306416, cosine=0.9999998807907104
gpt_neox.layers.8.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 2.5920569896697998e-05, relativediff: 0.0008061685948632658, cosine=0.9999999403953552
gpt_neox.layers.9.input_layernorm.weight grad match: False Maxdiff: 6.195157766342163e-06, relativediff: 0.0019708990585058928, cosine=0.9999998807907104
gpt_neox.layers.9.input_layernorm.bias grad match: False Maxdiff: 5.588633939623833e-06, relativediff: 0.0013916906900703907, cosine=0.9999999403953552
gpt_neox.layers.9.post_attention_layernorm.weight grad match: False Maxdiff: 6.753206253051758e-05, relativediff: 0.002181797521188855, cosine=1.0000001192092896
gpt_neox.layers.9.post_attention_layernorm.bias grad match: False Maxdiff: 1.5349360182881355e-05, relativediff: 0.0013889750698581338, cosine=1.0
gpt_neox.layers.9.attention.query_key_value.weight grad match: False Maxdiff: 0.0005746409296989441, relativediff: 0.014891190454363823, cosine=0.9999997019767761
gpt_neox.layers.9.attention.query_key_value.bias grad match: False Maxdiff: 9.512528777122498e-06, relativediff: inf, cosine=0.9999998807907104
gpt_neox.layers.9.attention.dense.weight grad match: False Maxdiff: 2.2362452000379562e-05, relativediff: 0.001138008083216846, cosine=1.0000001192092896
gpt_neox.layers.9.attention.dense.bias grad match: False Maxdiff: 2.1470244973897934e-05, relativediff: 0.0005539876292459667, cosine=0.9999999403953552
gpt_neox.layers.9.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0002691000699996948, relativediff: 0.004176828544586897, cosine=1.0
gpt_neox.layers.9.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 5.860254168510437e-05, relativediff: 0.029437750577926636, cosine=0.9999997615814209
gpt_neox.layers.9.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00012803077697753906, relativediff: 0.002299846848472953, cosine=1.0000001192092896
gpt_neox.layers.9.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 2.1470244973897934e-05, relativediff: 0.0005539876292459667, cosine=0.9999999403953552
gpt_neox.layers.10.input_layernorm.weight grad match: False Maxdiff: 0.00041857361793518066, relativediff: 0.013254483230412006, cosine=1.0000001192092896
gpt_neox.layers.10.input_layernorm.bias grad match: False Maxdiff: 3.1903618946671486e-06, relativediff: 0.0007093409076333046, cosine=1.0
gpt_neox.layers.10.post_attention_layernorm.weight grad match: False Maxdiff: 3.400444984436035e-05, relativediff: 0.0071408688090741634, cosine=1.0000001192092896
gpt_neox.layers.10.post_attention_layernorm.bias grad match: False Maxdiff: 1.761317253112793e-05, relativediff: 0.0008480017422698438, cosine=0.9999999403953552
gpt_neox.layers.10.attention.query_key_value.weight grad match: False Maxdiff: 0.004154682159423828, relativediff: 0.009285532869398594, cosine=0.999999463558197
gpt_neox.layers.10.attention.query_key_value.bias grad match: False Maxdiff: 5.601905286312103e-06, relativediff: 0.1024150550365448, cosine=0.9999998807907104
gpt_neox.layers.10.attention.dense.weight grad match: False Maxdiff: 2.562534064054489e-05, relativediff: 0.000457679241662845, cosine=1.0
gpt_neox.layers.10.attention.dense.bias grad match: False Maxdiff: 2.1981075406074524e-05, relativediff: 0.00017686102364677936, cosine=1.0000001192092896
gpt_neox.layers.10.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00013469159603118896, relativediff: 0.0026225948240607977, cosine=0.9999997615814209
gpt_neox.layers.10.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.5781802833080292e-05, relativediff: 0.0007138229557313025, cosine=0.9999999403953552
gpt_neox.layers.10.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00019641220569610596, relativediff: 0.0010898219188675284, cosine=0.9999999403953552
gpt_neox.layers.10.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 2.1981075406074524e-05, relativediff: 0.00017686102364677936, cosine=1.0000001192092896
gpt_neox.layers.11.input_layernorm.weight grad match: False Maxdiff: 1.2040138244628906e-05, relativediff: 0.0023908887524157763, cosine=1.0000001192092896
gpt_neox.layers.11.input_layernorm.bias grad match: False Maxdiff: 2.6496127247810364e-06, relativediff: 0.0017001176020130515, cosine=1.0
gpt_neox.layers.11.post_attention_layernorm.weight grad match: False Maxdiff: 4.336237907409668e-06, relativediff: 0.00013862353807780892, cosine=1.0
gpt_neox.layers.11.post_attention_layernorm.bias grad match: False Maxdiff: 1.0542571544647217e-06, relativediff: 3.621343421400525e-05, cosine=0.9999998807907104
gpt_neox.layers.11.attention.query_key_value.weight grad match: False Maxdiff: 0.02075481414794922, relativediff: 0.06359207630157471, cosine=0.9999991655349731
gpt_neox.layers.11.attention.query_key_value.bias grad match: False Maxdiff: 2.696644514799118e-06, relativediff: 0.25537019968032837, cosine=1.0000001192092896
gpt_neox.layers.11.attention.dense.weight grad match: False Maxdiff: 2.91336327791214e-05, relativediff: 0.0007787492941133678, cosine=0.9999998807907104
gpt_neox.layers.11.attention.dense.bias grad match: False Maxdiff: 1.8852879293262959e-06, relativediff: 0.00015919502766337246, cosine=0.9999998807907104
gpt_neox.layers.11.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 2.5608460418879986e-05, relativediff: 0.00040599319618195295, cosine=0.9999999403953552
gpt_neox.layers.11.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.766493707895279e-06, relativediff: 0.00014488919987343252, cosine=1.0
gpt_neox.layers.11.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 2.9437243938446045e-05, relativediff: 0.0002886583679355681, cosine=1.0000001192092896
gpt_neox.layers.11.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 1.8852879293262959e-06, relativediff: 0.00015919502766337246, cosine=0.9999998807907104
gpt_neox.final_layer_norm.weight grad match: False Maxdiff: 1.341104507446289e-06, relativediff: 2.6080595034727594e-06, cosine=1.0000001192092896
gpt_neox.final_layer_norm.bias grad match: True Maxdiff: 0.0, relativediff: 0.0, cosine=1.0
embed_out.weight grad match: True Maxdiff: 1.3242242857813835e-09, relativediff: 1.7900250668390072e-06, cosine=0.9999999403953552
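The `nan` and `inf` values in the `relativediff` column come from how the script computes it: `torch.mean(torch.abs(a - b) / torch.abs(a))` divides elementwise by the vanilla gradient, so any entry where that gradient is exactly zero yields 0/0 = `nan` (if the other gradient is also zero there) or x/0 = `inf`, and a single such entry dominates the mean. Below is a sketch of two alternatives; both helpers are hypothetical, the `eps` thresholds are arbitrary choices, and `relative_l2_error` is a different metric (a norm ratio over the whole tensor) rather than a drop-in replacement.

```python
import torch


def masked_relative_diff(ref: torch.Tensor, other: torch.Tensor, eps: float = 1e-12) -> float:
    """Mean elementwise relative difference over entries where |ref| > eps.

    Returns nan only when no entry is large enough to compare.
    """
    ref, other = ref.detach().double(), other.detach().double()
    keep = ref.abs() > eps
    if not keep.any():
        return float("nan")
    return ((ref - other).abs()[keep] / ref.abs()[keep]).mean().item()


def relative_l2_error(ref: torch.Tensor, other: torch.Tensor, eps: float = 1e-30) -> float:
    """||ref - other|| / ||ref||, with the denominator clamped away from zero."""
    ref, other = ref.detach().double(), other.detach().double()
    return ((ref - other).norm() / ref.norm().clamp_min(eps)).item()


ref = torch.tensor([0.0, 1.0, -2.0, 0.0])
other = torch.tensor([0.0, 1.001, -2.0, 1e-6])

naive = torch.mean(torch.abs(ref - other) / torch.abs(ref))
print(naive)  # tensor(nan): the first entry is 0/0
masked = masked_relative_diff(ref, other)
l2 = relative_l2_error(ref, other)
print(masked, l2)
```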
lengstrom commented 1 year ago

Thank you for looking into this and for the quick response! Sorry, I'm confused: what did you change to get the gradients to match in NeoX?

fxmarty commented 1 year ago

@lengstrom I changed nothing. I used a simpler script (I tried changing the model in yours but ran into errors) and cannot reproduce the issue. Do you see anything wrong with my script? One difference is the device: I did not try on CUDA (which may in turn dispatch to the flash attention or memory-efficient attention kernel).

lengstrom commented 1 year ago

@fxmarty: Thank you for the help and the useful pointer. By changing three basic things in your script, I am able to replicate the behavior I showed:

  1. Put the model and inputs on the GPU (I have an 80GB A100)
  2. Use the 70m model instead of the 160m model
  3. Use a 3x longer input (from ~7 tokens to 22 tokens)

If you skip any of these three changes, you will miss the behavior I showed. Furthermore, I found that changing the loss function to cross entropy instead of the mean of the output logits also exacerbates the issue. Switching the models between .train() and .eval() does not change the behavior. Here's a diff (and accompanying gist of the problem) showing the behavior: https://gist.github.com/lengstrom/aad52170c69a0f9502b7332ed22c4048/revisions
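The gist's exact loss is not shown here, but a common way to get a cross-entropy loss from causal LM logits is next-token prediction: the logits at position t are scored against the token at position t + 1. A minimal sketch, assuming logits of shape (batch, seq_len, vocab) and integer `input_ids` of shape (batch, seq_len); Transformers' causal LM heads apply a similar shift internally when `labels` is passed, but details such as ignored padding positions are omitted here, and `causal_lm_cross_entropy` is a hypothetical name.

```python
import torch
import torch.nn.functional as F


def causal_lm_cross_entropy(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy: position t predicts token t + 1."""
    shift_logits = logits[:, :-1, :]  # drop the last position (nothing left to predict)
    shift_labels = input_ids[:, 1:]   # drop the first token (nothing predicts it)
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )


# With all-zero logits every token has probability 1/V, so the loss is log(V).
V = 50
logits = torch.zeros(2, 7, V)
input_ids = torch.randint(0, V, (2, 7))
loss = causal_lm_cross_entropy(logits, input_ids)
print(loss)  # ≈ log(50) ≈ 3.912
```

Unlike `logits.mean()`, this loss backpropagates through a log-softmax, so small differences in the attention outputs can be weighted quite differently per token.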

lengstrom commented 1 year ago

It looks like memory-efficient SDPA is the issue: when I disable it, the behavior I found goes away:

import torch

# Force the math (reference) SDPA backend by disabling the fused CUDA kernels
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
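The same global switches can be wrapped in a context manager so that the math-only setting applies to a single comparison and the previous flags are restored afterwards. `math_sdp_only` is a hypothetical helper; PyTorch 2.0/2.1 also ship `torch.backends.cuda.sdp_kernel(...)` for this purpose, and newer releases add further backends that this sketch does not toggle.

```python
import contextlib

import torch


@contextlib.contextmanager
def math_sdp_only():
    """Temporarily route SDPA to the math backend only, then restore the old flags.

    Only the flash, memory-efficient and math backends are covered.
    """
    saved = (
        torch.backends.cuda.flash_sdp_enabled(),
        torch.backends.cuda.mem_efficient_sdp_enabled(),
        torch.backends.cuda.math_sdp_enabled(),
    )
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
    try:
        yield
    finally:
        flash, mem_efficient, math_ = saved
        torch.backends.cuda.enable_flash_sdp(flash)
        torch.backends.cuda.enable_mem_efficient_sdp(mem_efficient)
        torch.backends.cuda.enable_math_sdp(math_)


with math_sdp_only():
    print(
        torch.backends.cuda.flash_sdp_enabled(),
        torch.backends.cuda.mem_efficient_sdp_enabled(),
        torch.backends.cuda.math_sdp_enabled(),
    )  # False False True
```

Running the vanilla and BetterTransformer forward/backward passes once inside and once outside the `with` block isolates whether a fused kernel is responsible.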
fxmarty commented 1 year ago

Thank you for investigating! I can reproduce the issue (even on EleutherAI/pythia-160m, to a lesser extent). It does look like the mem-efficient attention kernel is at fault. I can't reproduce it with a minimal PyTorch example yet, so I haven't pinpointed where the issue is.

import torch
import torch.nn as nn
import math
import copy

class ModelVanilla(nn.Module):
    def __init__(self):
        super().__init__()

        self.q = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.k = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.v = nn.Parameter(torch.rand(1, 8, 22, 64))

    def forward(self, x):
        L = self.q.shape[-2]
        S = self.k.shape[-2]

        scale_factor = 1 / math.sqrt(self.q.size(-1))
        attn_mask = torch.ones(L, S).tril(diagonal=0).to(x.device)
        attn_mask = attn_mask.masked_fill(attn_mask == 0, -float('inf'))

        attn_weight = torch.softmax((self.q @ self.k.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)

        # attn_weight = torch.dropout(attn_weight, dropout_p)

        res = attn_weight @ self.v
        res = res * x
        return res

class ModelSDPA(nn.Module):
    def __init__(self):
        super().__init__()

        self.q = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.k = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.v = nn.Parameter(torch.rand(1, 8, 22, 64))

    def forward(self, x):
        res = torch.nn.functional.scaled_dot_product_attention(self.q, self.k, self.v, is_causal=True, dropout_p=0.0) * x
        return res

model_vanilla = ModelVanilla().to("cuda")
model_vanilla = model_vanilla.train()

model_sdpa = ModelSDPA().to("cuda")
model_sdpa = model_sdpa.train()

model_sdpa.q.data = copy.deepcopy(model_vanilla.q.data)
model_sdpa.k.data = copy.deepcopy(model_vanilla.k.data)
model_sdpa.v.data = copy.deepcopy(model_vanilla.v.data)

inp = torch.Tensor([3.]).to("cuda", torch.float32)

torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

res_vanilla = model_vanilla(inp)
res_sdpa = model_sdpa(inp)

print("Res match:", torch.allclose(res_vanilla, res_sdpa), "Maxdiff:", torch.max(torch.abs(res_vanilla - res_sdpa)).item())

loss_vanilla = res_vanilla.mean()
loss_sdpa = res_sdpa.mean()

print("Loss vanilla:", loss_vanilla)
print("Loss SDPA:", loss_sdpa)

loss_vanilla.backward()
loss_sdpa.backward()

print("Loss match:", torch.allclose(loss_vanilla, loss_sdpa), "Maxdiff:", torch.max(torch.abs(loss_vanilla - loss_sdpa)).item())

for name, param_vanilla in model_vanilla.named_parameters():
    param_sdpa = getattr(model_sdpa, name)

    maxdiff = torch.max(torch.abs(param_vanilla.grad - param_sdpa.grad)).item()
    relativediff = torch.mean(torch.abs(param_vanilla.grad - param_sdpa.grad) / torch.abs(param_vanilla.grad))
    cosine = torch.nn.functional.cosine_similarity(param_vanilla.grad.flatten(), param_sdpa.grad.flatten(), dim=0)
    print(f"{name} grad match:", torch.allclose(param_vanilla.grad, param_sdpa.grad), f"Maxdiff: {maxdiff}, relativediff: {relativediff}, cosine={cosine}")
lengstrom commented 1 year ago

Interesting, thanks for looking into it more! In case it is informative: using Flash SDP, I found the same issue with the 70m model (albeit I had to do a lot of float/half casting to get the NeoX model to even run without errors).

fxmarty commented 1 year ago

I can actually only reproduce with the mem-efficient kernel (flash in fp16 looks fine to me):

edit: I made a mistake

lengstrom commented 1 year ago

Cool! This is not what I found, but I compared the gradients of vanilla fp32 models with flash attention fp16 models, so that could be why. Sleeping now, but I will take a look tomorrow.

lengstrom commented 1 year ago

Curiosity got the better of me. It's weird that you get consistent results in half precision: when I modify both models to use half precision in the script above (i.e., from https://github.com/huggingface/optimum/issues/1091#issuecomment-1588040170), I get the same kinds of cosine similarity issues regardless of whether the vanilla model is fp16 or fp32: https://gist.github.com/lengstrom/39267bbeb9a40ec68f21e71fb60b1b2b

fxmarty commented 1 year ago

My bad, it was just a copy/paste error in the plot.