FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

generic_matmul! hit in `back!` because of type promotion in activation function #613

Open oxinabox opened 5 years ago

oxinabox commented 5 years ago

Sometimes `generic_matmul!` is hit in `back!`. For example, adding a leak to a unit can be done by writing an activation function like:

    leaky_relu6(x) = 0.01x + clamp(x, 0, 6)

And this is well and good if `x` is a `Float64`. But if `x` is a `Float32`, this will trigger a type promotion, which is bad because the user almost certainly did not intend it. Worse, it means that rather than hitting fast BLAS, we fall back to the slow `generic_matmul!`.
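To make the promotion concrete, here is a minimal sketch in plain Julia (no Flux needed). `leaky_relu6_typed` is a hypothetical name for a type-preserving variant; it uses `oftype(x, 0.01)` to convert the constant to the input's float type before multiplying.

```julia
# Original activation: the literal 0.01 is a Float64, so Float32 inputs get promoted
leaky_relu6(x) = 0.01x + clamp(x, 0, 6)

# Hypothetical type-preserving variant: convert the constant to typeof(x) first.
# clamp with Int bounds already returns the float type of x, so only 0.01 needs care.
leaky_relu6_typed(x) = oftype(x, 0.01) * x + clamp(x, 0, 6)

x = 3.0f0
println(typeof(leaky_relu6(x)))        # Float64: promoted
println(typeof(leaky_relu6_typed(x)))  # Float32: preserved
```

The same function still returns `Float64` for `Float64` input, so nothing is lost for users who do want double precision.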

Here is a MWE:

    using Flux

    function flux_model()
        return Chain(
    #        Dense(1280, 64, x->0.1f0x),    # Uncomment one of these lines
    #        Dense(1280, 64, x->0.1e0x),    # Uncomment one of these lines
            Dense(64, 1),
        )
    end

    function demo_flux()
        mdl = flux_model()
        features = rand(Float32, (1280, 1000))

        # The do-block is the loss; this is the older implicit-params
        # API, train!(loss, params, data, opt).
        Flux.train!(
            params(mdl),
            [(features,)],
            Flux.ADAM()
        ) do xs
            sum(mdl(xs))
        end
    end

Time if it has to promote (the `0.1e0x` line), `@time demo_flux()`:

    0.143774 seconds (607 allocations: 19.635 MiB)

Time normally (the `0.1f0x` line), `@time demo_flux()`:

    0.016475 seconds (568 allocations: 13.218 MiB, 47.67% gc time)

That is a 10x time difference, and it grows as the matrix sizes grow.

KristofferC commented 5 years ago

Isn't this expected though? Why aren't you using the eltype of the input to determine the type of the float constants?

oxinabox commented 5 years ago

Depends what you mean by "expected". It follows the rules of the language, yes. But the fact that a mistake like this can trash your performance is not great.

As a rule, when looking for code that might be causing a slowdown, one doesn't immediately go looking for constants whose types don't match. This is fully type-stable, it doesn't allocate, etc. Even looking at the profiler output, it took @staticfloat and me a while to work this out.

It is of course obvious in retrospect, but it is not obvious from inspecting the code base that this is causing an order-of-magnitude slowdown.

We could certainly consider giving a suppressible warning if the return type of the activation does not match its inputs (or even just if it switches between floating-point types).
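A rough sketch of such a check, assuming it is enough to probe the activation with a single value `one(T)` (activations whose output type depends on the input value would slip through). `check_activation` is a hypothetical name, not an existing Flux or NNlib function.

```julia
# Hypothetical helper: warn if activation `f` does not return the same
# floating-point type it was given. Returns true when the types match.
function check_activation(f, T::Type{<:AbstractFloat}=Float32)
    y = f(one(T))  # probe with a single value of type T
    if typeof(y) !== T
        @warn "Activation changes element type" activation=f input=T output=typeof(y)
        return false
    end
    return true
end

check_activation(x -> 0.1f0 * x)  # true
check_activation(x -> 0.1 * x)    # logs a warning, returns false
```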

Or we could do other things.

oxinabox commented 5 years ago

If nothing else we can have a "performance tips" page, saying to be careful of the types of literals in your activation functions.

I also think we can probably make this faster than it is. If nothing else we can promote the matrix and use BLAS rather than the generic matmul.
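A minimal sketch of the "promote, then multiply" idea, assuming promotion to the wider type is acceptable (it costs memory and keeps the user's unintended Float64). `promote_mul` is a hypothetical name; once both operands share an element type, the product can go to BLAS rather than the generic fallback.

```julia
using LinearAlgebra

# Hypothetical helper: convert both operands to a common float type first,
# so a Float32-by-Float64 product becomes a same-type product.
function promote_mul(A::AbstractMatrix, B::AbstractVecOrMat)
    T = promote_type(eltype(A), eltype(B))
    # convert is a no-op (no copy) for an operand that already has eltype T
    return convert(AbstractArray{T}, A) * convert(AbstractArray{T}, B)
end

A = rand(Float32, 4, 3)
B = rand(3, 2)          # Float64
C = promote_mul(A, B)   # Matrix{Float64}
```

Demoting to `Float32` instead would be cheaper, but it silently changes the user's arithmetic, which is a harder call to make automatically.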

staticfloat commented 5 years ago

I think likely what we should do is trigger a warning on our fallback operations; by default I don't think any user ever wants to use the generic matmul, and so while I like that we support using it, we should have a (silenceable) warning that spits out the first time it is invoked, along with a backtrace to figure out where it's coming from.
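A minimal sketch of a silenceable warn-once switch with a backtrace, assuming a module-level flag is acceptable. `FALLBACK_WARNED`, `FALLBACK_WARNINGS_ENABLED`, and `warn_fallback_once` are hypothetical names; the fallback path would call `warn_fallback_once` before doing its work.

```julia
# Hypothetical flags: has the warning fired yet, and has the user silenced it?
const FALLBACK_WARNED = Ref(false)
const FALLBACK_WARNINGS_ENABLED = Ref(true)  # set to false to silence

# Warn only on the first call, attaching the current stack trace so the
# caller can see where the slow path was entered. Returns true if it warned.
function warn_fallback_once(what::AbstractString)
    if FALLBACK_WARNINGS_ENABLED[] && !FALLBACK_WARNED[]
        FALLBACK_WARNED[] = true
        @warn "Slow fallback used: $what" trace=stacktrace()
        return true
    end
    return false
end
```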

MikeInnes commented 5 years ago

At the risk of flagrant type piracy, we could just override the behaviour of x::Array{T} * y::Array{S} from Base.

It would be worth some notes in the activation functions section of the docs though. NNlib's ones are all set up to preserve input type and there's testing infrastructure for this as well; it's really a matter of following standard Julia style.

darsnack commented 3 years ago

Given that #615 added this to the docs, do we still want to address this with a warning somehow?

oxinabox commented 3 years ago

Yeah, I think we should, it can wreck your performance.

darsnack commented 3 years ago

I believe changing the types to hit BLAS makes things troublesome for mixed precision. I'm not an expert on the topic, but I've heard that mentioned quite a few times on orthogonal issues/PRs.

A warning would be good though. Only concern is the type-piracy. Either way I added this to the triage project so it gets discussed during the next ML community call.

KristofferC commented 3 years ago

Please no type piracy.

oxinabox commented 3 years ago

You don't need type piracy. You put the warning in `Dense` (or a helper called by `Dense`, etc., maybe even shadowing `*`) that checks the types before it calls `Base.*`.
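A sketch of the helper variant, assuming a layer would call it in place of `W * x`. `checked_matmul` is a hypothetical name; it neither pirates nor shadows `Base.*`, it only inspects element types and uses the logging keyword `maxlog=1` so the warning fires once.

```julia
# Hypothetical helper a layer like Dense could call instead of W * x directly.
function checked_matmul(W::AbstractMatrix, x::AbstractVecOrMat)
    if eltype(W) !== eltype(x)
        @warn "Mixed element types in layer multiplication; this may hit the slow generic_matmul! path" W_eltype=eltype(W) x_eltype=eltype(x) maxlog=1
    end
    return W * x
end

W32 = rand(Float32, 3, 4)
checked_matmul(W32, rand(Float32, 4))  # no warning
checked_matmul(W32, rand(4))           # warns once, result is Vector{Float64}
```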

DhairyaLGandhi commented 3 years ago

Doing these things generically in a manner that doesn't affect AD and runtime performance in the forward or backward pass can be tough.

oxinabox commented 3 years ago

You can have a disable-able safety rails mode that compiles away.


    using Zygote  # needed for Zygote.ignore, which keeps the check out of AD

    # Safety rails default on
    has_safety_rails() = true

    # Function to let advanced users turn them off.
    # Triggers recompilation
    safety_rails!(enable) = @eval has_safety_rails() = $enable

    macro safety_rail(cond, msg::String, logargs...)
        disable = " To disable this warning, run `Flux.safety_rails!(false)`."
        #TODO this doesn't quite display logargs right.
        warning = :(@warn($msg*$disable, $(esc.(logargs)...)))
        return quote
            has_safety_rails() && Zygote.ignore() do
                $(esc(cond)) && $warning
            end
        end
    end

    # Shadows (does not extend) Base.* in the module where it is defined
    function *(a::T, b::S) where {T, S}
        @safety_rail(
            T!==S,
            "Mixed type multiplication encountered. This probably means you ...",
            T, S
        )

        return Base.:(*)(a, b)
    end

    1.0 * 2  # Float64 * Int64: triggers the warning

I think I stole this trick from TimerOutputs.jl, to have a debug mode that compiles away when not in use. ChainRulesCore uses it.

darsnack commented 3 years ago

Summarizing what was discussed during triage today:

The appropriate place for a @safety_rail style check is probably NNlib instead of Flux (if we want it). People were generally uncomfortable with shadowing *, and it was suggested that if generic_matmul! is almost never wanted, then maybe the warning mechanism for it belongs in Base. Admittedly, the penalty is much greater when used in AD like Zygote, but Flux seems like the wrong place in the hierarchy to address this issue generically.

What was suggested instead is to package the promotion check into a utility function. Something like performance_check (could work on a similar mechanism to outputsize) that runs a forwards and backwards pass on the model and throws warnings for any performance issues. Could have a forward pointer to the performance tips docs in the warning string. I don't think there is anything like this w.r.t. performance, but other frameworks do have such utilities as sanity checks.
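A forward-pass-only sketch of such a utility, assuming "output element type differs from input element type" is the signal worth flagging. `performance_check` is the name floated above but the body here is hypothetical; a real version would also run a backward pass (e.g. with Zygote), which is omitted.

```julia
# Hypothetical: run `model` on a sample input and warn if the output
# element type differs from the input's. Returns true when they match.
function performance_check(model, x::AbstractArray)
    y = model(x)
    ok = eltype(y) === eltype(x)
    ok || @warn "Model output eltype differs from input; check literals in activations (see the performance tips docs)" input=eltype(x) output=eltype(y)
    return ok
end

W = rand(Float32, 2, 3)
good = x -> W * x
bad  = x -> 0.5 .* (W * x)   # 0.5 is a Float64 literal, so the result is Float64

performance_check(good, rand(Float32, 3))  # true
performance_check(bad, rand(Float32, 3))   # warns, false
```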

KristofferC commented 3 years ago

> and it was suggested that if generic_matmul! is almost never wanted, then maybe the warning mechanism for it belongs in Base

It is wanted in Base all the time though so I think you will have a hard time putting such a warning there.

darsnack commented 3 years ago

Would a warning via a utility function be acceptable then @oxinabox?

oxinabox commented 3 years ago

I honestly don't care how it is done. The point is to protect people who are first learning julia and first learning flux from footguns. People who haven't yet read all the docs.

We should understand the context.

The thing that matters here is that it is very easy to get Float64s into your network, because floating-point literals are Float64 and `float(::Int)` returns a Float64. So if you are not careful, you can end up with one coming out of a helper function (e.g. I had to fix a few things in Distributions.jl for this recently) or even just by making a mistake and using a literal yourself. Now, in normal Julia code, promoting to Float64 is fine. It isn't much slower, it only uses 2x as much memory, and it is more accurate. It is the safe and good bet.

But in NN code you often intentionally want Float32, because lower precision costs you nothing and is even said to act as a regularizer. Further, the whole process of training a NN boils down to a loop of matmuls. We need to hit the fast matmul.

One day we might be able to do `|> f32` like we do `|> gpu`. That would also solve this.
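A quick illustration of where Float64 sneaks in. Note that integer literals are safe: multiplying a `Float32` by an `Int` stays `Float32`.

```julia
println(typeof(0.1))               # Float64: plain float literals
println(typeof(0.1f0))             # Float32: f0 suffix
println(typeof(float(1)))          # Float64: float(::Int)

x = 2.0f0
println(typeof(x * 0.5))           # Float64: promoted by the literal
println(typeof(x * 2))             # Float32: integer literals do not promote
println(typeof(x * Float32(0.5)))  # Float32: explicit conversion
```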
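As a rough illustration of what a pipeable conversion could do: `to_f32` below is a hypothetical name and a simplified sketch, not Flux's own function. It converts float arrays and scalars to `Float32` and leaves everything else untouched, applied field by field to a `NamedTuple` of parameters.

```julia
# Hypothetical: convert float data to Float32, leave everything else alone.
to_f32(x::AbstractArray{<:AbstractFloat}) = Float32.(x)
to_f32(x::AbstractFloat) = Float32(x)
to_f32(x) = x
# Apply to each field of a NamedTuple of parameters.
to_f32(nt::NamedTuple) = map(to_f32, nt)

params64 = (weight = rand(3, 2), bias = zeros(3), name = "dense")
params32 = params64 |> to_f32
```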

mcabbott commented 2 years ago

Times with https://github.com/FluxML/Zygote.jl/pull/1044 : now hardly any slowdown, but a few more allocations than the all-Float32 version:

julia> function flux_model()  # Float32
           return Chain(
               Dense(1280, 64, x->0.1f0x),    # Uncomment one of these lines
       #        Dense(1280, 64, x->0.1e0x),   # Uncomment one of these lines
               Dense(64, 1),
           )
       end
flux_model (generic function with 1 method)

julia> @time demo_flux()
  0.571037 seconds (1.88 M allocations: 110.176 MiB, 3.09% gc time, 0.97% compilation time)

julia> @time demo_flux()
  0.011878 seconds (388 allocations: 13.074 MiB)
  0.007133 seconds (388 allocations: 13.074 MiB)  # another run

julia> function flux_model()  # Float64
           return Chain(
       #        Dense(1280, 64, x->0.1f0x),    # Uncomment one of these lines
               Dense(1280, 64, x->0.1e0x),   # Uncomment one of these lines
               Dense(64, 1),
           )
       end
flux_model (generic function with 1 method)

julia> @time demo_flux()
  0.583360 seconds (1.91 M allocations: 113.019 MiB, 3.05% gc time, 0.88% compilation time)

julia> @time demo_flux()
  0.010863 seconds (389 allocations: 14.543 MiB)
  0.011858 seconds (389 allocations: 14.543 MiB)  # another run

Compared with the tagged version without that PR (slow case only), which is 10x slower than Float32, as above:

julia> @time demo_flux()  # Float64, first run after re-defining demo_flux()
  0.655448 seconds (1.91 M allocations: 117.824 MiB, 2.61% gc time, 0.79% compilation time)

julia> @time demo_flux()
  0.097526 seconds (388 allocations: 19.495 MiB, 17.16% gc time)
  0.107293 seconds (388 allocations: 19.495 MiB)  # another run

(@v1.7) pkg> st Zygote
      Status `~/.julia/environments/v1.7/Project.toml`
  [e88e6eb3] Zygote v0.6.20

This version without the PR has some ProjectTo stuff in place, e.g. in the rule for *, but not in broadcasting, so it catches the problem a little later.
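Very roughly, the projection idea is to convert a gradient back to the type of the value it belongs to, so a Float64 gradient for a Float32 weight is narrowed early instead of propagating. A simplified stand-in, assuming only the element type matters; `project_like` is hypothetical and not ChainRulesCore's `ProjectTo` implementation.

```julia
# Hypothetical simplified stand-in for the projection idea:
# make the gradient `dx` match the element type of the primal `x`.
project_like(x::AbstractArray{T}, dx::AbstractArray) where {T<:AbstractFloat} =
    convert(AbstractArray{T}, dx)

W = rand(Float32, 2, 2)
dW = rand(2, 2)              # Float64 gradient, e.g. from a Float64 literal upstream
dW32 = project_like(W, dW)   # Matrix{Float32}
```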