lengstrom opened this issue 1 year ago
Intriguingly, the logits (i.e., just the result of the forward pass) are super close to one another:
```python
def get_logits(m):
    fmodel, fweights, fbuffs, names = make_functional_with_buffers(m)
    logits_fn = ch.func.vmap(output_logits, in_dims=(None, None, None, 0))
    logits = logits_fn(fmodel, fweights, fbuffs, input_ids)
    return logits

logits = get_logits(model)
logits_btfm = get_logits(model_bettertfmer)
ch.nn.functional.cosine_similarity(logits.reshape(bs, -1), logits_btfm.reshape(bs, -1), dim=1)
```
yields
```
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       device='cuda:0', grad_fn=<SumBackward1>)
```
Thank you - this is rather critical; apologies, this should have been tested upfront. I tried reproducing with just `torch.nn.functional.scaled_dot_product_attention` and the boilerplate code from the PyTorch docs - gradients indeed match there. So the issue is really in our implementation. Let me have a look, debug, and add tests for this.
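For reference, this is roughly the kind of standalone check meant above (a sketch, not the exact script; the shapes, seed, and mask construction are arbitrary assumptions):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Leaf tensors shared by both attention implementations
q = torch.rand(1, 8, 22, 64, requires_grad=True)
k = torch.rand(1, 8, 22, 64, requires_grad=True)
v = torch.rand(1, 8, 22, 64, requires_grad=True)

# Fused SDPA path
F.scaled_dot_product_attention(q, k, v, is_causal=True).mean().backward()
grads_sdpa = [t.grad.clone() for t in (q, k, v)]
for t in (q, k, v):
    t.grad = None

# Manual reference, following the formula in the PyTorch docs
L, S = q.shape[-2], k.shape[-2]
causal_mask = torch.ones(L, S).tril().log()  # 0 on/below the diagonal, -inf above
attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(q.size(-1)) + causal_mask, dim=-1)
(attn @ v).mean().backward()

for name, g_sdpa, t in zip("qkv", grads_sdpa, (q, k, v)):
    print(name, "grads match:", torch.allclose(g_sdpa, t.grad, atol=1e-6))
```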
I can't reproduce the issue with gpt2 and gpt-neo type models, but I can with gpt-neox, where the gradients for `attention.query_key_value.bias` differ by a large relative margin. We see the same in your logs, and things get fuzzy from there:
```
cosine sim for gpt_neox.layers.5.post_attention_layernorm.weight: tensor(1.) tensor(8.2333e-08)
cosine sim for gpt_neox.layers.5.post_attention_layernorm.bias: tensor(1.) tensor(8.3089e-08)
cosine sim for gpt_neox.layers.5.attention.query_key_value.weight: tensor(0.9972) tensor(0.0091)
cosine sim for gpt_neox.layers.5.attention.query_key_value.bias: tensor(0.7171) tensor(0.4962)
cosine sim for gpt_neox.layers.5.attention.dense.weight: tensor(1.0000) tensor(4.8597e-07)
cosine sim for gpt_neox.layers.5.attention.dense.bias: tensor(1.) tensor(8.8518e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_h_to_4h.weight: tensor(1.) tensor(9.2877e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_h_to_4h.bias: tensor(1.) tensor(8.5622e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_4h_to_h.weight: tensor(1.) tensor(8.9708e-08)
cosine sim for gpt_neox.layers.5.mlp.dense_4h_to_h.bias: tensor(1.) tensor(8.8518e-08)
cosine sim for gpt_neox.final_layer_norm.weight: tensor(1.) tensor(8.7570e-08)
cosine sim for gpt_neox.final_layer_norm.bias: tensor(1.) tensor(8.9339e-08)
cosine sim for embed_out.weight: tensor(0.9811) tensor(0.1333)
```
Here's a simpler script (yours errors out when I change the model):
```python
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, GPT2LMHeadModel
from optimum.bettertransformer import BetterTransformer
import torch
import copy
import functools


def recurse_getattr(obj, attr: str):
    def _getattr(obj, attr):
        return getattr(obj, attr)

    return functools.reduce(_getattr, [obj] + attr.split("."))


model_id = "EleutherAI/pythia-160m"

model_vanilla = AutoModelForCausalLM.from_pretrained(model_id)
model_sdpa = AutoModelForCausalLM.from_pretrained(model_id)

# Sanity check: both models start from identical parameters.
for name, param_vanilla in model_vanilla.named_parameters():
    param_sdpa = recurse_getattr(model_sdpa, name)
    maxdiff = torch.max(torch.abs(param_vanilla - param_sdpa)).item()
    relativediff = torch.mean(torch.abs(param_vanilla - param_sdpa) / torch.abs(param_vanilla))
    print(f"{name} param match:", torch.allclose(param_vanilla, param_sdpa), f"Maxdiff: {maxdiff}, relativediff: {relativediff}")

tokenizer = AutoTokenizer.from_pretrained(model_id)

model_sdpa = BetterTransformer.transform(model_sdpa)

inp = tokenizer(["This is just to test gradients"] * 16, return_tensors="pt")

model_vanilla = model_vanilla.train()
model_sdpa = model_sdpa.train()

model_vanilla.zero_grad()
model_sdpa.zero_grad()

model_vanilla = model_vanilla.eval()
model_sdpa = model_sdpa.eval()

torch.set_printoptions(threshold=100000000000)

res_vanilla = model_vanilla(**inp).logits
print("\n\n\n\n RUNNING SDPA")
res_sdpa = model_sdpa(**inp).logits

print("Res match:", torch.allclose(res_vanilla, res_sdpa), "Maxdiff:", torch.max(torch.abs(res_vanilla - res_sdpa)).item())

loss_vanilla = res_vanilla.mean()
loss_sdpa = res_sdpa.mean()

loss_vanilla.backward()
loss_sdpa.backward()

print("Loss match:", torch.allclose(loss_vanilla, loss_sdpa), "Maxdiff:", torch.max(torch.abs(loss_vanilla - loss_sdpa)).item())

# Compare gradients parameter by parameter.
for name, param_vanilla in model_vanilla.named_parameters():
    param_sdpa = recurse_getattr(model_sdpa, name)
    maxdiff = torch.max(torch.abs(param_vanilla.grad - param_sdpa.grad)).item()
    relativediff = torch.mean(torch.abs(param_vanilla.grad - param_sdpa.grad) / torch.abs(param_vanilla.grad))
    cosine = torch.nn.functional.cosine_similarity(param_vanilla.grad.flatten(), param_sdpa.grad.flatten(), dim=0)
    print(f"{name} grad match:", torch.allclose(param_vanilla.grad, param_sdpa.grad), f"Maxdiff: {maxdiff}, relativediff: {relativediff}, cosine={cosine}")
```
Actually, I don't think there is any issue with gpt-neox either. The only fuzzy bit I see is in these gradients, but we are at the 10e-14 level here.
Otherwise, printing the cosine similarity (I updated the code above), I get fine results:
```
gpt_neox.embed_in.weight grad match: False Maxdiff: 0.0004773736000061035, relativediff: nan, cosine=1.0
gpt_neox.layers.0.input_layernorm.weight grad match: False Maxdiff: 6.480515003204346e-05, relativediff: 0.0008785824757069349, cosine=1.0
gpt_neox.layers.0.input_layernorm.bias grad match: False Maxdiff: 2.306140959262848e-05, relativediff: 0.002862692577764392, cosine=1.0
gpt_neox.layers.0.post_attention_layernorm.weight grad match: False Maxdiff: 2.3115426301956177e-05, relativediff: 0.0007874640286900103, cosine=1.0000001192092896
gpt_neox.layers.0.post_attention_layernorm.bias grad match: False Maxdiff: 1.6219913959503174e-05, relativediff: 0.0012927413918077946, cosine=1.0000001192092896
gpt_neox.layers.0.attention.query_key_value.weight grad match: False Maxdiff: 0.00034546852111816406, relativediff: 0.0016394504345953465, cosine=1.0
gpt_neox.layers.0.attention.query_key_value.bias grad match: False Maxdiff: 3.819167613983154e-05, relativediff: nan, cosine=1.0000001192092896
gpt_neox.layers.0.attention.dense.weight grad match: False Maxdiff: 7.398426532745361e-05, relativediff: 0.001369286677800119, cosine=0.9999998211860657
gpt_neox.layers.0.attention.dense.bias grad match: False Maxdiff: 7.408857345581055e-05, relativediff: 0.0011199495056644082, cosine=1.0000001192092896
gpt_neox.layers.0.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0016028881072998047, relativediff: 0.001413618098013103, cosine=1.000000238418579
gpt_neox.layers.0.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 4.233419895172119e-05, relativediff: 0.0007718717097304761, cosine=1.0
gpt_neox.layers.0.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00015889108180999756, relativediff: 0.001413600635714829, cosine=0.9999999403953552
gpt_neox.layers.0.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 7.408857345581055e-05, relativediff: 0.0011199495056644082, cosine=1.0000001192092896
gpt_neox.layers.1.input_layernorm.weight grad match: False Maxdiff: 1.6685575246810913e-05, relativediff: 0.0006055051344446838, cosine=0.9999999403953552
gpt_neox.layers.1.input_layernorm.bias grad match: False Maxdiff: 1.432374119758606e-05, relativediff: 0.0005490509211085737, cosine=0.9999999403953552
gpt_neox.layers.1.post_attention_layernorm.weight grad match: False Maxdiff: 3.9480626583099365e-05, relativediff: 0.001803021994419396, cosine=0.9999999403953552
gpt_neox.layers.1.post_attention_layernorm.bias grad match: False Maxdiff: 1.3155164197087288e-05, relativediff: 0.0043691834434866905, cosine=1.0
gpt_neox.layers.1.attention.query_key_value.weight grad match: False Maxdiff: 0.00036919116973876953, relativediff: 0.0015037399716675282, cosine=0.9999999403953552
gpt_neox.layers.1.attention.query_key_value.bias grad match: False Maxdiff: 3.212690353393555e-05, relativediff: nan, cosine=1.0000001192092896
gpt_neox.layers.1.attention.dense.weight grad match: False Maxdiff: 0.00013083219528198242, relativediff: 0.0013935145689174533, cosine=0.9999999403953552
gpt_neox.layers.1.attention.dense.bias grad match: False Maxdiff: 6.993114948272705e-05, relativediff: 0.0010215052170678973, cosine=0.9999999403953552
gpt_neox.layers.1.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0004469156265258789, relativediff: nan, cosine=1.0
gpt_neox.layers.1.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.7128495275974274e-05, relativediff: nan, cosine=0.9999999403953552
gpt_neox.layers.1.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 7.201731204986572e-05, relativediff: nan, cosine=1.0
gpt_neox.layers.1.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 6.993114948272705e-05, relativediff: 0.0010215052170678973, cosine=0.9999999403953552
gpt_neox.layers.2.input_layernorm.weight grad match: False Maxdiff: 1.7311424016952515e-05, relativediff: 0.001300215837545693, cosine=1.0
gpt_neox.layers.2.input_layernorm.bias grad match: False Maxdiff: 1.7337501049041748e-05, relativediff: 0.001227947068400681, cosine=1.0
gpt_neox.layers.2.post_attention_layernorm.weight grad match: False Maxdiff: 3.092736005783081e-05, relativediff: 0.0005665196222253144, cosine=1.0
gpt_neox.layers.2.post_attention_layernorm.bias grad match: False Maxdiff: 1.5107914805412292e-05, relativediff: 0.003649575635790825, cosine=0.9999998807907104
gpt_neox.layers.2.attention.query_key_value.weight grad match: False Maxdiff: 0.0003069639205932617, relativediff: 0.0019483822397887707, cosine=0.9999999403953552
gpt_neox.layers.2.attention.query_key_value.bias grad match: False Maxdiff: 3.221631050109863e-05, relativediff: nan, cosine=1.0
gpt_neox.layers.2.attention.dense.weight grad match: False Maxdiff: 8.071213960647583e-05, relativediff: 0.001911609317176044, cosine=0.9999999403953552
gpt_neox.layers.2.attention.dense.bias grad match: False Maxdiff: 6.04093074798584e-05, relativediff: 0.0007769144140183926, cosine=1.0000001192092896
gpt_neox.layers.2.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0002313852310180664, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.2.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.1599233150482178e-05, relativediff: nan, cosine=0.9999999403953552
gpt_neox.layers.2.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.0001385807991027832, relativediff: nan, cosine=1.0
gpt_neox.layers.2.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 6.04093074798584e-05, relativediff: 0.0007769144140183926, cosine=1.0000001192092896
gpt_neox.layers.3.input_layernorm.weight grad match: False Maxdiff: 3.2275915145874023e-05, relativediff: 0.0011378898052498698, cosine=1.0000001192092896
gpt_neox.layers.3.input_layernorm.bias grad match: False Maxdiff: 1.817941665649414e-05, relativediff: 0.001662743859924376, cosine=0.9999998807907104
gpt_neox.layers.3.post_attention_layernorm.weight grad match: False Maxdiff: 1.9088387489318848e-05, relativediff: 0.0008994439267553389, cosine=1.0
gpt_neox.layers.3.post_attention_layernorm.bias grad match: False Maxdiff: 1.938268542289734e-05, relativediff: 0.0015516605926677585, cosine=0.9999999403953552
gpt_neox.layers.3.attention.query_key_value.weight grad match: False Maxdiff: 0.00041878223419189453, relativediff: 0.0024944001343101263, cosine=1.0
gpt_neox.layers.3.attention.query_key_value.bias grad match: False Maxdiff: 3.2745301723480225e-05, relativediff: nan, cosine=1.0000001192092896
gpt_neox.layers.3.attention.dense.weight grad match: False Maxdiff: 0.00010060705244541168, relativediff: 0.0022460015024989843, cosine=1.0
gpt_neox.layers.3.attention.dense.bias grad match: False Maxdiff: 5.4156407713890076e-05, relativediff: 0.0030462404247373343, cosine=1.0
gpt_neox.layers.3.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00016620755195617676, relativediff: 0.0018187703099101782, cosine=1.0
gpt_neox.layers.3.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.2606924176216125e-05, relativediff: 0.0009581153281033039, cosine=1.0
gpt_neox.layers.3.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.0003941059112548828, relativediff: 0.0019543024245649576, cosine=1.0
gpt_neox.layers.3.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 5.4156407713890076e-05, relativediff: 0.0030462404247373343, cosine=1.0
gpt_neox.layers.4.input_layernorm.weight grad match: False Maxdiff: 9.484589099884033e-06, relativediff: 0.002448969054967165, cosine=0.9999999403953552
gpt_neox.layers.4.input_layernorm.bias grad match: False Maxdiff: 1.4883698895573616e-05, relativediff: 0.0007243411964736879, cosine=0.9999998807907104
gpt_neox.layers.4.post_attention_layernorm.weight grad match: False Maxdiff: 2.366676926612854e-05, relativediff: 0.0014065144350752234, cosine=1.0000001192092896
gpt_neox.layers.4.post_attention_layernorm.bias grad match: False Maxdiff: 1.585017889738083e-05, relativediff: 0.0007562157697975636, cosine=0.9999998807907104
gpt_neox.layers.4.attention.query_key_value.weight grad match: False Maxdiff: 0.0006879568099975586, relativediff: 0.002234308049082756, cosine=1.0
gpt_neox.layers.4.attention.query_key_value.bias grad match: False Maxdiff: 3.507733345031738e-05, relativediff: inf, cosine=1.0000001192092896
gpt_neox.layers.4.attention.dense.weight grad match: False Maxdiff: 6.670504808425903e-05, relativediff: 0.0063442327082157135, cosine=1.0
gpt_neox.layers.4.attention.dense.bias grad match: False Maxdiff: 4.736892879009247e-05, relativediff: 0.0013332836097106338, cosine=1.0
gpt_neox.layers.4.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00012412108480930328, relativediff: 0.003574290545657277, cosine=1.0
gpt_neox.layers.4.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.1204352378845215e-05, relativediff: 0.001701479428447783, cosine=1.0
gpt_neox.layers.4.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00010700523853302002, relativediff: 0.002418705727905035, cosine=1.0
gpt_neox.layers.4.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 4.736892879009247e-05, relativediff: 0.0013332836097106338, cosine=1.0
gpt_neox.layers.5.input_layernorm.weight grad match: False Maxdiff: 1.722574234008789e-05, relativediff: 0.0006909793592058122, cosine=1.0000001192092896
gpt_neox.layers.5.input_layernorm.bias grad match: False Maxdiff: 2.1550804376602173e-05, relativediff: 0.000694852031301707, cosine=1.0
gpt_neox.layers.5.post_attention_layernorm.weight grad match: False Maxdiff: 1.827254891395569e-05, relativediff: 0.0037720382679253817, cosine=1.0000001192092896
gpt_neox.layers.5.post_attention_layernorm.bias grad match: False Maxdiff: 1.3113021850585938e-05, relativediff: 0.0005836377968080342, cosine=1.0
gpt_neox.layers.5.attention.query_key_value.weight grad match: False Maxdiff: 0.00030303001403808594, relativediff: 0.0020995934028178453, cosine=1.0
gpt_neox.layers.5.attention.query_key_value.bias grad match: False Maxdiff: 2.4404376745224e-05, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.5.attention.dense.weight grad match: False Maxdiff: 3.871321678161621e-05, relativediff: 0.002867215545848012, cosine=1.0
gpt_neox.layers.5.attention.dense.bias grad match: False Maxdiff: 3.5896897315979004e-05, relativediff: 0.0045297760516405106, cosine=0.9999999403953552
gpt_neox.layers.5.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0001646876335144043, relativediff: 0.0037182371597737074, cosine=1.0
gpt_neox.layers.5.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.504885196685791e-05, relativediff: 0.0012974942801520228, cosine=0.9999999403953552
gpt_neox.layers.5.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 9.982287883758545e-05, relativediff: 0.00209479290060699, cosine=0.9999999403953552
gpt_neox.layers.5.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 3.5896897315979004e-05, relativediff: 0.0045297760516405106, cosine=0.9999999403953552
gpt_neox.layers.6.input_layernorm.weight grad match: False Maxdiff: 2.893805503845215e-05, relativediff: 0.0009801711421459913, cosine=1.0
gpt_neox.layers.6.input_layernorm.bias grad match: False Maxdiff: 1.376122236251831e-05, relativediff: 0.0006134507711976767, cosine=1.0
gpt_neox.layers.6.post_attention_layernorm.weight grad match: False Maxdiff: 5.0127506256103516e-05, relativediff: 0.007451085839420557, cosine=1.0
gpt_neox.layers.6.post_attention_layernorm.bias grad match: False Maxdiff: 1.2849457561969757e-05, relativediff: 0.0008298815810121596, cosine=1.0
gpt_neox.layers.6.attention.query_key_value.weight grad match: False Maxdiff: 0.0002187490463256836, relativediff: 0.0022141069639474154, cosine=0.9999999403953552
gpt_neox.layers.6.attention.query_key_value.bias grad match: False Maxdiff: 2.746284008026123e-05, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.6.attention.dense.weight grad match: False Maxdiff: 4.579313099384308e-05, relativediff: 0.0017882962711155415, cosine=1.0000001192092896
gpt_neox.layers.6.attention.dense.bias grad match: False Maxdiff: 3.5293400287628174e-05, relativediff: 0.0008495993097312748, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00011885911226272583, relativediff: 0.0021896674297749996, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.1107494831085205e-05, relativediff: 0.0006742423865944147, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 9.671971201896667e-05, relativediff: 0.00230384455062449, cosine=1.0000001192092896
gpt_neox.layers.6.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 3.5293400287628174e-05, relativediff: 0.0008495993097312748, cosine=1.0000001192092896
gpt_neox.layers.7.input_layernorm.weight grad match: False Maxdiff: 3.8079917430877686e-05, relativediff: 0.002206353936344385, cosine=0.9999999403953552
gpt_neox.layers.7.input_layernorm.bias grad match: False Maxdiff: 1.3178214430809021e-05, relativediff: 0.0006930383387953043, cosine=1.0000001192092896
gpt_neox.layers.7.post_attention_layernorm.weight grad match: False Maxdiff: 1.747533679008484e-05, relativediff: 0.0013376493006944656, cosine=1.0
gpt_neox.layers.7.post_attention_layernorm.bias grad match: False Maxdiff: 1.3608485460281372e-05, relativediff: 0.0008480687974952161, cosine=1.0
gpt_neox.layers.7.attention.query_key_value.weight grad match: False Maxdiff: 0.00023524463176727295, relativediff: 0.002312550786882639, cosine=0.9999999403953552
gpt_neox.layers.7.attention.query_key_value.bias grad match: False Maxdiff: 2.3789703845977783e-05, relativediff: inf, cosine=1.0000001192092896
gpt_neox.layers.7.attention.dense.weight grad match: False Maxdiff: 4.3258070945739746e-05, relativediff: 0.0021618057508021593, cosine=1.0
gpt_neox.layers.7.attention.dense.bias grad match: False Maxdiff: 3.759562969207764e-05, relativediff: 0.0006909780204296112, cosine=0.9999999403953552
gpt_neox.layers.7.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00015154480934143066, relativediff: 0.0025800946168601513, cosine=0.9999999403953552
gpt_neox.layers.7.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.353265881538391e-05, relativediff: 0.0012763891136273742, cosine=1.0
gpt_neox.layers.7.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 7.880479097366333e-05, relativediff: 0.0030845303554087877, cosine=1.0
gpt_neox.layers.7.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 3.759562969207764e-05, relativediff: 0.0006909780204296112, cosine=0.9999999403953552
gpt_neox.layers.8.input_layernorm.weight grad match: False Maxdiff: 2.9072165489196777e-05, relativediff: 0.002014160854741931, cosine=0.9999998807907104
gpt_neox.layers.8.input_layernorm.bias grad match: False Maxdiff: 5.7248398661613464e-06, relativediff: 0.0009236705373041332, cosine=1.0
gpt_neox.layers.8.post_attention_layernorm.weight grad match: False Maxdiff: 3.810226917266846e-05, relativediff: 0.0014925239374861121, cosine=0.9999998807907104
gpt_neox.layers.8.post_attention_layernorm.bias grad match: False Maxdiff: 1.50240957736969e-05, relativediff: 0.0015558138256892562, cosine=0.9999998807907104
gpt_neox.layers.8.attention.query_key_value.weight grad match: False Maxdiff: 0.0007872581481933594, relativediff: 0.009567000903189182, cosine=0.9999999403953552
gpt_neox.layers.8.attention.query_key_value.bias grad match: False Maxdiff: 9.194482117891312e-06, relativediff: nan, cosine=0.9999998807907104
gpt_neox.layers.8.attention.dense.weight grad match: False Maxdiff: 5.197152495384216e-05, relativediff: 0.0020594012457877398, cosine=1.0
gpt_neox.layers.8.attention.dense.bias grad match: False Maxdiff: 2.5920569896697998e-05, relativediff: 0.0008061685948632658, cosine=0.9999999403953552
gpt_neox.layers.8.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00012427568435668945, relativediff: 0.004034997895359993, cosine=0.9999999403953552
gpt_neox.layers.8.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.0071864128112793e-05, relativediff: 0.0011649620719254017, cosine=1.0
gpt_neox.layers.8.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 9.659864008426666e-05, relativediff: 0.0026284982450306416, cosine=0.9999998807907104
gpt_neox.layers.8.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 2.5920569896697998e-05, relativediff: 0.0008061685948632658, cosine=0.9999999403953552
gpt_neox.layers.9.input_layernorm.weight grad match: False Maxdiff: 6.195157766342163e-06, relativediff: 0.0019708990585058928, cosine=0.9999998807907104
gpt_neox.layers.9.input_layernorm.bias grad match: False Maxdiff: 5.588633939623833e-06, relativediff: 0.0013916906900703907, cosine=0.9999999403953552
gpt_neox.layers.9.post_attention_layernorm.weight grad match: False Maxdiff: 6.753206253051758e-05, relativediff: 0.002181797521188855, cosine=1.0000001192092896
gpt_neox.layers.9.post_attention_layernorm.bias grad match: False Maxdiff: 1.5349360182881355e-05, relativediff: 0.0013889750698581338, cosine=1.0
gpt_neox.layers.9.attention.query_key_value.weight grad match: False Maxdiff: 0.0005746409296989441, relativediff: 0.014891190454363823, cosine=0.9999997019767761
gpt_neox.layers.9.attention.query_key_value.bias grad match: False Maxdiff: 9.512528777122498e-06, relativediff: inf, cosine=0.9999998807907104
gpt_neox.layers.9.attention.dense.weight grad match: False Maxdiff: 2.2362452000379562e-05, relativediff: 0.001138008083216846, cosine=1.0000001192092896
gpt_neox.layers.9.attention.dense.bias grad match: False Maxdiff: 2.1470244973897934e-05, relativediff: 0.0005539876292459667, cosine=0.9999999403953552
gpt_neox.layers.9.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.0002691000699996948, relativediff: 0.004176828544586897, cosine=1.0
gpt_neox.layers.9.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 5.860254168510437e-05, relativediff: 0.029437750577926636, cosine=0.9999997615814209
gpt_neox.layers.9.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00012803077697753906, relativediff: 0.002299846848472953, cosine=1.0000001192092896
gpt_neox.layers.9.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 2.1470244973897934e-05, relativediff: 0.0005539876292459667, cosine=0.9999999403953552
gpt_neox.layers.10.input_layernorm.weight grad match: False Maxdiff: 0.00041857361793518066, relativediff: 0.013254483230412006, cosine=1.0000001192092896
gpt_neox.layers.10.input_layernorm.bias grad match: False Maxdiff: 3.1903618946671486e-06, relativediff: 0.0007093409076333046, cosine=1.0
gpt_neox.layers.10.post_attention_layernorm.weight grad match: False Maxdiff: 3.400444984436035e-05, relativediff: 0.0071408688090741634, cosine=1.0000001192092896
gpt_neox.layers.10.post_attention_layernorm.bias grad match: False Maxdiff: 1.761317253112793e-05, relativediff: 0.0008480017422698438, cosine=0.9999999403953552
gpt_neox.layers.10.attention.query_key_value.weight grad match: False Maxdiff: 0.004154682159423828, relativediff: 0.009285532869398594, cosine=0.999999463558197
gpt_neox.layers.10.attention.query_key_value.bias grad match: False Maxdiff: 5.601905286312103e-06, relativediff: 0.1024150550365448, cosine=0.9999998807907104
gpt_neox.layers.10.attention.dense.weight grad match: False Maxdiff: 2.562534064054489e-05, relativediff: 0.000457679241662845, cosine=1.0
gpt_neox.layers.10.attention.dense.bias grad match: False Maxdiff: 2.1981075406074524e-05, relativediff: 0.00017686102364677936, cosine=1.0000001192092896
gpt_neox.layers.10.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 0.00013469159603118896, relativediff: 0.0026225948240607977, cosine=0.9999997615814209
gpt_neox.layers.10.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.5781802833080292e-05, relativediff: 0.0007138229557313025, cosine=0.9999999403953552
gpt_neox.layers.10.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 0.00019641220569610596, relativediff: 0.0010898219188675284, cosine=0.9999999403953552
gpt_neox.layers.10.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 2.1981075406074524e-05, relativediff: 0.00017686102364677936, cosine=1.0000001192092896
gpt_neox.layers.11.input_layernorm.weight grad match: False Maxdiff: 1.2040138244628906e-05, relativediff: 0.0023908887524157763, cosine=1.0000001192092896
gpt_neox.layers.11.input_layernorm.bias grad match: False Maxdiff: 2.6496127247810364e-06, relativediff: 0.0017001176020130515, cosine=1.0
gpt_neox.layers.11.post_attention_layernorm.weight grad match: False Maxdiff: 4.336237907409668e-06, relativediff: 0.00013862353807780892, cosine=1.0
gpt_neox.layers.11.post_attention_layernorm.bias grad match: False Maxdiff: 1.0542571544647217e-06, relativediff: 3.621343421400525e-05, cosine=0.9999998807907104
gpt_neox.layers.11.attention.query_key_value.weight grad match: False Maxdiff: 0.02075481414794922, relativediff: 0.06359207630157471, cosine=0.9999991655349731
gpt_neox.layers.11.attention.query_key_value.bias grad match: False Maxdiff: 2.696644514799118e-06, relativediff: 0.25537019968032837, cosine=1.0000001192092896
gpt_neox.layers.11.attention.dense.weight grad match: False Maxdiff: 2.91336327791214e-05, relativediff: 0.0007787492941133678, cosine=0.9999998807907104
gpt_neox.layers.11.attention.dense.bias grad match: False Maxdiff: 1.8852879293262959e-06, relativediff: 0.00015919502766337246, cosine=0.9999998807907104
gpt_neox.layers.11.mlp.dense_h_to_4h.weight grad match: False Maxdiff: 2.5608460418879986e-05, relativediff: 0.00040599319618195295, cosine=0.9999999403953552
gpt_neox.layers.11.mlp.dense_h_to_4h.bias grad match: False Maxdiff: 2.766493707895279e-06, relativediff: 0.00014488919987343252, cosine=1.0
gpt_neox.layers.11.mlp.dense_4h_to_h.weight grad match: False Maxdiff: 2.9437243938446045e-05, relativediff: 0.0002886583679355681, cosine=1.0000001192092896
gpt_neox.layers.11.mlp.dense_4h_to_h.bias grad match: False Maxdiff: 1.8852879293262959e-06, relativediff: 0.00015919502766337246, cosine=0.9999998807907104
gpt_neox.final_layer_norm.weight grad match: False Maxdiff: 1.341104507446289e-06, relativediff: 2.6080595034727594e-06, cosine=1.0000001192092896
gpt_neox.final_layer_norm.bias grad match: True Maxdiff: 0.0, relativediff: 0.0, cosine=1.0
embed_out.weight grad match: True Maxdiff: 1.3242242857813835e-09, relativediff: 1.7900250668390072e-06, cosine=0.9999999403953552
```
Thank you for looking into this and for the quick response! Sorry, I'm confused - what did you change to get the gradients to match in NeoX?
@lengstrom I changed nothing. I used a simpler script (I tried to change the model in yours but ran into errors) and cannot reproduce the issue. Do you see anything wrong with my script? One difference is the device: I did not try on CUDA (which may in turn use the flash attention or memory-efficient attention kernel).
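For what it's worth, on CUDA you can check which fused SDPA kernels the dispatcher is currently allowed to pick - a quick sketch, assuming PyTorch 2.x where these query functions exist:

```python
import torch

# Which SDPA kernels are currently allowed to be dispatched (PyTorch 2.x)?
print("flash enabled:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient enabled:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math enabled:         ", torch.backends.cuda.math_sdp_enabled())
```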
@fxmarty: Thank you for the help, and for the helpful pointer. By changing three basic things in your script I am able to replicate the behavior I showed. Not doing any of these three things will make you miss the behavior. Furthermore, I found that changing the loss function to cross entropy instead of the mean of the output logits also exacerbates the issue. Changing the mode on the models from `.train()` to `.eval()` does not change the behavior either. Here's a diff (and accompanying gist of the problem) showing the behavior: https://gist.github.com/lengstrom/aad52170c69a0f9502b7332ed22c4048/revisions
It looks like memory-efficient SDP is the issue - when I disable it, the behavior I found goes away:

```python
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(True)
```
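A scoped alternative, if you only want to force the math kernel around the forward/backward under test (a sketch, assuming PyTorch 2.0/2.1 where the `torch.backends.cuda.sdp_kernel` context manager is available; `model_sdpa` and `inp` stand in for the objects from the script above):

```python
import torch

# Restrict SDPA to the math kernel only within this block.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
    res_sdpa = model_sdpa(**inp).logits  # hypothetical forward pass under the math kernel
    res_sdpa.mean().backward()
```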
Thank you for investigating it, I can reproduce the issue (even on EleutherAI/pythia-160m, to a lesser extent)! It indeed looks like the mem-efficient attention kernel is at fault. I can't reproduce it with a minimal PyTorch example yet, so I haven't pinpointed where the issue is.
```python
import torch
import torch.nn as nn
import math
import copy


class ModelVanilla(nn.Module):
    # Reference attention implemented by hand (the "math" path).
    def __init__(self):
        super().__init__()
        self.q = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.k = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.v = nn.Parameter(torch.rand(1, 8, 22, 64))

    def forward(self, x):
        L = self.q.shape[-2]
        S = self.k.shape[-2]
        scale_factor = 1 / math.sqrt(self.q.size(-1))
        attn_mask = torch.ones(L, S).tril(diagonal=0).to(x.device)
        attn_mask = attn_mask.masked_fill(attn_mask == 0, -float('inf'))
        attn_weight = torch.softmax((self.q @ self.k.transpose(-2, -1) * scale_factor) + attn_mask, dim=-1)
        # attn_weight = torch.dropout(attn_weight, dropout_p)
        res = attn_weight @ self.v
        res = res * x
        return res


class ModelSDPA(nn.Module):
    # Same parameters, but using the fused SDPA op.
    def __init__(self):
        super().__init__()
        self.q = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.k = nn.Parameter(torch.rand(1, 8, 22, 64))
        self.v = nn.Parameter(torch.rand(1, 8, 22, 64))

    def forward(self, x):
        res = torch.nn.functional.scaled_dot_product_attention(self.q, self.k, self.v, is_causal=True, dropout_p=0.0) * x
        return res


model_vanilla = ModelVanilla().to("cuda")
model_vanilla = model_vanilla.train()

model_sdpa = ModelSDPA().to("cuda")
model_sdpa = model_sdpa.train()

# Copy q/k/v so both models start from identical parameters.
model_sdpa.q.data = copy.deepcopy(model_vanilla.q.data)
model_sdpa.k.data = copy.deepcopy(model_vanilla.k.data)
model_sdpa.v.data = copy.deepcopy(model_vanilla.v.data)

inp = torch.Tensor([3.]).to("cuda", torch.float32)

# Force the mem-efficient kernel only.
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

res_vanilla = model_vanilla(inp)
res_sdpa = model_sdpa(inp)

print("Res match:", torch.allclose(res_vanilla, res_sdpa), "Maxdiff:", torch.max(torch.abs(res_vanilla - res_sdpa)).item())

loss_vanilla = res_vanilla.mean()
loss_sdpa = res_sdpa.mean()

print("Loss vanilla:", loss_vanilla)
print("Loss SDPA:", loss_sdpa)

loss_vanilla.backward()
loss_sdpa.backward()

print("Loss match:", torch.allclose(loss_vanilla, loss_sdpa), "Maxdiff:", torch.max(torch.abs(loss_vanilla - loss_sdpa)).item())

for name, param_vanilla in model_vanilla.named_parameters():
    param_sdpa = getattr(model_sdpa, name)
    maxdiff = torch.max(torch.abs(param_vanilla.grad - param_sdpa.grad)).item()
    relativediff = torch.mean(torch.abs(param_vanilla.grad - param_sdpa.grad) / torch.abs(param_vanilla.grad))
    cosine = torch.nn.functional.cosine_similarity(param_vanilla.grad.flatten(), param_sdpa.grad.flatten(), dim=0)
    print(f"{name} grad match:", torch.allclose(param_vanilla.grad, param_sdpa.grad), f"Maxdiff: {maxdiff}, relativediff: {relativediff}, cosine={cosine}")
```
Interesting, thanks for looking into it more! If it's informative: using flash SDP I found the same issue with the 70m model (albeit I had to do a lot of float/half casting to get the NeoX model to even run without errors).
I can actually only reproduce with the mem-efficient kernel (flash fp16 is fine for me).
(edit: I made a mistake here)
Cool! This is not what I found - but I compared the gradients of vanilla fp32 models against flash-attention fp16 models, so that could be why. Sleeping now, but I will take a look tomorrow.
Curiosity got the better of me. It's weird that you get consistent half-precision results - when I modify both models to use half precision in the script above (i.e., from https://github.com/huggingface/optimum/issues/1091#issuecomment-1588040170) I get the same kinds of cosine similarity issues, regardless of whether the vanilla model is fp16 or fp32: https://gist.github.com/lengstrom/39267bbeb9a40ec68f21e71fb60b1b2b
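Roughly, the modification is just casting the models before the forward/backward passes (a sketch, not the exact diff in the gist; `model_vanilla`, `model_sdpa`, and `inp` are the objects from the script above):

```python
# Hypothetical fp16 variant of the comparison script above: cast both models to
# half precision and move them (and the tokenized inputs) to the GPU.
model_vanilla = model_vanilla.half().to("cuda")
model_sdpa = model_sdpa.half().to("cuda")
inp = {k: v.to("cuda") for k, v in inp.items()}
```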
My bad - it was just a copy/paste error on the plot.
System Info

Who can help?

@fxmarty @younesbelkada

Reproduction
https://gist.github.com/lengstrom/84a5d0d95f8942eb4eca82d9eb5aa272
This code should reproduce perfectly with just PyTorch 2.1, optimum, datasets, numpy, and transformers installed. The full log is here: https://gist.github.com/lengstrom/84a5d0d95f8942eb4eca82d9eb5aa272?permalink_comment_id=4593821#gistcomment-4593821. The log shows that the gradients are often rather different, and that in some parameter groups the compared gradients are nearly uncorrelated (i.e., cosine similarity close to 0).
Expected behavior
The gradients of samples for models with and without the `BetterTransformer` transformation should be nearly identical, but in practice are often very different. Considering a GPTNeoX model, we fix a sample and record the gradient of the BetterTransformer-transformed model and the gradient of the original model. We then measure the cosine similarity (a vector similarity metric, 0 = uncorrelated, 1 = perfectly correlated) and find that both:
gpt_neox.embed_in.weight
Is there a reason for this behavior? Based on the documentation, the only difference between the transformed and untransformed models is different code paths being used, which should not change the underlying gradients.
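For concreteness, the per-parameter measurement is roughly the following (a sketch; `grad_cosine_report` is a hypothetical helper, and both models are assumed to have `.grad` populated by a backward pass on the same batch - the full script is in the gist above):

```python
import torch

def grad_cosine_report(model_a, model_b):
    # For each parameter, compare the flattened gradients of the two models:
    # cosine similarity ~0 means uncorrelated, ~1 means perfectly correlated.
    params_b = dict(model_b.named_parameters())
    for name, p_a in model_a.named_parameters():
        p_b = params_b[name]
        cos = torch.nn.functional.cosine_similarity(p_a.grad.flatten(), p_b.grad.flatten(), dim=0)
        print(f"{name}: cosine similarity = {cos.item():.6f}")
```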