JuliaDiff / ChainRules.jl

forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

svd_rev issue with orthogonal matrices #664

Open AlexRobson opened 2 years ago

AlexRobson commented 2 years ago
using Zygote
using LinearAlgebra

r = rand(8, 8); Σ = r' * r                    # random symmetric PSD matrix (distinct singular values)
foo(X) = tr(svd(X).U)
_orthogonal(X) = svd(X).U * svd(X).V'         # nearest orthogonal matrix to X

Zygote.gradient(foo, Σ)                       # Works
Zygote.gradient(foo, 1.0 * Matrix(I(5)))      # NaN
Zygote.gradient(foo, _orthogonal(rand(5, 7))) # NaN/Inf

This reads as an edge case, but the reason it came up is an interest in introducing an orthogonality constraint in a pipeline. As the matrix is constrained to be orthogonal, a natural initialisation would also be orthogonal, but that leads to the issue above (the following essentially does the same thing):

using ParameterHandling: value, orthogonal

bar(X) = tr(value(orthogonal(X)))
Zygote.gradient(bar, Σ) # Works

# A natural initialisation for an orthogonal constrained matrix is orthogonal
Zygote.gradient(bar, _orthogonal(Σ)) # Fails

IIUC, since the singular values of an orthogonal matrix are all exactly 1, the F factor in the svd rule (whose off-diagonal entries involve 1 / (sⱼ² - sᵢ²)) will explode.
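To illustrate, here is a minimal sketch using the textbook SVD adjoint formula (not the actual svd_rev code):

using LinearAlgebra

s = svd(1.0 * Matrix(I(3))).S                                    # singular values of I: all 1.0
F = [i == j ? 0.0 : inv(s[j]^2 - s[i]^2) for i in 1:3, j in 1:3]
# every off-diagonal entry is inv(0.0) == Inf, which then propagates
# NaN/Inf through the rest of the pullback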

A workaround was to add a small noise term, which suggests that somewhere these contributions should perhaps cancel (the (UᵀŪ - ŪᵀU) and (VᵀV̄ - V̄ᵀV) terms), but idk:

Zygote.gradient(bar, _orthogonal(Σ) + 1e-10 * Diagonal(rand(size(Σ, 1)))) # Works
# ([3.032964106070741e-29 -6.862457650747e-16 … -8.654203384732967e-16 2.8416261314487268e-15; 6.862728701290122e-16 -2.4424906535614213e-15 … 1.1744348982906105e-15 -1.2198799099766983e-15; … ; 8.654203384732967e-16 -1.4343655928804322e-15 … -1.7763568394002505e-15 2.310651670001107e-15; -2.841639683975883e-15 1.0009591625611408e-15 … -1.7551064768195346e-15 -1.5543122344752192e-15],)
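Presumably the jitter helps because it splits the repeated singular values apart, so the 1 / (sⱼ² - sᵢ²) entries stay finite (though huge). A quick check:

A = _orthogonal(Σ)
svd(A).S                                 # all 1.0 (up to machine precision)
svd(A + 1e-10 * Diagonal(rand(8))).S     # distinct, but only by ~1e-10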

There doesn't seem to be anything wrong with the rules themselves, and the workaround is reasonable once it's understood what's occurring, but as this started impacting another use case internally I thought I'd raise it as an issue. If the diagnosis is correct, are there any approaches to make this nicer in the scenario described?

oxinabox commented 2 years ago

@sethaxen might have thoughts

sethaxen commented 2 years ago

I'll look at this more closely later, but I think here with SVD we have the same problem we have with the eigendecomposition. Namely, when singular values are exactly non-unique, svd is actually non-differentiable, as is any function that uses svd internally and is sensitive to swapping of the singular vectors. If the function is not sensitive to this swapping, then adding a noise term might resolve the issue, as it prevents NaN propagation. But the introduction of very large and small values might introduce too much floating point error and is at best a hack.
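To see the non-uniqueness concretely: for the identity, (Q, ones, Q) is a valid SVD for any orthogonal Q, so something like tr(svd(X).U) isn't even well-defined at that point. A quick sketch:

using LinearAlgebra

Q = Matrix(qr(randn(3, 3)).Q)    # a random orthogonal matrix
I3 = 1.0 * Matrix(I(3))
Q * Diagonal(ones(3)) * Q' ≈ I3  # true: (Q, ones(3), Q) is another valid SVD of I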

The right solution, when possible, is to use a function with more structure and define a rule for that function. For example, f(A::Matrix) for symmetric A and a matrix function f (e.g. exp or sign) uses the eigendecomposition but is not sensitive to eigenvalue ordering or eigenvector phase, so we write generic rules that don't have this numerical stability issue when eigenvalues are non-unique or close to non-unique: https://github.com/JuliaDiff/ChainRules.jl/blob/ffbaa5fecca8da39f20aeb66cc6e4edf3e0c3f11/src/rulesets/LinearAlgebra/symmetric.jl#L432-L444
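For instance, if the objective can be phrased through one of those structured functions, the degenerate input is handled (hypothetical objective, just for illustration):

using Zygote, LinearAlgebra

# fully degenerate eigenvalues, but the Symmetric matrix-function rule
# is insensitive to that; the gradient of tr(exp(A)) is exp(A)' = e*I here
Zygote.gradient(A -> tr(exp(Symmetric(A))), 1.0 * Matrix(I(3)))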

Is there a chance your use case could be handled with such a rule?

AlexRobson commented 2 years ago

Thanks for the quick response - that was helpful

IIUC, the code you've linked creates generic rrules for a set of matrix operations that, under the hood, call eigen so as to be more stable in AD for degenerate matrices. This does read as the same issue, as we have similar matrices on which we are calling svd. The issue with crossing over singular vectors is actually what led to the introduction of the orthogonality constraint in our use case. I may not be able to advance this much, but I'll consider how that applies here.