opus111 opened this issue 4 years ago
Oops.
Could you please add in the stack trace you see as well? It'd make it easier to spot where the issue is.
```
crf_loss=11.085941941312553
ERROR: LoadError: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
 [1] check_broadcast_shape(::Tuple{}, ::Tuple{Base.OneTo{Int64}}) at ./broadcast.jl:506
 [2] check_broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}) at ./broadcast.jl:509
 [3] check_broadcast_axes at ./broadcast.jl:511 [inlined]
 [4] check_broadcast_axes at ./broadcast.jl:515 [inlined]
 [5] instantiate at ./broadcast.jl:259 [inlined]
 [6] materialize! at ./broadcast.jl:822 [inlined]
 [7] (::Zygote.var"#1023#1025"{Array{Float32,2},Tuple{Colon,Int64}})(::Array{Float64,2}) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/lib/array.jl:47
 [8] (::Zygote.var"#2707#back#1019"{Zygote.var"#1023#1025"{Array{Float32,2},Tuple{Colon,Int64}}})(::Array{Float64,2}) at /Users/peter.wolf/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [9] forward_score at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:29 [inlined]
 [10] (::typeof(∂(forward_score)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
 [11] crf_loss at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:42 [inlined]
 [12] (::typeof(∂(crf_loss)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
 [13] #65 at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53 [inlined]
 [14] (::typeof(∂(#65)))(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#38#39"{typeof(∂(#65))})(::Float64) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface.jl:36
 [16] gradient(::Function) at /Users/peter.wolf/.julia/packages/Zygote/ApBXe/src/compiler/interface.jl:45
 [17] top-level scope at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53
 [18] include at ./boot.jl:328 [inlined]
 [19] include_relative(::Module, ::String) at ./loading.jl:1105
 [20] include(::Module, ::String) at ./Base.jl:31
 [21] include(::String) at ./client.jl:424
 [22] top-level scope at none:0
in expression starting at /Users/peter.wolf/dev/inquiries-proto/v4/src/zygote_bug.jl:53
```
So `/src/lib/array.jl:47` is `∇getindex`, and adding some `@show` statements to that, here's what it gets before crashing:
```
typeof(x) = Array{Float32,2}
size(x) = (10, 10)
typeof(dy) = Array{Float64,2}
size(dy) = (10, 1)
inds = (Colon(), 10)
size(dxv) = (10,)
```
You can make the error go away by defining `Zygote._droplike(dy::AbstractMatrix, dxv::AbstractVector) = vec(dy)`; this function is from https://github.com/FluxML/Zygote.jl/pull/499. But I'm not certain that's a good idea in general, and it might be worth understanding what causes this.
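For anyone following along, the shapes in the `@show` output above reproduce the failure outside Zygote: broadcasting a `(10, 1)` matrix into a length-10 vector destination is exactly what `materialize!` rejects inside the `getindex` adjoint. A minimal sketch in plain Julia (no Zygote involved; the variable names just mirror the `@show` output):

```julia
# dxv plays the role of the gradient slice dx[:, 10] (a vector),
# dy the incoming (10, 1) cotangent shown in the @show output above.
dxv = zeros(Float32, 10)
dy = zeros(Float64, 10, 1)

# In-place broadcast into a lower-dimensional destination throws
# DimensionMismatch("cannot broadcast array to have fewer dimensions").
try
    dxv .= dy
catch e
    println(e)
end

# The proposed _droplike workaround amounts to dropping the trailing
# singleton dimension first, which broadcasts cleanly:
dxv .= vec(dy)
```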
(Notice also, aside, that the gradient has a different element type, which is a common performance bug, ref #1031.)
Thank you for responding so quickly :-D
Yes, adding _droplike does seem to make the issue go away, at least in my short test. Working on a bigger test now.
Please let me know if I can be helpful in any way
@mcabbott sorry, but I don't think the fix works. It does run without complaint (and quickly). However, in my application I do not get models with the same performance. In fact, they are terrible, so I don't think Zygote is producing the correct answer. Is there a good way to compare the results of Flux 0.9 and Flux 0.10?
I am going to create some known trivial input for the CRF training and compare the output for Flux 0.9 and 0.10. With the same hyperparameters the results should be pretty similar. Will report back.
No smarter ideas, that sounds like the right course.
Good morning Michael. I am pleased to report that when trained with identical starting conditions, the CRF exactly matches the results of Flux 0.9/Tracker. Unfortunately, I am building a CRF/LSTM model, and when the CRF is placed on top of another layer it is not training properly. It runs, but does not produce good models. Since I am new to Flux 0.10, I assume it is my bug, and Tomas Pevny has offered to look at my code. However, is it possible that your suggested fix could affect lower layers in a DNN?
Hello Michael, happy April Fool. Unfortunately, not a fool here. I now have 2 versions of a CRF on top of other layers that behave the same way: the CRF works on its own, but the lower layers do not train properly. As in the first example, the weights change, but the loss does not decrease. The second example is the CRF test from TextAnalysis.jl ported to Flux 0.10. Here is the code. Let me know if you want the full branch of the TextAnalysis port to Flux 0.10.
```julia
LSTM_STATE_SIZE = 5

d_out = Dense(LSTM_STATE_SIZE, num_labels + 2)
lstm = LSTM(num_features, LSTM_STATE_SIZE)
m(x) = d_out.(lstm.(x))

c = CRF(num_labels)
init_α = fill(-10000, (stop(c), 1))
init_α[start(c)] = 0

loss(xs, ys) = crf_loss(c, m(xs), ys, init_α)

opt = Descent(0.01)
data = zip(X, Y)
ps = params(lstm, d_out, c)

function train()
    for d in data
        reset!(lstm)
        grads = gradient(() -> loss(d[1], d[2]), ps)
        Flux.Optimise.update!(opt, ps, grads)
    end
end

function find_loss(d)
    reset!(lstm)
    loss(d[1], d[2])
end

l1 = sum([find_loss(d) for d in data])
dense_param_1 = deepcopy(d_out.W)
lstm_param_1 = deepcopy(lstm.cell.Wh)
crf_param_1 = deepcopy(c.W)

for i in 1:10
    train()
end

dense_param_2 = deepcopy(d_out.W)
lstm_param_2 = deepcopy(lstm.cell.Wh)
crf_param_2 = deepcopy(c.W)
l2 = sum([find_loss(d) for d in data])

@test l1 > l2
@test dense_param_1 != dense_param_2
@test lstm_param_1 != lstm_param_2
@test crf_param_1 != crf_param_2
```
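One way to narrow down whether the lower layers are receiving any signal at all (a debugging sketch I would try, not part of the test above; the two `Dense` layers and their sizes are stand-ins for the LSTM/Dense stack) is to check whether every parameter gets a nonzero gradient, since "weights change but loss doesn't decrease" can mean the gradient is wrong rather than absent:

```julia
using Flux

# Stand-in two-layer model: `lower` plays the role of the LSTM stack,
# `upper` the layer above it (sizes are arbitrary for illustration).
lower = Dense(4, 3)
upper = Dense(3, 2)
ps = params(lower, upper)

x = rand(Float32, 4)
loss() = sum(abs2, upper(lower(x)))

grads = gradient(loss, ps)
for p in ps
    g = grads[p]
    # A missing (`nothing`) or all-zero gradient here would point to
    # the lower layers being cut off from the loss.
    if g === nothing || all(iszero, g)
        @warn "no gradient reaching parameter of size $(size(p))"
    end
end
```

If all gradients are present and nonzero, the problem is more likely an incorrect gradient value than a broken connection, which is where a finite-difference comparison helps.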
Reopening...
@opus111 can you check if this issue has been resolved on master?
Bump @opus111 I think you said this can be closed?
Yes. I think RNNs work now. Haven't checked running a CRF on top of an RNN yet
Here is a file that reproduces the problem. This code is copied from the TextAnalysis package and slightly altered for Flux 0.10. The version in GitHub works with Flux 0.9.
```julia
#= This code is copied from CRF of TextAnalysis.
The current version in GitHub works with Flux 0.9:
https://github.com/JuliaText/TextAnalysis.jl/tree/master/src/CRF =#
using Flux

log_sum_exp(z) = log_sum_exp(z, maximum(z, dims = 1))
log_sum_exp(z, m) = log.(sum(exp.(z .- m), dims = 1)) .+ m

mutable struct CRF{S}
    W::S   # Transition Scores
    n::Int # Num Labels
end

function CRF(n::Integer)
    W = rand(Float32, n + 2, n + 2)
    W[:, n + 1] .= -10000
    W[n + 2, :] .= -10000
    return CRF(W, n)
end

Flux.@functor CRF (W,)

preds_first(c::CRF, y) = c.W[c.n + 1, Flux.onecold(y, 1:length(y))]
preds_last(c::CRF, y) = c.W[Flux.onecold(y, 1:length(y)), c.n + 2]
preds_single(c::CRF, y, y_prev) = c.W[Flux.onecold(y_prev, 1:length(y_prev)), Flux.onecold(y, 1:length(y))]

function forward_score(c::CRF, x, init_α)
    forward_var = log_sum_exp((c.W .+ transpose(x[1])) .+ init_α)
    for i in 2:length(x)
        forward_var = log_sum_exp((c.W .+ transpose(x[i])) .+ transpose(forward_var))
    end
    fs = log_sum_exp(c.W[:, c.n + 2] + transpose(forward_var))
    return fs[1]
end

function score_sequence(c::CRF, x, label_seq)
    score = preds_first(c, label_seq[1]) + Flux.onecold(label_seq[1], x[1])
    for i in 2:length(label_seq)
        score += preds_single(c, label_seq[i], label_seq[i-1]) + Flux.onecold(label_seq[i], x[i])
    end
    return score + preds_last(c, label_seq[end])
end

crf_loss(c::CRF, x, label_seq, init_α) = forward_score(c, x, init_α) - score_sequence(c, x, label_seq)

label_count = 10
seq_length = 5
crf = CRF(label_count - 2)
init_α = fill(-10000.0, label_count)
init_α[label_count - 1] = 0.0
label_seq = [Flux.onehot(i, 1:label_count) for i in 1:seq_length]
x = [rand(label_count) for _ in 1:seq_length]
print("crf_loss=$(crf_loss(crf,x,label_seq,init_α))")
print("gradient(crf_loss)=$(gradient(() -> crf_loss(crf,x,label_seq,init_α)))")
```
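One way to test whether Zygote's answer is actually wrong here (a generic sketch in plain Julia; `numeric_grad` is my name, not an API) is to compare it against a central finite difference of the same loss with respect to `crf.W`:

```julia
# Central finite-difference gradient of a zero-argument closure `f`
# with respect to the array `W`, perturbing and restoring W in place.
function numeric_grad(f, W; eps = 1e-5)
    g = zeros(Float64, size(W))
    for i in eachindex(W)
        w0 = W[i]
        W[i] = w0 + eps; fp = f()
        W[i] = w0 - eps; fm = f()
        W[i] = w0                     # restore the original entry
        g[i] = (fp - fm) / (2 * eps)
    end
    return g
end

# Against the CRF above, usage would be something like:
#   g_num = numeric_grad(() -> crf_loss(crf, x, label_seq, init_α), crf.W)
# compared entrywise with the gradient Zygote reports for crf.W.
# (Note crf.W is Float32, so expect only a few digits of agreement.)
```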