yhtang / FunFact

Tensor decomposition with arbitrary expressions: inner, outer, elementwise operators; nonlinear transformations; and more.

Parametrized Tensors #222

Closed: campsd closed this pull request 2 years ago

campsd commented 2 years ago

This PR adds a framework for parametrized tensors to FunFact and, once complete, closes #220. An example implementation of parametrized Givens rotations is provided.

Example usage: computing a Givens QR factorization of a 2 x 2 matrix:

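import numpy as np
import funfact as ff

# G: parametrized 2 x 2 Givens rotation acting in the (0, 1) plane
G = ff.parametrized.planar_rotation(0, 1, 2)
# R: 2 x 2 factor with a preference for upper-triangular structure
R = ff.tensor('R', 2, 2, prefer=ff.conditions.UpperTriangular())
tsrex = G @ R
target = np.array([[1.0, 2.0], [3.0, 4.0]])
fac = ff.factorize(tsrex, target, vec_size=8)
fac()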
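
For reference, the 2 x 2 case has a closed-form answer to compare the factorization against. The NumPy sketch below is independent of FunFact and assumes that `planar_rotation(0, 1, 2)` produces the matrix `[[cos t, -sin t], [sin t, cos t]]`; the sign convention in the PR may differ. It picks the angle that zeroes the (1, 0) entry of `G.T @ A`, which is the classical Givens step.

```python
import numpy as np

A = np.array([[1.0, 2.0], [3.0, 4.0]])

# Angle that rotates the first column of A onto the first axis
theta = np.arctan2(A[1, 0], A[0, 0])
c, s = np.cos(theta), np.sin(theta)

# Assumed convention for the rotation in the (0, 1) plane
G = np.array([[c, -s], [s, c]])

# Since G is orthogonal, A = G @ R  implies  R = G.T @ A
R = G.T @ A
# R ≈ [[3.162, 4.427], [0.0, -0.632]]  (i.e. sqrt(10), 14/sqrt(10), -2/sqrt(10))
```

The solution is only unique up to sign: replacing `theta` with `theta + pi` negates both `G` and `R`, so a fitted factorization may land on either one.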

There is an outstanding issue with generating the rotation matrices on the JAX backend, related to the immutability of DeviceArrays that we encountered before. The commented-out code in funfact.parametrized.givens_rotation uses the in-place update approach from https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates. @yhtang, do you see a better way to construct these parametrized tensors?
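
One way around the immutability problem is to avoid index assignment altogether and build the rotation from arithmetic on identity columns. The sketch below is a hypothetical helper, not the PR's implementation; it assumes the convention `G[i, i] = G[j, j] = cos(theta)`, `G[j, i] = sin(theta)`, `G[i, j] = -sin(theta)`. It only calls `eye`, `cos`, `sin`, and `outer`, which exist in both NumPy and `jax.numpy`, so passing `xp=jax.numpy` should give a pure, differentiable function of `theta`. The gotchas page linked above offers the other standard route, `x.at[idx].set(value)`, which returns an updated copy instead of mutating `x`.

```python
import numpy as np


def planar_rotation_matrix(i, j, n, theta, xp=np):
    """Hypothetical helper: n x n Givens rotation in the (i, j) plane.

    Built purely from array arithmetic (no in-place writes), so the same
    code can run on immutable arrays such as JAX's.
    """
    eye = xp.eye(n)
    e_i, e_j = eye[i], eye[j]  # unit vectors (read-only indexing)
    c, s = xp.cos(theta), xp.sin(theta)
    return (
        eye
        # replace the 1s at (i, i) and (j, j) with cos(theta)
        + (c - 1.0) * (xp.outer(e_i, e_i) + xp.outer(e_j, e_j))
        # +sin(theta) at (j, i), -sin(theta) at (i, j)
        + s * (xp.outer(e_j, e_i) - xp.outer(e_i, e_j))
    )


G = planar_rotation_matrix(0, 2, 3, 0.3)
```

Every entry outside rows and columns `i` and `j` keeps its identity value, so the result is orthogonal for any `theta`.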

The code is compatible with the PyTorch backend.

campsd commented 2 years ago

There is also some code duplication in the initialization interpreter that would be nice to avoid.

campsd commented 2 years ago

Additional code snippet that factorizes a 3 x 3 matrix:

import numpy as np
import funfact as ff

# Three parametrized rotations in the (0, 1), (1, 2), and (0, 1) planes
G01a = ff.parametrized.planar_rotation(0, 1, 3, initializer=ff.initializers.Normal(std=np.pi))
G12 = ff.parametrized.planar_rotation(1, 2, 3, initializer=ff.initializers.Normal(std=np.pi))
G01b = ff.parametrized.planar_rotation(0, 1, 3, initializer=ff.initializers.Normal(std=np.pi))
R = ff.tensor('R', 3, 3, prefer=ff.conditions.UpperTriangular(), initializer=ff.initializers.Normal(std=1.0))
target = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
tsrex = G01a @ G12 @ G01b @ R
fac = ff.factorize(tsrex, target, vec_size=16, lr=0.05)
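
A quick sanity check that the ansatz `G01a @ G12 @ G01b` can represent an exact solution: any proper 3 x 3 rotation can be written as a rotation in the (0, 1) plane, then the (1, 2) plane, then the (0, 1) plane again (ZXZ-type Euler angles). The NumPy sketch below takes the orthogonal factor from `np.linalg.qr`, flips a sign if needed so it is a proper rotation, and recovers the three angles. The helper `planar` is hypothetical and assumes the same `[[c, -s], [s, c]]` convention within each plane; the PR's sign convention may differ. Note that this target is singular (rank 2), so in exact arithmetic the bottom-right entry of R is zero.

```python
import numpy as np


def planar(i, j, n, theta):
    """Hypothetical helper: Givens rotation in the (i, j) plane (NumPy, in-place is fine here)."""
    G = np.eye(n)
    c, s = np.cos(theta), np.sin(theta)
    G[i, i] = G[j, j] = c
    G[i, j], G[j, i] = -s, s
    return G


A = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
Q, R = np.linalg.qr(A)

if np.linalg.det(Q) < 0:
    # Make Q a proper rotation; flipping column 2 of Q and row 2 of R keeps Q @ R == A
    Q[:, 2] *= -1.0
    R[2, :] *= -1.0

# Solve Q = planar(0,1,alpha) @ planar(1,2,beta) @ planar(0,1,gamma):
# Q[2,2] = cos(beta), Q[0,2] = sin(alpha) sin(beta), Q[1,2] = -cos(alpha) sin(beta),
# Q[2,0] = sin(beta) sin(gamma), Q[2,1] = sin(beta) cos(gamma)  (valid when sin(beta) != 0)
beta = np.arccos(np.clip(Q[2, 2], -1.0, 1.0))
alpha = np.arctan2(Q[0, 2], -Q[1, 2])
gamma = np.arctan2(Q[2, 0], Q[2, 1])

Q_rebuilt = planar(0, 1, 3, alpha) @ planar(1, 2, 3, beta) @ planar(0, 1, 3, gamma)
```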

yhtang commented 2 years ago

I found out that the failing tests can be fixed by upgrading to mkdocs-jupyter==0.20.1. Will take care of that in a separate PR.

yhtang commented 2 years ago

Nice work! Merging now.