NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

an error seemingly related to matmul #2532

Open crcrpar opened 1 week ago

crcrpar commented 1 week ago

nvFuser commit: d75fc93, CUDA: 12.5

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1, 1], contiguity=[True, None, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 0, 1])
    T1 = fd.define_tensor(shape=[-1, 1, -1], contiguity=[True, None, True], dtype=DataType.Float, is_cpu=False, stride_order=[2, 1, 0])
    T2 = fd.ops.sum(T1, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T3 = fd.ops.matmul(T0, T1)
    T4 = fd.ops.sum(T3, dims=[0], keepdim=False, dtype=DataType.Null)
    fd.add_output(T2)
    fd.add_output(T4)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

inputs = [
    torch.randn((262400,), dtype=torch.float32, device='cuda:0').as_strided((1025, 256, 1), (256, 1, 256)),
    torch.randn((1049600,), dtype=torch.float32, device='cuda:0').as_strided((1025, 1, 1024), (1024, 1024, 1)),
]
fd.execute(inputs)
RuntimeError: Expected T4_g[ iS103{( ceilDiv(( ceilDiv(( ceilDiv(( ( (( (( getMetaData(T4) )).logical_size ))[1] ) * ( ( (( (( getMetaData(T4) )).logical_size ))[2] ) * 1 ) ), 4) ), blockDim.x) ), 1) )}, iS102{blockDim.x}, iS112{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( (( (( getMetaData(T4) )).logical_size ))[0] ), blockDim.y) ), 2) ), 1) ), gridDim.y) )}, iS104{1}, iS100{4}, iS111{gridDim.y}, iS106{blockDim.y}, iS110{1}, iS108{2} ] to be bound to a tensor of rank 4, but got a tensor of rank 3
Exception raised from validateValWithConcreteValue at /opt/pytorch/nvfuser/csrc/expr_evaluator.cpp:38 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f8269a088e5 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::ExpressionEvaluator::bind_(nvfuser::Val const*, dynamic_type::DynamicType<dynamic_type::Containers<std::vector>, nvfuser::StructHandle, nvfuser::Pointer, nvfuser::Opaque, at::Tensor, std::complex<double>, double, long, bool>, bool) + 0x15b5 (0x7f8269d97755 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x48eb82 (0x7f8269d6db82 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: nvfuser::FusionExecutor::runFusion(nvfuser::KernelArgumentHolder&, nvfuser::LaunchParams const&, nvfuser::CompileParams, std::vector<at::Tensor, std::allocator<at::Tensor> >) + 0x2b8 (0x7f8269d648d8 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x67bbf1 (0x7f8269f5abf1 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x681481 (0x7f8269f60481 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: nvfuser::FusionKernelRuntime::runWithInputs(nvfuser::KernelArgumentHolder&) + 0xa9 (0x7f8269f60cd9 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x3fe (0x7f8269f6bebe in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x38b (0x7f826a15dd4b in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x1a38ee (0x7f8269a828ee in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x21a5ef (0x7f8269af95ef in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x2afd60 (0x7f8269b8ed60 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #36: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x95 (0x7f8688f2a115 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #37: <unknown function> + 0x4c6e8bb (0x7f8680ce18bb in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #38: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xd36 (0x7f8680cdbb46 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #39: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x58e (0x7f8680cdcd7e in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #40: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x2a9 (0x7f8680cd54e9 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #41: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x75 (0x7f8688f24c15 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #42: <unknown function> + 0xdc253 (0x7f8696232253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #43: <unknown function> + 0x94ac3 (0x7f8696416ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #44: <unknown function> + 0x126850 (0x7f86964a8850 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Encountered this error when running the following script:

import torch
from nemo.collections.nlp.modules.common.megatron.mlp import ParallelMLP
from megatron.core.tensor_parallel.layers import LinearWithGradAccumulationAndAsyncCommunication

import thunder

def f(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    return LinearWithGradAccumulationAndAsyncCommunication.apply(
        x,
        weight,
        bias,
        False,  # gradient_accumulation_fusion
        False,  # allreduce_dgrad
        False,  # sequence_parallel
        None,  # grad_output_buffer
    )

def main():
    device = torch.device("cuda")
    s, b, hp = 1025, 1, 256
    with device:
        l = torch.nn.Linear(256, 1024).to(device)
        weight = l.weight
        bias = l.bias
        x = torch.randn((s, b, hp))

    jitted = thunder.jit(f, nv_enable_matmul=True, nv_enable_linear=True)
    y = jitted(x, weight, bias)
    y.backward(torch.ones_like(y))

if __name__ == "__main__":
    main()
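
Our reading of the shapes, not something stated in the thread: with b=1, the first fusion's inputs look like the transposed activation x ([s, hp, 1]) and the output gradient ([s, 1, out]). If so, T4 is the weight gradient (up to a transpose) built as a sum of K=1 outer products, and T2 is the bias gradient. The pure-Python sketch below uses tiny hypothetical sizes and no torch or nvFuser calls. It checks that summing the per-row outer products equals a single x^T @ g product:

```python
# Toy check in pure Python (hypothetical tiny sizes, no torch/nvFuser):
# summing per-row K=1 outer products equals one x^T @ g matrix product.

def outer(u, v):
    """[len(u), 1] @ [1, len(v)] as nested lists, i.e. a K=1 matmul."""
    return [[a * b for b in v] for a in u]

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def transpose(A):
    return [list(row) for row in zip(*A)]

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

S, H, O = 3, 2, 4  # stand-ins for s=1025, hp=256, out_features=1024

x = [[float(s * H + h + 1) for h in range(H)] for s in range(S)]     # activations [S, H]
g = [[float((s + 1) * (o + 2)) for o in range(O)] for s in range(S)]  # output grads [S, O]

# What the fusion appears to do: batched [H, 1] @ [1, O] per row, then sum over rows.
acc = [[0.0] * O for _ in range(H)]
for s in range(S):
    acc = add(acc, outer(x[s], g[s]))

direct = matmul(transpose(x), g)                                # single GEMM, shape [H, O]
bias_grad = [sum(g[s][o] for s in range(S)) for o in range(O)]  # analogous to T2

print(acc == direct)  # True
print(bias_grad)      # [12.0, 18.0, 24.0, 30.0]
```

Each per-row product has K=1, so the batched matmul result is naturally rank 3 ([s, hp, out]) before the sum over dim 0.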

jacobhinkle commented 1 week ago

@Priya2698, it looks like this K=1 case with a batch dimension creates a 4D TensorView output that should be 3D:

Inputs:
  T0_g[ iS0{i0}, iS1{i1}, bS2{1} ], float
  T1_g[ iS18{i0}, bS4{1}, iS5{i6} ], float
Outputs:
  T3_g[ rS20{i0}, iS9{i6} ], float
  T5_g[ rS14{i0}, iS15{i1}, iS16{i6}, bS17{1} ], float

%kernel_math {
T2_l[ iS19{i0}, iS7{i6} ]
   = squeeze( T1_g[ iS18{i0}, bS4{1}, iS5{i6} ] )
T3_g[ rS20{i0}, iS9{i6} ]
   = reduction( T2_l[ iS19{i0}, iS7{i6} ], op = add, initial value = float(0), allreduce = false )
T4_l[ iS10{i0}, iS11{i1}, iS12{i6}, bS13{1} ]
   = matmul(T0_g[ iS0{i0}, iS1{i1}, bS2{1} ],
            T1_g[ iS18{i0}, bS4{1}, iS5{i6} ])
T5_g[ rS14{i0}, iS15{i1}, iS16{i6}, bS17{1} ]
   = reduction( T4_l[ iS10{i0}, iS11{i1}, iS12{i6}, bS13{1} ], op = add, initial value = float(0), allreduce = false )
}
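
A toy model may help connect this printout to the "expected rank 4, got rank 3" error. It assumes (our reading, not confirmed in the thread) that when a TensorView is bound to a concrete tensor, reduction axes are skipped while broadcast axes still count as size-1 dimensions. Under that assumption, the Broadcast K axis bS13{1} on T4_l makes the kernel expect rank 4, whereas a Reduction K axis would give rank 3, which is what PyTorch returns. The `IterType` enum and `materialized_rank` helper are hypothetical, not nvFuser APIs:

```python
from enum import Enum

class IterType(Enum):
    ITERATION = "i"
    BROADCAST = "b"
    REDUCTION = "r"

def materialized_rank(axes):
    """Toy rule (assumption): reduction axes are dropped, broadcast axes kept as size 1."""
    return sum(1 for t in axes if t is not IterType.REDUCTION)

# Matmul output axes [B, M, N, K] for T0 [B, M, 1] @ T1 [B, 1, N].
as_printed = [IterType.ITERATION, IterType.ITERATION, IterType.ITERATION, IterType.BROADCAST]
as_expected = [IterType.ITERATION, IterType.ITERATION, IterType.ITERATION, IterType.REDUCTION]

print(materialized_rank(as_printed))   # 4
print(materialized_rank(as_expected))  # 3
```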

That last axis in T4_l should be a Reduction, not a Broadcast.

[nav] In [1]: import torch
[nav] In [2]: x = torch.randn([5, 5, 1])
[nav] In [3]: y = torch.randn([5, 1, 7])
[nav] In [4]: z = torch.matmul(x,y)
[nav] In [5]: z.shape
Out[5]: torch.Size([5, 5, 7])
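
The session above can be generalized with a small pure-Python shape rule. This is a sketch that only covers inputs with at least two dimensions (the 1-D special cases of `torch.matmul` are deliberately left out). It broadcasts the batch dimensions, checks that the contraction sizes agree, and appends [M, N], so a K=1 contraction never adds an output dimension. The helper names are hypothetical:

```python
def broadcast_shapes(a, b):
    """PyTorch/NumPy-style broadcasting of two batch shapes."""
    out = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and da != 1 and db != 1:
            raise ValueError(f"cannot broadcast {a} and {b}")
        out.append(da if db == 1 else db)
    return tuple(reversed(out))

def matmul_shape(a, b):
    """Output shape of matmul for inputs with ndim >= 2 (sketch only)."""
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only handles inputs with at least 2 dims")
    *batch_a, m, k = a
    *batch_b, k2, n = b
    if k != k2:
        raise ValueError(f"contraction size mismatch: {k} vs {k2}")
    return broadcast_shapes(tuple(batch_a), tuple(batch_b)) + (m, n)

print(matmul_shape((5, 5, 1), (5, 1, 7)))             # (5, 5, 7)
print(matmul_shape((1025, 256, 1), (1025, 1, 1024)))  # (1025, 256, 1024)
```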

jacobhinkle commented 1 week ago

> That last axis in T4_l should be a Reduction, not a Broadcast.

Actually, I think we will hit errors with Reductions mapping to Iteration domains. Instead, we should simply not include this Reduction dimension when K=1.
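
The two options in this discussion, dropping the K axis when K=1 versus always keeping a trailing reduction axis, can be contrasted with a hypothetical helper that lists axis-type labels ("i" for iteration, "r" for reduction) for a matmul output. It is an illustration only, not how nvFuser builds its TensorDomains. In both variants the non-reduction rank is the same; the difference is whether the IR carries a size-1 reduction axis that must be mapped to the inputs' Broadcast K axes:

```python
def matmul_output_axes(batch_rank: int, k_size: int, keep_unit_k: bool) -> list:
    """Axis labels for a hypothetical matmul output: batch dims, M, N, then optional K."""
    axes = ["i"] * (batch_rank + 2)  # batch dims, then M and N
    if k_size != 1 or keep_unit_k:
        axes.append("r")             # trailing K reduction axis
    return axes

print(matmul_output_axes(1, 1, keep_unit_k=False))   # ['i', 'i', 'i']
print(matmul_output_axes(1, 1, keep_unit_k=True))    # ['i', 'i', 'i', 'r']
print(matmul_output_axes(1, 64, keep_unit_k=False))  # ['i', 'i', 'i', 'r']
```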

There is a related issue: ops::newOutputIterDomain does not respect force_iter_type when all the inputs are Broadcast.

Priya2698 commented 1 week ago

> That last axis in T4_l should be a Reduction, not a Broadcast.
>
> Actually, I think we will hit errors with Reductions mapping to Iteration domains. Instead, we should simply not include this Reduction dimension when K=1.

I prefer keeping the reduction axis uniform across all cases. Can you elaborate on what issues we may run into?

> There is a related issue: ops::newOutputIterDomain does not respect force_iter_type when all the inputs are Broadcast.

That's right; I will fix that. force_iter_type should be respected regardless of whether the axis is a broadcast.
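
A minimal Python sketch of the behavior being described, assuming a simplified rule for choosing an output axis type. `output_iter_type_buggy` and `output_iter_type_fixed` are hypothetical stand-ins, not the C++ `ops::newOutputIterDomain`. The matmul K axis is the motivating case: both input K axes (bS2{1} and bS4{1}) are Broadcast, and the caller asks for a Reduction output axis:

```python
from enum import Enum
from typing import Optional

class IterType(Enum):
    ITERATION = "Iteration"
    BROADCAST = "Broadcast"
    REDUCTION = "Reduction"

def output_iter_type_buggy(input_types, force_iter_type: Optional[IterType] = None):
    # Models the reported behavior: all-Broadcast inputs short-circuit
    # to Broadcast before force_iter_type is consulted.
    if all(t is IterType.BROADCAST for t in input_types):
        return IterType.BROADCAST
    if force_iter_type is not None:
        return force_iter_type
    return IterType.ITERATION

def output_iter_type_fixed(input_types, force_iter_type: Optional[IterType] = None):
    # force_iter_type wins regardless of the input axis types.
    if force_iter_type is not None:
        return force_iter_type
    if all(t is IterType.BROADCAST for t in input_types):
        return IterType.BROADCAST
    return IterType.ITERATION

k_inputs = [IterType.BROADCAST, IterType.BROADCAST]  # bS2{1}, bS4{1}
print(output_iter_type_buggy(k_inputs, IterType.REDUCTION).name)  # BROADCAST
print(output_iter_type_fixed(k_inputs, IterType.REDUCTION).name)  # REDUCTION
```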

jacobhinkle commented 1 week ago

> I prefer keeping the reduction axis uniform across all cases. Can you elaborate on what issues we may run into?

If you do that in the K=1 case, you will have a mapping from the reduction domain to broadcast domains, which caused an IdModel error in my brief testing. You could possibly do it if you don't map the K domain, but that might lead to other problems.

Priya2698 commented 1 week ago

> I prefer keeping the reduction axis uniform across all cases. Can you elaborate on what issues we may run into?
>
> If you do that in the K=1 case, you will have a mapping from the reduction domain to broadcast domains, which caused an IdModel error in my brief testing. You could possibly do it if you don't map the K domain, but that might lead to other problems.

Yes, not mapping them will probably lead to other errors downstream. But iteration domains are mapped to reduction IterDomains in the consumer, and ideally the same should be allowed for broadcast domains as well. I will look into the IdModel errors and whether we need to revise our constraints.
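
The mapping constraint under discussion can be pictured as a table of allowed producer-to-consumer axis-type pairs. In this sketch, only the Iteration-to-Reduction entry and the problematic Broadcast-to-Reduction case come from the thread. The other entries, the `can_map` helper, and the `allow_broadcast_to_reduction` flag (modeling the proposed relaxation) are assumptions for illustration and do not reflect IdModel's actual implementation:

```python
from enum import Enum

class IterType(Enum):
    ITERATION = "Iteration"
    BROADCAST = "Broadcast"
    REDUCTION = "Reduction"

# Hypothetical producer -> consumer pairs a toy validator accepts.
ALLOWED = {
    (IterType.ITERATION, IterType.ITERATION),
    (IterType.ITERATION, IterType.REDUCTION),  # allowed today, per the thread
    (IterType.BROADCAST, IterType.BROADCAST),  # assumption
    (IterType.BROADCAST, IterType.ITERATION),  # assumption
}

def can_map(producer, consumer, allow_broadcast_to_reduction=False):
    """Toy check; the flag models treating Broadcast like Iteration for reductions."""
    if allow_broadcast_to_reduction and (producer, consumer) == (IterType.BROADCAST, IterType.REDUCTION):
        return True
    return (producer, consumer) in ALLOWED

# K axis of the K=1 matmul: producer bS2{1} (Broadcast) -> consumer K (Reduction).
print(can_map(IterType.BROADCAST, IterType.REDUCTION))  # False
print(can_map(IterType.BROADCAST, IterType.REDUCTION, allow_broadcast_to_reduction=True))  # True
```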