For example, instead of writing a separate function project_matmul, we can implement something like

project_matmul = fuse(lambda a, b: project(a, b), lambda b, c: matmul(b, c))
To do that, let's express each operation as a sequence of recurrent steps of the form
def recurrent(a, b):
    res = 1.0
    for a_core, b_core in zip(a.tt_cores, b.tt_cores):
        res = einsum('rule', a_core, b_core, res)
    return res
and of independent steps
def independent(a, b):
    res_cores = []
    for a_core, b_core in zip(a.tt_cores, b.tt_cores):
        res_cores.append(einsum('rule', a_core, b_core))
    return res_cores
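To make the two patterns concrete, here is a minimal NumPy sketch (the core shapes `(r_left, n, r_right)` and the explicit core lists are my assumptions about the TT representation, standing in for the hypothetical `tt_cores` attribute): the dot product follows the recurrent pattern, and the elementwise product follows the independent pattern.

```python
import numpy as np

# Recurrent pattern, instantiated as the TT dot product <a, b>.
# Each core has the (assumed) shape (r_left, n, r_right).
def tt_dot(a_cores, b_cores):
    res = np.ones((1, 1))
    for a_core, b_core in zip(a_cores, b_cores):
        # Contract the running state with one core of each operand.
        res = np.einsum('ab,aic,bid->cd', res, a_core, b_core)
    return res[0, 0]

# Independent pattern, instantiated as the elementwise product a * b:
# each result core is a Kronecker product of the operand cores over
# the rank indices, computed per core with no state carried over.
def tt_elementwise(a_cores, b_cores):
    res_cores = []
    for a_core, b_core in zip(a_cores, b_cores):
        core = np.einsum('aic,bid->abicd', a_core, b_core)
        r1 = a_core.shape[0] * b_core.shape[0]
        r2 = a_core.shape[2] * b_core.shape[2]
        res_cores.append(core.reshape(r1, a_core.shape[1], r2))
    return res_cores
```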
Then, we can automatically concatenate the einsums of the individual operations into a single big einsum (per core), and by using opt_einsum guarantee that the resulting einsum will be fast.
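As a sanity check of the fusion idea, here is a sketch on a single pair of cores (the shapes and index letters are made up for illustration): chaining a per-core matmul einsum with a separate state-update einsum gives the same result as one fused einsum, and NumPy's `optimize=True` performs opt_einsum-style contraction path search on the fused expression.

```python
import numpy as np

# Hypothetical TT-matrix cores of (assumed) shape (r_left, m, n, r_right).
rng = np.random.default_rng(0)
a_core = rng.standard_normal((2, 3, 3, 2))
b_core = rng.standard_normal((2, 3, 3, 2))

# Unfused: first form the matmul core, then contract it with the
# running state res (indexed by the left ranks of A and B).
ab_core = np.einsum('aijb,cjkd->acikbd', a_core, b_core)
res = rng.standard_normal((2, 2))
step_unfused = np.einsum('ac,acikbd->ikbd', res, ab_core)

# Fused: concatenate the two einsum expressions into one; optimize=True
# lets NumPy pick a fast contraction order, as opt_einsum would.
step_fused = np.einsum('ac,aijb,cjkd->ikbd', res, a_core, b_core,
                       optimize=True)

assert np.allclose(step_unfused, step_fused)
```

The fused form never materializes the large intermediate `ab_core` if the optimizer finds a cheaper order, which is exactly the speedup we want from fusion.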
Off the top of my head, we can support any combination of
1) matmul(A, B)
2) addition a + b
3) elementwise product a * b
Additionally, as the last operation of the combination, we can support
1) dot product a^T b
2) Gram matrix G_ij = a_i^T b_j
3) projection on the tangent space P_x y
4) trace
By combining these ops we can, for example, automatically get fast versions of
1) x^T A y (already implemented as a separate fast operation)
2) ||A B||
3) A B x
4) P_x A y (already implemented)
5) ||a b||
6) P_x A B y
7) ||A + B||
8) P_x (a b)
9) x^T A B y
10) ||(Ax) * (By)||
11) trace(A^T B A)
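For instance, the quadratic form x^T A y reduces to one recurrent sweep with a single fused einsum per core triple. A minimal NumPy sketch (the core shape conventions, `(r, n, r')` for TT-vectors and `(R, m, n, R')` for TT-matrices, are my assumptions):

```python
import numpy as np

def tt_quadratic_form(x_cores, a_cores, y_cores):
    # x^T A y in one recurrent sweep: the state res is indexed by the
    # current ranks of x, A and y; each step fuses the contraction of
    # one core from each operand into a single optimized einsum.
    res = np.ones((1, 1, 1))
    for x_core, a_core, y_core in zip(x_cores, a_cores, y_cores):
        res = np.einsum('abc,aid,bije,cjf->def',
                        res, x_core, a_core, y_core, optimize=True)
    return res[0, 0, 0]
```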
Does anyone need this?