hidet-org / hidet

An open-source, efficient deep learning framework/compiler, written in Python.
https://hidet.org
Apache License 2.0

[IR] Support inplace operators #416

Closed: yaoyaoding closed this 6 months ago

yaoyaoding commented 6 months ago

Some operators, such as "inplace_relu" or "key_cache = write_cache(key_cache, value)", require an output to share memory with an input. This PR adds a new field "share_map" to Task and FlowGraph to indicate that an output shares memory with a particular input, and updates the CompiledGraph, CompiledTask, and FlowGraph forward functions to reuse the input's memory when this attribute is set (see the sketch after the example below).

import hidet
from hidet.ir.task import Task
from hidet.ir import primitives
from hidet.ir.dtypes import float32
from hidet.runtime import CompiledGraph
from hidet.graph.ops.utils import input_like, Operator, Tensor, compute, TensorInput

class InplaceReLUTask(Task):
    def __init__(self, x: TensorInput):
        # y[i...] = max(x[i...], 0), i.e., a standard relu over x
        y = compute(name='y', shape=x.shape, fcompute=lambda *indices: primitives.max(x[indices], x.type.dtype.zero))
        # share_map={0: 0}: output 0 reuses the memory of input 0, so y is written into x's buffer
        super().__init__(name='inplace_relu', inputs=[x], outputs=[y], share_map={0: 0})

class InplaceReLUOp(Operator):
    def __init__(self, x: Tensor):
        super().__init__(inputs=[x], attributes={}, task=InplaceReLUTask(input_like(x, 'x')))

def inplace_relu(x: Tensor) -> Tensor:
    return InplaceReLUOp(x).outputs[0]

def test_inplace_relu():
    x = hidet.symbol([1, 10], dtype=float32, device='cuda')
    y = inplace_relu(x)
    graph = hidet.trace_from(y)
    xx = hidet.randn_like(x)

    y1 = graph(xx)  # the FlowGraph forward function reuses xx's storage for the output
    assert y1.storage is xx.storage

    compiled_graph: CompiledGraph = graph.build()

    # clearing the dispatch table forces the next call to take the slow path;
    # the call after that reuses the recorded dispatch (fast path)
    compiled_graph.dispatch_table.clear()
    y2 = compiled_graph(xx)  # run in slow path
    assert y2.storage is xx.storage

    y3 = compiled_graph(xx)  # run in fast path
    assert y3.storage is xx.storage

    # y1, y2, and y3 all alias xx, and relu is idempotent, so rerunning on the
    # already-rectified buffer leaves the values unchanged
    yy = hidet.ops.relu(xx)
    hidet.utils.assert_close(y1, yy)
    hidet.utils.assert_close(y2, yy)
    hidet.utils.assert_close(y3, yy)
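
For intuition, the sketch below models the reuse step that a forward function performs when share_map is present. It is a minimal, self-contained sketch, assuming share_map maps an output index to an input index as above; Storage, SimpleTensor, and run_with_share_map are hypothetical stand-ins, not hidet's actual runtime classes.

from typing import Dict, List

class Storage:
    """Stand-in for a device memory buffer (hypothetical, not hidet's Storage)."""
    def __init__(self, nbytes: int):
        self.nbytes = nbytes

class SimpleTensor:
    """Stand-in for a tensor holding a reference to its storage."""
    def __init__(self, storage: Storage):
        self.storage = storage

def run_with_share_map(inputs: List[SimpleTensor], output_nbytes: List[int],
                       share_map: Dict[int, int]) -> List[SimpleTensor]:
    # share_map maps output index -> input index: a mapped output aliases the
    # input's buffer instead of getting a fresh allocation, so the kernel
    # effectively writes its result in place
    outputs = []
    for out_idx, nbytes in enumerate(output_nbytes):
        if out_idx in share_map:
            outputs.append(SimpleTensor(inputs[share_map[out_idx]].storage))
        else:
            outputs.append(SimpleTensor(Storage(nbytes)))
    return outputs

x = SimpleTensor(Storage(4 * 10))  # e.g. a [1, 10] float32 tensor
(y,) = run_with_share_map([x], [4 * 10], share_map={0: 0})
assert y.storage is x.storage      # the same aliasing the test above checks

Note that when an output is mapped, the input buffer is deliberately overwritten; operators like write_cache rely on exactly this behavior to update key_cache in place.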