JuliaML / MLUtils.jl

Utilities and abstractions for Machine Learning tasks
MIT License

`getobs!(buffer, data::BatchView, idx)` ignores `buffer` #156

Open jondeuce opened 1 year ago

jondeuce commented 1 year ago

For example, the buffer is left unchanged after the call:

using MLUtils
buffer = zeros(3,2)
data = BatchView(rand(3,100); batchsize = 2)
getobs!(buffer, data, 1)
buffer == zeros(3,2) # true: the buffer was never written to

It looks like this call just falls back to the `getobs` method defined here, so the buffer is never used. One consequence is that `DataLoader` never calls `getobs!` when `batchsize > 0`, even with `buffer=true`:

struct DummyData{X}
    x::X
end
MLUtils.numobs(data::DummyData) = numobs(data.x)
MLUtils.getobs(data::DummyData, idx) = getobs(data.x, idx)
MLUtils.getobs!(buffer, data::DummyData, idx) = error("getobs! is called")

data = DummyData(rand(3,100))
collect(DataLoader(data; batchsize=1, buffer=true)) # no error: getobs! is never called
collect(DataLoader(data; batchsize=0, buffer=true)) # error: getobs! is called
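
In the meantime, a possible workaround for the first example is to copy the batch into the buffer by hand. This is only a sketch, assuming `getobs` on the `BatchView` returns an array with the same shape as `buffer`. It still allocates inside `getobs`, so it restores the in-place result but not the allocation savings, and it is not a proposed fix for `getobs!` itself.

```julia
using MLUtils

buffer = zeros(3, 2)
data = BatchView(rand(3, 100); batchsize = 2)

# Workaround sketch: broadcast-assign the freshly allocated batch into the buffer.
# Assumes the batch returned by getobs has the same size as the buffer.
buffer .= getobs(data, 1)

# The buffer now holds the first batch instead of zeros.
@assert buffer == getobs(data, 1)
```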