Bihaqo / t3f

Tensor Train decomposition on TensorFlow
https://t3f.readthedocs.io/en/latest/index.html
MIT License

Automatic operation fusion #169

Open Bihaqo opened 5 years ago

Bihaqo commented 5 years ago

For example, instead of writing a separate function project_matmul, we could implement something like project_matmul = fuse(lambda a, b: project(a, b), lambda b, c: matmul(b, c))

To do that, let's express each operation as a sequence of recurrent steps of the form

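import tensorflow as tf

def recurrent(a, b, rule):
  # `rule` is an operation-specific einsum string; `res` is carried from core to core.
  res = tf.ones((1, 1), dtype=a.dtype)
  for a_core, b_core in zip(a.tt_cores, b.tt_cores):
    res = tf.einsum(rule, a_core, b_core, res)
  return res

and of independent steps

def independent(a, b, rule):
  # Each output core depends only on the matching input cores.
  res_cores = []
  for a_core, b_core in zip(a.tt_cores, b.tt_cores):
    res_cores.append(tf.einsum(rule, a_core, b_core))
  return res_cores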

Then we can automatically concatenate the einsums of the individual operations into a single big einsum per core, and rely on opt_einsum to find an efficient contraction order, so that the resulting fused einsum stays fast.
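To make the idea concrete, here is a minimal NumPy sketch of x^t A y computed two ways: first as an independent matmul step followed by a recurrent inner-product step, then as one fused einsum per core. It assumes a simplified layout (not necessarily t3f's exact one): x and y are lists of cores shaped (r_prev, n, r_next), A is a list of cores shaped (r_prev, n, m, r_next). `np.einsum(..., optimize=True)` stands in for opt_einsum here, since it also searches for a cheap contraction order. All function names are hypothetical.

```python
import numpy as np


def matmul_independent(A_cores, y_cores):
    # Independent step: each core of A y depends only on the matching cores of A and y.
    out = []
    for a, c in zip(A_cores, y_cores):
        core = np.einsum('aijb,cjd->acibd', a, c)
        ra, rc, n, rb, rd = core.shape
        out.append(core.reshape(ra * rc, n, rb * rd))  # merge the rank dimensions
    return out


def inner_recurrent(x_cores, z_cores):
    # Recurrent step: a (rank_x, rank_z) state is carried from core to core.
    res = np.ones((1, 1))
    for a, b in zip(x_cores, z_cores):
        res = np.einsum('aic,bid,ab->cd', a, b, res, optimize=True)
    return res[0, 0]


def fused_xAy(x_cores, A_cores, y_cores):
    # Fused version: both steps merged into one einsum per core with a
    # (rank_x, rank_A, rank_y) state, so A y is never formed explicitly.
    res = np.ones((1, 1, 1))
    for xc, ac, yc in zip(x_cores, A_cores, y_cores):
        res = np.einsum('aib,cijd,ejf,ace->bdf', xc, ac, yc, res, optimize=True)
    return res[0, 0, 0]


rng = np.random.default_rng(0)
x = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 1)]]
y = [rng.standard_normal(s) for s in [(1, 3, 3), (3, 4, 1)]]
A = [rng.standard_normal(s) for s in [(1, 3, 3, 2), (2, 4, 4, 1)]]
two_step = inner_recurrent(x, matmul_independent(A, y))
fused = fused_xAy(x, A, y)
print(np.isclose(two_step, fused))  # True
```

Both paths compute the same number; the fused one only ever holds a small three-index state per core instead of the intermediate cores of A y, whose ranks are products of the ranks of A and y.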

Off the top of my head, we can support any combination of 1) matmul(A, B), 2) addition a + b, and 3) the elementwise product a * b.
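Of these, matmul and the elementwise product map directly onto independent per-core einsums (followed by a reshape that merges the rank dimensions). Addition does not fit the einsum form: in TT format the sum is usually built by concatenating the first and last cores and stacking the middle cores block-diagonally, so a fusion scheme would presumably need a separate rule for it. A NumPy sketch of the elementwise product and the sum, with hypothetical helper names and cores shaped (r_prev, n, r_next); `full` densifies a TT tensor only to check the results.

```python
import numpy as np


def elementwise_product_independent(a_cores, b_cores):
    # Each core of a * b is a Kronecker product of the matching cores
    # along the rank dimensions, so the ranks multiply.
    out = []
    for a, b in zip(a_cores, b_cores):
        core = np.einsum('aic,bid->abicd', a, b)
        ra, rb, n, rc, rd = core.shape
        out.append(core.reshape(ra * rb, n, rc * rd))
    return out


def add_block_diagonal(a_cores, b_cores):
    # Cores of a + b: first core concatenated along the right rank, last core
    # along the left rank, middle cores block-diagonal, so the ranks add.
    if len(a_cores) == 1:
        return [a_cores[0] + b_cores[0]]
    out = []
    last = len(a_cores) - 1
    for k, (a, b) in enumerate(zip(a_cores, b_cores)):
        if k == 0:
            out.append(np.concatenate([a, b], axis=2))
        elif k == last:
            out.append(np.concatenate([a, b], axis=0))
        else:
            ra, n, rc = a.shape
            rb, _, rd = b.shape
            core = np.zeros((ra + rb, n, rc + rd), dtype=np.result_type(a, b))
            core[:ra, :, :rc] = a
            core[ra:, :, rc:] = b
            out.append(core)
    return out


def full(cores):
    # Dense reconstruction (row-major over the modes), used only for checking.
    res = np.ones((1, 1))
    for core in cores:
        res = np.einsum('pa,aib->pib', res, core).reshape(-1, core.shape[2])
    return res.reshape(-1)


rng = np.random.default_rng(0)
a = [rng.standard_normal(s) for s in [(1, 2, 2), (2, 3, 2), (2, 2, 1)]]
b = [rng.standard_normal(s) for s in [(1, 2, 3), (3, 3, 3), (3, 2, 1)]]
print(np.allclose(full(elementwise_product_independent(a, b)), full(a) * full(b)))  # True
print(np.allclose(full(add_block_diagonal(a, b)), full(a) + full(b)))  # True
```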

Additionally, as the last operation in the combination, we can support 1) the dot product a^t b, 2) the Gram matrix G_ij = a_i^t b_j, 3) projection onto the tangent space P_x y, and 4) the trace.
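These final reductions also fit the recurrent form. For instance, the trace of a TT-matrix only needs a state with one rank index (each core contributes its diagonal), and a Gram matrix of two batches of TT vectors carries one (rank_a, rank_b) state per pair. A NumPy sketch with hypothetical function names; TT-matrix cores are assumed to be shaped (r_prev, n, m, r_next), and batched cores are assumed to have a leading batch dimension, (batch, r_prev, n, r_next).

```python
import numpy as np


def trace_recurrent(A_cores):
    # tr(A) = sum_i A[i, i]: take the diagonal of every core ('aiib')
    # and pass a rank-sized state along the train.
    res = np.ones((1,))
    for core in A_cores:
        res = np.einsum('a,aiib->b', res, core)
    return res[0]


def gram_recurrent(a_batch, b_batch):
    # G[x, y] = <a_x, b_y>; the state holds one (r_a, r_b) matrix per pair (x, y).
    N, M = a_batch[0].shape[0], b_batch[0].shape[0]
    res = np.ones((N, M, 1, 1))
    for a, b in zip(a_batch, b_batch):
        res = np.einsum('xyac,xaib,ycid->xybd', res, a, b, optimize=True)
    return res[:, :, 0, 0]


rng = np.random.default_rng(0)
A = [rng.standard_normal(s) for s in [(1, 3, 3, 2), (2, 4, 4, 1)]]
a = [rng.standard_normal(s) for s in [(5, 1, 3, 2), (5, 2, 4, 1)]]
b = [rng.standard_normal(s) for s in [(7, 1, 3, 3), (7, 3, 4, 1)]]
print(trace_recurrent(A))
G = gram_recurrent(a, b)
print(G.shape)  # (5, 7)
```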

By combining these ops we could, for example, automatically get fast versions of 1) x^t A y (already implemented as a separate fast operation), 2) ||A B||, 3) A B x, 4) P_x A y (already implemented), 5) ||a b||, 6) P_x A B y, 7) ||A + B||, 8) P_x (a b), 9) x^t A B y, 10) ||(Ax) * (By)||, and 11) trace(A^T B A).
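As one item from this list, ||a b|| (the norm of the elementwise product) can be fused so that the product TT is never built explicitly: ||a * b||^2 is the sum over all indices of a_i b_i a_i b_i, which is a single recurrent einsum over the cores (a, b, a, b) with a four-index state. A NumPy sketch with a hypothetical function name and cores shaped (r_prev, n, r_next); the contraction order inside each step is left to the einsum optimizer.

```python
import numpy as np


def fused_norm_of_product(a_cores, b_cores):
    # ||a * b||^2 = sum_i a_i * b_i * a_i * b_i: one einsum per core over
    # (a, b, a, b) with an (r_a, r_b, r_a, r_b) state carried along the train.
    res = np.ones((1, 1, 1, 1))
    for a, b in zip(a_cores, b_cores):
        res = np.einsum('pqrs,pib,qic,rid,sie->bcde', res, a, b, a, b, optimize=True)
    return np.sqrt(res[0, 0, 0, 0])


rng = np.random.default_rng(0)
a = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 1)]]
b = [rng.standard_normal(s) for s in [(1, 3, 3), (3, 4, 1)]]
value = fused_norm_of_product(a, b)
print(value)
```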

Does anyone need this?

Bihaqo commented 5 years ago

A potential way to design the API:

with t3f.Fuse() as f:
  Ax = t3f.matmul(A, x)
  xAx = t3f.flat_inner(x, Ax)
  fast_xAx = f.optimize(xAx)
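A toy sketch of how such a recorder might work, restricted to chains of the form x^t A_1 ... A_k y. Everything here is hypothetical: `Fuse`, `Node` and `fused_chain` are not t3f functions, the ops are methods on the recorder (f.matmul, f.flat_inner) rather than intercepted t3f.matmul / t3f.flat_inner calls, and plain NumPy lists of cores (vectors shaped (r_prev, n, r_next), matrices shaped (r_prev, n, m, r_next)) stand in for t3f objects. The point it illustrates is that `optimize` can walk the recorded expression and generate the fused per-core einsum string automatically.

```python
import string

import numpy as np


class Node:
    # Hypothetical lazy node: records an operation instead of computing it.
    def __init__(self, op, args):
        self.op, self.args = op, args


def fused_chain(x, mats, y):
    # Computes x^t A_1 ... A_k y with one generated einsum per core.
    # Needs 3k + 5 distinct letters, so k <= 15 with ascii_letters.
    k = len(mats)
    letters = iter(string.ascii_letters)
    modes = [next(letters) for _ in range(k + 1)]  # i_0, ..., i_k
    left = [next(letters) for _ in range(k + 2)]   # incoming rank of each operand
    right = [next(letters) for _ in range(k + 2)]  # outgoing rank of each operand
    terms = [left[0] + modes[0] + right[0]]
    terms += [left[m + 1] + modes[m] + modes[m + 1] + right[m + 1] for m in range(k)]
    terms.append(left[k + 1] + modes[k] + right[k + 1])
    rule = ','.join(terms + [''.join(left)]) + '->' + ''.join(right)
    state = np.ones((1,) * (k + 2))
    for cores in zip(x, *mats, y):
        state = np.einsum(rule, *cores, state, optimize=True)
    return state.item()


class Fuse:
    # Hypothetical recorder mimicking the proposed `with t3f.Fuse() as f:` API.
    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        return False

    def matmul(self, A, x):
        return Node('matmul', (A, x))

    def flat_inner(self, x, y):
        return Node('flat_inner', (x, y))

    def optimize(self, node):
        # Only flat_inner(x, matmul(A_1, matmul(A_2, ... y))) is recognized here.
        if node.op != 'flat_inner':
            raise NotImplementedError(node.op)
        x, rest = node.args
        mats = []
        while isinstance(rest, Node) and rest.op == 'matmul':
            mats.append(rest.args[0])
            rest = rest.args[1]
        return fused_chain(x, mats, rest)


rng = np.random.default_rng(0)
x = [rng.standard_normal(s) for s in [(1, 3, 2), (2, 4, 1)]]
A = [rng.standard_normal(s) for s in [(1, 3, 3, 2), (2, 4, 4, 1)]]
with Fuse() as f:
    Ax = f.matmul(A, x)
    xAx = f.flat_inner(x, Ax)
    fast_xAx = f.optimize(xAx)
print(fast_xAx)
```

Recording lazily and deferring all work to `optimize` is what lets the whole chain collapse into one einsum per core; a real implementation would also need rules for the other ops and final reductions listed above.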