FluxML / Mjolnir.jl

A little less conversation, a little more abstraction

Applications #1

AriMKatz opened 4 years ago

AriMKatz commented 4 years ago

Hello Mike,

In the spirit of your readme, I'm wondering to what extent this package can (or is intended to) address some common pain points aside from speeding up Flux/Zygote:

  1. Lightweight static compilation of a useful subset of non-ML Julia programs for deployment (including general binary deployment and also minimal-runtime targets; I'm interested in lightweight drone vision and control systems).
  2. Non-runtime semantic or correctness analysis; at the least, finding method errors.
  3. Non-runtime shape inference/checking for numerical code.
  4. ONNX reading and writing.
  5. Dispatching array ops (even on complex nested arrays) to their correct GPU implementations.
  6. As a bonus, transpilation to targets like WASM.

For those that apply, are they planned roadmap items, and if not, how much additional work would they require?

Thanks

MikeInnes commented 4 years ago
  1. Yes, definitely. The XLA work is initially targeting training, but assuming there are deployment tools that we can easily compile to, that's an easy corollary. Mjolnir is a much more powerful superset of the DataFlow.jl approach used by FluxJS and ONNX.jl (and should replace it in both cases).
  2. Yes and no. Not only are missing methods a compile-time error but, at least as far as XLA and co. are concerned, so is anything that would result in slow code. That said, Julia's type system is more geared towards performance than correctness, so the former is all we're guaranteeing. OTOH, a more advanced user could hijack Mjolnir's abstract interpreter to infer whatever they want, not just Julia's type tags. But this isn't something we're initially supporting.
  3. Shape inference comes under type/constant prop, so that's easy, though we're not currently doing it for the XLA work (instead letting XLA do the checking). Checking things like numerical stability takes the form of a custom analysis (see 2) or just inserting runtime checks to help debug where NaNs come from, which is pretty easy.
  4. Mostly see 1. Reading models gives us a choice of spitting out Julia code (easy to edit) or just eval'ing an IR fragment (easier to implement). There was some initial work on IR -> Julia AST but it isn't complete; ONNX's needs are simple though.
  5. If we compile Julia code via Mjolnir we can see moving things to the GPU as an optimisation, and then we'd have total control over what kernels are executed etc. This obviously wouldn't help regular CuArrays users though; it'd be more like working with XLA, with CUDA hidden underneath. I do think something like this is the right eventual interface though (https://github.com/JuliaGPU/CuArrays.jl/pull/406).
  6. Mjolnir is effectively good at taking code fragments (ML models, numerical kernels) that don't use too much dynamism, and removing the need for the Julia runtime. This could be helpful for compiling small code snippets to WASM. It wouldn't help with, e.g., compiling large projects or packages to WASM; at some point the assumptions Mjolnir makes to do its job are going to be violated. Compiling numerical kernels to WASM could have its uses, though.
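To make point 3 concrete: treating shape inference as constant propagation just means running the program on shapes instead of values, so a mismatch is reported before any data exists. The sketch below assumes a hypothetical straight-line IR with three ops (`:matmul`, `:add`, `:relu`) and a made-up `(op, argument ids)` statement format; it mirrors a two-layer 10 → 5 → 2 model and is not Mjolnir's abstract interpreter or API.

```julia
# Toy shape inference: run a straight-line program on shapes instead of values.
# The op set and the (op, argument ids) statement format are hypothetical.
const Shape = Tuple{Vararg{Int}}

# Abstract semantics of each op: input shapes -> output shape, or an error.
function infer(op::Symbol, args::Vector{Shape})::Shape
    if op === :matmul
        a, b = args
        length(a) == 2 || error("matmul: left operand must be a matrix, got $a")
        a[2] == b[1] || error("matmul: inner dimensions $(a[2]) and $(b[1]) differ")
        return length(b) == 1 ? (a[1],) : (a[1], b[2])
    elseif op === :add
        a, b = args
        a == b || error("add: shapes $a and $b differ")
        return a
    elseif op === :relu
        return args[1]   # element-wise ops preserve shape
    else
        error("no shape rule for $op")
    end
end

# `env` starts out holding the input shapes; statement k's result is appended
# at index length(inputs) + k, so later statements can refer to it by id.
function infer_program(inputs::Vector{Shape}, stmts)
    env = copy(inputs)
    for (op, ids) in stmts
        push!(env, infer(op, Shape[env[i] for i in ids]))
    end
    return env[end]
end

# x, W1, b1, W2, b2 for the model Chain(Dense(10, 5, relu), Dense(5, 2)).
inputs = Shape[(10,), (5, 10), (5,), (2, 5), (2,)]
prog = [(:matmul, (2, 1)),   # 6: W1 * x
        (:add,    (6, 3)),   # 7: + b1
        (:relu,   (7,)),     # 8
        (:matmul, (4, 8)),   # 9: W2 * h
        (:add,    (9, 5))]   # 10: + b2
infer_program(inputs, prog)  # (2,)
```

Feeding an input of shape `(9,)` instead makes the first `:matmul` throw "matmul: inner dimensions 10 and 9 differ" before any arithmetic runs, which is the kind of check point 3 describes.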

Happy to help give pointers if you want to hack on any of these things. Something in the ONNX/FluxJS/deployment bucket would be easy to get started with. WebAssembly.jl is solid and would probably make Mjolnir->WASM quite easy.

AriMKatz commented 4 years ago

That all sounds quite excellent, I'm very excited about this work and what it bodes for me being able to use Julia at work :)

To start, I'd like to explore working on emitting code for resource-constrained systems. My initial inclination is that it would be easiest to start by targeting TensorFlow Lite, which handles things like quantization, and potentially even TF Lite for Microcontrollers to reach even lighter targets like https://www.youtube.com/watch?v=HzCRZsGJLbI. Another possible target is IREE: https://github.com/google/iree

Though, to what extent would it be a good idea to skip all that and just work on emitting slim C code? Especially because I'm not sure yet whether TF Lite for Microcontrollers allows custom ops.

I'm going to have to do a bit more digging to sharpen this, but these are my initial thoughts.

Edit: I don't want to get ahead of myself, though. Perhaps just focusing on basic TF Lite for now would be best, though I'd need to be able to integrate custom ops.

Another question I need to explore is at what point in the stack quantization needs to happen: https://blog.tensorflow.org/2020/04/quantization-aware-training-with-tensorflow-model-optimization-toolkit.html

MikeInnes commented 4 years ago

It's a little clumsy right now, but here's how you can get a graph for the forward pass of a simple model, ready to deploy:

```julia
(xla-test) pkg> add Flux https://github.com/MikeInnes/Mjolnir.jl https://github.com/MikeInnes/XLATools.jl#next

julia> using Flux, XLA

julia> m = Chain(Dense(10, 5, relu), Dense(5, 2));

julia> XLA.@trace XLA.Primitives() m(Vector{Float32})
1: (%1 :: const(Chain(Dense(10, 5, relu), Dense(5, 2))), %2 :: Array{Float32,1})
  %3 = Float32[-0.50585645 -0.20598492 … -0.1412567 0.15082987; -0.0841699 -0.57924235 … -0.3025245 -0.27678147; … ; -0.16991931 -0.6295842 … -0.13748969 -0.32836327; 0.018975155 -0.22297584 … 0.1435846 0.5270162] :: const(Float32[-0.50585645 -0.20598492 … -0.1412567 0.15082987; -0.0841699 -0.57924235 … -0.3025245 -0.27678147; … ; -0.16991931 -0.6295842 … -0.13748969 -0.32836327; 0.018975155 -0.22297584 … 0.1435846 0.5270162])
  %4 = Float32[0.0, 0.0, 0.0, 0.0, 0.0] :: const(Float32[0.0, 0.0, 0.0, 0.0, 0.0])
  %5 = (*)(%3, %2) :: Array{Float32,1}
  %6 = (Base.Broadcast.broadcasted)(+, %5, %4) :: Array{Float32,1}
  %7 = (Base.Broadcast.broadcasted)(NNlib.relu, %6) :: Array{Float32,1}
  %8 = Float32[-0.078295 0.9035908 … 0.76721174 0.37824208; 0.08101376 0.5027532 … 0.39849186 0.39398715] :: const(Float32[-0.078295 0.9035908 … 0.76721174 0.37824208; 0.08101376 0.5027532 … 0.39849186 0.39398715])
  %9 = Float32[0.0, 0.0] :: const(Float32[0.0, 0.0])
  %10 = (*)(%8, %7) :: Array{Float32,1}
  %11 = (Base.Broadcast.broadcasted)(+, %10, %9) :: Array{Float32,1}
  return %11
```

Turning this into a graph for whatever framework, or even C code, should be pretty straightforward.
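As a rough sketch of that last step, the code below walks a hand-written stand-in for the trace (matrix-vector `*`, broadcasted `+`, and `relu`) and prints a plain C function. The `Stmt` type and the `emit_stmt`/`emit_c` names are hypothetical, and weights are assumed to be passed in Julia's column-major order (so `vec(W)` could be handed over directly); a real backend would consume Mjolnir's IR rather than this toy format.

```julia
# Toy code generator: lowers a straight-line IR (a hand-written stand-in for the
# trace) to a C function. Stmt, emit_stmt and emit_c are hypothetical names.
struct Stmt
    name::String          # C name of the Float32 vector this statement defines
    op::Symbol            # :matvec, :add or :relu
    args::Vector{String}  # operand names (function parameters or earlier Stmts)
    len::Int              # length of the result vector
end

# Emit one statement. Matrices are indexed column-major (W[j*rows + i]) to match
# Julia's memory layout, so `vec(W)` could be passed to the generated function as-is.
function emit_stmt(io::IO, s::Stmt, lens::Dict{String,Int})
    println(io, "  float $(s.name)[$(s.len)];")
    if s.op === :matvec
        W, x = s.args
        n = lens[x]  # number of columns = length of the input vector
        println(io, "  for (int i = 0; i < $(s.len); i++) {")
        println(io, "    $(s.name)[i] = 0.0f;")
        println(io, "    for (int j = 0; j < $n; j++) $(s.name)[i] += $W[j*$(s.len) + i] * $x[j];")
        println(io, "  }")
    elseif s.op === :add
        a, b = s.args
        println(io, "  for (int i = 0; i < $(s.len); i++) $(s.name)[i] = $a[i] + $b[i];")
    elseif s.op === :relu
        a = s.args[1]
        println(io, "  for (int i = 0; i < $(s.len); i++) $(s.name)[i] = $a[i] > 0.0f ? $a[i] : 0.0f;")
    else
        error("unsupported op $(s.op)")
    end
    lens[s.name] = s.len
end

# Wrap the statements in `void fname(const float *param..., float *out)`.
function emit_c(fname::String, params::Vector{Pair{String,Int}}, stmts::Vector{Stmt}, result::String)
    io = IOBuffer()
    lens = Dict{String,Int}(params)
    args = join(["const float *$p" for (p, _) in params], ", ")
    println(io, "void $fname($args, float *out) {")
    foreach(s -> emit_stmt(io, s, lens), stmts)
    println(io, "  for (int i = 0; i < $(lens[result]); i++) out[i] = $result[i];")
    println(io, "}")
    return String(take!(io))
end

# The forward pass from the trace: Dense(10, 5, relu) followed by Dense(5, 2).
stmts = [Stmt("t5", :matvec, ["W1", "x"], 5),
         Stmt("t6", :add, ["t5", "b1"], 5),
         Stmt("t7", :relu, ["t6"], 5),
         Stmt("t10", :matvec, ["W2", "t7"], 2),
         Stmt("t11", :add, ["t10", "b2"], 2)]
src = emit_c("model", ["x" => 10, "W1" => 50, "b1" => 5, "W2" => 10, "b2" => 2], stmts, "t11")
print(src)
```

Each IR statement becomes one loop, which is why the step from trace to C looks mechanical; the hard parts (arbitrary broadcasted functions, memory planning, anything beyond straight-line code) are exactly what this toy leaves out.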

> would be a good idea to skip all that just work on emitting slim c code?

I think this could be a nice approach; the main potential problem is that we support broadcasting/mapping arbitrary functions. That's hard to do in C but might be possible in a templated C++ library like Eigen. XLA can do it too, so perhaps TF Lite can. The other option is to only support built-in activation functions.
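The "only support built-in activation functions" option essentially amounts to a lookup from known Julia functions to C snippets, rejecting everything else at compile time. A minimal sketch: `relu` and `sigmoid` are defined locally as stand-ins for NNlib's versions, and the table, `lower_broadcast`, and the C templates (which assume `fmaxf`, `expf`, and `tanhf` from `<math.h>`) are illustrative, not an existing API.

```julia
# Local stand-ins for NNlib's activations, so the sketch has no dependencies.
relu(x) = max(x, zero(x))
sigmoid(x) = one(x) / (one(x) + exp(-x))

# Whitelist of element-wise functions we know how to lower to C. Each entry maps
# a C expression for the input element to a C expression for the output element.
const C_ACTIVATIONS = Dict{Function,Function}(
    relu     => (x -> "fmaxf($x, 0.0f)"),
    sigmoid  => (x -> "1.0f / (1.0f + expf(-$x))"),
    tanh     => (x -> "tanhf($x)"),
    identity => (x -> x),
)

# Lower `broadcast(f, xs)` to a C loop, or fail loudly for arbitrary user functions.
function lower_broadcast(f, out::String, xs::String, n::Int)
    haskey(C_ACTIVATIONS, f) ||
        error("cannot lower broadcast of $f to C; only built-in activations are supported")
    body = C_ACTIVATIONS[f]("$xs[i]")
    return "for (int i = 0; i < $n; i++) $out[i] = $body;"
end

println(lower_broadcast(relu, "h", "z", 5))
# for (int i = 0; i < 5; i++) h[i] = fmaxf(z[i], 0.0f);
```

Anything outside the table, such as a closure like `x -> x^2 + 1`, hits the error branch; handling those would mean lowering the function body itself, which is where the templated C++ or XLA routes come in.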

Theoretically, I think you may even be able to just get XLA to dump object code, but I've no idea how hard that is in practice.

> at what point in the stack does quantization need to happen

This is a good question that I'm not sure of either. AIUI you can potentially do quantisation (and similar things like weight pruning) before training or after it, as a deployment optimisation. It feels like that could be a fairly straightforward API in Flux (basically an fmap to convert the weights and make sure we support low precision in the AD), but I'm not sure if these techniques ever take advantage of the network structure somehow.
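The "fmap to convert the weights" idea can be sketched for post-training quantization without any Flux dependency: walk a nested model, replace each `Float32` array with `Int8` values plus a per-tensor scale, and dequantize on use. This assumes symmetric max-abs scaling, which is only one of several common schemes; `Quantized`, `quantize`, `dequantize`, and `walk` are hypothetical names, and in Flux the hand-rolled `walk` would be replaced by `fmap` over real layers.

```julia
using Random

# Hypothetical container for a symmetrically quantized tensor: W ≈ scale .* q.
struct Quantized{N}
    q::Array{Int8,N}
    scale::Float32
end

# Symmetric per-tensor quantization: map [-max|W|, max|W|] onto [-127, 127].
function quantize(W::AbstractArray{Float32})
    m = maximum(abs, W)
    scale = m == 0 ? 1.0f0 : m / 127
    q = round.(Int8, clamp.(W ./ scale, -127, 127))
    return Quantized(q, scale)
end

dequantize(Q::Quantized) = Q.scale .* Float32.(Q.q)

# Minimal stand-in for Flux's fmap: rebuild (named) tuples, apply f to Float32
# arrays, and leave everything else (e.g. activation tags) untouched.
walk(f, x::AbstractArray{Float32}) = f(x)
walk(f, x::Union{Tuple,NamedTuple}) = map(y -> walk(f, y), x)
walk(f, x) = x

Random.seed!(0)
# A 10 -> 5 -> 2 model like the traced one, as plain named tuples.
model = (layer1 = (W = randn(Float32, 5, 10), b = zeros(Float32, 5), act = :relu),
         layer2 = (W = randn(Float32, 2, 5),  b = zeros(Float32, 2), act = :identity))
qmodel = walk(quantize, model)

# Rounding error should be at most half a quantization step (plus float slack).
err = maximum(abs, dequantize(qmodel.layer1.W) .- model.layer1.W)
println(err <= 0.51f0 * qmodel.layer1.W.scale)
```

This only covers the "after training" case; quantization-aware training would additionally need the AD to see the low-precision rounding during the forward pass.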