CTUAvastLab / Mill.jl

Build flexible hierarchical multi-instance learning models.
https://ctuavastlab.github.io/Mill.jl/stable/
MIT License

skipping derivation of reduce catobs during inference and adding helper functions #91

Closed: racinmat closed this pull request 2 years ago.

racinmat commented 2 years ago

Fixes #90. Adds the following methods:

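```julia
# Apply a model to a vector of nodes: concatenate them into a single node first.
# Zygote.@ignore runs the concatenation normally but does not differentiate through it.
(m::AbstractMillModel)(x::AbstractVector{<:AbstractMillNode}) = m(Zygote.@ignore(reduce(catobs, x)))
# Apply a model to a DataSubset (e.g. a minibatch) by materializing it with getobs first.
(m::AbstractMillModel)(x::DataSubset) = m(Zygote.@ignore(getobs(x)))
```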

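These methods simplify minibatching with `RandomBatches` from MLDataPattern: a model can be applied directly to a vector of nodes or to a `DataSubset`, and Zygote skips differentiating the concatenation step.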
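The dispatch idea can be illustrated without Mill itself. The following is a minimal, dependency-free sketch under stated assumptions: `ToyNode`, `ToyModel` and the local `catobs` are hypothetical stand-ins, not Mill's types or functions; `Zygote.@ignore` is omitted because there is no automatic differentiation here; and drawing random indices with `rand` stands in for `RandomBatches`.

```julia
# Dependency-free sketch of "concatenate a vector of nodes, then run the model once".
# ToyNode, ToyModel and catobs are hypothetical stand-ins, NOT Mill's API.
using Random

# A node stores one observation per column.
struct ToyNode
    data::Matrix{Float64}
end

# Pairwise concatenation along the observation (column) dimension.
catobs(a::ToyNode, b::ToyNode) = ToyNode(hcat(a.data, b.data))

# A "model" that applies the same linear map to every observation.
struct ToyModel
    W::Matrix{Float64}
end

(m::ToyModel)(x::ToyNode) = m.W * x.data

# Mirrors the shape of the PR's vector method: concatenate first, then call the model.
# (In Mill the concatenation is wrapped in Zygote.@ignore; there is no AD here.)
(m::ToyModel)(xs::AbstractVector{<:ToyNode}) = m(reduce(catobs, xs))

Random.seed!(1)
nodes = [ToyNode(rand(3, 1)) for _ in 1:10]   # ten single-observation nodes
model = ToyModel(rand(2, 3))

# Hypothetical manual minibatching standing in for MLDataPattern's RandomBatches.
batch = nodes[rand(1:length(nodes), 4)]
out = model(batch)
println(size(out))   # (2, 4): two outputs for each of the four sampled observations
```

Because the model accepts a vector of nodes, the training loop does not have to concatenate each sampled batch by hand before the forward pass.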

It also adds a convenience method:

```julia
# Concatenate a whole vector of nodes in one call.
catobs(as::AbstractVector{<:AbstractMillNode}) = reduce(catobs, as)
```
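A minimal sketch of what the convenience method amounts to, using a hypothetical `ToyNode` rather than Mill's node types: the vector method is a left fold of the pairwise `catobs` over the vector, so it returns the same result as calling `reduce(catobs, …)` directly.

```julia
# Hypothetical stand-in for a Mill node: one observation per column (NOT Mill's API).
struct ToyNode
    data::Matrix{Float64}
end

# Pairwise concatenation along the observation (column) dimension.
catobs(a::ToyNode, b::ToyNode) = ToyNode(hcat(a.data, b.data))

# Convenience method with the same shape as the one added in the PR.
catobs(as::AbstractVector{<:ToyNode}) = reduce(catobs, as)

parts = [ToyNode([1.0 2.0; 3.0 4.0]),
         ToyNode(reshape([5.0, 6.0], 2, 1)),
         ToyNode([7.0 8.0; 9.0 10.0])]
c = catobs(parts)
# c.data is a 2×5 matrix with rows 1 2 5 7 8 and 3 4 6 9 10
```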