Oops, this should work. Expand to a concrete size was added pretty recently, so we can double-check `expand` as well as `broadcast_in_dim` first.

I just realized that `broadcast_in_dim` is a Python-bindings-only thing, but I was able to get the same error with a normal `broadcast` call. I will post the C++ variant of the code shortly.
Here's the Python script using `broadcast` instead of `broadcast_in_dim`:
```python
import torch
from torch._C._nvfuser import Fusion, FusionDefinition

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(2)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast(t0, [False, False])  # using broadcast instead of broadcast_in_dim
    t1_b = fd.Ops.broadcast(t1, [True, False])
    t2 = fd.Ops.add(t0_b, t1_b)

    fd.add_output(t2)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(3, 1, device='cuda')
input2 = torch.ones(3, device='cuda')

fusion.execute([input1, input2])
```
It fails with the same error: `RuntimeError: Attempting to bind T0.size[1] to 3 but it's already set to 1`. But the C++ test fails with a different error message:
```
C++ exception with description "false INTERNAL ASSERT FAILED at "/home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/executor.cpp":236, please report a bug to PyTorch. Allocations must be based on constant integers for local memory. However, found: T3_l[ iS15{T0.size[0]}, bS6{1} ], T2_l[ iS11{T0.size[0]}, iS12{T0.size[1]} ], have dynamic allocations but are placed in local memory.
Exception raised from compileFusion at /home/iyashchuk/dev/pytorch/master/torch/csrc/jit/codegen/cuda/executor.cpp:236 (most recent call first):
```
```cpp
TEST_F(NVFuserTest, FusionBroadcastVectors_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  TensorView* t0 = makeSymbolicTensor(2);
  TensorView* t1 = makeSymbolicTensor(1);
  fusion.addInput(t0);
  fusion.addInput(t1);

  TensorView* t0_b = broadcast(t0, {false, false});
  TensorView* t1_b = broadcast(t1, {true, false});
  TensorView* t2 = add(t0_b, t1_b);
  fusion.addOutput(t2);

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto input0 = at::randn({3, 1}, options);
  auto input1 = at::randn({3}, options);
  auto aten_output = at::add(input0, input1);

  FusionExecutor fe;
  fe.compileFusion(&fusion, {input0, input1});
  auto cg_outputs = fe.runFusion({input0, input1});

  testValidate(&fusion, cg_outputs, {input0, input1}, {aten_output}, __LINE__, __FILE__);
}
```
Broadcasting on `t0` doesn't look right to me... to repro the original problem, you might need to use `expand` instead. Let me take a quick look there as well.

Okay, broadcasting on `t0` isn't actually needed. The error is also raised with `t2 = add(t0, t1_b)`.
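For reference, a sketch of that reduced repro in the Python frontend (same pattern as the script above with the broadcast on `t0` dropped; this is an assumed mirror of the C++ reduction, not a quote from the thread):

```python
import torch
from torch._C._nvfuser import Fusion, FusionDefinition

# Reduced variant: t0 participates in the add directly, with no
# explicit broadcast; only t1 gets the leading broadcast axis.
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(2)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t1_b = fd.Ops.broadcast(t1, [True, False])
    t2 = fd.Ops.add(t0, t1_b)  # no broadcast on t0; reportedly fails the same way

    fd.add_output(t2)

fusion.execute([torch.ones(3, 1, device='cuda'), torch.ones(3, device='cuda')])
```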
Looks like this logic is flipped. What are the semantics of `broadcast_in_dim` supposed to be? Does `[1]` mean that axis 1 is broadcast or non-broadcast?
This worked for me on TOT (top-of-tree), for example:
```python
import torch
from torch._C._nvfuser import Fusion, FusionDefinition

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(2)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t1_b = fd.Ops.broadcast_in_dim(t1, [3, 3], [0])  # 1 -> 0
    t2 = fd.Ops.add(t0, t1_b)

    fd.add_output(t2)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(3, 1, device='cuda')
input2 = torch.ones(3, device='cuda')

fusion.execute([input1, input2])
```
I believe eager mode interpreted this as an implicit broadcast on axis 1 of `input1 = torch.ones(3, 1, device='cuda')`, and this should somehow be reflected in our fusion definition without seeing the actual input, if that was the original intention.
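For reference, a minimal eager-mode sketch of the broadcasting behavior being discussed (plain PyTorch, no nvFuser involved):

```python
import torch

# Eager PyTorch aligns trailing dims: (3, 1) + (3,) -> (3, 1) + (1, 3) -> (3, 3),
# so the size-1 axis of the first input is implicitly broadcast.
a = torch.ones(3, 1, device='cuda')
b = torch.ones(3, device='cuda')
out = a + b
print(out.shape)  # torch.Size([3, 3])
```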
Something is wrong with `broadcast_in_dim`... The sample Python code runs fine without error, but it returns an output tensor that is not expanded to the right size. Here's the printed IR followed by `print(out[0].shape)`:
```
Inputs:
T0_g[ iS0{i0}, iS1{i2} ], float
T1_g[ iS2{i3} ], float
Outputs:
T3_g[ iS5{i0}, iS6{i2} ], float

%kernel_math {
T2_l[ iS3{i3}, bS4{1} ]
   = broadcast( T1_g[ iS2{i3} ] )
T3_g[ iS5{i0}, iS6{i2} ]
   = T0_g[ iS0{i0}, iS1{i2} ]
   + T2_l[ iS3{i3}, bS4{1} ];
}

torch.Size([3, 1])
```
So we want the output size to be `[3, 3]`? In that case I believe we need to make T0 `[I, B]` instead of `[I, I]`; currently the fusion definition is saying T0 is a concrete tensor of size `[3, 1]`.
The `iS1` in `T0_g[ iS0{i0}, iS1{i2} ]` is symbolically shaped, so there's no indication that it could be a broadcast. Do we currently support creating tensors with broadcast axes with `define_tensor`?
`t0 = fd.define_tensor([3, 1], [1, 1])` (for some reason `-1` for size is not supported yet).
Hmm, even that doesn't work, since I don't see `expand` in the kernel math... I guess this is expected: `broadcast_in_dim` only does the broadcast at this moment, but doesn't expand to the right size. https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/python_frontend/python_bindings.cpp#L606-L635
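A minimal eager-mode sketch of the two steps `broadcast_in_dim(t1, [3, 3], [0])` would need to compose (plain PyTorch, used only to illustrate the intended semantics; this is an assumption about the fix, not the frontend's actual implementation):

```python
import torch

t1 = torch.ones(3, device='cuda')

# Step 1: broadcast -- insert a size-1 axis for every output dim not
# listed in broadcast_dims ([0] maps input dim 0 to output axis 0).
t1_b = t1.unsqueeze(-1)    # shape (3, 1)

# Step 2: expand -- materialize the size-1 axis to the requested
# output shape; this is the piece the current binding skips.
t1_e = t1_b.expand(3, 3)   # shape (3, 3)
print(t1_e.shape)          # torch.Size([3, 3])
```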
I'll patch this. Self-assigned.
Just got it patched.
Maybe there should be a `squeeze` operation first to remove dimensions with shape 1, and then use `broadcast`. This code works as expected without using `expand`:
```python
import torch
from torch._C._nvfuser import Fusion, FusionDefinition

# Construct and Define Fusion
fusion = Fusion()

with FusionDefinition(fusion) as fd:
    t0 = fd.define_tensor(1)
    t1 = fd.define_tensor(1)

    fd.add_input(t0)
    fd.add_input(t1)

    t0_b = fd.Ops.broadcast(t0, [False, True])
    t1_b = fd.Ops.broadcast(t1, [True, False])
    t2 = fd.Ops.add(t0_b, t1_b)

    fd.add_output(t2)

fusion.print_ir()

# Execute Fusion
input1 = torch.ones(3, device='cuda')
input2 = torch.ones(3, device='cuda')

fusion.execute([input1, input2])
```

which prints:

```
Inputs:
T0_g[ iS0{i0} ], float
T1_g[ iS1{i2} ], float
Outputs:
T4_g[ iS6{i0}, iS7{i2} ], float

%kernel_math {
T2_l[ iS2{i0}, bS3{1} ]
   = broadcast( T0_g[ iS0{i0} ] )
T3_l[ bS4{1}, iS5{i2} ]
   = broadcast( T1_g[ iS1{i2} ] )
T4_g[ iS6{i0}, iS7{i2} ]
   = T2_l[ iS2{i0}, bS3{1} ]
   + T3_l[ bS4{1}, iS5{i2} ];
}

[tensor([[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]], device='cuda:0')]
```
We shouldn't need to add `squeeze`; `expand` can work on a broadcasted (size-1) dimension.
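A minimal eager-mode sketch of that point (plain PyTorch; it shows `expand` acting directly on a size-1 axis, no `squeeze` required):

```python
import torch

t = torch.ones(3, 1, device='cuda')  # size-1 (broadcast-like) axis at dim 1
expanded = t.expand(3, 3)            # expand works on the size-1 dim directly
print(expanded.shape)                # torch.Size([3, 3])
```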
🐛 Describe the bug

Is this a bug in nvFuser, or is the code below invalid? I think the code is valid: I translated the trace of `torch._refs.add` to nvFuser Python API calls. There's no error with ATen execution, but the `fusion.execute` call raises the `RuntimeError` quoted above.

Versions