JuliaML / MLUtils.jl

Utilities and abstractions for Machine Learning tasks
MIT License

`unstack` does not pass the `@inferred` test #149

Closed: gabrevaya closed this issue 1 year ago

gabrevaya commented 1 year ago

Below is an MWE. Is there any workaround?

julia> using MLUtils, Test

julia> x =  reshape(1:24, 3,4,2) |> collect
3×4×2 Array{Int64, 3}:
[:, :, 1] =
 1  4  7  10
 2  5  8  11
 3  6  9  12

[:, :, 2] =
 13  16  19  22
 14  17  20  23
 15  18  21  24

julia> @inferred unstack(x; dims=2)
ERROR: return type Vector{Matrix{Int64}} does not match inferred return type Vector
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ REPL[42]:1

(MLUtils_unstack) pkg> st
Status `~/Documents/issues/MLUtils_unstack/Project.toml`
  [f1d291b0] MLUtils v0.4.1

julia> versioninfo()
Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.5.0)
  CPU: 8 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
  Threads: 1 on 6 virtual cores
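One caller-side workaround, sketched under assumptions, is to assert the concrete result type at the call site. So that the snippet runs without MLUtils, it uses a hypothetical stand-in, `unstack_kw`, with the same body as the `dims::Int` method quoted further down this thread, plus a hypothetical wrapper `slices_along_dim2`. With MLUtils loaded, you would put the assertion on `unstack(x; dims=2)` instead. The assertion makes the wrapper's return type inferable, at the cost of a runtime check that throws if the actual type differs.

```julia
using Test

# Hypothetical stand-in mirroring the keyword-based `dims::Int` method quoted
# later in this thread (not imported from MLUtils).
unstack_kw(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims)]

# Hypothetical wrapper: the `::Vector{Matrix{Int}}` assertion fixes the
# inferred return type, even if the inner call only infers as `Vector`.
slices_along_dim2(x::Array{Int,3}) = unstack_kw(x; dims=2)::Vector{Matrix{Int}}

x = collect(reshape(1:24, 3, 4, 2))
ys = @inferred slices_along_dim2(x)   # passes: inference sees the asserted type
```

This only helps callers who know the concrete array type in advance. It does not make `unstack` itself inferable.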

ToucheSir commented 1 year ago

Adding a method that takes `dims` as a `Val` would make this infer, but someone would have to implement it.
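Here is a minimal sketch of that suggestion, assuming a positional `Val` argument. The name `unstack_val` is hypothetical and is not MLUtils API. Because the dimension is part of the argument's type, `D` is a compile-time constant inside the method. The sketch also builds the result with a typed comprehension, `Array{T,N-1}[...]`, so the element type is fixed regardless of how far constant propagation through `selectdim` gets. It is restricted to `Array` inputs, because `copy` of a view into an `Array` returns an `Array`. For such inputs, the typed comprehension alone would already pin the return type.

```julia
using Test

# Hypothetical `Val`-based variant; a sketch, not the MLUtils implementation.
# D (the dimension to split along) is carried in the type of the second argument.
function unstack_val(xs::Array{T,N}, ::Val{D}) where {T,N,D}
    # Typed comprehension: every element is stored as an Array{T,N-1},
    # so the result is a Vector{Array{T,N-1}} by construction.
    return Array{T,N-1}[copy(selectdim(xs, D, i)) for i in 1:size(xs, D)]
end

x = collect(reshape(1:24, 3, 4, 2))
ys = @inferred unstack_val(x, Val(2))   # 4 matrices of size 3×2
```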

gabrevaya commented 1 year ago

Thanks! Would this be OK?

unstack(xs; dims::Val{D}) where D = [copy(selectdim(xs, D, i)) for i in 1:size(xs, D)]

ToucheSir commented 1 year ago

If it infers for you and passes the existing tests (with Val instead of Int dims), then drop a PR :)

gabrevaya commented 1 year ago

I don't understand why, after I add the new method,

unstack(xs; dims::Int) = [copy(selectdim(xs, dims, i)) for i in 1:size(xs, dims)]
unstack(xs; dims::Val{D}) where {D} = [copy(selectdim(xs, D, i)) for i in 1:size(xs, D)]

when I run the tests, the previous method with dims::Int is not found:

Test threw exception
  Expression: stack(unstack(stacked_array, dims = 1), dims = 1) == stacked_array
  MethodError: no method matching var"#unstack#102"(::Int64, ::typeof(unstack), ::Matrix{Int64})
  Closest candidates are:
    var"#unstack#102"(::Val{D}, ::typeof(unstack), ::Any) where D at ~/Documents/issues/MLUtils.jl/src/utils.jl:80

It's as if the new method were overwriting the other one. But that doesn't make sense, because they should be two different methods with different signatures. For example, the following works as expected:

julia> f(x::Int) = "int"
julia> f(x::Val{d}) where d = "val"
julia> f(2)
"int"
julia> f(Val(2))
"val"

Do you know what the problem is in the case of unstack?

gabrevaya commented 1 year ago

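I've just realized that it's due to the keyword argument: Julia does not dispatch on keyword arguments, so both definitions have the same signature and the second one replaces the first. I guess I'll have to define that method with dims as a positional argument instead.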
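The behaviour can be reproduced in isolation. In this sketch, all function names are hypothetical. Two keyword-only definitions collapse into a single method. Forwarding the keyword to positional helper methods restores dispatch on the type of `dims`. This forwarding pattern is one common option, and it is not necessarily what the eventual PR did.

```julia
# Hypothetical names, used only to illustrate the behaviour discussed above.

# Keyword arguments do not take part in dispatch: both definitions share the
# positional signature `kwdims()`, so the second one replaces the first.
kwdims(; dims::Int) = "int"
kwdims(; dims::Val) = "val"

kwdims(dims = Val(2))            # "val"
int_call_fails = try
    kwdims(dims = 2)             # the Int version is gone, so this throws
    false
catch
    true
end

# One way out: keep a keyword front end and forward to positional methods,
# which do dispatch on the type of `dims`.
posdims(; dims) = _posdims(dims)
_posdims(dims::Int) = "int"
_posdims(::Val{D}) where {D} = "val"

posdims(dims = 2)                # "int"
posdims(dims = Val(2))           # "val"
```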

I'll make the PR now :)