FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Issue with CRF loss function #1087

Open opus111 opened 4 years ago

opus111 commented 4 years ago

Here is a file that reproduces the problem. This code is copied from the TextAnalysis package and slightly altered for Flux 0.10. The version on GitHub works with Flux 0.9.

```julia
#= This code is copied from CRF of TextAnalysis

The current version in GitHub works with Flux 0.9

https://github.com/JuliaText/TextAnalysis.jl/tree/master/src/CRF =#

using Flux

log_sum_exp(z) = log_sum_exp(z, maximum(z, dims = 1))
log_sum_exp(z, m) = log.(sum(exp.(z .- m), dims = 1)) .+ m

mutable struct CRF{S}
    W::S   # Transition Scores
    n::Int # Num Labels
end

function CRF(n::Integer)
    W = rand(Float32, n + 2, n + 2)
    W[:, n + 1] .= -10000
    W[n + 2, :] .= -10000
    return CRF(W, n)
end

Flux.@functor CRF (W,)

preds_first(c::CRF, y) = c.W[c.n + 1, Flux.onecold(y, 1:length(y))]
preds_last(c::CRF, y) = c.W[Flux.onecold(y, 1:length(y)), c.n + 2]
preds_single(c::CRF, y, y_prev) = c.W[Flux.onecold(y_prev, 1:length(y_prev)), Flux.onecold(y, 1:length(y))]

function forward_score(c::CRF, x, init_α)
    forward_var = log_sum_exp((c.W .+ transpose(x[1])) .+ init_α)
    for i in 2:length(x)
        forward_var = log_sum_exp((c.W .+ transpose(x[i])) .+ transpose(forward_var))
    end
    fs = log_sum_exp(c.W[:, c.n + 2] + transpose(forward_var))
    return fs[1]
end

function score_sequence(c::CRF, x, label_seq)
    score = preds_first(c, label_seq[1]) + Flux.onecold(label_seq[1], x[1])
    for i in 2:length(label_seq)
        score += preds_single(c, label_seq[i], label_seq[i-1]) + Flux.onecold(label_seq[i], x[i])
    end
    return score + preds_last(c, label_seq[end])
end

crf_loss(c::CRF, x, label_seq, init_α) = forward_score(c, x, init_α) - score_sequence(c, x, label_seq)

label_count = 10
seq_length = 5
crf = CRF(label_count - 2)
init_α = fill(-10000.0, label_count)
init_α[label_count - 1] = 0.0
label_seq = [Flux.onehot(i, 1:label_count) for i in 1:seq_length]
x = [rand(label_count) for _ in 1:seq_length]

print("crf_loss=$(crf_loss(crf, x, label_seq, init_α))")
print("gradient(crf_loss)=$(gradient(() -> crf_loss(crf, x, label_seq, init_α)))")
```
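For reference, `crf_loss` above is the negative log-likelihood of the gold label path: `forward_score` computes the log partition function with the forward algorithm, and `score_sequence` computes the score of the gold path. A minimal plain-Julia sketch (no Flux, no AD) of the same recursion is below, checked against brute-force enumeration of every path. As an assumption for readability, the start and stop scores are separate vectors `s` and `e` here, whereas the reproducer stores them in the extra rows/columns `n + 1` and `n + 2` of `W`; all names are hypothetical.

```julia
# Numerically stable log(sum(exp.(v))), the same max-shift trick as log_sum_exp.
logsumexp(v) = (m = maximum(v); m + log(sum(exp.(v .- m))))

# Forward algorithm: log partition function of a linear-chain CRF.
# T[i, j] = score of moving from label i to label j,
# s[j] = start score, e[j] = stop score, x[t][j] = emission score at step t.
function log_partition(T, s, e, x)
    K = length(s)
    α = s .+ x[1]
    for t in 2:length(x)
        # α_new[j] = logsumexp_i(α[i] + T[i, j]) + x[t][j]
        α = [logsumexp(α .+ T[:, j]) + x[t][j] for j in 1:K]
    end
    return logsumexp(α .+ e)
end

# Score of one fixed label path y (a vector of label indices).
function path_score(T, s, e, x, y)
    score = s[y[1]] + x[1][y[1]]
    for t in 2:length(y)
        score += T[y[t-1], y[t]] + x[t][y[t]]
    end
    return score + e[y[end]]
end

# Brute force: log of the sum of exp(score) over all K^L label paths.
function log_partition_bruteforce(T, s, e, x)
    K, L = length(s), length(x)
    paths = Iterators.product(ntuple(_ -> 1:K, L)...)
    scores = [path_score(T, s, e, x, collect(y)) for y in paths]
    return logsumexp(vec(scores))
end

# Negative log-likelihood of the gold path, the quantity crf_loss computes.
nll(T, s, e, x, y) = log_partition(T, s, e, x) - path_score(T, s, e, x, y)

# Usage on a tiny random problem: 3 labels, sequence length 4.
K, L = 3, 4
T, s, e = randn(K, K), randn(K), randn(K)
x = [randn(K) for _ in 1:L]
println(log_partition(T, s, e, x) ≈ log_partition_bruteforce(T, s, e, x))  # true
```

Because the forward recursion only ever works in log space with the max-shift trick, it stays finite even with the `-10000` "forbidden transition" scores the reproducer uses.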

DhairyaLGandhi commented 4 years ago

Oops.

Could you please add the stacktrace you see as well? It'd make it easier to spot where the issue is.

opus111 commented 4 years ago

```
crf_loss=11.085941941312553
ERROR: LoadError: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
 [1] check_broadcast_shape(::Tuple{}, ::Tuple{Base.OneTo{Int64}}) at ./broadcast.jl:506
 [2] check_broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}) at ./broadcast.jl:509
 [3] check_broadcast_axes at ./broadcast.jl:511 [inlined]
 [4] check_broadcast_axes at ./broadcast.jl:515 [inlined]
 [5] instantiate at ./broadcast.jl:259 [inlined]
 [6] materialize! at ./broadcast.jl:822 [inlined]
 [7] (::Zygote.var"#1023#1025"{Array{Float32,2},Tuple{Colon,Int64}})(::Array{Float64,2}) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/lib/array.jl:47
 [8] (::Zygote.var"#2707#back#1019"{Zygote.var"#1023#1025"{Array{Float32,2},Tuple{Colon,Int64}}})(::Array{Float64,2}) at /Users/peter.wolf/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [9] forward_score at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:29 [inlined]
 [10] (::typeof(∂(forward_score)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
 [11] crf_loss at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:42 [inlined]
 [12] (::typeof(∂(crf_loss)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
 [13] #65 at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53 [inlined]
 [14] (::typeof(∂(#65)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#38#39"{typeof(∂(#65))})(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface.jl:36
 [16] gradient(::Function) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface.jl:45
 [17] top-level scope at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53
 [18] include at ./boot.jl:328 [inlined]
 [19] include_relative(::Module, ::String) at ./loading.jl:1105
 [20] include(::Module, ::String) at ./Base.jl:31
 [21] include(::String) at ./client.jl:424
 [22] top-level scope at none:0
in expression starting at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53
```

mcabbott commented 4 years ago

So /src/lib/array.jl:47 is ∇getindex, and adding some @show statements to that, here's what it gets before crashing:

```
typeof(x) = Array{Float32,2}
size(x) = (10, 10)
typeof(dy) = Array{Float64,2}
size(dy) = (10, 1)
inds = (Colon(), 10)
size(dxv) = (10,)
```

You can make the error go away by defining `Zygote._droplike(dy::AbstractMatrix, dxv::AbstractVector) = vec(dy)`; this function comes from https://github.com/FluxML/Zygote.jl/pull/499. But I'm not certain that's a good idea in general, and it might be worth understanding what causes this.
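The shapes in that `@show` output (a `(10, 10)` array indexed with `(Colon(), 10)`, receiving a `(10, 1)` cotangent) appear to match the last line of `forward_score`, `c.W[:, c.n + 2] + transpose(forward_var)`: with `n = 8` labels, column `c.n + 2` is column 10. A minimal plain-Julia sketch of those shapes, with no Zygote involved, is below. It only illustrates the hypothesis: adding a `Vector` to an `n×1` matrix with `+` is allowed but produces an `n×1` result, so a gradient flowing back through it has one more dimension than the column it must be accumulated into, and reshaping with `vec` (the idea behind the `_droplike` workaround) resolves that. `dy` and `dW` are hypothetical stand-ins, not Zygote's internals.

```julia
n = 10
W = rand(Float32, n, n)             # stands in for the CRF transition matrix c.W
forward_var = rand(Float32, 1, n)   # log_sum_exp(...) with dims = 1 returns a 1×n row

col = W[:, n]                       # indexing with (Colon(), Int) gives a length-n Vector
s = col + transpose(forward_var)    # Vector + n×1 matrix is allowed; the result is n×1

# A cotangent for `s` therefore has size (n, 1), while the slot it must be
# accumulated into, the column W[:, n], has size (n,).
dy = ones(Float32, n, 1)
dW = zeros(Float32, n, n)
view(dW, :, n) .+= vec(dy)          # vec drops the trailing singleton dimension first
println(size(col), " ", size(s))    # (10,) (10, 1)
```

Whether the better fix is in the user code or in Zygote's adjoint is exactly the open question in this thread; the sketch only shows where the extra dimension comes from.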

(As an aside, notice also that the gradient has a different element type, `Float64` rather than the `Float32` of the weights, which is a common performance bug; see #1031.)
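One likely source of the `Float64` cotangent is the reproducer's inputs: `init_α` is filled with the `Float64` literal `-10000.0`, and `rand(n)` defaults to `Float64`, so every broadcast against the `Float32` weights promotes to `Float64`. A minimal sketch of the promotion and of keeping everything in `Float32`, assuming that is the desired element type:

```julia
W = rand(Float32, 10, 10)              # Float32 weights, like c.W
x64 = rand(10)                         # rand(n) defaults to Float64
x32 = rand(Float32, 10)                # request Float32 explicitly
α64 = fill(-10000.0, 10)               # a Float64 literal
α32 = fill(-10000f0, 10)               # a Float32 literal

mixed = (W .+ transpose(x64)) .+ α64   # promotes to Matrix{Float64}
same  = (W .+ transpose(x32)) .+ α32   # stays Matrix{Float32}
println(eltype(mixed), " ", eltype(same))  # Float64 Float32
```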

opus111 commented 4 years ago

Thank you for responding so quickly :-D

Yes, adding `_droplike` does seem to make the issue go away, at least in my short test. I'm working on a bigger test now.

Please let me know if I can be helpful in any way.

opus111 commented 4 years ago

@mcabbott Sorry, but I don't think the fix works. It does run without complaint (and quickly). However, in my application I do not get models that perform as well as with Flux 0.9. In fact, they are terrible, so I don't think Zygote is producing the correct answer. Is there a good way to compare the results of Flux 0.9 and Flux 0.10?

opus111 commented 4 years ago

I am going to create some simple, known input for the CRF training and compare the output of Flux 0.9 and 0.10. With the same hyperparameters, the results should be very similar. Will report back.
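Besides comparing against Flux 0.9, one standard way to check the gradients Zygote produces is to compare them with a central finite-difference estimate. A minimal plain-Julia sketch, with no Flux dependency, is below; `fd_gradient` and the step size `h` are hypothetical choices, not part of Flux or Zygote.

```julia
# Central finite-difference estimate of the gradient of a scalar function f at W.
# Float64 inputs keep the rounding error small relative to the step h.
function fd_gradient(f, W::AbstractArray{Float64}; h = 1e-6)
    g = zero(W)
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += h    # nudge one entry up
        Wm = copy(W); Wm[i] -= h    # and down
        g[i] = (f(Wp) - f(Wm)) / (2h)
    end
    return g
end

# Usage on a function with a known gradient: f(W) = sum(W .^ 2) has gradient 2W.
W = randn(3, 3)
g = fd_gradient(W -> sum(W .^ 2), W)
println(maximum(abs.(g .- 2 .* W)))  # should be close to zero
```

For the CRF, the same helper could be applied to a hypothetical closure such as `W -> crf_loss(CRF(W, crf.n), x, label_seq, init_α)` (with `W` converted to `Float64`) and compared entry by entry with Zygote's `gradient` of the same closure; the emissions `x`, which carry the gradient into the lower layers, can be checked the same way.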

mcabbott commented 4 years ago

No smarter ideas, that sounds like the right course.

opus111 commented 4 years ago

Good morning Michael. I am pleased to report that, when trained with identical starting conditions, the CRF exactly matches the results of Flux 0.9/Tracker. Unfortunately, I am building a CRF/LSTM model, and when the CRF is placed on top of another layer, the model does not train properly. It runs, but does not produce good models. Since I am new to Flux 0.10, I assume it is my bug, and Tomas Pevny has offered to look at my code. However, is it possible that your suggested fix could affect lower layers in a DNN?

opus111 commented 4 years ago

Hello Michael, happy April Fools' Day. Unfortunately, I'm not fooling. I now have two versions of a CRF on top of other layers that behave the same way: the CRF works on its own, but the lower layers do not train properly. As in the first example, the weights change, but the loss does not decrease. The second example is the CRF test from TextAnalysis.jl ported to Flux 0.10. Here is the code. Let me know if you want the full branch of the TextAnalysis port to Flux 0.10.

```julia
LSTM_STATE_SIZE = 5
d_out = Dense(LSTM_STATE_SIZE, num_labels + 2)
lstm = LSTM(num_features, LSTM_STATE_SIZE)
m(x) = d_out.(lstm.(x))

c = CRF(num_labels)
init_α = fill(-10000, (stop(c), 1))
init_α[start(c)] = 0

loss(xs, ys) = crf_loss(c, m(xs), ys, init_α)

opt = Descent(0.01)
data = zip(X, Y)
ps = params(lstm, d_out, c)

function train()
    for d in data
        reset!(lstm)
        grads = gradient(() -> loss(d[1], d[2]), ps)
        Flux.Optimise.update!(opt, ps, grads)
    end
end

function find_loss(d)
    reset!(lstm)
    loss(d[1], d[2])
end

l1 = sum([find_loss(d) for d in data])
dense_param_1 = deepcopy(d_out.W)
lstm_param_1 = deepcopy(lstm.cell.Wh)
crf_param_1 = deepcopy(c.W)

for i in 1:10
    train()
end

dense_param_2 = deepcopy(d_out.W)
lstm_param_2 = deepcopy(lstm.cell.Wh)
crf_param_2 = deepcopy(c.W)
l2 = sum([find_loss(d) for d in data])

@test l1 > l2
@test dense_param_1 != dense_param_2
@test lstm_param_1 != lstm_param_2
@test crf_param_1 != crf_param_2
```
opus111 commented 4 years ago

Reopening...

darsnack commented 3 years ago

@opus111 can you check if this issue has been resolved on master?

darsnack commented 3 years ago

Bump: @opus111, I think you said this can be closed?

opus111 commented 3 years ago

Yes. I think RNNs work now. I haven't checked running a CRF on top of an RNN yet.

On Mon, Dec 7, 2020 at 5:24 PM Kyle Daruwalla wrote:

> Bump @opus111 I think you said this can be closed?