campsd closed this 2 years ago.
There is also some code duplication in the initialization
interpreter that would be nice to avoid.
Here is an additional code snippet that factorizes a 3 x 3 matrix:
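import numpy as np
import funfact as ff

# Three parametrized 3 x 3 planar (Givens) rotations acting in the (0, 1), (1, 2) and (0, 1) planes
G01a = ff.parametrized.planar_rotation(0, 1, 3, initializer=ff.initializers.Normal(std=np.pi))
G12 = ff.parametrized.planar_rotation(1, 2, 3, initializer=ff.initializers.Normal(std=np.pi))
G01b = ff.parametrized.planar_rotation(0, 1, 3, initializer=ff.initializers.Normal(std=np.pi))
# Dense 3 x 3 factor, nudged toward upper-triangular form by a `prefer` condition
R = ff.tensor('R', 3, 3, prefer=ff.conditions.UpperTriangular(), initializer=ff.initializers.Normal(std=1.0))
# Matrix to factorize
target = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
# QR-type model: the product of rotations is orthogonal (Q), followed by R
tsrex = G01a @ G12 @ G01b @ R
# vec_size=16 optimizes 16 randomly initialized instances in parallel
fac = ff.factorize(tsrex, target, vec_size=16, lr=0.05)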
I found that the failing tests can be fixed by upgrading to `mkdocs-jupyter==0.20.1`. I will take care of that in a separate PR.
Nice work! Merging now.
This PR adds a framework for defining parametrized tensors in FunFact and will close #220 upon completion. An example implementation of parametrized Givens rotations is included.
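For context, a minimal plain-NumPy sketch of what a Givens rotation is and how a single rotation QR-factorizes a 2 x 2 matrix. This is independent of FunFact: the function name `givens` and its sign convention are assumptions for illustration, not the API of `funfact.parametrized`. Note the in-place item assignment, which works for mutable NumPy arrays but not for immutable JAX arrays.

```python
import numpy as np


def givens(theta, i, j, n):
    """Hypothetical helper: n x n rotation by angle theta in the (i, j) plane."""
    G = np.eye(n)
    c, s = np.cos(theta), np.sin(theta)
    # In-place item assignment: fine for NumPy, not allowed on JAX arrays.
    G[i, i] = c
    G[j, j] = c
    G[i, j] = -s
    G[j, i] = s
    return G


# QR of a 2 x 2 matrix with one rotation: pick theta so that G.T @ A
# has a zero in entry (1, 0); then A = G @ R with R upper triangular.
A = np.array([[3.0, 1.0], [4.0, 2.0]])
theta = np.arctan2(A[1, 0], A[0, 0])
G = givens(theta, 0, 1, 2)
R = G.T @ A
print(np.round(R, 6))
```

With this sign convention, `arctan2(A[1, 0], A[0, 0])` gives `cos(theta) = A[0, 0] / r` and `sin(theta) = A[1, 0] / r` with `r = hypot(A[0, 0], A[1, 0])`, which is exactly what zeroes the subdiagonal entry.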
Example usage to compute a Givens QR factorization of a 2 x 2 matrix:
There is an outstanding issue with generating the rotation matrices with the JAX backend, related to the immutability of `DeviceArray`s that we encountered before. The commented-out code in `funfact.parametrized.givens_rotation` uses the idea from https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates. @yhtang, do you see a better way to construct these parametrized tensors? The code is compatible with PyTorch.
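One possible way around the immutability issue, offered as a hedged sketch rather than what FunFact does: the linked JAX page recommends `x.at[idx].set(y)` for functional updates, but a Givens rotation can also be assembled without any indexed update, from the identity plus outer products of basis vectors. The function name `givens_functional`, the `xp` array-module parameter, and the sign convention are assumptions for illustration.

```python
import numpy as np


def givens_functional(theta, i, j, n, xp=np):
    """Hypothetical helper: n x n rotation by theta in the (i, j) plane, i != j.

    Uses only eye, indexing, outer products and arithmetic, so no array
    is mutated in place.
    """
    eye = xp.eye(n)
    e_i, e_j = eye[i], eye[j]  # standard basis vectors
    c, s = xp.cos(theta), xp.sin(theta)
    return (
        eye
        # entries (i, i) and (j, j): 1 -> c
        + (c - 1.0) * (xp.outer(e_i, e_i) + xp.outer(e_j, e_j))
        # entry (i, j): 0 -> -s
        - s * xp.outer(e_i, e_j)
        # entry (j, i): 0 -> s
        + s * xp.outer(e_j, e_i)
    )


G = givens_functional(0.3, 0, 2, 3)
print(np.round(G, 4))
```

Because only `eye`, indexing, `outer`, `cos`, `sin` and arithmetic are used, passing `xp=jax.numpy` should also work and stay differentiable with respect to `theta`; that is an assumption, not something verified against FunFact's backends. The trade-off is that each rotation materializes a few dense n x n matrices, which is cheap for small n but wasteful compared with updating four entries.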