Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0
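As a quick illustration of the API involved in this issue, here is a minimal, hypothetical sketch of compiling a toy module with thunder.jit (the module and shapes are made up for illustration):

import torch
import thunder

toy = torch.nn.Linear(4, 4)
jtoy = thunder.jit(toy)          # source-to-source compilation of the module
out = jtoy(torch.randn(2, 4))    # runs the compiled forward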

Difference in backward output compared to eager #969

Open · kshitij12345 opened 2 months ago

kshitij12345 commented 2 months ago
import torch
import thunder
from thunder.benchmarks.targets import LitGPTConfig, LitGPTBenchmark

name = "thunder"

def get_model_and_args(name):
    with torch.device("cuda"):
        cfg: LitGPTConfig = LitGPTConfig.from_name("Llama-2-7b-hf")
        cfg.n_layer = 2
        cfg.block_size = 512
        b = LitGPTBenchmark(cfg, batchdims=(2,), dtype=thunder.float64)
        model = b.fn()
        args, kwargs = b.make_batch()

    if name == "thunder":
        jmodel = thunder.jit(model)
    elif name == "tcompile":
        jmodel = torch.compile(model)
    elif name == "eager":
        jmodel = model

    return jmodel, model, args, kwargs, cfg

jmodel, model, args, kwargs, cfg = get_model_and_args(name)

a = jmodel(*args, **kwargs)

g = torch.rand_like(a)
actual_grads = torch.autograd.grad(a, model.parameters(), g)

# Sanity Check values vs Eager
e = model(*args, **kwargs)
expected_grads = torch.autograd.grad(e, model.parameters(), g)
try:
    torch.testing.assert_close(a, e)
except Exception as exception:
    print("Difference in forward")
    print(exception)

try:
    torch.testing.assert_close(actual_grads, expected_grads)
except Exception as exception:
    print("Difference in backward")
    print(exception)

Output

Difference in forward
Tensor-likes are not close!

Mismatched elements: 701 / 32768000 (0.0%)
Greatest absolute difference: 3.4200238951953565e-07 at index (0, 196, 27841) (up to 1e-07 allowed)
Greatest relative difference: 0.0003868698751363598 at index (0, 232, 16553) (up to 1e-07 allowed)
Difference in backward
Tensor-likes are not close!

Mismatched elements: 20801070 / 131072000 (15.9%)
Greatest absolute difference: 5.836758326438485e-06 at index (28794, 902) (up to 1e-07 allowed)
Greatest relative difference: 2.512434422647769 at index (22700, 3441) (up to 1e-07 allowed)

The failure occurred for item [0]

Note that in the backward pass the greatest relative difference is 2.512434422647769, and this is at float64.
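One way to narrow down where the backward mismatch comes from is to compare the gradients parameter by parameter instead of as a single tuple. A minimal sketch, reusing actual_grads and expected_grads from the repro above (the max_mismatch helper is only for illustration):

# Illustrative helper: report the parameter with the worst absolute gradient mismatch.
def max_mismatch(actual_grads, expected_grads, named_params):
    worst = None
    for (pname, _), a_g, e_g in zip(named_params, actual_grads, expected_grads):
        diff = (a_g - e_g).abs().max().item()
        if worst is None or diff > worst[1]:
            worst = (pname, diff)
    return worst

print(max_mismatch(actual_grads, expected_grads, list(model.named_parameters())))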

This also happens with PyTorch's native torch.compile(model), so it seems to be a general consequence of compilation. Filing this issue to document it.
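For reference, the same script exercises torch.compile by switching the name flag at the top, e.g.:

name = "tcompile"  # selects the torch.compile(model) branch in get_model_and_args
jmodel, model, args, kwargs, cfg = get_model_and_args(name)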

cc: @IvanYashchuk

mruberry commented 2 months ago

triage review: