huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

TorchDynamo graph needlessly fragmented for GPTNeoX due to baddbmm type mistake #24940

Closed norabelrose closed 1 year ago

norabelrose commented 1 year ago

System Info

Who can help?

@ArthurZucker @younesbelkada

Information

Tasks

Reproduction

from transformers import AutoModelForCausalLM
import torch

def debug_backend(gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]):
    # TorchDynamo calls this backend once per captured graph fragment
    print("debug_backend() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
jitted = torch.compile(model, backend=debug_backend)

jitted(**model.dummy_inputs)

The output is too long to fit in a comment, so you'll have to run the code yourself. It prints "debug_backend() called with FX graph:" several times, each time followed by a fragment of the whole computation graph. This is unexpected, since GPTNeoX has no data-dependent control flow.
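
If you only want to count the fragments rather than read them, here is a variant of the repro above (my own sketch; counting_backend and the fragments list are names I made up) that tallies how many times Dynamo calls the backend:

from transformers import AutoModelForCausalLM
import torch

fragments = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]):
    # Dynamo calls the backend once per captured graph fragment
    fragments.append(gm)
    return gm.forward

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
jitted = torch.compile(model, backend=counting_backend)
jitted(**model.dummy_inputs)

print(f"{len(fragments)} graph fragment(s); 1 would mean no graph breaks")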

Expected behavior

The torch.compile backend should be called only once, and therefore "debug_backend() called with FX graph:" should appear only once, because GPTNeoX does not actually require any data-dependent control flow.

I've already checked that this can be fixed by turning GPTNeoXAttention.norm_factor into a Python scalar instead of a tensor. A scalar is what torch.baddbmm expects for its alpha parameter anyway; it just seems to silently convert tensors into scalars, so passing a tensor doesn't cause a crash in normal eager use. It only trips up Dynamo.
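
As a quick sanity check (a standalone sketch of mine, not code from the model), a Python float alpha gives the same numbers as an explicit scaled matmul, so the change only affects tracing, not the math:

import torch

head_size = 64
query = torch.randn(2, 8, head_size)
key = torch.randn(2, 8, head_size)
bias = torch.zeros(2, 8, 8)

norm_factor = head_size ** -0.5  # plain Python float, as baddbmm documents for alpha
scores = torch.baddbmm(bias, query, key.transpose(1, 2), beta=1.0, alpha=norm_factor)
reference = bias + norm_factor * (query @ key.transpose(1, 2))
torch.testing.assert_close(scores, reference)  # passes up to float tolerance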

[Screenshot: 2023-07-19, 5:27 p.m.]

The exact fix is to replace lines 103-107 of modeling_gpt_neox.py with:

self.norm_factor = self.head_size ** -0.5

and replace the baddbmm call inside _attn with:

attn_scores = torch.baddbmm(
    attn_scores,
    query,
    key.transpose(1, 2),
    beta=1.0,
    alpha=self.norm_factor,
)
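
For completeness, a small self-contained sketch (mine, not part of the PR) of the scalar-alpha pattern, which should compile to a single graph:

import torch

class ToyScores(torch.nn.Module):
    # Minimal stand-in for the patched scaling; not the real _attn code
    def __init__(self, head_size: int):
        super().__init__()
        self.norm_factor = head_size ** -0.5  # Python float, as in the fix above

    def forward(self, query, key):
        bias = torch.zeros(query.size(0), query.size(1), key.size(1),
                           dtype=query.dtype, device=query.device)
        return torch.baddbmm(bias, query, key.transpose(1, 2),
                             beta=1.0, alpha=self.norm_factor)

graphs = []

def record_backend(gm: torch.fx.GraphModule, example_inputs):
    graphs.append(gm)  # one entry per Dynamo graph fragment
    return gm.forward

compiled = torch.compile(ToyScores(64), backend=record_backend)
compiled(torch.randn(2, 8, 64), torch.randn(2, 8, 64))
print(len(graphs))  # expected: 1, i.e. no graph breaks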

ydshieh commented 1 year ago

@fxmarty I think you are more familiar with this topic. If so, could you take a look? Thanks!

fxmarty commented 1 year ago

Hi @norabelrose, would you like to submit a PR?

norabelrose commented 1 year ago

> Hi @norabelrose, would you like to submit a PR?

I already did! 😊 See #24941.

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.