JuliaAI / MLJLinearModels.jl

Generalized Linear Regression Models (penalized regressions, robust regressions, ...)
MIT License
80 stars 13 forks source link

Problem with logistic regression #104

Closed goerch closed 3 years ago

goerch commented 3 years ago

I wanted to check test set "GH> LogitL2" with AD. Unfortunately ForwardDiff seems to have problems with the definition of logsigmoid which I therefore simplified to

#= function logsigmoid(x::T) where T <: AbstractFloat
    τ = SIGMOID_THRESH(T)
    x > τ  && return zero(T)
    x < -τ && return x
    return -log1p(exp(-x))
end
logsigmoid(x) = logsigmoid(float(x)) =#
"""
    logsigmoid(x)

Return `log(σ(x)) = -log(1 + exp(-x))`, the logarithm of the logistic sigmoid of `x`.

The computation is split on the sign of `x` so that `exp` is only ever called with a
non-positive argument. The naive form `-log1p(exp(-x))` overflows for large negative `x`
(`exp(-x)` becomes `Inf`, giving `-Inf`), whereas the mathematically equivalent form
`x - log1p(exp(x))` stays finite there.

The argument is left untyped and only comparison, `exp` and `log1p` are used, so number
types other than `AbstractFloat` (e.g. the dual numbers used by ForwardDiff) can be
passed through.
"""
function logsigmoid(x)
    # Negative side: log σ(x) = x - log(1 + exp(x)); exp(x) ≤ 1 here, so no overflow.
    x < zero(x) && return x - log1p(exp(x))
    # Non-negative side: exp(-x) ≤ 1, so the direct formula is already safe.
    return -log1p(exp(-x))
end

Afterwards the modified test case

# Checks the closures returned by `R.fgh!` (objective value, gradient, Hessian) and by
# `R.Hv!` (Hessian-vector product) for `LogisticRegression(λ; ...)` against closed-form
# expressions and against ForwardDiff applied to the objective `J`. Three configurations
# are exercised: no intercept, penalised intercept, and unpenalised intercept.
#
# `X`, `X1`, `y`, `p` and `maskint` are not defined in this block; they come from the
# surrounding test file.
# NOTE(review): `X1` presumably is `X` with an intercept column appended and `maskint` a
# 0/1 mask that removes the intercept from the penalty — confirm against the test setup.
@testset "GH> LogitL2" begin
    # Seeded RNG: every `randn(rng, ...)` below consumes from it, so the order of the
    # draws determines the values of θ, v, θ1.
    rng = StableRNG(551551)
    # fgh! without fit_intercept
    s = R.scratch(X; i=false)
    λ = 0.5
    lr = LogisticRegression(λ; fit_intercept=false)
    fgh! = R.fgh!(lr, X, y, s)
    θ = randn(rng, p)
    J = objective(lr, X, y)
    f = 0.0
    g = similar(θ)
    H = zeros(p, p)
    # `fgh!` returns the objective value; `g` and `H` are passed in as output buffers and
    # compared afterwards.
    f = fgh!(f, g, H, θ)
    # NOTE(review): exact `==` on floats here, whereas the two later cases use `≈`.
    @test f == J(θ)
    @test g ≈               -X' * (y .* R.σ.(-y .* (X * θ))) .+ λ .* θ
    # NOTE(review): this closed form weights the rows by σ(y·Xθ) only. The thread around
    # this snippet reports that it disagrees with ForwardDiff (norm ≈ 78.8) and that
    # weights σ(·)(1 - σ(·)) match AD, so treat this expected value (and the two analogous
    # ones further down) as suspect.
    @test H ≈                X' * (Diagonal(R.σ.(y .* (X * θ))) * X) + λ * I
    # Cross-check against automatic differentiation of the objective itself.
    @test g ≈               ForwardDiff.gradient(θ -> J(θ), θ)
    @test H ≈               ForwardDiff.hessian(θ -> J(θ), θ)

    # Hv! without fit_intercept
    s = R.scratch(X; i=false)
    Hv! = R.Hv!(lr, X, y, s)
    v   = randn(rng, p)
    Hv  = similar(v)
    # `Hv` is the output buffer; the result must equal the explicit product with `H`.
    Hv!(Hv, θ, v)
    @test Hv ≈               H * v

    # fgh! with fit_intercept
    # (intercept is also penalised here: `penalize_intercept=true`, parameter length p+1)
    s = R.scratch(X; i=true)
    λ = 0.5
    lr1 = LogisticRegression(λ; penalize_intercept=true)
    fgh! = R.fgh!(lr1, X, y, s)
    θ1 = randn(rng, p+1)
    J  = objective(lr1, X, y)
    f1 = 0.0
    g1 = similar(θ1)
    H1 = zeros(p+1, p+1)
    f1 = fgh!(f1, g1, H1, θ1)
    @test f1 ≈ J(θ1)
    # Same closed forms as above, with `X1` in place of `X`.
    @test g1 ≈              -X1' * (y .* R.σ.(-y .* (X1 * θ1))) .+ λ .* θ1
    @test H1 ≈               X1' * (Diagonal(R.σ.(y .* (X1 * θ1))) * X1) + λ * I
    @test g1 ≈               ForwardDiff.gradient(θ -> J(θ), θ1)
    @test H1 ≈               ForwardDiff.hessian(θ -> J(θ), θ1)

    # Hv! with fit_intercept
    Hv! = R.Hv!(lr1, X, y, s)
    v   = randn(rng, p+1)
    Hv  = similar(v)
    Hv!(Hv, θ1, v)
    @test Hv ≈               H1 * v

    # fgh! with fit intercept and no penalty on intercept
    # (the scratch `s` from the previous case is reused)
    lr1 = LogisticRegression(λ)
    fgh! = R.fgh!(lr1, X, y, s)
    θ1 = randn(rng, p+1)
    J  = objective(lr1, X, y)
    f1 = 0.0
    g1 = similar(θ1)
    H1 = zeros(p+1, p+1)
    f1 = fgh!(f1, g1, H1, θ1)
    @test f1 ≈ J(θ1)
    # The penalty terms are multiplied by `maskint`, in the gradient elementwise and in
    # the Hessian as a diagonal matrix.
    @test g1 ≈              -X1' * (y .* R.σ.(-y .* (X1 * θ1))) .+ λ .* θ1 .* maskint
    @test H1 ≈               X1' * (Diagonal(R.σ.(y .* (X1 * θ1))) * X1) + λ * Diagonal(maskint)
    @test g1 ≈               ForwardDiff.gradient(θ -> J(θ), θ1)
    @test H1 ≈               ForwardDiff.hessian(θ -> J(θ), θ1)
    # Hessian-vector product for the same configuration.
    Hv! = R.Hv!(lr1, X, y, s)
    v   = randn(rng, p+1)
    Hv  = similar(v)
    Hv!(Hv, θ1, v)
    @test Hv ≈               H1 * v
end

resulted in

GH> LogitL2: Test Failed at C:\Users\Win10\source\repos\julia\MLJLinearModels.jl-0.5.5\test\glr\grad-hess-prox.jl:119
  Expression: H1 ≈ ForwardDiff.hessian((θ->begin
                J(θ)
            end), θ1)
   Evaluated: [24.397827452783897 -2.8953009691182245 … 6.051655547790141 -10.127380336191182; -2.8953009691182245 26.910348017983537 … -2.3587690657784335 -1.157501101551889; … ; 6.051655547790139 -2.3587690657784357 … 31.49926106433512 -8.152842107478648; -10.127380336191182 -1.157501101551889 … -8.152842107478648 31.579672388719718] ≈ [5.315743473424368 -0.9877692597110505 … 1.857040773115254 -4.4508880076679285; -0.9877692597110451 4.347729494124755 … -3.18922740206638 1.0524042106260683; … ; 1.8570407731152607 -3.1892274020663844 … 9.052300688144538 -4.606378400038858; -4.45088800766793 1.0524042106260734 … -4.606378400038857 9.432163635398592]
tlienart commented 3 years ago

I'm not sure what your question is?

The evaluations above your ForwardDiff call are the analytical form (*) so if ForwardDiff doesn't return that, that's an issue with ForwardDiff.

(*) unless I made a mistake there in which case it would be very sneaky as the results are identical to those of sklearn as far as I know

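Could I suggest you compute and post the norms of the differences between the three Hessians (the one returned by `fgh!`, the analytical form, and the ForwardDiff result)?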

goerch commented 3 years ago

I'd expect to see the derivative of sigmoid in the analytical form of the hessian.

tlienart commented 3 years ago

I'm sorry but so far I'm not clear on what the issue is and what you'd like me to do.

If you think there's an error with the way the Gradient/Hessian is computed and that FD should be the ground truth, then please provide the norms as indicated above.

goerch commented 3 years ago

Adding the following lines

    # Build the two reference Hessians once, then print the pairwise norm gaps between
    # the package result `H`, the closed-form expression and the ForwardDiff result.
    H_analytic = X' * (Diagonal(R.σ.(y .* (X * θ))) * X) + λ * I
    H_fd = ForwardDiff.hessian(t -> J(t), θ)
    println("norm(H - Analytical) ", norm(H - H_analytic))
    println("norm(H - FD.hessian) ", norm(H - H_fd))
    println("norm(Analytical - FD.hessian) ", norm(H_analytic - H_fd))

results in

norm(H - Analytical) 0.0
norm(H - FD.hessian) 78.76258276024724
norm(Analytical - FD.hessian) 78.76258276024724
goerch commented 3 years ago

Adding another line

    # Candidate Hessian whose row weights are y · σ(y·Xθ) · (1 - σ(y·Xθ)) · y, compared
    # with the ForwardDiff Hessian of the objective.
    sig_margin = R.σ.(y .* (X * θ))
    alt_weights = y .* sig_margin .* (1 .- sig_margin) .* y
    H_alt = X' * (Diagonal(alt_weights) * X) + λ * I
    println("norm(Alternative - FD.hessian) ",
        norm(H_alt - ForwardDiff.hessian(t -> J(t), θ)))

it seems AD believes the hessian to be more like this:

norm(Alternative - FD.hessian) 5.273714211591725e-14
tlienart commented 3 years ago

Thanks, I'll check this in detail and get back to you.

tlienart commented 3 years ago

ok after re-doing this by hand carefully, I see the mistake in

https://github.com/JuliaAI/MLJLinearModels.jl/blob/430a6293eb669e307873d37ae47e6555a015615a/src/glr/d_logistic.jl#L6

since the derivative of σ(x) is σ(x)(1 − σ(x)), and so indeed the formula as you wrote it fixes this and matches what you get with AD.

Thanks for reporting this, I'll fix this asap

tlienart commented 3 years ago

done, thanks again for reporting this mistake