MikeInnes / Lazy.jl

I was gonna maintain this package, but then I got high

@switch doesn't work on Array index #59

AStupidBear closed this issue 7 years ago

AStupidBear commented 7 years ago
julia> x = [1, 2, 3]
3-element Array{Int64,1}:
 1
 2
 3

julia> macroexpand(:(@switch x[1] begin
         1 ; x[1] +=2
         nothing
       end))
:(if let _ = 1 # C:\PortableSoftware\Scoop\apps\julia\pkgs-0.5.0\v0.5\Lazy\src\macros.jl, line 45:
            x[1]
        end
        x[1] += 2
    else
        nothing
    end)

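Instead of comparing x[1] with 1, the expansion binds _ to 1 and uses x[1] itself as the if condition. How can I get around this?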

stevenjgilmore commented 7 years ago

This happens whenever the test after @switch is a function call, or any other expression. For example:

@switch sum([5,2]) begin
    7; true
    5; false
end

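This is because the macro receives its test unevaluated, as an Expr, and @switch treats every Expr test as a predicate on _: it wraps the test in let _ = value; test; end instead of comparing test == value.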
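A macro sees its arguments as syntax, not values, so both tests above reach @switch as Expr objects, while a literal or a bare symbol does not. The sketch below assumes the rule in the released @switch was simply "isa(test, Expr) means predicate on _", which is what the macroexpand output in the report suggests; the names tests and treated_as_predicate are illustrative only.

```julia
# Tests as @switch receives them: unevaluated syntax, not values.
tests = Any[:(x[1]), :(sum([5, 2])), 1, :x]

# Assumed old rule: any Expr test is treated as a predicate on `_`;
# everything else (literals, plain symbols) is compared with ==.
treated_as_predicate = [t isa Expr for t in tests]

for (t, p) in zip(tests, treated_as_predicate)
    println(rpad(repr(t), 16), p ? "let _ = value; test; end" : "test == value")
end
```

Both x[1] and sum([5, 2]) land in the predicate branch, even though neither mentions _.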

By adding the following lines to the @switch macro (at line 42 of macros.jl), I get the desired behavior, and all the examples in the README still work:

  # Evaluate any Expr test without a top-level _ while the macro expands
  test = !isa(test,Expr) ? test :
         in(:_,test.args) ? test : eval(test)

Disclaimer: I am still learning how macros work and am unsure of the side effects this could cause. If desired, I can open a PR for this, but I suspect it could cause unintended results.

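ETA: This breaks when used anywhere other than global scope, because eval runs while the macro is being expanded, in global scope, where the caller's local variables are not visible.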
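The failure comes from when eval runs. Called inside a macro, it evaluates the expression during macro expansion, in the global scope of the module that defines the macro, before any local variable of the calling code exists. The hypothetical macro @eval_at_expansion below isolates that behaviour; it is a sketch of the problem, not of the patch itself.

```julia
# Hypothetical macro mimicking the patch: `eval` runs during macro expansion,
# in the global scope of the module where the macro is defined.
macro eval_at_expansion(ex)
    eval(ex)
end

# At global scope this works: `g` is already a global when the call expands.
g = 40
at_global = @eval_at_expansion(g + 2)

# Inside a local scope it fails: `loc` is a local, invisible to `eval`.
# The code is passed to `eval` as a quoted expression so that the expansion
# error is raised inside `try`, not when this whole statement is expanded.
failed_locally = try
    eval(:(let loc = 1
        @eval_at_expansion(loc + 1)
    end))
    false
catch
    true
end
```

The global case only works because g happens to live in the same module as the macro; a test referring to a local variable, or to a global in another module, has nothing for eval to find.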

AStupidBear commented 7 years ago

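@stevenjgilmore Thanks! I have also modified this macro to suit my needs. According to the Julia manual, using eval inside a macro is not appropriate, so my version avoids it: it only binds _ when the test actually contains _, and otherwise compares the test with ==. Here's my implementation: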

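using Lazy: splitswitch   # splits the @switch arguments into the test and the cases
export @switch
macro switch(args...)
  test, exprs = splitswitch(args...)

  length(exprs) == 0 && return nothing
  length(exprs) == 1 && return esc(exprs[1])

  # Condition for one case value val:
  #   test is _            -> the case value itself is the condition
  #   test mentions _      -> bind _ to the case value, then evaluate the test
  #   anything else        -> compare the test with the case value using ==
  test_expr(test, val) =
    test == :_           ? val :
    has_symbol(test, :_) ? :(let _ = $val; $test; end) :
                           :($test==$val)

  # Chain the cases into nested ternaries; raise an error if nothing matches
  thread(val, yes, no) = :($(test_expr(test, val)) ? $yes : $no)
  thread(val, yes) = thread(val, yes, :(error($"No match for $test in @switch")))
  thread(val, yes, rest...) = thread(val, yes, thread(rest...))

  esc(thread(exprs...))
end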

export has_symbol
# True if the symbol s appears anywhere in the expression tree ex
has_symbol(ex::Expr, s) = any(arg -> has_symbol(arg, s), ex.args)
# Leaves (symbols, literals, ...) match only if they are s itself
has_symbol(ex, s) = ex == s
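To try the rule without Lazy.jl, here is a standalone sketch of has_symbol and test_expr as used in the implementation above. It only builds the condition for a single case and never evaluates it; the tests and case values (x[1], sum([5, 2]), isa(x, _)) are illustrative.

```julia
# Does the symbol `s` appear anywhere in the expression tree `ex`?
has_symbol(ex::Expr, s::Symbol) = any(arg -> has_symbol(arg, s), ex.args)
has_symbol(ex, s::Symbol) = ex == s

# Condition for one case value `val` of a @switch on `test`:
#   test is `_`          -> the case value itself is the condition
#   test mentions `_`    -> bind `_` to the case value, then evaluate the test
#   anything else        -> compare the test with the case value using ==
function test_expr(test, val)
    test == :_ && return val
    has_symbol(test, :_) && return :(let _ = $val; $test; end)
    return :($test == $val)
end

c1 = test_expr(:(x[1]), 1)            # the case from the original report
c2 = test_expr(:(sum([5, 2])), 7)     # the function-call case
c3 = test_expr(:(isa(x, _)), :Integer) # a genuine predicate on `_`

println(c1)
println(c2)
```

The first two tests now become plain == comparisons, which is what the original report expected; only a test that mentions _ gets the let binding. One caveat unrelated to this thread: Julia 1.0 and later reject reading _ as a value, so the generated let _ = ... form only runs on the older Julia versions this discussion was written for.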