Open shino16 opened 1 week ago
Apparently, `fd.ops.set` creates an intermediate no-op node that binds what it receives to another `TensorView`:
```cpp
Val* set(Val* v) {
  Val* out = ops::newValLike(v, v->getDataType().value());
  IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, v);
  return out;
}
```
`IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, out, v);` is used in other places too, e.g. when `fd.ops.expand` does not have to expand anything (link). So `fd.ops.set` does not necessarily make a copy; it just binds a separate node to the output of another.
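For context, here is a minimal sketch of what inserting `fd.ops.set` before an aliased output looks like in nvFuser's Python frontend (my own illustration, not code from #1110; shapes and dtypes are arbitrary):

```python
import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    a = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float)
    b = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float)
    t0 = fd.ops.add(a, b)
    t1 = fd.ops.set(t0)   # no-op LoadStoreOp: a fresh TensorView bound to t0
    fd.add_output(t1, a)  # alias the output into input a (in-place update)

x = torch.ones(4, device="cuda")
y = torch.ones(4, device="cuda")
fd.execute([x, y])  # x is updated in place with x + y
```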
@jjsjann123 could you take a look?
```python
fd.add_output(A, B)
fd.add_output(C, D)
```
Consider the above fusion definition. If I understand correctly, #1110 intends to avoid one tensor appearing twice, i.e. (A or B) and (C or D) aliasing each other.
I suspect that, when there is no direct call of `Tensor.copy_` and all the other in-place ops are functionalized, this aliasing never happens even without #1110. For example, consider
```python
def f(a, b):
    a.add_(b)
```
Its trace is
```python
def computation(a, b):
  # a: "cuda:0 f32[]"
  # b: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:78: a.add_(b)
  t1 = ltorch.add_(a, b, alpha=None)  # t1: "cuda:0 f32[]"
    # t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
      # t0 = prims.add(a, b)  # t0: "cuda:0 f32[]"
    # t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
  return {'output': None, 'flat_args': [a, b]}
```
```python
# Constructed by Dead Code Elimination (took 0 milliseconds)
def computation(a, b):
  # a: "cuda:0 f32[]"
  # b: "cuda:0 f32[]"

  # Functionalized from `t1 = add_(a,b,None)`
  t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
    # t0 = prims.add(a, b)  # t0: "cuda:0 f32[]"

  # /opt/pytorch/lightning-thunder/mshinokawa/sandbox/debug.py:78: a.add_(b)
  t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
  return {'output': None, 'flat_args': [t1, b]}
```
The LHS of `copy_`, `t0`, is the output of the arithmetic op placed just before `prims.copy_`. As long as the preceding arithmetic op does not produce an alias, and the dsts of copies do not alias each other (which is guaranteed by #798), we never get unwanted aliases.
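For anyone who wants to regenerate these traces, a minimal sketch (assuming a recent lightning-thunder with `thunder.jit` and `thunder.last_traces`, and a CUDA device):

```python
import torch
import thunder

def f(a, b):
    a.add_(b)

jf = thunder.jit(f)
a = torch.zeros((), device="cuda")
b = torch.ones((), device="cuda")
jf(a, b)
print(thunder.last_traces(jf)[-1])  # final trace ends with prims.copy_(t0, a)
```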
Nonetheless, we must pay attention to direct calls of `Tensor.copy_` without a preceding arithmetic op. We could write
```python
def f(a, b, c):
    c.copy_(a)
    c.copy_(b)
```
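Here both `prims.copy_` calls write into `c`'s storage, so their outputs alias each other. For concreteness, the eager behavior that must be preserved (a self-contained example of mine, not from the issue) is that the later copy wins:

```python
import torch

def f(a, b, c):
    c.copy_(a)
    c.copy_(b)

a, b, c = torch.tensor(1.0), torch.tensor(2.0), torch.empty(())
f(a, b, c)
assert c.item() == 2.0  # the second copy_ overwrites the first
```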
So I suggest applying #1110's fix only to plain `Tensor.copy_`. As `Adam.step` does not involve a direct `copy_`, this would restore the Thunder-jitted `Adam.step` to its original efficiency.
@crcrpar Would you check if this is correct?
This could be a good opportunity to prepare a better construct for `Tensor.copy_`. Currently it is directly translated into `prims.copy_`, but `Tensor.copy_` allows broadcast, dtype cast and device transfer.
rel: https://github.com/Lightning-AI/lightning-thunder/issues/1084
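To illustrate (a standalone PyTorch example, not from the issue), a single `copy_` can do all three at once:

```python
import torch

dev = "cuda" if torch.cuda.is_available() else "cpu"
dst = torch.empty(2, 3, dtype=torch.float16)            # fp16 on CPU
src = torch.arange(3, dtype=torch.float32, device=dev)  # fp32, possibly on GPU
dst.copy_(src)  # broadcasts (3,) -> (2, 3), casts fp32 -> fp16, crosses devices
```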
Another solution is to make nvFuserExecutor track all the arguments to `add_output` and insert `fd.ops.set` only where needed. This is more work, but more general and robust.
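A hypothetical sketch of that idea (every name below is invented for illustration; this is not actual nvFuserExecutor code):

```python
def add_output_dedup(fd, out, alias_input, seen_outputs):
    """Alias `out` into `alias_input`, inserting fd.ops.set only on reuse."""
    if out in seen_outputs:
        # `out` was already bound to an earlier output; break the alias
        # with a no-op set node instead of paying for it on every copy.
        out = fd.ops.set(out)
    seen_outputs.add(out)
    fd.add_output(out, alias_input)
```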
> I took a benchmark of `torch.compile`'d `Adam.step` with the Thunder backend. Surprisingly, the compiled `Adam.step` was even slower than eager mode.
>
> |  | compilation (s) | execution (ms) |
> | --- | --- | --- |
> | eager | 0.0 | 10.87 |
> | `torch.compile(adam.step)` | 21.0 | 5.83 |
> | `torch.compile(adam.step, backend=thunder)` | 95.0 | 11.59 |
> | `torch.compile(adam.step, backend=thunder)`, #1110 reverted | 46.0 | 6.24 |
>
> #1110 adds `fd.ops.set` on every `prims.copy_` (diff). What is `fd.ops.set`? Can we avoid using this op? cc @tfogal
That's a surprising regression. Wondering if you can give some repro scripts so I can investigate what's causing the regression. cc'ing @shino16
For the question re: `fd.ops.set`: out of curiosity, does #798 also detect cases when we are returning `t0`?
```python
t0 = ltorch.add(a, b, alpha=None)  # t0: "cuda:0 f32[]"
t1 = prims.copy_(t0, a)  # t1: "cuda:0 f32[]"
```
The PR adding `fd.ops.set` is trying to avoid returning aliases. That seems to be the simplest solution to the issues caused by such side effects when we have multiple aliases sharing storage.
I tend to agree with @t-vi on his suggestion in #1177: we should have a precise way to specify in-place update vs. copying. That would maybe make it easier for the nvFuser executor to construct a simpler fusion.
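Purely to illustrate that suggestion (both prim names below are hypothetical, not Thunder's actual API):

```python
# A trace could state the intent explicitly instead of overloading copy_:
t1 = prims.copy_into(t0, a)  # must write into a's existing storage
t2 = prims.clone(t0)         # must allocate fresh storage, never an alias
```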
> Wondering if you can give some repro scripts so I can investigate what's causing the regression. cc'ing @shino16
I should have done so. The benchmark script is in this gist; you can try reverting 7c9cd8c02841b14996902ae17ad81eb9cdaa9839.
And thank you for your opinion about implementation!