hidet-org / hidet

An open-source efficient deep learning framework/compiler, written in python.
Apache License 2.0
634 stars 50 forks source link

[Operator] Add an opaque operator base class #414

Closed yaoyaoding closed 6 months ago

yaoyaoding commented 6 months ago

When we do not want to give the computation definition for some operator, because it is too tedious or cannot be expressed using our computation definition DSL, we can define an opaque operator that only gives

  1. the dtype and shape inference function that infers the output dtype and shape given the inputs' dtypes and shapes
  2. the implement function that implements the operator given the input/output dtype and shape

An example that defines an opaque operator performing matrix multiplication:

from typing import List, Union

import hidet
from hidet import Tensor
from hidet.graph.ops.opaque import OpaqueOperator
from hidet.ir.dtypes import float32
from hidet.ir import IRModule


class OpaqueMatmul(OpaqueOperator):
    """Opaque operator that multiplies two float32 matrices on CUDA.

    No computation definition is given. The operator only supplies
    ``symbolic_forward`` (output dtype/shape inference) and ``implement_cuda``
    (the kernel that implements the operator for concrete shapes).
    """

    def __init__(self, x: Tensor, y: Tensor):
        """Create the operator with ``x`` and ``y`` registered as inputs 'x' and 'y'."""
        super().__init__(
            name='matmul',
            inputs={
                'x': x,
                'y': y,
            },
        )

    def symbolic_forward(self, x: Tensor, y: Tensor):
        """Infer the output tensor from the inputs.

        Both inputs must be float32 and ``x`` must live on a CUDA device.

        Returns a dict mapping the output name 'z' to a symbolic tensor of
        shape ``[m, n]``, where ``x`` is ``[m, k]`` and ``y`` is ``[k, n]``,
        with the dtype and device of ``x``.
        """
        assert x.dtype == y.dtype == float32
        assert x.device.is_cuda()
        m, _ = x.shape
        _, n = y.shape
        return {
            'z': self.symbol(
                shape=[m, n],
                dtype=x.dtype,
                device=x.device,
            )
        }

    def implement_cuda(self, inputs: List[Tensor], outputs: List[Tensor]) -> Union[IRModule, List[IRModule]]:
        """Build the IR module holding the CUDA kernel for the given input shapes.

        ``inputs[0]`` is read as an ``[m_size, k_size]`` matrix and
        ``inputs[1]`` as a ``[k_size, n_size]`` matrix; ``outputs`` is not
        consulted. Each thread computes one element of the
        ``[m_size, n_size]`` result.
        """
        import hidet
        from hidet.lang import attrs
        from hidet.lang.types import f32
        from hidet.lang.cuda import threadIdx, blockIdx

        m_size, k_size = inputs[0].shape
        k_size, n_size = inputs[1].shape

        with hidet.script_module() as script_module:
            # The decorator is what registers the function in the enclosing
            # script module; without it the resulting IR module has no kernel.
            @hidet.script
            def matmul(x: f32[m_size, k_size], y: f32[k_size, n_size], z: f32[m_size, n_size]):
                attrs.func_kind = 'cuda_kernel'
                attrs.cuda.block_dim = (32, 32)
                # Ceil-divide so the 32x32 blocks cover every column (x axis)
                # and every row (y axis) of the output.
                attrs.cuda.grid_dim = ((n_size + 31) // 32, (m_size + 31) // 32)

                i = threadIdx.x + blockIdx.x * 32
                j = threadIdx.y + blockIdx.y * 32
                # Threads in edge blocks may fall outside the matrix.
                if i < n_size and j < m_size:
                    z[j, i] = 0.0
                    for k in range(k_size):
                        z[j, i] += x[j, k] * y[k, i]

        return script_module.ir_module()

def opaque_matmul(x: Tensor, y: Tensor) -> Tensor:
    """Apply ``OpaqueMatmul`` to ``x`` and ``y`` and return its first output."""
    operator = OpaqueMatmul(x, y)
    product = operator.outputs[0]
    return product

def test_opaque_operator():
    """Print the largest absolute difference between ``opaque_matmul`` and ``@``.

    Both products are computed on the same pair of random 128x128 float32
    CUDA tensors. The difference is only printed, not asserted.
    """
    tensor_spec = dict(dtype='float32', device='cuda')
    lhs = hidet.randn([128, 128], **tensor_spec)
    rhs = hidet.randn([128, 128], **tensor_spec)

    from_opaque = opaque_matmul(lhs, rhs)
    from_builtin = lhs @ rhs

    abs_error = hidet.ops.abs(from_opaque - from_builtin)
    print(hidet.ops.max(abs_error, dims=[0, 1]))