Closed LeshengJin closed 1 month ago
Typically, Einsum only performs element-wise multiplication and summation across indices. This PR expands Einsum's capabilities:

- custom `fcompute`/`fcombine`/`fidentity` functions
- `Tuple(tensor)` outputs

This enhanced Einsum can represent more complex computations in a few lines of code. For example:
sum(x), sum(x^2)

```python
einsum("ij -> i, i", x, fcompute=lambda x_ij: (x_ij, x_ij * x_ij))
```
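For reference, the same per-row reduction written in plain NumPy (with a hypothetical 2-D input `x` for illustration) looks like:

```python
import numpy as np

# Hypothetical input; in the einsum call above, "ij -> i, i" reduces over j,
# producing two outputs per row: sum(x_ij) and sum(x_ij * x_ij).
x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

row_sum = x.sum(axis=1)            # sum(x) per row
row_sum_sq = (x ** 2).sum(axis=1)  # sum(x^2) per row

assert np.allclose(row_sum, [6.0, 15.0])
assert np.allclose(row_sum_sq, [14.0, 77.0])
```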
sum(x), prod(x)

```python
einsum(
    "ij -> i, i",
    x,
    fcombine=lambda x, y: (x[0] + y[0], x[1] * y[1]),
    fidentity=lambda dtype1, dtype2: (tvm.tir.const(0, dtype1), tvm.tir.const(1, dtype2)),
)
```
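As a sanity check, here is what that `fcombine`/`fidentity` pair computes, folded by hand in NumPy over a hypothetical 2-D input:

```python
import numpy as np
from functools import reduce

x = np.array([[1.0, 2.0, 3.0],
              [4.0, 5.0, 6.0]])

# fcombine merges two (sum, prod) pairs; fidentity supplies (0, 1),
# the identities of + and *, so the fold starts from a neutral element.
fcombine = lambda a, b: (a[0] + b[0], a[1] * b[1])
fidentity = (0.0, 1.0)

rows = [reduce(fcombine, [(v, v) for v in row], fidentity) for row in x]
row_sum = np.array([r[0] for r in rows])
row_prod = np.array([r[1] for r in rows])

assert np.allclose(row_sum, [6.0, 15.0])
assert np.allclose(row_prod, [6.0, 120.0])
```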
Online Softmax

```python
def fcombine(tensor1, tensor2):
    mi = tensor1[0]
    di = tensor1[1]
    mj = tensor2[0]
    dj = tensor2[1]
    r0 = tvm.tir.max(mi, mj)
    r1 = di * tvm.tir.exp(mi - r0) + dj * tvm.tir.exp(mj - r0)
    return r0, r1


def fidentity(dtype1, dtype2):
    return tvm.te.min_value(dtype1), tvm.tir.const(0, dtype2)


mv, dv = einsum(
    "ij -> i, i",
    x,
    fcompute=lambda x_ij: (x_ij, 1.0),
    fcombine=fcombine,
    fidentity=fidentity,
)

softmax_x = einsum(
    "ij, i, i -> ij",
    (x, mv, dv),
    fcompute=lambda x_ij, mv_i, dv_i: tvm.tir.exp(x_ij - mv_i) / dv_i,
)
```
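The reduction above can be mirrored in plain NumPy to check the math: a sketch that carries the `(running max, running exp-sum)` pair element by element, exactly as `fcombine` does (the function name and test input here are illustrative, not part of the PR):

```python
import numpy as np

def online_softmax(x):
    """Row-wise softmax via the online (m, d) recurrence used by fcombine."""
    out = np.empty_like(x)
    for i, row in enumerate(x):
        m, d = -np.inf, 0.0  # fidentity: (min_value, 0)
        for v in row:
            # Merge state (m, d) with element state (v, 1):
            m_new = max(m, v)                               # r0 = max(mi, mj)
            d = d * np.exp(m - m_new) + np.exp(v - m_new)   # r1 rescales both sums
            m = m_new
        out[i] = np.exp(row - m) / d  # exp(x_ij - mv_i) / dv_i
    return out

x = np.random.rand(4, 8)
ref = np.exp(x - x.max(axis=1, keepdims=True))
ref /= ref.sum(axis=1, keepdims=True)
assert np.allclose(online_softmax(x), ref)
```

The key design point is that `fcombine` is associative, so the reduction can be split and reordered freely; each merge rescales both partial exp-sums to the new shared maximum, which is what makes the single-pass softmax numerically stable.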
You may take a look at this: https://einops.rocks/