Gradient dimension mismatch error when training RNNs

Open chwons opened 2 years ago

chwons commented 2 years ago
using Flux

# Minimal reproduction: a two-step sequence where each step is a 1×3 matrix
# (1 input feature, 3 columns), fed to a GRU with 1 input and 1 output.
xs = [[1f0 2f0 3f0], [2f0 3f0 4f0]]
ys = [[2f0 3f0 4f0], [3f0 4f0 5f0]]
m = GRU(1, 1)

# Sum over the sequence of the per-step mean-squared error.
# `m` is the global model above; each call advances its recurrent state.
function loss(xs, ys)
    return sum(Flux.mse.([m(x) for x in xs], ys))
end

opt = ADAM()
ps = params(m)  # implicit-style parameter collection for `m`
grads = gradient(ps) do
    loss(xs, ys)
end

julia> Flux.update!(opt, ps, grads)
ERROR: DimensionMismatch("new dimensions (1, 1) must be consistent with array size 3")
 [1] (::Base.var"#throw_dmrsa#272")(dims::Tuple{Int64, Int64}, len::Int64)
   @ Base ./reshapedarray.jl:41
 [2] reshape
   @ ./reshapedarray.jl:45 [inlined]
 [3] reshape
   @ ./reshapedarray.jl:116 [inlined]
 [4] restructure
   @ ~/.julia/packages/ArrayInterface/mJodK/src/ArrayInterface.jl:400 [inlined]
 [5] update!(opt::ADAM, x::Matrix{Float32}, x̄::Matrix{Float32})
   @ Flux.Optimise ~/.julia/packages/Flux/qAdFM/src/optimise/train.jl:24
 [6] update!(opt::ADAM, xs::Zygote.Params, gs::Zygote.Grads)
   @ Flux.Optimise ~/.julia/packages/Flux/qAdFM/src/optimise/train.jl:32
 [7] top-level scope
   @ REPL[9]:1
 [8] top-level scope
   @ ~/.julia/packages/CUDA/Axzxe/src/initialization.jl:52

julia> [size(p) for p in ps]
4-element Vector{Tuple{Int64, Vararg{Int64}}}:
 (3, 1)
 (3, 1)
 (3,)
 (1, 1)

julia> [size(grads[p]) for p in ps]
4-element Vector{Tuple{Int64, Vararg{Int64}}}:
 (3, 1)
 (3, 1)
 (3,)
 (1, 3)

This happens with RNN and GRU, but not with LSTM.

ToucheSir commented 2 years ago

I suspect this has something to do with LSTMs not accepting 1D inputs, but I have not dug into the code. To be on the safe side, always make sure your inputs have a batch dimension. Currently each element of xs has size (3,) instead of (1, 3) (if each value is a separate sample) or (3, 1) (if each value is a feature).

chwons commented 2 years ago

The size of each element of xs is already (1, 3). Changing it to (3, 1) doesn't work, because GRU(1, 1) expects the first dimension to have size 1. I also tried changing the size to (3, 3), but the same thing happens.

ToucheSir commented 2 years ago

You are correct; I misread your snippet and imagined commas that aren't there. If you don't need a learnable initial state, the easiest way to avoid this is to call reset! outside of both your loss function and the gradient callback. Alternatively, you could use https://fluxml.ai/Zygote.jl/latest/utils/#Zygote.ignore.

Interestingly, using explicit params results in the correct shape:

# Explicit-parameter variant of the sequence loss: the model is an argument, so
# `gradient(m -> loss2(m, xs, ys), m)` differentiates with respect to its fields.
# Sums the per-step mean-squared error over paired elements of `xs` and `ys`.
function loss2(m, xs, ys)
    return sum(map((x, y) -> Flux.mse(m(x), y), xs, ys))
end

julia> gradient(m -> loss2(m, xs, ys), m)
((cell = (σ = nothing, Wi = Float32[-14.8857;;], Wh = Float32[-4.192739;;], b = Float32[-2.7642984], state0 = Float32[-7.955205;;]), state = nothing),)

It's not immediately clear to me what the issue with implicit params could be since I know people are using reset! inside gradient in the wild, but we'll look into it.

wpeguero commented 1 year ago

Hello, is there any update on this issue? I am getting a similar error with the following script, which trains on heating oil prices.csv:

using Plots
using DataFrames
using DelimitedFiles
using ProgressMeter
using Statistics
using Dates
using Flux
using BSON: @save

"""
    read_csv(filename, delimiter=',')

Load a delimited text file into a `DataFrame`.

The first row of the file is read as the header (`readdlm(...; header=true)`) and
becomes the column names; the remaining rows become the data.
"""
function read_csv(filename::AbstractString, delimiter::Char=',')
    data, headers = readdlm(filename, delimiter; header=true)
    # `headers` comes back as a 1×n matrix; DataFrame wants a vector of names.
    return DataFrame(data, vec(headers))
end
df = read_csv("B:/Documents/Github/grove-cost-predictors/data/heating oil prices.csv")

# The date header starts with U+FEFF (a byte-order mark left in the CSV file), so the
# column has to be addressed by that exact name.
datecol = Symbol("\ufeffdate")

# Target column: weekly price in dollars per gallon, as Float32.
df[!, :weekly__dollars_per_gallon] = convert.(Float32, df[!, :weekly__dollars_per_gallon])

# Parse the "Mon d, yyyy" date strings and derive Float32 calendar features.
df[!, datecol] = Date.(df[!, datecol], Dates.DateFormat("u d, yyyy"))
df[!, :years] = convert.(Float32, year.(df[!, datecol]))
df[!, :months] = convert.(Float32, month.(df[!, datecol]))
df[!, :weeks] = convert.(Float32, week.(df[!, datecol]))
df[!, :days] = convert.(Float32, day.(df[!, datecol]))

# Lag feature: each row gets the price from the row before it. The first row has no
# predecessor (its value is `missing`), so it is dropped before narrowing to Float32.
# NOTE(review): assumes rows are in chronological, oldest-first order — confirm, or
# `sort!(df, datecol)` first; otherwise this feature leaks the next week's price.
prices = df[!, :weekly__dollars_per_gallon]
df[!, :pprice] = [missing; prices[1:end-1]]
deleteat!(df, 1)
df[!, :pprice] = convert.(Float32, df[!, :pprice])
first(df, 10)

# 75/25 split by row order. Matrices are laid out features × samples, the layout
# Flux layers and DataLoader batch over (last dimension = observations).
features = [:years, :months, :weeks, :days, :pprice]
ntrain = round(Int, nrow(df) * 0.75)
ntest = nrow(df) - ntrain
X_train = permutedims(Matrix(first(df[!, features], ntrain)))
X_test = permutedims(Matrix(last(df[!, features], ntest)))
y_train = permutedims(first(df[!, :weekly__dollars_per_gallon], ntrain))
y_test = permutedims(last(df[!, :weekly__dollars_per_gallon], ntest))
train_loader = Flux.Data.DataLoader((X_train, y_train), batchsize=32, shuffle=true);
test_loader = Flux.Data.DataLoader((data=X_test, label=y_test), batchsize=64, shuffle=true);
#first(train_loader, 10)
# Three stacked GRUs (5 input features → 20 → 10 → 2) followed by a linear read-out.
model = Chain(
    GRU(5 => 20),
    GRU(20 => 10),
    GRU(10 => 2),
    Dense(2 => 1),
)
optim = Flux.setup(Adam(0.01), model)
losses = Float32[]       # every mini-batch loss, across all epochs
mean_losses = Float32[]  # mean mini-batch loss of each epoch
@showprogress for epoch in 1:300
    epoch_losses = Float32[]
    for (x, y) in train_loader
        # Start every batch from the initial hidden state. Otherwise the state keeps
        # the previous batch's width (32), and the final, smaller batch of an epoch
        # fails with "DimensionMismatch ... lengths 9 and 32". Batches are shuffled,
        # so carrying state across them is meaningless anyway. Reset outside the
        # gradient call so the reset is not differentiated.
        Flux.reset!(model)
        loss, grads = Flux.withgradient(model) do m
            # Evaluate model and loss inside gradient context:
            y_hat = m(x)
            Flux.mae(y_hat, y)
        end
        Flux.update!(optim, model, grads[1])
        push!(losses, loss)
        push!(epoch_losses, loss)
    end
    push!(mean_losses, mean(epoch_losses))
end

# Clear the recurrent state left over from the last training batch. Otherwise the
# saved model carries stale state, and its batch width does not match `X_test`.
Flux.reset!(model)
@save "mymodel.bson" model
preds = model(X_test)
# Per-sample accuracy: 100 minus the absolute percentage error, as a flat vector.
acc = vec(@. 100 - abs(preds - y_test) / y_test * 100)

Here is the error that I get:

ERROR: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 9 and 32 Stacktrace: [1] _bcs1 @ .\broadcast.jl:516 [inlined] [2] _bcs (repeats 2 times) @ .\broadcast.jl:510 [inlined] [3] broadcast_shape @ .\broadcast.jl:504 [inlined] [4] combine_axes @ .\broadcast.jl:498 [inlined] [5] instantiate @ .\broadcast.jl:281 [inlined] [6] materialize @ .\broadcast.jl:860 [inlined] [7] broadcast(::typeof(+), ::SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, ::SubArray{Float32, 2, Matrix{Float32}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}, ::SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}) @ Base.Broadcast .\broadcast.jl:798 [8] adjoint @ C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\lib\broadcast.jl:78 [inlined] [9] _pullback @ C:\Users\wpegu.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:65 [inlined] [10] _pullback @ C:\Users\wpegu.julia\packages\Flux\kq9Et\src\layers\recurrent.jl:364 [inlined] [11] _pullback @ C:\Users\wpegu.julia\packages\Flux\kq9Et\src\layers\recurrent.jl:382 [inlined] [12] _pullback(::Zygote.Context{false}, ::Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, ::Matrix{Float32}, ::Matrix{Float32}) @ Zygote C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0 [13] _pullback @ C:\Users\wpegu.julia\packages\Flux\kq9Et\src\layers\recurrent.jl:134 [inlined] [14] _pullback(ctx::Zygote.Context{false}, f::Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, args::Matrix{Float32}) @ Zygote C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0 [15] macro expansion @ C:\Users\wpegu.julia\packages\Flux\kq9Et\src\layers\basic.jl:53 [inlined] [16] _pullback @ C:\Users\wpegu.julia\packages\Flux\kq9Et\src\layers\basic.jl:53 [inlined] [17] _pullback(::Zygote.Context{false}, ::typeof(Flux._applychain), 
::Tuple{Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Matrix{Float32}) @ Zygote C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0 [18] _pullback @ C:\Users\wpegu.julia\packages\Flux\kq9Et\src\layers\basic.jl:51 [inlined] [19] _pullback(ctx::Zygote.Context{false}, f::Chain{Tuple{Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, args::Matrix{Float32}) @ Zygote C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0 [20] _pullback @ b:\Documents\Github\grove-cost-predictors\case studies\sample.jl:70 [inlined] [21] _pullback(ctx::Zygote.Context{false}, f::var"#27#28"{LinearAlgebra.Adjoint{Float32, Vector{Float32}}, Matrix{Float32}}, args::Chain{Tuple{Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}) @ Zygote C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0 
[22] pullback(f::Function, cx::Zygote.Context{false}, args::Chain{Tuple{Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}) @ Zygote C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:44 [23] pullback @ C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:42 [inlined] [24] withgradient(f::Function, args::Chain{Tuple{Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dropout{Float64, Colon, Random.TaskLocalRNG}, Flux.Recur{Flux.GRUCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Matrix{Float32}}, Matrix{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}) @ Zygote C:\Users\wpegu.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:132 [25] macro expansion @ b:\Documents\Github\grove-cost-predictors\case studies\sample.jl:68 [inlined] [26] top-level scope @ C:\Users\wpegu.julia\packages\ProgressMeter\sN2xr\src\ProgressMeter.jl:938