FluxML / Zygote.jl

21st century AD
https://fluxml.ai/Zygote.jl/

gradient of SVD not working for complex input #1481

Open GibbsJR opened 11 months ago

GibbsJR commented 11 months ago

Hello,

I would like to write a function that uses a truncated SVD and take its gradient with respect to a complex input (the output itself is real). With Zygote this works fine for real inputs but fails for complex inputs (e.g. if you uncomment the complex `X` line in the example below), even though I believe Zygote supports gradients of functions with complex inputs.

Could you please suggest how to resolve this?

Thanks, Joe

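```julia
using LinearAlgebra, Zygote

# 16×16 test matrix built from Kronecker products of random 4×4 blocks
X = kron(rand(Float64, 4, 4), rand(Float64, 4, 4)) + kron(rand(Float64, 4, 4), rand(Float64, 4, 4))
# Uncomment to use a complex input instead; the gradient then does not work:
#X = kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4)) + kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4))

# Truncated SVD: sum of the two largest singular values (a real number)
function foo(X)
    F = svd(X)
    return abs(sum(F.S[1:2]))
end

G = foo'(X)  # Zygote's adjoint syntax, equivalent to gradient(foo, X)[1]
```
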
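
For reference when checking any fix: for a real-valued function `f` of a complex array, Zygote reports the gradient as `∂f/∂Re(X) + im * ∂f/∂Im(X)` (so the gradient of `abs2` at `1 + 2im` is `2 + 4im`). The sketch below estimates that quantity by central finite differences using only Base Julia; `fd_gradient` and its step size `h` are hypothetical names chosen for this example, not part of Zygote.

```julia
using LinearAlgebra

# Central finite-difference estimate of the gradient of a real-valued function f
# at a complex array X, using the convention G[i] = ∂f/∂Re(X[i]) + im * ∂f/∂Im(X[i]).
# `fd_gradient` and the step `h` are hypothetical names for this sketch.
function fd_gradient(f, X::AbstractArray{<:Complex}; h = 1e-6)
    T = float(eltype(X))           # e.g. ComplexF64, so fractional steps fit
    G = zeros(T, size(X))
    E = zeros(T, size(X))          # perturbation, nonzero in one entry at a time
    for i in eachindex(X)
        E[i] = h                   # step along the real part of X[i]
        d_re = (f(X + E) - f(X - E)) / (2h)
        E[i] = im * h              # step along the imaginary part of X[i]
        d_im = (f(X + E) - f(X - E)) / (2h)
        E[i] = 0                   # reset before moving to the next entry
        G[i] = d_re + im * d_im
    end
    return G
end

# abs2(x + iy) = x^2 + y^2, so the estimate at 1 + 2im is approximately 2 + 4im.
g = fd_gradient(z -> sum(abs2, z), [1.0 + 2.0im])
```

With the complex `X` and `foo` from the example, `fd_gradient(foo, X)` gives a reference gradient to compare against whatever Zygote eventually returns; it costs four SVDs per matrix entry, which is cheap at 16×16.
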
ToucheSir commented 11 months ago

Support for complex inputs is very much a function-by-function thing in Zygote. Some functions work without dedicated AD rules, but that is unlikely to be the case for svd. Zygote uses the SVD rule from https://github.com/JuliaDiff/ChainRules.jl, and I think that rule only handles real-valued inputs, so I'd recommend opening a feature request there.
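
Until such a rule exists, one possible workaround (a sketch, not ChainRules' own SVD rule) is to give the truncated sum its own reverse rule through ChainRulesCore, which Zygote picks up automatically. It relies on the standard first-order result that, for `X = U * Diagonal(S) * V'`, a perturbation `dX` changes `S[i]` by `real(U[:, i]' * dX * V[:, i])`, so the gradient of the sum of the `k` largest singular values is `U[:, 1:k] * V[:, 1:k]'` for real and complex `X` alike. The name `topk_singular_sum` is hypothetical, and the rule assumes `S[k] > S[k+1]`, since the sum is not differentiable where the two coincide.

```julia
using LinearAlgebra
using ChainRulesCore, Zygote

# Hypothetical helper (not from any library): sum of the k largest singular values.
topk_singular_sum(X::AbstractMatrix, k::Integer) = sum(svdvals(X)[1:k])

# Hand-written reverse rule; Zygote uses a ChainRulesCore rrule when one exists.
# Gradient: U[:, 1:k] * V[:, 1:k]', valid for real and complex X.
# Assumes S[k] > S[k+1]; the sum is not differentiable where they are equal.
function ChainRulesCore.rrule(::typeof(topk_singular_sum), X::AbstractMatrix, k::Integer)
    F = svd(X)
    y = sum(F.S[1:k])
    P = F.U[:, 1:k] * F.V[:, 1:k]'  # unchanged if a singular vector pair gets a common phase
    function topk_singular_sum_pullback(ȳ)
        # The output is real, so ȳ is a real scalar; scale the fixed gradient by it.
        return NoTangent(), unthunk(ȳ) * P, NoTangent()
    end
    return y, topk_singular_sum_pullback
end

# Complex input built the same way as in the issue, differentiated through abs as in foo.
X = kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4)) + kron(rand(ComplexF64, 4, 4), rand(ComplexF64, 4, 4))
G = Zygote.gradient(Y -> abs(topk_singular_sum(Y, 2)), X)[1]
```

Because `U[:, 1:k] * V[:, 1:k]'` does not change when a pair of singular vectors is multiplied by a common phase, this special-case rule does not have to deal with the phase ambiguity of complex singular vectors.
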