CTUAvastLab / Mill.jl

Build flexible hierarchical multi-instance learning models.
https://ctuavastlab.github.io/Mill.jl/stable/
MIT License

MaybeHotMatrix does not support `Flux.onecold` #83

Closed: racinmat closed this issue 2 years ago

racinmat commented 2 years ago

We use `Flux.onecold` as the inverse of one-hot encoding. This works for `OneHotMatrix`, but not for `MaybeHotMatrix`. See:

julia> t = Flux.onehotbatch(1:3, 1:10)
10×3 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 1  ⋅  ⋅
 ⋅  1  ⋅
 ⋅  ⋅  1
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅
 ⋅  ⋅  ⋅

julia> t2 = maybehotbatch(1:3, 1:10)
10×3 MaybeHotMatrix{UInt32, Int64, Bool}:
 1  0  0
 0  1  0
 0  0  1
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0
 0  0  0

julia> Flux.onecold(t)
3-element Vector{Int64}:
 1
 2
 3

julia> Flux.onecold(t2)
ERROR: LoadError: MethodError: no method matching _getindex(::MaybeHotMatrix{UInt32, Int64, Bool}, ::Int64, ::CartesianIndex{1})
Closest candidates are:
  _getindex(::MaybeHotMatrix, ::Union{Integer, AbstractVector{T} where T}, ::Integer) at C:\Users\racinsky\.julia\packages\Mill\f48u2\src\special_arrays\maybe_hot_matrix.jl:32
  _getindex(::MaybeHotMatrix, ::Integer, ::Colon) at C:\Users\racinsky\.julia\packages\Mill\f48u2\src\special_arrays\maybe_hot_matrix.jl:33
  _getindex(::MaybeHotMatrix, ::CartesianIndex{2}) at C:\Users\racinsky\.julia\packages\Mill\f48u2\src\special_arrays\maybe_hot_matrix.jl:34
  ...
Stacktrace:
 [1] getindex(::MaybeHotMatrix{UInt32, Int64, Bool}, ::Int64, ::CartesianIndex{1})
   @ Mill C:\Users\racinsky\.julia\packages\Mill\f48u2\src\special_arrays\maybe_hot_matrix.jl:31
 [2] findminmax!(f::typeof(Base.isgreater), Rval::Matrix{Bool}, Rind::Matrix{CartesianIndex{2}}, A::MaybeHotMatrix{UInt32, Int64, Bool})
   @ Base .\reducedim.jl:928
 [3] _findmax(A::MaybeHotMatrix{UInt32, Int64, Bool}, region::Int64)
   @ Base .\reducedim.jl:1048
 [4] #findmax#726
   @ .\reducedim.jl:1038 [inlined]
 [5] #argmax#728
   @ .\reducedim.jl:1103 [inlined]
 [6] _fast_argmax
   @ C:\Users\racinsky\.julia\packages\Flux\ZnXxS\src\onehot.jl:211 [inlined]
 [7] onecold(y::MaybeHotMatrix{UInt32, Int64, Bool}, labels::UnitRange{Int64}) (repeats 2 times)
   @ Flux C:\Users\racinsky\.julia\packages\Flux\ZnXxS\src\onehot.jl:205
 [8] top-level scope
   @ c:\Projects\others\JsonGrinder.jl\examples\recipes.jl:70
in expression starting at c:\Projects\others\JsonGrinder.jl\examples\recipes.jl:70
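
The stack trace suggests the failure is in indexing, not in `onecold` itself. `Flux._fast_argmax` calls `argmax(y; dims=1)`, which ends up in Base's `findminmax!`. That function indexes the array as `A[i, CartesianIndex{1}]`, and the `getindex` method at `maybe_hot_matrix.jl:31` forwards this `(Int64, CartesianIndex{1})` pair to an `_getindex` with no matching method.

Below is a minimal sketch using a hypothetical toy type, `ToyHot`, which is not Mill's `MaybeHotMatrix`. It shows that an `AbstractMatrix` subtype defining only `size` and a scalar `getindex(::Int, ::Int)` lets Base flatten the `CartesianIndex` itself, so `argmax(...; dims=1)` succeeds:

```julia
# Hypothetical toy type for illustration; it is NOT Mill's MaybeHotMatrix.
# Column j has a single `true` at row idx[j]; the matrix has `nrows` rows.
struct ToyHot <: AbstractMatrix{Bool}
    idx::Vector{Int}
    nrows::Int
end

Base.size(X::ToyHot) = (X.nrows, length(X.idx))

# Only scalar indexing is defined. Base's generic fallback converts mixed
# indices such as (Int, CartesianIndex{1}) into (Int, Int) before calling it.
Base.getindex(X::ToyHot, i::Int, j::Int) = X.idx[j] == i

X = ToyHot([1, 2, 3], 10)

# This is the indexing pattern Base.findminmax! uses; it now resolves.
v = X[2, CartesianIndex(2)]       # true

# argmax over rows returns a 1×3 Matrix{CartesianIndex{2}}.
am = argmax(X; dims=1)
rows = [c[1] for c in vec(am)]    # [1, 2, 3]
```

The design point is to let Base handle index conversion instead of catching every index combination in a single `getindex(X, idcs...)` method.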

It works if `t2` is first passed through `Flux.onehotbatch`:

julia> Flux.onecold(Flux.onehotbatch(t2))
3-element Vector{Int64}:
 1
 2
 3

but that feels cumbersome.
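
A one-hot-like matrix already stores the row index of each column. A dedicated method can therefore return those indices directly, with no `argmax` and no conversion to `OneHotMatrix`.

The sketch below rests on a few assumptions. Its type `ToyMaybeHot` and function `toy_onecold` are hypothetical; neither belongs to Mill or Flux, and this is not how Mill actually resolved the issue. It also assumes that a column whose index is `missing` should decode to `missing`:

```julia
# Hypothetical maybe-hot-like type: idx[j] is the row of the single `true`
# in column j, or `missing` if column j is unknown. Not Mill's implementation.
struct ToyMaybeHot <: AbstractMatrix{Union{Bool,Missing}}
    idx::Vector{Union{Int,Missing}}
    nrows::Int
end

Base.size(X::ToyMaybeHot) = (X.nrows, length(X.idx))
Base.getindex(X::ToyMaybeHot, i::Int, j::Int) =
    ismissing(X.idx[j]) ? missing : X.idx[j] == i

# Hypothetical onecold: map the stored indices to labels, propagating
# `missing`, instead of running argmax down every column.
toy_onecold(X::ToyMaybeHot, labels=1:X.nrows) =
    [ismissing(i) ? missing : labels[i] for i in X.idx]

X = ToyMaybeHot([1, 2, 3], 10)
a = toy_onecold(X)                              # [1, 2, 3]
b = toy_onecold(X, 'a':'j')                     # ['a', 'b', 'c']
c = toy_onecold(ToyMaybeHot([2, missing], 4))   # [2, missing]
```

Reading the stored indices costs one lookup per column, and it never touches the `getindex` path that fails in the trace.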

simonmandlik commented 2 years ago

Implementing this certainly wouldn't hurt. Out of curiosity, where do you need this?

racinmat commented 2 years ago

It's useful in some JsonGrinder examples where I use an extractor to extract labels in a task where I predict one field of a JSON document from the other fields.