JuliaML / MLUtils.jl

Utilities and abstractions for Machine Learning tasks
MIT License

`mapobs(f, DataLoader)` should just work? #153

Open mcabbott opened 1 year ago

mcabbott commented 1 year ago

I think the last case here should work like the second-to-last one, i.e. it should move the `mapobs` inside the `DataLoader`:

julia> using MLUtils: mapobs, DataLoader

julia> pr(x) = (@show(x); x);

julia> data = (a=[1 2 3; 4 5 6], b=[7,8,9]);

julia> mapobs(pr, data) |> collect
x = (a = [1, 4], b = 7)
x = (a = [2, 5], b = 8)
x = (a = [3, 6], b = 9)
3-element Vector{Any}:
 (a = [1, 4], b = 7)
 (a = [2, 5], b = 8)
 (a = [3, 6], b = 9)

julia> DataLoader(mapobs(pr, data); batchsize=2) |> collect
x = (a = [1 2; 4 5], b = [7, 8])
x = (a = [3; 6;;], b = [9])
2-element Vector{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}}:
 (a = [1 2; 4 5], b = [7, 8])
 (a = [3; 6;;], b = [9])

julia> mapobs(pr, DataLoader(data; batchsize=2))
mapobs(pr, DataLoader(::@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, batchsize=2); batched=:auto)

julia> collect(ans)
ERROR: MethodError: no method matching getindex(::DataLoader{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}, ::Int64)
Stacktrace:
  [1] getobs(::Type{SimpleTraits.Not{MLUtils.IsTable{DataLoader{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}}}}, data::DataLoader{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}, idx::Int64)
    @ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:110

Are there any downsides to defining something like this? Or should it be done at some other level?

# Rebuild the loader around the mapped data, passing every other field through unchanged.
mapobs(f, d::DataLoader) = DataLoader(
    mapobs(f, d.data),
    d.batchsize,
    d.buffer,
    d.partial,
    d.shuffle,
    d.parallel,
    d.collate,
    d.rng,
)
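
Until then, one possible workaround is to skip `mapobs` and map lazily over the loader with Base's `Iterators.map` (Julia 1.6 or later). This is only a sketch, not something MLUtils documents for this purpose. It assumes only that `DataLoader` is iterable, which the `collect` call in the second-to-last example already relies on. The result is a plain iterator rather than a data container, so `getobs` and `numobs` will not work on it.

```julia
using MLUtils: DataLoader

pr(x) = (@show(x); x)
data = (a=[1 2 3; 4 5 6], b=[7, 8, 9])

# DataLoader can be iterated but not indexed, so map over the iterator
# instead of going through getobs/getindex.
batches = Iterators.map(pr, DataLoader(data; batchsize=2))

# pr is called once per batch as the iterator is consumed.
result = collect(batches)
```

Unlike the proposed method, this does not return a `DataLoader`, so anything downstream that expects one would still need the definition above.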