mcabbott opened 1 year ago
I think the last case here should work like the second-to-last one, i.e. put the `mapobs` inside:
```julia
julia> using MLUtils: mapobs, DataLoader

julia> pr(x) = (@show(x); x);

julia> data = (a=[1 2 3; 4 5 6], b=[7,8,9]);

julia> mapobs(pr, data) |> collect
x = (a = [1, 4], b = 7)
x = (a = [2, 5], b = 8)
x = (a = [3, 6], b = 9)
3-element Vector{Any}:
 (a = [1, 4], b = 7)
 (a = [2, 5], b = 8)
 (a = [3, 6], b = 9)

julia> DataLoader(mapobs(pr, data); batchsize=2) |> collect
x = (a = [1 2; 4 5], b = [7, 8])
x = (a = [3; 6;;], b = [9])
2-element Vector{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}}:
 (a = [1 2; 4 5], b = [7, 8])
 (a = [3; 6;;], b = [9])

julia> mapobs(pr, DataLoader(data; batchsize=2))
mapobs(pr, DataLoader(::@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, batchsize=2); batched=:auto)

julia> collect(ans)
ERROR: MethodError: no method matching getindex(::DataLoader{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}, ::Int64)
Stacktrace:
 [1] getobs(::Type{SimpleTraits.Not{MLUtils.IsTable{DataLoader{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}}}}, data::DataLoader{@NamedTuple{a::Matrix{Int64}, b::Vector{Int64}}, Random._GLOBAL_RNG, Val{nothing}}, idx::Int64)
   @ MLUtils ~/.julia/dev/MLUtils/src/observation.jl:110
```
Are there any downsides to defining something like this? Or should it be done at some other level?
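```julia
# Apply `f` to the wrapped data, then rebuild the loader with the same settings,
# so that `f` sees each batch just as in `DataLoader(mapobs(f, data))`.
mapobs(f, d::DataLoader) = DataLoader(
    mapobs(f, d.data),
    d.batchsize,
    d.buffer,
    d.partial,
    d.shuffle,
    d.parallel,
    d.collate,
    d.rng,
)
```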
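To illustrate the "push the map inside" idea without depending on MLUtils internals, here is a minimal, self-contained sketch. `ToyMapped`, `ToyLoader` and `toymap` are hypothetical stand-ins invented for this example, not MLUtils types or functions. Mapping over a loader rebuilds the loader around a mapped dataset, so the function receives whole batches, as in the second-to-last case above.

```julia
# Hypothetical toy types for illustration only; these are NOT the MLUtils
# `mapobs`/`DataLoader` implementations.

# Lazily applies `f` whenever observations are indexed. Indexing with a range
# passes the whole batch to `f` at once.
struct ToyMapped{F,D}
    f::F
    data::D
end
Base.length(m::ToyMapped) = length(m.data)
Base.getindex(m::ToyMapped, idx) = m.f(m.data[idx])

# Minimal batching loader: yields consecutive slices of `batchsize` observations.
struct ToyLoader{D}
    data::D
    batchsize::Int
end
Base.length(l::ToyLoader) = cld(length(l.data), l.batchsize)
function Base.iterate(l::ToyLoader, start::Int = 1)
    start > length(l.data) && return nothing
    stop = min(start + l.batchsize - 1, length(l.data))
    return l.data[start:stop], stop + 1
end

# Generic case: wrap any indexable dataset.
toymap(f, data) = ToyMapped(f, data)

# The proposed idea: mapping over a loader pushes the map inside and rebuilds
# the loader with the same settings, instead of wrapping the loader itself.
toymap(f, l::ToyLoader) = ToyLoader(toymap(f, l.data), l.batchsize)

# Usage: `double` receives each batch.
double(x) = 2 .* x
loader = ToyLoader([1, 2, 3, 4, 5], 2)
batches = collect(toymap(double, loader))  # batches [2, 4], [6, 8], [10]
```

One possible downside this sketch makes visible: the method copies every loader setting by hand, so any field later added to the loader would also have to be threaded through it. The proposed `DataLoader` method above has the same shape.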