JuliaML / MLUtils.jl

Utilities and abstractions for Machine Learning tasks
MIT License

Using DataLoader for multi input #151

Closed MaAl13 closed 1 year ago

MaAl13 commented 1 year ago

Hello, basically I have a neural network that takes multiple heads as input. The problem is that when I use the DataLoader from MLUtils I get an error message, whereas when I use the DataLoader from DataLoaders it runs without any problem. Ideally I want to use MLUtils, since the DataLoaders package seems abandoned and keeps me from using the latest versions of quite a few packages. The problem is illustrated below; thanks for the help in advance!

The following package versions are used: Flux v0.13.14, DataLoaders v0.1.3, MLUtils v0.4.1.

using Flux
using MLUtils
using DataLoaders

# Setting up the model
struct Join{T, F}
    combine::F
    paths::T
end

Flux.@functor Join
Join(combine, paths) = Parallel(combine, paths)
Join(combine, paths...) = Join(combine, paths)

model = Chain(
              Join(vcat,
                   Chain(Dense(2 => 5, relu), Dense(5 => 1)), # branch 1
                   Dense(2 => 2),                             # branch 2
                   Dense(2 => 1)                              # branch 3
                  ),
              Dense(4 => 1)
             )

# Testing the two different DataLoader approaches
train_data = [tuple([rand(2), rand(2), rand(2)]...) for i in 1:10]
train_label = rand(10)
data_loader = DataLoaders.DataLoader((train_data,train_label), 2)
for (x,y) in data_loader
    println(x)
    println(model(x))
end

data_loader = MLUtils.DataLoader((train_data,train_label), batchsize = 2)
for (x,y) in data_loader
    println(x)
    println(model(x))
end

However, I get the following error:

ERROR: LoadError: MethodError: no method matching *(::Float32, ::Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}})
Closest candidates are:
  *(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  *(::T, ::T) where T<:Union{Float16, Float32, Float64} at float.jl:385
  *(::Union{Float16, Float32, Float64}, ::BigFloat) at mpfr.jl:414
  ...
Stacktrace:
  [1] generic_matvecmul!(C::Vector{Union{}}, tA::Char, A::Matrix{Float32}, B::Vector{Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra ~/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:805
  [2] mul!
    @ ~/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:81 [inlined]
  [3] mul!
    @ ~/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:276 [inlined]
  [4] *
    @ ~/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/matmul.jl:56 [inlined]
  [5] (::Dense{typeof(relu), Matrix{Float32}, Vector{Float32}})(x::Vector{Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}})
    @ Flux ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:174
  [6] macro expansion
    @ ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:53 [inlined]
  [7] _applychain
    @ ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:53 [inlined]
  [8] Chain
    @ ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:51 [inlined]
  [9] #227
    @ ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:527 [inlined]
 [10] map
    @ ./tuple.jl:223 [inlined]
 [11] (::Parallel{typeof(vcat), Tuple{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})(x::Vector{Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}})
    @ Flux ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:527
 [12] macro expansion
    @ ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:53 [inlined]
 [13] _applychain
    @ ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:53 [inlined]
 [14] (::Chain{Tuple{Parallel{typeof(vcat), Tuple{Chain{Tuple{Dense{typeof(relu), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})(x::Vector{Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}})
    @ Flux ~/.julia/packages/Flux/Nzh8J/src/layers/basic.jl:51
 [15] top-level scope
    @ /cluster/work/cobi/Marius/Tissue_sims/Test_9.jl:40
in expression starting at /cluster/work/cobi/Marius/Tissue_sims/Test_9.jl:38

lorenzoh commented 1 year ago

I think passing collate = true, i.e. MLUtils.DataLoader(...; collate = true) may fix your issue. Let me know if it doesn't :)
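For reference, a minimal sketch of that suggestion applied to the data layout from the MWE above (variable names reused from the MWE for illustration; shapes in the comments are expectations, not verified output):

using MLUtils

# same toy data as in the MWE: 10 observations, each a tuple of three length-2 vectors
train_data = [tuple([rand(2), rand(2), rand(2)]...) for i in 1:10]
train_label = rand(10)

# collate = true stacks the observations of each batch into arrays,
# so each input head arrives as a matrix instead of a vector of tuples
data_loader = MLUtils.DataLoader((train_data, train_label); batchsize = 2, collate = true)

for (x, y) in data_loader
    println(map(size, x))   # x should be a tuple of three 2×2 matrices
    println(size(y))        # y should be a length-2 vector of labels
end

With this layout each batch x is a tuple of arrays, so a Parallel-style layer can apply one branch per input, which is what the DataLoaders.jl loader was already producing.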

darsnack commented 1 year ago

The model code doesn't make sense to me here. I know it worked before, but to simplify the issue, do you mind removing the model-related code from the MWE? Just have a loop over the data loader and print size(x) and size(y) (and post that output).
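For illustration, a stripped-down version of that check might look like this (a sketch only, reusing the train_data and train_label defined above and the default, non-collated loader):

data_loader = MLUtils.DataLoader((train_data, train_label); batchsize = 2)
for (x, y) in data_loader
    # without collate, each x is expected to be a length-2 vector of observation tuples
    println(size(x), " ", size(y))
end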

MaAl13 commented 1 year ago

Thanks, so collate = true fixed my issue! Awesome!