JuliaDiff / ForwardDiff.jl

Forward Mode Automatic Differentiation for Julia
Other
892 stars 145 forks source link

Derivative is wrong for this inverse quadratic form #633

Open colinfang opened 1 year ago

colinfang commented 1 year ago

The derivative is correct only if I wrap the matrix in `Symmetric`.

using LinearAlgebra
using ForwardDiff

"""
    f_backward(x1, x2, rho)

Return the inverse quadratic form of `x = [x1, x2]` with respect to the 2×2 matrix
`[1 rho; rho 1]`, evaluated with a left-division solve on a plain dense matrix.
"""
function f_backward(x1, x2, rho)
    v = [x1, x2]
    sigma = [1.0 rho; rho 1.0]
    # Solve sigma * y = v first, then contract with v.
    solved = sigma \ v
    return v' * solved
end

"""
    f_backward_symmetric(x1, x2, rho)

Return the inverse quadratic form of `x = [x1, x2]` with respect to the 2×2 matrix
`[1 rho; rho 1]`, evaluated with a left-division solve on the matrix wrapped in
`Symmetric`.
"""
function f_backward_symmetric(x1, x2, rho)
    v = [x1, x2]
    dense = [1.0 rho; rho 1.0]
    sigma = Symmetric(dense)
    return v' * (sigma \ v)
end

"""
    f_inv(x1, x2, rho)

Return the inverse quadratic form of `x = [x1, x2]` with respect to the 2×2 matrix
`[1 rho; rho 1]`, evaluated by explicitly inverting a plain dense matrix.
"""
function f_inv(x1, x2, rho)
    v = [x1, x2]
    sigma = [1.0 rho; rho 1.0]
    # The explicit inverse is intentional: this variant exercises `inv`.
    sigma_inv = inv(sigma)
    return v' * sigma_inv * v
end

"""
    f_inv_symmetric(x1, x2, rho)

Return the inverse quadratic form of `x = [x1, x2]` with respect to the 2×2 matrix
`[1 rho; rho 1]`, evaluated by explicitly inverting the matrix wrapped in `Symmetric`.
"""
function f_inv_symmetric(x1, x2, rho)
    v = [x1, x2]
    dense = [1.0 rho; rho 1.0]
    # The explicit inverse is intentional: this variant exercises `inv` on `Symmetric`.
    sigma_inv = inv(Symmetric(dense))
    return v' * sigma_inv * v
end

# Print the value of each of the four quadratic-form variants at x = (0.1, 0.2)
# and the given `rho`, then print ForwardDiff's derivative of each variant with
# respect to `rho`. The four variants differ only in whether the matrix is
# wrapped in `Symmetric` and whether `\` or `inv` is used, so they describe the
# same mathematical function and should agree on both value and derivative.
function test(rho)
    # Primal values.
    @show f_backward(0.1, 0.2, rho)
    @show f_backward_symmetric(0.1, 0.2, rho)
    @show f_inv(0.1, 0.2, rho)
    @show f_inv_symmetric(0.1, 0.2, rho)

    # Derivatives with respect to the third argument. Analytically, at rho = 0
    # the derivative is -2 * 0.1 * 0.2 = -0.04.
    @show ForwardDiff.derivative(x -> f_backward(0.1, 0.2, x), rho)
    @show ForwardDiff.derivative(x -> f_backward_symmetric(0.1, 0.2, x), rho)
    @show ForwardDiff.derivative(x -> f_inv(0.1, 0.2, x), rho)
    @show ForwardDiff.derivative(x -> f_inv_symmetric(0.1, 0.2, x), rho)
end

# Run the comparison at rho = 0.0, where the 2×2 matrix reduces to the identity.
test(0.0)

f_backward(0.1, 0.2, rho) = 0.05000000000000001
f_backward_symmetric(0.1, 0.2, rho) = 0.05000000000000001
f_inv(0.1, 0.2, rho) = 0.05000000000000001
f_inv_symmetric(0.1, 0.2, rho) = 0.05000000000000001
ForwardDiff.derivative((x->begin
            f_backward(0.1, 0.2, x)
        end), rho) = 0.0
ForwardDiff.derivative((x->begin
            f_backward_symmetric(0.1, 0.2, x)
        end), rho) = -0.04000000000000001
ForwardDiff.derivative((x->begin
            f_inv(0.1, 0.2, x)
        end), rho) = -0.020000000000000004
ForwardDiff.derivative((x->begin
            f_inv_symmetric(0.1, 0.2, x)
        end), rho) = -0.04000000000000001
mcabbott commented 1 year ago

Is the desired answer -0.04 for all four derivatives?

If so, this is fixed by https://github.com/JuliaDiff/ForwardDiff.jl/pull/481 which is available on master.