Add `FusionDefinition` methods to print the `fusion_ir`, `scheduled_fusion_ir`, and `cuda_code` #2573

Closed: kevinstephano closed 1 year ago

kevinstephano commented 1 year ago

Resolves issue #2387 for @mruberry.

The following APIs are added to the Python FusionDefinition. They all return strings instead of printing to stdout, so the Python user can do whatever they want with the string output.

fd.fusion_ir()
fd.last_scheduled_fusion_ir(tensor_transforms=False)
fd.scheduled_fusion_ir_for(inputs, tensor_transforms=False)
fd.last_cuda_code(intrinsic_code=False)
fd.cuda_code_for(inputs, intrinsic_code=False)

Example test case:

import torch
from nvfuser import FusionDefinition

inputs = [
    torch.randn(4, 4, device='cuda'),
]

with FusionDefinition() as fd:
    t0 = fd.from_pytorch(inputs[0])
    t1 = fd.ops.relu(t0)
    fd.add_output(t1)

out = fd.execute(inputs)
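
Because each method returns a string, the output can be captured programmatically rather than printed. A minimal sketch, assuming the fd and inputs defined above (the file name is only illustrative):

# Capture the unscheduled and scheduled IR as strings and save them for
# later inspection; "fusion_debug.txt" is just an illustrative file name.
unscheduled_ir = fd.fusion_ir()
scheduled_ir = fd.last_scheduled_fusion_ir()
with open("fusion_debug.txt", "w") as f:
    f.write(unscheduled_ir + "\n" + scheduled_ir)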

Fusion IR: Returns a string of the unscheduled Fusion IR for a given definition. Example Usage:

print(fd.fusion_ir())

Output:

%kernel {
T1_g[ iS2{i0}, iS3{i1} ]
   = relu(T0_g[ iS0{i0}, iS1{i1} ]);
}
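
Since the IR comes back as a string, it can also be inspected programmatically, for example as a quick check in a test. A minimal sketch, assuming the fd defined above:

# Simple sanity check on the returned IR string, e.g. inside a test.
ir_str = fd.fusion_ir()
assert "relu" in ir_str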

Scheduled Fusion IR: Returns a string of the scheduled Fusion IR, with or without tensor transforms, for either the last scheduled fusion or for a specific set of inputs. Note that segmented fusions also print the segment groupings to help the user understand what the scheduled Fusion IR represents.

Example Usages:

print(fd.last_scheduled_fusion_ir(tensor_transforms=True))
print(fd.scheduled_fusion_ir_for(inputs, tensor_transforms=True))

Output:

%kernel {
T2_l[ iblockIdx.x27{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS28{1}, iS26{1}, ithreadIdx.x24{128} ] ca_pos( 2 )
   = T0_g[ iS34{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iS35{1}, iS33{1}, iS31{128} ];
T3_l[ iblockIdx.x20{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS21{1}, iS19{1}, ithreadIdx.x17{128} ] ca_pos( 4 ) produce_pos( 2 )
   = relu(T2_l[ iblockIdx.x27{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS28{1}, iS26{1}, ithreadIdx.x24{128} ] ca_pos( 2 ));
T1_g[ iblockIdx.x13{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS14{1}, iS12{1}, ithreadIdx.x10{128} ] ca_pos( 2 ) produce_pos( 4 )
   = T3_l[ iblockIdx.x20{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS21{1}, iS19{1}, ithreadIdx.x17{128} ] ca_pos( 4 ) produce_pos( 2 );

TransformPrinter : 
T0_g[ iS34{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iS35{1}, iS33{1}, iS31{128} ]
 root domain : (iS36{T0.size[0]},iS37{T0.size[1]})
 contiguity: t t
  Merge: iS36{T0.size[0]} and iS37{T0.size[1]} -> iS29{( T0.size[0] * T0.size[1] )}
  Split: iS29{( T0.size[0] * T0.size[1] )} by factor 128 -> iS30{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )}, iS31{128}, start offset: 0, stop offset: 0
  Split: iS30{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )} by factor 1 -> iS32{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )}, iS33{1}, start offset: 0, stop offset: 0
  Split: iS32{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )} by factor 1 -> iS34{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iS35{1}, start offset: 0, stop offset: 0
T2_l[ iblockIdx.x27{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS28{1}, iS26{1}, ithreadIdx.x24{128} ] ca_pos( 2 )
 root domain : (iS38{T0.size[0]},iS39{T0.size[1]})
 contiguity: t t
  Merge: iS38{T0.size[0]} and iS39{T0.size[1]} -> iS22{( T0.size[0] * T0.size[1] )}
  Split: iS22{( T0.size[0] * T0.size[1] )} by factor 128 -> iS23{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )}, ithreadIdx.x24{128}, start offset: 0, stop offset: 0
  Split: iS23{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )} by factor 1 -> iS25{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )}, iS26{1}, start offset: 0, stop offset: 0
  Split: iS25{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )} by factor 1 -> iblockIdx.x27{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS28{1}, start offset: 0, stop offset: 0
T3_l[ iblockIdx.x20{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS21{1}, iS19{1}, ithreadIdx.x17{128} ] ca_pos( 4 ) produce_pos( 2 )
 root domain : (iS40{T0.size[0]},iS41{T0.size[1]})
 contiguity: t t
  Merge: iS40{T0.size[0]} and iS41{T0.size[1]} -> iS15{( T0.size[0] * T0.size[1] )}
  Split: iS15{( T0.size[0] * T0.size[1] )} by factor 128 -> iS16{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )}, ithreadIdx.x17{128}, start offset: 0, stop offset: 0
  Split: iS16{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )} by factor 1 -> iS18{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )}, iS19{1}, start offset: 0, stop offset: 0
  Split: iS18{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )} by factor 1 -> iblockIdx.x20{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS21{1}, start offset: 0, stop offset: 0
T1_g[ iblockIdx.x13{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS14{1}, iS12{1}, ithreadIdx.x10{128} ] ca_pos( 2 ) produce_pos( 4 )
 root domain : (iS42{T0.size[0]},iS43{T0.size[1]})
 contiguity: t t
  Merge: iS42{T0.size[0]} and iS43{T0.size[1]} -> iS8{( T0.size[0] * T0.size[1] )}
  Split: iS8{( T0.size[0] * T0.size[1] )} by factor 128 -> iS9{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )}, ithreadIdx.x10{128}, start offset: 0, stop offset: 0
  Split: iS9{( ceilDiv(( T0.size[0] * T0.size[1] ), 128) )} by factor 1 -> iS11{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )}, iS12{1}, start offset: 0, stop offset: 0
  Split: iS11{( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) )} by factor 1 -> iblockIdx.x13{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * T0.size[1] ), 128) ), 1) ), 1) )}, iUS14{1}, start offset: 0, stop offset: 0
}
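
Because scheduled_fusion_ir_for takes an input set, the schedule that would be generated for other inputs can be inspected as well. A minimal sketch; other_inputs is only an illustrative name and shape:

# Inspect the scheduled Fusion IR corresponding to a different set of
# inputs; `other_inputs` is purely illustrative.
other_inputs = [
    torch.randn(16, 16, device='cuda'),
]
print(fd.scheduled_fusion_ir_for(other_inputs, tensor_transforms=False))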

Cuda Code: Returns a string of the CUDA code for a fusion, with or without the full intrinsic code required to execute the kernel, for either the last scheduled fusion or for a specific set of inputs.

Example Usages:

print(fd.last_cuda_code(intrinsic_code=False))
print(fd.cuda_code_for(inputs, intrinsic_code=False))

Output:

__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T1) {
  int i50;
  i50 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)blockIdx.x));
  if ((i50 < (T0.size[0] * T0.size[1]))) {
    float T2[1];
    T2[0] = 0;
    T2[0]
       = T0[i50];
    float T3[1];
    T3[0]
       = relu(T2[0]);
    T1[i50]
       = T3[0];
  }
}
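
The returned CUDA source can likewise be saved instead of printed; with intrinsic_code=True the string includes the full intrinsic code needed to compile the kernel. A minimal sketch, assuming the fd defined above (the file name is only illustrative):

# Dump the complete kernel source, including intrinsic code, to a file;
# "kernel1_full.cu" is just an illustrative file name.
full_src = fd.last_cuda_code(intrinsic_code=True)
with open("kernel1_full.cu", "w") as f:
    f.write(full_src)
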
kevinstephano commented 1 year ago

Closing in favor of the PR in the new repo: https://github.com/NVIDIA/Fuser/pull/6