april-tools / cirkit

a python framework to build, learn and reason about probabilistic circuits and tensor networks
https://cirkit-docs.readthedocs.io/en/latest/
GNU General Public License v3.0

Parameter saving optimization #182

Closed IrwinChay closed 4 months ago

IrwinChay commented 9 months ago

Starting point: PR #178

In the circuit product, we currently expand the parameters as param_1 \kron param_2.

However, in CP for example, param_1 has shape (output_dim, input_dim). Hence, after N products (or an N-th power), the parameter would have shape (output_dim^N, input_dim^N).

But (param_1 \kron param_2) vec(x) = vec(param_2 \cdot x \cdot param_1^T), where vec stacks the columns of x.

So we don't need to materialize the huge parameter matrix during the forward pass.
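
A minimal sketch of the trick, written in NumPy rather than taken from the torch code in the PR. The helper name `kron_matvec` is hypothetical and not part of cirkit. NumPy reshapes in row-major order, so here the identity takes the form (A \kron B) vec_r(X) = vec_r(A X B^T), which is the same statement under a different vec convention.

```python
import numpy as np


def kron_matvec(param_1, param_2, x_vec):
    """Compute (param_1 kron param_2) @ x_vec without building the Kronecker product.

    Hypothetical helper for illustration only. It relies on the row-major
    identity (A kron B) vec_r(X) = vec_r(A @ X @ B.T).
    """
    in_1 = param_1.shape[1]
    in_2 = param_2.shape[1]
    # Un-flatten the input so each factor acts on its own axis.
    x = x_vec.reshape(in_1, in_2)
    # Two small matmuls replace one (out_1*out_2, in_1*in_2) matmul.
    return (param_1 @ x @ param_2.T).reshape(-1)


# Usage: compare against the explicit Kronecker product on small shapes.
rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4))
b = rng.standard_normal((2, 5))
x_vec = rng.standard_normal(4 * 5)
assert np.allclose(kron_matvec(a, b, x_vec), np.kron(a, b) @ x_vec)
```

The explicit product is only built here to check the result; the function itself only ever touches the two small factors.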

An example implementation is in dense.py of PR #178, lines 75-94.

The same applies to Tucker.

P.S. We have to store the input vector x, but it only has dimension input_dim^N, which is much smaller than the parameter we currently need to store for a product of CP layers (dim^(2N)), not to mention Tucker (dim^(3N)).
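
To illustrate the N-fold case, here is a rough sketch that applies the N-th Kronecker power of a single parameter matrix one tensor mode at a time. It is again NumPy, not cirkit code, and the name `kron_power_matvec` is hypothetical. It assumes all N factors are the same (output_dim, input_dim) matrix, so only that matrix and the intermediate tensor are held in memory.

```python
from functools import reduce

import numpy as np


def kron_power_matvec(param, x_vec, n):
    """Apply (param kron ... kron param), with n factors, to x_vec mode by mode.

    Hypothetical helper for illustration only; assumes every factor is `param`.
    """
    in_dim = param.shape[1]
    # View the flat input as an n-way tensor with one axis per factor.
    x = x_vec.reshape((in_dim,) * n)
    for axis in range(n):
        # Contract param's input axis with the current mode; tensordot puts the
        # new output axis first, so move it back to position `axis`.
        x = np.moveaxis(np.tensordot(param, x, axes=([1], [axis])), 0, axis)
    return x.reshape(-1)


# Usage: the dense matrix is built only to check the result.
rng = np.random.default_rng(0)
param = rng.standard_normal((3, 4))
n = 3
x_vec = rng.standard_normal(4 ** n)
dense = reduce(np.kron, [param] * n)  # shape (3**n, 4**n)
assert np.allclose(kron_power_matvec(param, x_vec, n), dense @ x_vec)
```

In this toy case the dense matrix has 27 * 64 entries while the factored form stores 3 * 4, and the gap grows exponentially with N.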

loreloc commented 4 months ago

This will be part of the optimizer that will be developed for the current torch compiler.