JuliaAI / MLJModelInterface.jl

Lightweight package to interface with MLJ
MIT License

Fix `@mlj_model` parsing bug #175

Closed MilesCranmer closed 1 year ago

MilesCranmer commented 1 year ago

Fixes #174.

The bug was caused by using `length(default.args) > 1` to check whether a constraint was present in the default expression. The correct check is `default.head == :(::)`.

Using `length(default.args)` is incorrect because some defaults have multiple args without carrying any constraint, such as

julia> :([0, 1, 2]).args
3-element Vector{Any}:
 0
 1
 2

Likewise, an expression with more than one argument would be treated as carrying a constraint even when the user did not use `::` at all.

This PR checks for the `::` operator to identify a constraint, which fixes this issue.
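
To see the difference on raw expressions, without going through the macro, here is a minimal sketch using only Base Julia. The names `old_check` and `new_check` are hypothetical and exist only for this illustration; they are not functions in MLJModelInterface, and the `ex isa Expr` guard is an assumption added so that literal defaults do not error.

```julia
# Hypothetical helpers mirroring the two checks discussed above;
# neither name exists in MLJModelInterface.
old_check(ex) = ex isa Expr && length(ex.args) > 1
new_check(ex) = ex isa Expr && ex.head == :(::)

constrained    = :(1::(_ >= 0))   # default 1 with constraint `_ >= 0`
vector_default = :([0, 1, 2])     # plain default, no constraint

for ex in (constrained, vector_default)
    println(ex, ": head = ", repr(ex.head),
            ", old check = ", old_check(ex),
            ", new check = ", new_check(ex))
end
```

Both checks flag the `::` expression, but only the length-based check also flags the vector literal, whose head is `:vect` rather than `::`.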

Note that this fix is backwards compatible, so long as people have not been exploiting the bug. See my note at the end of this post for more details.

With this change, we can now do:

julia> using MLJModelInterface

julia> @mlj_model mutable struct Foo
           x::Vector{Int} = [0, 1, 2]
           y::Int = -1
       end

julia> Foo()
Foo([0, 1, 2], -1)

whereas before you would get an error because the default was parsed incorrectly, with `0` taken as the default and `1` as the constraint:

julia> Foo()
ERROR: MethodError: Cannot `convert` an object of type Int64 to an object of type Vector{Int64}

Closest candidates are:
  convert(::Type{T}, ::LinearAlgebra.Factorization) where T<:AbstractArray
   @ LinearAlgebra ~/.julia/juliaup/julia-1.9.1+0.aarch64.apple.darwin14/share/julia/stdlib/v1.9/LinearAlgebra/src/factorization.jl:59
  convert(::Type{T}, ::AbstractArray) where T<:Array
   @ Base array.jl:613
  convert(::Type{T}, ::T) where T<:AbstractArray
   @ Base abstractarray.jl:16
  ...

Stacktrace:
 [1] Foo(; x::Int64, y::Int64)
   @ Main ./none:0
 [2] Foo()
   @ Main ./none:0
 [3] top-level scope
   @ REPL[5]:1

cc @ablaom


What changes is that a user can no longer exploit the bug. For example, before, you could actually do:

julia> @mlj_model mutable struct Foo
           x::Int = [1, (_ >= 0)]
       end

julia> Foo()
Foo(1)

to give a constraint. Any expression with the constraint as its second argument would have been a valid way of defining this model. Unless you are aware of anybody doing this, I wouldn't worry about it, though.

Now, if you try to use that syntax, you get:

julia> @mlj_model mutable struct Foo
           x::Int = [1, (_ >= 0)]
       end
ERROR: syntax: all-underscore identifier used as rvalue
Stacktrace:
 [1] top-level scope
   @ REPL[3]:1

which is what we want.
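
As a rough sketch of why the old syntax no longer yields a constraint, assume the macro splits the right-hand side into a default and a constraint only when the expression head is `::`. The helper `split_default` below is hypothetical and is not the package's implementation.

```julia
# Hypothetical helper, not the package's implementation: split the
# right-hand side of a field into (default, constraint).
function split_default(ex)
    if ex isa Expr && ex.head == :(::) && length(ex.args) == 2
        return ex.args[1], ex.args[2]
    end
    return ex, nothing   # no `::`, so the whole expression is the default
end

exploit  = :([1, (_ >= 0)])   # the old, accidental syntax
intended = :(1::(_ >= 0))     # the `::` constraint syntax

default, constraint = split_default(exploit)
println("exploit:  default = ", default, ", constraint = ", constraint)

default2, constraint2 = split_default(intended)
println("intended: default = ", default2, ", constraint = ", constraint2)
```

Under this assumption the whole vector literal is kept as the default, so `_` ends up in a value position, which is consistent with the "all-underscore identifier used as rvalue" error shown above.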

ablaom commented 1 year ago

@MilesCranmer This looks like a great contribution, thank you. Appreciate the work and excellent presentation.

Because this package sits so low in the MLJ ecosystem, let's run some tests from MLJTestIntegration.jl, just to minimize the chance of surprises. @OkonSamuel do you have some time to verify the following?

(Sometimes when I run these very extensive tests, I find new unrelated issues. In those cases I just remove appropriate models from tests in the notebook, add links in the notebook to the issues blocking inclusion, and update the notebook on the MLJTestIntegration#dev branch. We should automate this, I know, but haven't got around to it yet.)

OkonSamuel commented 1 year ago

nice catch @MilesCranmer

ablaom commented 1 year ago

@OkonSamuel Any progress on integration tests?

ablaom commented 1 year ago

The integration tests are good.