Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.

dynamo: HF Bert - `NameError: name 't0' is not defined.` #864

Closed kshitij12345 closed 2 months ago

kshitij12345 commented 2 months ago

See the comment below for a minimal repro.

Repro (requires transformers==4.42.4, as pinned in requirements/test.txt):

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_scheduler

dataset = load_dataset("yelp_review_full")

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(100))

train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

progress_bar = tqdm(range(num_training_steps))

model.train()

import thunder
class ThunderJitBackend:
    def __init__(self, **compile_options) -> None:        
        self.thunder_jit_fns = []
        self.dynamo_graphs = []
        self.cnt = 0
        self.compile_options = compile_options

    def compile(self, gm, sample_args):
        self.dynamo_graphs.append(gm)
        gm.real_recompile()
        thunder_jit_fn = thunder.jit(gm, **self.compile_options)
        self.thunder_jit_fns.append(thunder_jit_fn)
        self.cnt += 1
        return thunder_jit_fn

tbackend = ThunderJitBackend()
# model = torch.compile(model)  # Works
model = torch.compile(model, backend=tbackend.compile)

for epoch in range(num_epochs):
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

Error

  File "/home/kkalambarkar/lightning-thunder/thunder/__init__.py", line 699, in fn_
    result = cache_entry.computation_fn(*inps)
  File "/home/kkalambarkar/git/pytorch/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "/home/kkalambarkar/git/pytorch/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "thunder.augmented_forward_fn_6", line 11, in augmented_forward_fn
NameError: name 't0' is not defined. Did you mean: 't20'?

Failing Dynamo Generated Graph

GraphModule()

def forward(self, L_logits_ : torch.Tensor, L_labels_ : torch.Tensor):
    l_logits_ = L_logits_
    l_labels_ = L_labels_
    view = l_logits_.view(-1, 5);  l_logits_ = None
    view_1 = l_labels_.view(-1);  l_labels_ = None
    loss = torch.nn.functional.cross_entropy(view, view_1, None, None, -100, None, 'mean', 0.0);  view = view_1 = None
    return (loss,)

Failing Trace -

# Constructed by Delete Last Used (took 0 milliseconds)
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(L_logits_, L_labels_):
  # t0: "cuda:0 f32[8, 5]"
  # t1: "cuda:0 i64[8]"
  t15 = torch.nn.functional.log_softmax(t0, 1)  # t15: "cuda:0 f32[8, 5]"
    # t15 = ltorch.log_softmax(t0, 1, dtype=None)  # t15: "cuda:0 f32[8, 5]"
      # t37 = ltorch.logsumexp(t0, 1, True)  # t37: "cuda:0 f32[8, 1]"
        # t27 = ltorch.amax(t0, 1, True)  # t27: "cuda:0 f32[8, 1]"
          # t26 = prims.amax(t0, (1,))  # t26: "cuda:0 f32[8]"
          # t27 = prims.broadcast_in_dim(t26, [8, 1], [0])  # t27: "cuda:0 f32[8, 1]"
        # t28 = ltorch.abs(t27)  # t28: "cuda:0 f32[8, 1]"
          # t28 = prims.abs(t27)  # t28: "cuda:0 f32[8, 1]"
        # t29 = ltorch.eq(t28, float('inf'))  # t29: "cuda:0 b8[8, 1]"
          # t29 = prims.eq(t28, float('inf'))  # t29: "cuda:0 b8[8, 1]"
        # t30 = ltorch.where(t29, 0, t27)  # t30: "cuda:0 f32[8, 1]"
          # _ = prims.convert_element_type(0, float)
          # t30 = prims.where(t29, 0.0, t27)  # t30: "cuda:0 f32[8, 1]"
        # t32 = ltorch.sub(t0, t30, alpha=None)  # t32: "cuda:0 f32[8, 5]"
          # t31 = prims.broadcast_in_dim(t30, (8, 5), (0, 1))  # t31: "cuda:0 f32[8, 5]"
          # t32 = prims.sub(t0, t31)  # t32: "cuda:0 f32[8, 5]"
        # t33 = ltorch.exp(t32)  # t33: "cuda:0 f32[8, 5]"
          # t33 = prims.exp(t32)  # t33: "cuda:0 f32[8, 5]"
        # t35 = ltorch.sum(t33, 1, True, dtype=None)  # t35: "cuda:0 f32[8, 1]"
          # t34 = prims.sum(t33, (1,))  # t34: "cuda:0 f32[8]"
          # t35 = prims.broadcast_in_dim(t34, [8, 1], [0])  # t35: "cuda:0 f32[8, 1]"
        # t36 = ltorch.log(t35)  # t36: "cuda:0 f32[8, 1]"
          # t36 = prims.log(t35)  # t36: "cuda:0 f32[8, 1]"
        # t37 = ltorch.add(t36, t30, alpha=None)  # t37: "cuda:0 f32[8, 1]"
          # t37 = prims.add(t36, t30)  # t37: "cuda:0 f32[8, 1]"
      # t15 = ltorch.sub(t0, t37, alpha=None)  # t15: "cuda:0 f32[8, 5]"
        # t38 = prims.broadcast_in_dim(t37, (8, 5), (0, 1))  # t38: "cuda:0 f32[8, 5]"
        # t15 = prims.sub(t0, t38)  # t15: "cuda:0 f32[8, 5]"
  t16 = torch.neg(t15)  # t16: "cuda:0 f32[8, 5]"
    # t16 = ltorch.neg(t15)  # t16: "cuda:0 f32[8, 5]"
      # t16 = prims.neg(t15)  # t16: "cuda:0 f32[8, 5]"
  t17 = torch.unsqueeze(t1, 1)  # t17: "cuda:0 i64[8, 1]"
    # t17 = ltorch.unsqueeze(t1, 1)  # t17: "cuda:0 i64[8, 1]"
      # t17 = prims.broadcast_in_dim(t1, [8, 1], [0])  # t17: "cuda:0 i64[8, 1]"
  t18 = torch.take_along_dim(t16, t17, 1)  # t18: "cuda:0 f32[8, 1]"
    # t18 = ltorch.take_along_dim(t16, t17, 1)  # t18: "cuda:0 f32[8, 1]"
      # t18 = prims.take_along_axis(t16, t17, 1)  # t18: "cuda:0 f32[8, 1]"
  del t16
  t19 = torch.ne(t17, -100)  # t19: "cuda:0 b8[8, 1]"
    # t19 = ltorch.ne(t17, -100)  # t19: "cuda:0 b8[8, 1]"
      # t19 = prims.ne(t17, -100)  # t19: "cuda:0 b8[8, 1]"
  del t17
  t20 = torch.where(t19, t18, 0)  # t20: "cuda:0 f32[8, 1]"
    # t20 = ltorch.where(t19, t18, 0)  # t20: "cuda:0 f32[8, 1]"
      # _ = prims.convert_element_type(0, float)
      # t20 = prims.where(t19, t18, 0.0)  # t20: "cuda:0 f32[8, 1]"
  del t18
  t21 = torch.sum(t20, None, False, dtype=None)  # t21: "cuda:0 f32[]"
    # t21 = ltorch.sum(t20, None, False, dtype=None)  # t21: "cuda:0 f32[]"
      # t21 = prims.sum(t20, (0, 1))  # t21: "cuda:0 f32[]"
  del t20
  t23 = torch.sum(t19, None, False, dtype=None)  # t23: "cuda:0 i64[]"
    # t23 = ltorch.sum(t19, None, False, dtype=None)  # t23: "cuda:0 i64[]"
      # t46 = ltorch.to(t19, dtypes.int64, None, device=None, dtype=None, copy=False, memory_format=None)  # t46: "cuda:0 i64[8, 1]"
        # t46 = prims.convert_element_type(t19, dtypes.int64)  # t46: "cuda:0 i64[8, 1]"
      # t23 = prims.sum(t46, (0, 1))  # t23: "cuda:0 i64[]"
  del t19
  t25 = torch.true_divide(t21, t23)  # t25: "cuda:0 f32[]"
    # t25 = ltorch.true_divide(t21, t23)  # t25: "cuda:0 f32[]"
      # t48 = prims.convert_element_type(t23, dtypes.float32)  # t48: "cuda:0 f32[]"
      # t25 = prims.div(t21, t48)  # t25: "cuda:0 f32[]"
  del t21
  return {'output': (t25,), 'flat_args': [t0, t1], 'flat_output': (t25,)}, ((t1, t15, t23), ())

cc: @IvanYashchuk

kshitij12345 commented 2 months ago

Running the dynamo graph independently with requires_grad=False works; with requires_grad=True on t (see the comments below), it hits the same error.

Minimal Repro -

import torch
import thunder

# Dynamo Generated Graph
def forward(self, L_logits_ : torch.Tensor, L_labels_ : torch.Tensor):
    l_logits_ = L_logits_
    l_labels_ = L_labels_
    view = l_logits_.view(-1, 5);  l_logits_ = None
    view_1 = l_labels_.view(-1);  l_labels_ = None
    loss = torch.nn.functional.cross_entropy(view, view_1, None, None, -100, None, 'mean', 0.0);  view = view_1 = None
    return (loss,)

t = torch.randn(8, 5, requires_grad=True, device='cuda:0')
labels = torch.tensor([2, 4, 2, 3, 1, 0, 4, 4], device='cuda:0')

thunder.jit(forward)(None, t, labels)

IvanYashchuk commented 2 months ago

I think you will be able to reproduce the error by setting requires_grad of t to True. In the current code snippet, neither an augmented forward nor a backward trace is generated.
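
One way to check (a sketch, reusing the objects from the minimal repro above; thunder.last_traces / thunder.last_backward_traces are the trace-inspection helpers):

# Sketch: with requires_grad=False on `t`, no backward trace should be recorded,
# which is why the snippet did not hit the error before the update.
jfn = thunder.jit(forward)
jfn(None, t.detach(), labels)                 # detach() drops requires_grad
print(thunder.last_traces(jfn)[-1])           # plain (non-augmented) forward trace
print(thunder.last_backward_traces(jfn))      # expected to be empty here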

kshitij12345 commented 2 months ago

Nice catch, thanks @IvanYashchuk - I can repro the error by setting requires_grad=True on t. Updating the above example.

t-vi commented 2 months ago

Looks like the renaming has a problem: t0 should be L_logits_.

kshitij12345 commented 2 months ago

So the problematic update is happening in _transform_for_operator_executor_execution.

https://github.com/Lightning-AI/lightning-thunder/blob/a6a304491ccee34dc025451f0667463472488d90/thunder/executors/passes.py#L29

The input trace it receives is as follows (sub-symbols removed to simplify the trace):

# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(L_logits_, L_labels_):
  # L_logits_: "cpu f32[8, 5]"
  # L_labels_: "cpu i64[8]"
  t0 = prims.reshape(L_logits_, (8, 5))  # t0: "cpu f32[8, 5]"
  t1 = prims.reshape(L_labels_, (8,))  # t1: "cpu i64[8]"
  t15 = ltorch.log_softmax(t0, 1, dtype=None)  # t15: "cpu f32[8, 5]"

  t16 = ltorch.neg(t15)  # t16: "cpu f32[8, 5]"

  t17 = ltorch.unsqueeze(t1, 1)  # t17: "cpu i64[8, 1]"

  t18 = ltorch.take_along_dim(t16, t17, 1)  # t18: "cpu f32[8, 1]"

  t19 = ltorch.ne(t17, -100)  # t19: "cpu b8[8, 1]"

  t20 = ltorch.where(t19, t18, 0)  # t20: "cpu f32[8, 1]"

  t21 = ltorch.sum(t20, None, False, dtype=None)  # t21: "cpu f32[]"

  t23 = ltorch.sum(t19, None, False, dtype=None)  # t23: "cpu i64[]"

  t25 = ltorch.true_divide(t21, t23)  # t25: "cpu f32[]"
  return {'output': t25, 'flat_args': [L_logits_, L_labels_], 'flat_output': (t25,)}, ((t1, t15, t23), ())

_transform_for_operator_executor_execution modifies this trace to

import thunder.core.dtypes as dtypes
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(L_logits_, L_labels_):
  # t0: "cpu f32[8, 5]"
  # t1: "cpu i64[8]"
  t0 = torch.reshape(t0, (8, 5))  # t0: "cpu f32[8, 5]"

  t1 = torch.reshape(t1, (8,))  # t1: "cpu i64[8]"

  t15 = torch.nn.functional.log_softmax(t0, 1)  # t15: "cpu f32[8, 5]"

  t16 = torch.neg(t15)  # t16: "cpu f32[8, 5]"

  t17 = torch.unsqueeze(t1, 1)  # t17: "cpu i64[8, 1]"

  t18 = torch.take_along_dim(t16, t17, 1)  # t18: "cpu f32[8, 1]"

  t19 = torch.ne(t17, -100)  # t19: "cpu b8[8, 1]"

  t20 = torch.where(t19, t18, 0)  # t20: "cpu f32[8, 1]"

  t21 = torch.sum(t20, None, False, dtype=None)  # t21: "cpu f32[]"

  t23 = torch.sum(t19, None, False, dtype=None)  # t23: "cpu i64[]"

  t25 = torch.true_divide(t21, t23)  # t25: "cpu f32[]"
  return {'output': t25, 'flat_args': [t0, t1], 'flat_output': (t25,)}, ((t1, t15, t23), ())

NOTE the lines

t0 = torch.reshape(t0, (8, 5))  # t0: "cpu f32[8, 5]"
t1 = torch.reshape(t1, (8,))  # t1: "cpu i64[8]"

This happens because we hit the special case for reshape (where the requested shape is the same as the tensor's original shape) and return the same proxy.

https://github.com/Lightning-AI/lightning-thunder/blob/a6a304491ccee34dc025451f0667463472488d90/thunder/clang/__init__.py#L1046-L1051
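
The special case is roughly this (a paraphrased sketch based on the description above, not a verbatim copy of the linked clang.reshape code):

import thunder.core.prims as prims

# Paraphrased sketch: when the requested shape equals the tensor's current
# shape, the input proxy itself is returned, so the reshape produces no new
# output proxy (and therefore no new "tN" name).
def reshape(a, shape):
    if tuple(shape) == tuple(a.shape):
        return a
    return prims.reshape(a, shape)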

What has confused me is that _transform_for_operator_executor_execution takes this updated output from the transformed symbol and adds it to the swapmap using update_swapmap (ref 1). update_swapmap maps the new output proxy to the old one (ref 2). This seems at odds with the documentation of from_bsym_swap_proxies:

This replaces :class:``Proxy``s, e.g. :class:`TensorProxy`, of inputs and output
with another ones already seen recorded in ``swap_map`` (``swap_map`` maps variableified
:class:``Proxy`` to an existing one generated by the same expression), and do the same to subsymbols.

So I assume update_swapmap should actually have mapped the old proxy to the new one (other usages of swapmap and from_bsym_swap_proxies do this). Is that correct, or am I missing something?

Ref 1 https://github.com/Lightning-AI/lightning-thunder/blob/a6a304491ccee34dc025451f0667463472488d90/thunder/executors/passes.py#L88

Ref 2 https://github.com/Lightning-AI/lightning-thunder/blob/a6a304491ccee34dc025451f0667463472488d90/thunder/executors/passes.py#L34-L45
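
To make the two directions concrete, here is a minimal sketch (hypothetical helpers, mirroring the lines touched by the patch below) of the current behaviour versus what the docstring seems to describe:

from thunder.core.proxies import variableify  # assumed import path

# `o` is the old output proxy, `no` the new one produced by the transformed symbol.
def update_swapmap_current(swapmap, o, no):
    swapmap[variableify(no)] = o   # ref 2: new proxy -> old proxy

def update_swapmap_expected(swapmap, o, no):
    swapmap[variableify(o)] = no   # what the from_bsym_swap_proxies docstring suggests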


The following patch, mapping the old proxy to the new one, fixes the above issue.

diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py
--- a/thunder/executors/passes.py
+++ b/thunder/executors/passes.py
@@ -42,7 +45,7 @@ def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list:
             vno = variableify(no)
             if vo == vno:
                 return
-            swapmap[vno] = o
+            swapmap[vo] = no

But it creates an invalid backward graph. Note that t47 is an input to ltorch.nll_loss_backward, yet it is only produced later as the output of ltorch.log_softmax_backward. This probably happens because of all the renaming in the forward graph: we end up saving a proxy named t47 for backward, while the backward graph already had a t47 as the output of something else.

import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
  # saved_for_backward: "Collection"
  # cotangents: "Collection"
  C0, _, = saved_for_backward
  t26, = cotangents
  L_labels_, t39, _, = C0
  t41 = ltorch.nll_loss_backward(t26, t39, L_labels_, None, 'mean', -100, t47)  # t41: "cpu f32[8, 5]"
  t47 = ltorch.log_softmax_backward(t41, t39, 1, dtypes.float32)  # t47: "cpu f32[8, 5]"

  t50 = prims.reshape(t47, (8, 5))  # t50: "cpu f32[8, 5]"
  return (t50, None)
IvanYashchuk commented 2 months ago

`t0 = prims.reshape(L_logits_, (8, 5))  # t0: "cpu f32[8, 5]"` is transformed into `t0 = torch.reshape(t0, (8, 5))  # t0: "cpu f32[8, 5]"`. How does clang.reshape get used here? Why did the input argument change from L_logits_ to t0?

It seems to me that the trace produced by _transform_for_operator_executor_execution is mostly correct; the only missing piece is the renaming of the function arguments: they need to be renamed L_logits_ -> t0 and L_labels_ -> t1.
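
In other words, the corrected trace would read roughly like this (a sketch: the body of the transformed trace above, with only the argument names changed):

import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def augmented_forward_fn(t0, t1):
  # t0: "cpu f32[8, 5]"  (formerly L_logits_)
  # t1: "cpu i64[8]"  (formerly L_labels_)
  t0 = torch.reshape(t0, (8, 5))  # t0: "cpu f32[8, 5]"
  t1 = torch.reshape(t1, (8,))  # t1: "cpu i64[8]"
  t15 = torch.nn.functional.log_softmax(t0, 1)  # t15: "cpu f32[8, 5]"
  t16 = torch.neg(t15)  # t16: "cpu f32[8, 5]"
  t17 = torch.unsqueeze(t1, 1)  # t17: "cpu i64[8, 1]"
  t18 = torch.take_along_dim(t16, t17, 1)  # t18: "cpu f32[8, 1]"
  t19 = torch.ne(t17, -100)  # t19: "cpu b8[8, 1]"
  t20 = torch.where(t19, t18, 0)  # t20: "cpu f32[8, 1]"
  t21 = torch.sum(t20, None, False, dtype=None)  # t21: "cpu f32[]"
  t23 = torch.sum(t19, None, False, dtype=None)  # t23: "cpu i64[]"
  t25 = torch.true_divide(t21, t23)  # t25: "cpu f32[]"
  return {'output': t25, 'flat_args': [t0, t1], 'flat_output': (t25,)}, ((t1, t15, t23), ())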