Running the dynamo graph independently works -
Minimal Repro -
import torch
import thunder
# Dynamo Generated Graph
def forward(self, L_logits_ : torch.Tensor, L_labels_ : torch.Tensor):
l_logits_ = L_logits_
l_labels_ = L_labels_
view = l_logits_.view(-1, 5); l_logits_ = None
view_1 = l_labels_.view(-1); l_labels_ = None
loss = torch.nn.functional.cross_entropy(view, view_1, None, None, -100, None, 'mean', 0.0); view = view_1 = None
return (loss,)
t = torch.randn(8, 5, requires_grad=True, device='cuda:0')
labels = torch.tensor([2, 4, 2, 3, 1, 0, 4, 4], device='cuda:0')
thunder.jit(forward)(None, t, labels)
I think you will be able to reproduce the error by setting requires_grad of t to True. In the current code snippet neither an augmented forward nor a backward is generated.
Nice catch, thanks @IvanYashchuk - I can repro the error by setting requires_grad=True on t. Updating the above example.
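For reference, a condensed sketch of the toggle that triggers it (assuming a CUDA device is available; the requires_grad=False call is expected to go through, while the requires_grad=True call reproduces the failure):
import torch
import thunder

def forward(self, L_logits_, L_labels_):
    loss = torch.nn.functional.cross_entropy(L_logits_.view(-1, 5), L_labels_.view(-1))
    return (loss,)

labels = torch.tensor([2, 4, 2, 3, 1, 0, 4, 4], device='cuda:0')

# No grad required: no augmented forward/backward is generated, so the call succeeds.
thunder.jit(forward)(None, torch.randn(8, 5, device='cuda:0'), labels)

# Grad required: augmented forward/backward generation kicks in and the error shows up.
thunder.jit(forward)(None, torch.randn(8, 5, requires_grad=True, device='cuda:0'), labels)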
Looks like the renaming has a problem - t0 should be L_logits_.
So the problematic update is happening in _transform_for_operator_executor_execution. The input trace it receives is (sub-symbols removed to simplify the trace) -
# Constructed by Dead Code Elimination (took 0 milliseconds)
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def augmented_forward_fn(L_logits_, L_labels_):
# L_logits_: "cpu f32[8, 5]"
# L_labels_: "cpu i64[8]"
t0 = prims.reshape(L_logits_, (8, 5)) # t0: "cpu f32[8, 5]"
t1 = prims.reshape(L_labels_, (8,)) # t1: "cpu i64[8]"
t15 = ltorch.log_softmax(t0, 1, dtype=None) # t15: "cpu f32[8, 5]"
t16 = ltorch.neg(t15) # t16: "cpu f32[8, 5]"
t17 = ltorch.unsqueeze(t1, 1) # t17: "cpu i64[8, 1]"
t18 = ltorch.take_along_dim(t16, t17, 1) # t18: "cpu f32[8, 1]"
t19 = ltorch.ne(t17, -100) # t19: "cpu b8[8, 1]"
t20 = ltorch.where(t19, t18, 0) # t20: "cpu f32[8, 1]"
t21 = ltorch.sum(t20, None, False, dtype=None) # t21: "cpu f32[]"
t23 = ltorch.sum(t19, None, False, dtype=None) # t23: "cpu i64[]"
t25 = ltorch.true_divide(t21, t23) # t25: "cpu f32[]"
return {'output': t25, 'flat_args': [L_logits_, L_labels_], 'flat_output': (t25,)}, ((t1, t15, t23), ())
_transform_for_operator_executor_execution modifies this trace to -
import thunder.core.dtypes as dtypes
import torch
import torch.nn.functional
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def augmented_forward_fn(L_logits_, L_labels_):
# t0: "cpu f32[8, 5]"
# t1: "cpu i64[8]"
t0 = torch.reshape(t0, (8, 5)) # t0: "cpu f32[8, 5]"
t1 = torch.reshape(t1, (8,)) # t1: "cpu i64[8]"
t15 = torch.nn.functional.log_softmax(t0, 1) # t15: "cpu f32[8, 5]"
t16 = torch.neg(t15) # t16: "cpu f32[8, 5]"
t17 = torch.unsqueeze(t1, 1) # t17: "cpu i64[8, 1]"
t18 = torch.take_along_dim(t16, t17, 1) # t18: "cpu f32[8, 1]"
t19 = torch.ne(t17, -100) # t19: "cpu b8[8, 1]"
t20 = torch.where(t19, t18, 0) # t20: "cpu f32[8, 1]"
t21 = torch.sum(t20, None, False, dtype=None) # t21: "cpu f32[]"
t23 = torch.sum(t19, None, False, dtype=None) # t23: "cpu i64[]"
t25 = torch.true_divide(t21, t23) # t25: "cpu f32[]"
return {'output': t25, 'flat_args': [t0, t1], 'flat_output': (t25,)}, ((t1, t15, t23), ())
NOTE the lines -
t0 = torch.reshape(t0, (8, 5)) # t0: "cpu f32[8, 5]"
t1 = torch.reshape(t1, (8,)) # t1: "cpu i64[8]"
This happens because we hit the special case for reshape (where the requested shape is the same as the original shape of the tensor) and return the same proxy.
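A toy sketch of that special case (the TensorProxy class and function names below are illustrative stand-ins, not thunder's actual internals):
from dataclasses import dataclass

@dataclass(frozen=True)
class TensorProxy:
    name: str
    shape: tuple

def prim_reshape(a, shape):
    # the primitive always yields a freshly named proxy
    return TensorProxy(name=a.name + "_reshaped", shape=tuple(shape))

def reshape(a, shape):
    # the special case discussed above: if the requested shape already matches,
    # the same proxy is returned and no new name is introduced for the result
    if tuple(shape) == tuple(a.shape):
        return a
    return prim_reshape(a, shape)

logits = TensorProxy("L_logits_", (8, 5))
assert reshape(logits, (8, 5)) is logits  # the input proxy flows straight through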
What has confused me is that _transform_for_operator_executor_execution takes this updated output from the transformed symbol and adds it to swapmap using update_swapmap (ref 1). update_swapmap maps the new output proxy to the old one (ref 2). This seems confusing based on the documentation of from_bsym_swap_proxies -
This replaces :class:``Proxy``s, e.g. :class:`TensorProxy`, of inputs and output
with another ones already seen recorded in ``swap_map`` (``swap_map`` maps variableified
:class:``Proxy`` to an existing one generated by the same expression), and do the same to subsymbols.
So, I assume update_swapmap should have actually mapped the old proxy to the new one (other usages of swapmap and from_bsym_swap_proxies do this). Is that correct or am I missing something?
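To make the direction question concrete, here is a toy illustration (plain Python, not thunder internals) of how a swap map gets applied when rewriting a symbol's arguments; which key/value orientation is the intended one is exactly the question above:
# Each proxy name in a symbol's argument list is replaced by its mapped counterpart.
def swap_args(args, swap_map):
    return [swap_map.get(a, a) for a in args]

# One orientation renames later uses of the old proxy to the new one ...
print(swap_args(["L_logits_", "t15"], {"L_logits_": "t0"}))  # ['t0', 't15']
# ... while the other rewrites uses of the new proxy back to the old name.
print(swap_args(["t0", "t15"], {"t0": "L_logits_"}))  # ['L_logits_', 't15']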
The following patch, mapping the old proxy to the new one, fixes the above issue -
--- a/thunder/executors/passes.py
+++ b/thunder/executors/passes.py
@@ -42,7 +45,7 @@ def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list:
         vno = variableify(no)
         if vo == vno:
             return
-        swapmap[vno] = o
+        swapmap[vo] = no
But it creates an invalid backward graph. Note that t47 is an input to ltorch.nll_loss_backward, while t47 is only produced later as the output of ltorch.log_softmax_backward. This probably occurs because of all the renaming that happens in the forward graph: we end up saving a proxy named t47 for the backward pass, while the backward graph already had a t47 as the output of something else.
import thunder
import thunder.core.dtypes as dtypes
import thunder.core.prims as prims
import thunder.torch as ltorch
import torch
from thunder.executors.torchex import no_autocast
@torch.no_grad()
@no_autocast
def backward_fn(saved_for_backward, cotangents):
# saved_for_backward: "Collection"
# cotangents: "Collection"
C0, _, = saved_for_backward
t26, = cotangents
L_labels_, t39, _, = C0
t41 = ltorch.nll_loss_backward(t26, t39, L_labels_, None, 'mean', -100, t47) # t41: "cpu f32[8, 5]"
t47 = ltorch.log_softmax_backward(t41, t39, 1, dtypes.float32) # t47: "cpu f32[8, 5]"
t50 = prims.reshape(t47, (8, 5)) # t50: "cpu f32[8, 5]"
return (t50, None)
t0 = prims.reshape(L_logits_, (8, 5)) # t0: "cpu f32[8, 5]"
is transformed into t0 = torch.reshape(t0, (8, 5)) # t0: "cpu f32[8, 5]". How does clang.reshape get used here? Why did the input argument change from L_logits_ to t0?
It seems to me that the trace produced by _transform_for_operator_executor_execution is mostly correct; the only missing thing is the renaming of the function arguments - they need to be renamed L_logits_ -> t0 and L_labels_ -> t1.
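In other words, a sketch of what the renamed trace header would then look like (derived by hand from the transformed trace above, not output actually produced by thunder):
import torch

def augmented_forward_fn(t0, t1):
    # t0: "cpu f32[8, 5]"  (was L_logits_)
    # t1: "cpu i64[8]"     (was L_labels_)
    t0 = torch.reshape(t0, (8, 5))  # t0: "cpu f32[8, 5]"
    t1 = torch.reshape(t1, (8,))    # t1: "cpu i64[8]"
    ...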
See comment below for minimal repro.

Repro (requires transformers==4.42.4 as present in requirements/test.txt) -
Error -
Failing Dynamo Generated Graph -
Failing Trace -

cc: @IvanYashchuk