Closed: thanasisT closed this issue 9 months ago.
Thanks for reporting, I'll have a look at it.
I'm sorry for the long radio silence. When I run the following example code, I don't get any errors.
```julia
using FluxArchitectures
using Plots

# Load the solar data set and move input and target to the GPU
poollength = 10
horizon = 15
datalength = 2000
input, target = get_data(:solar, poollength, datalength, horizon) |> gpu

# TPA-LSTM hyperparameters
inputsize = size(input, 1)
hiddensize = 10
layers = 2
filternum = 32
filtersize = 1
model = TPALSTM(inputsize, hiddensize, poollength, layers, filternum, filtersize) |> gpu

# Reset the recurrent state before each forward pass;
# ignore_derivatives keeps the reset out of the gradient computation
function loss(x, y)
    Flux.ChainRulesCore.ignore_derivatives() do
        Flux.reset!(model)
    end
    return Flux.mse(model(x), y')
end

# Callback: plot the current prediction against the data
cb = function ()
    Flux.reset!(model)
    pred = model(input)' |> cpu
    Flux.reset!(model)
    p1 = plot(pred, label="Predict")
    p1 = plot!(cpu(target), label="Data", title="Loss $(loss(input, target))")
    display(plot(p1))
end

@info "Start loss" loss = loss(input, target)
@info "Starting training"
# 20 training steps on the same full batch, Adam with learning rate 0.01
Flux.train!(loss, Flux.params(model), Iterators.repeated((input, target), 20), Adam(0.01), cb=cb)
@info "Final loss" loss = loss(input, target)
```
Are you using the package versions defined in the Project.toml file? Unfortunately, the package is incompatible with newer versions of Flux and Zygote.
I am running into the same issue, and I believe it stems from the incompatibility mentioned above. Will newer versions be supported in the future?
Well, unfortunately the underlying issue (see https://github.com/FluxML/Zygote.jl/issues/1304 and the issues therein) is somewhere deep in the Zygote ecosystem. I don't understand enough of Zygote to fix it myself, and I don't really have the resources right now to work around it. So the honest answer is that it will probably stay this way short-term.
I looked around and noticed two things: a) you run the example on the GPU, so it doesn't use CPU arrays, which is why you don't hit the same issue; b) Base.Slices and JuliennedArrays.Slices are practically the same, based on a discussion on the JuliennedArrays GitHub. I replaced JuliennedArrays.Slices with Base.Slices, and it seemed to work. Maybe someone can confirm?
Thank you for investigating that; I will give Base.Slices a try. See https://github.com/bramtayl/JuliennedArrays.jl/issues/34
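For anyone trying the swap, here is a minimal sketch of the slicing that Base.Slices provides, assuming Julia 1.9 or later (where `eachslice` returns a `Base.Slices` and `stack` exists). The matrix `A` is a made-up example, not FluxArchitectures data, and this is not the package's internal code; it only shows splitting a matrix into column views and reassembling them.

```julia
# Minimal sketch, assuming Julia >= 1.9. The matrix A is a hypothetical example.
VERSION >= v"1.9" || error("this sketch needs Julia 1.9 or later")

A = reshape(Float32.(1:6), 2, 3)   # 2×3 Matrix{Float32}; columns [1,2], [3,4], [5,6]

# Split A into its columns; each element is a view into A, no data is copied.
cols = eachslice(A; dims=2)

# Refer to the type by its qualified name Base.Slices.
@assert cols isa Base.Slices
@assert length(cols) == 3
@assert cols[2] == Float32[3, 4]

# Because the slices are views, writing through one of them modifies A.
cols[1][1] = 10f0
@assert A[1, 1] == 10f0

# `stack` glues the column slices back into a matrix equal to A.
B = stack(cols)
@assert size(B) == (2, 3)
@assert B == A
```

The view semantics are the main design point: slicing this way avoids copies, but any code that mutates a slice also mutates the parent array.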
I tried to run the TPALSTM example on Julia 1.9, but it throws an error saying that `Slices` is not defined:
```
exception =
│ UndefVarError: Slices not defined
│ Stacktrace:
│  [1] (::FluxArchitectures.Seq{FluxArchitectures.StackedLSTMCell{Chain{Tuple{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}}, Vector{Float32}}})(#unused#::Vector{Float32}, #unused#::FluxArchitectures.StackedLSTMCell{Chain{Tuple{Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Flux.Recur{Flux.LSTMCell{Matrix{Float32}, Matrix{Float32}, Vector{Float32}, Tuple{Matrix{Float32}, Matrix{Float32}}}, Tuple{Matrix{Float32}, Matrix{Float32}}}}}, Vector{Float32}}, x::Matrix{Float32})
```