JuliaML / MLUtils.jl

Utilities and abstractions for Machine Learning tasks
MIT License
109 stars, 22 forks

`batchseq` for sequences of matrices #118

Closed etpinard closed 2 years ago

etpinard commented 2 years ago

I find `batchseq` quite useful for reshaping data in preparation for training a recurrent model in Flux.

Unfortunately, it looks like `batchseq` only works for sequences of vectors (i.e. single-feature inputs), not for sequences of matrices (multi-feature inputs). Consider:

```julia
# ex.jl
using MLUtils: batchseq

seqlen = 2

seq_a = [[1, 2, 3] for i in 1:seqlen]
@show seq_a
batched_seq_a = batchseq(seq_a)
@show batched_seq_a

seq_B = [[1 2 3; 4 5 6] for i in 1:seqlen]
@show seq_B
batched_seq_B = batchseq(seq_B)
```

With MLUtils@v0.2.10, this gives:

```
julia> include("ex.jl")
seq_a = [[1, 2, 3], [1, 2, 3]]
batched_seq_a = [[1, 1], [2, 2], [3, 3]]
seq_B = [[1 2 3; 4 5 6], [1 2 3; 4 5 6]]
ERROR: LoadError: MethodError: no method matching rpad(::Matrix{Int64}, ::Int64, ::Nothing)
Closest candidates are:
  rpad(::Any, ::Integer, ::Union{AbstractChar, AbstractString}) at ~/bin/julia-1.7.2/share/julia/base/strings/util.jl:371
  rpad(::AbstractVector, ::Integer, ::Any) at ~/.julia/dev/MLUtils/src/utils.jl:412
  rpad(::Any, ::Integer) at ~/bin/julia-1.7.2/share/julia/base/strings/util.jl:371
  ...
Stacktrace:
 [1] (::MLUtils.var"#130#135"{Nothing, Int64})(x::Matrix{Int64})
   @ MLUtils ./none:0
 [2] iterate
   @ ./generator.jl:47 [inlined]
 [3] collect
   @ ./array.jl:724 [inlined]
 [4] batchseq (repeats 2 times)
   @ ~/.julia/dev/MLUtils/src/utils.jl:433 [inlined]
 [5] top-level scope
   @ ~/.julia/dev/MLUtils/ex.jl:12
 [6] include(fname::String)
   @ Base.MainInclude ./client.jl:451
 [7] top-level scope
   @ REPL[1]:1
in expression starting at /home/etetreault/.julia/dev/MLUtils/ex.jl:12
```

Thank you very much!

CarloLucibello commented 2 years ago

Is the desired behavior the following? Given an input list `xs` containing `B` arrays of size `* x Li`, where `*` is common to all arrays and `Li` can vary, return a list of length `Lmax` where the elements have size `* x B`.
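
To make the question concrete, here is a minimal sketch of one way to read that behavior. The name `batchseq_nd` is hypothetical and is not part of MLUtils. The default zero padding and the choice of the last dimension as the time axis are assumptions made for illustration, not the library's implementation.

```julia
# Hypothetical helper, NOT an MLUtils function: takes B arrays of size (* x Li)
# and returns Lmax arrays of size (* x B), padding short sequences with `val`.
function batchseq_nd(xs::AbstractVector{<:AbstractArray{T,N}}, val = zero(T)) where {T,N}
    B = length(xs)
    Lmax = maximum(x -> size(x, N), xs)    # longest sequence; time is the last dimension (assumption)
    feat = size(first(xs))[1:N-1]          # the common leading dimensions `*`
    all(x -> size(x)[1:N-1] == feat, xs) ||
        throw(DimensionMismatch("leading dimensions must match across all arrays"))
    # One output array per time step, pre-filled with the padding value.
    out = [fill(convert(T, val), feat..., B) for _ in 1:Lmax]
    for (b, x) in enumerate(xs), t in 1:size(x, N)
        # Copy time step t of sequence b into slot b of the t-th output.
        selectdim(out[t], N, b) .= selectdim(x, N, t)
    end
    return out
end

# Two sequences with 2 features each, of lengths 3 and 2.
xs = [[1 2 3; 4 5 6], [7 8; 9 10]]
out = batchseq_nd(xs)
@show length(out)      # Lmax
@show size(out[1])     # (features, B)
```

Under these assumptions, the `seq_B` input from the report above (two `2 x 3` matrices) would come back as three `2 x 2` matrices, one per time step.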