mlc-ai / relax

Apache License 2.0
137 stars, 69 forks

[Op] Enhanced Einsum #283

Closed: LeshengJin closed this 1 month ago

LeshengJin commented 10 months ago

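Typically, Einsum only performs element-wise multiplication followed by summation over the reduced indices. This PR extends Einsum in two ways:

  1. The element-wise computation and the way values are combined across reduced indices can be customized (the fcompute, fcombine, and fidentity arguments).
  2. Einsum can return a tuple of tensors instead of a single tensor.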
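To make the two extensions concrete, here is a minimal pure-Python sketch of the intended semantics for a single-operand reduction such as "ij -> i". It is not the PR's implementation: reduce_rows is a hypothetical helper, and fidentity here takes no arguments (unlike the dtype-taking version in the example below). Classic einsum is the special case where fcompute yields the element value (the product of operands, when there are several), fcombine adds, and fidentity returns 0. Returning tuples from all three functions produces several outputs in one pass.

```python
import math
from typing import Callable, List, Sequence, Tuple


def reduce_rows(
    x: Sequence[Sequence[float]],
    fcompute: Callable[[float], Tuple[float, ...]],
    fcombine: Callable[[Tuple[float, ...], Tuple[float, ...]], Tuple[float, ...]],
    fidentity: Callable[[], Tuple[float, ...]],
) -> List[Tuple[float, ...]]:
    """Hypothetical reference semantics for "ij -> i[, i, ...]": map each x[i][j]
    through fcompute, then fold over j with fcombine, starting from fidentity()."""
    out = []
    for row in x:
        acc = fidentity()
        for x_ij in row:
            acc = fcombine(acc, fcompute(x_ij))
        out.append(acc)
    return out


x = [[1.0, 2.0, 3.0], [0.5, -1.0, 4.0]]

# Classic einsum "ij -> i": element value, combined by addition, identity 0.
row_sums = reduce_rows(
    x,
    fcompute=lambda v: (v,),
    fcombine=lambda a, b: (a[0] + b[0],),
    fidentity=lambda: (0.0,),
)

# Custom combine with a tuple output, "ij -> i, i": row max and row min together.
row_max_min = reduce_rows(
    x,
    fcompute=lambda v: (v, v),
    fcombine=lambda a, b: (max(a[0], b[0]), min(a[1], b[1])),
    fidentity=lambda: (-math.inf, math.inf),
)

print(row_sums)     # [(6.0,), (3.5,)]
print(row_max_min)  # [(3.0, 1.0), (4.0, -1.0)]
```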

With these extensions, Einsum can express more complex computations in a few lines of code. For example, the following computes a row-wise softmax of a 2-D tensor x with two Einsum calls:

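import tvm

# Assumes `x` is a 2-D tensor and `einsum` is the enhanced Einsum from this PR.

# Merge two partial states (m, d): m is a running max, d a running sum of exp(value - m).
def fcombine(tensor1, tensor2):
    mi = tensor1[0]
    di = tensor1[1]
    mj = tensor2[0]
    dj = tensor2[1]
    r0 = tvm.tir.max(mi, mj)
    # Rescale both partial sums to the new max before adding them.
    r1 = di * tvm.tir.exp(mi - r0) + dj * tvm.tir.exp(mj - r0)
    return r0, r1

# Identity element of fcombine: (most negative value, empty sum).
def fidentity(dtype1, dtype2):
    return tvm.te.min_value(dtype1), tvm.tir.const(0, dtype2)

# Step 1: per-row max (mv) and normalizer (dv). Each x_ij starts as the state (x_ij, 1.0).
mv, dv = einsum(
    "ij -> i, i",
    x,
    fcompute=lambda x_ij: (x_ij, 1.0),
    fcombine=fcombine,
    fidentity=fidentity,
)

# Step 2: softmax_x[i, j] = exp(x_ij - mv_i) / dv_i (subtract the max inside exp).
softmax_x = einsum(
    "ij, i, i -> ij",
    (x, mv, dv),
    fcompute=lambda x_ij, mv_i, dv_i: tvm.tir.exp(x_ij - mv_i) / dv_i,
)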
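As a sanity check on the softmax example, the following pure-Python sketch (no TVM required) replays both Einsum steps on plain lists. row_stats and softmax are hypothetical helper names; fcombine reuses the example's arithmetic, the identity mirrors te.min_value with the most negative finite float, and the second step uses exp(x_ij - mv_i) / dv_i, the form that matches the statistics produced by the first step.

```python
import math
import sys
from functools import reduce
from typing import List, Sequence, Tuple

# (m, d): running max and running sum of exp(x - m) over the values seen so far.
Pair = Tuple[float, float]

# Mirrors fidentity: te.min_value(dtype) is the most negative finite value.
IDENTITY: Pair = (-sys.float_info.max, 0.0)


def fcombine(a: Pair, b: Pair) -> Pair:
    """Same arithmetic as the example's fcombine, on plain Python floats."""
    mi, di = a
    mj, dj = b
    r0 = max(mi, mj)
    r1 = di * math.exp(mi - r0) + dj * math.exp(mj - r0)
    return r0, r1


def row_stats(row: Sequence[float]) -> Pair:
    """Step 1 ("ij -> i, i"): fcompute maps x_ij to (x_ij, 1.0), then fold with fcombine."""
    return reduce(fcombine, [(v, 1.0) for v in row], IDENTITY)


def softmax(x: Sequence[Sequence[float]]) -> List[List[float]]:
    """Step 2 ("ij, i, i -> ij"): exp(x_ij - mv_i) / dv_i."""
    out = []
    for row in x:
        mv, dv = row_stats(row)
        out.append([math.exp(v - mv) / dv for v in row])
    return out


x = [[1.0, 2.0, 3.0, 4.0], [1000.0, 1001.0, 999.0, 1000.5]]
y = softmax(x)

# Row 0 is small enough for the naive formula exp(v) / sum(exp(row));
# for row 1, math.exp(1000.0) would raise OverflowError.
naive0 = [math.exp(v) / sum(math.exp(u) for u in x[0]) for v in x[0]]

# fcombine is associative with IDENTITY as its identity, so a row can be reduced
# in chunks (e.g. by different threads) and the partial states merged afterwards.
row = x[1]
whole = row_stats(row)
chunked = fcombine(row_stats(row[:2]), row_stats(row[2:]))
```

The chunked check is why fcombine must be associative and fidentity must be its identity: only then can the reduction be split and merged without changing the result. Subtracting the running max keeps every exp argument at or below zero, which is what lets row 1 work where the naive formula overflows.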
junrushao commented 10 months ago

You may take a look at this: https://einops.rocks/