denizyuret / Knet.jl

Koç University deep learning framework.
https://denizyuret.github.io/Knet.jl/latest

Modular Interface #144

Open MikeInnes opened 7 years ago

MikeInnes commented 7 years ago

Knet's current API requires that all model parameters are collected together and passed into the model at once. This design has several issues: it makes code reuse more difficult, leading to hand-coded implementations of standard layers like LSTM in each model. It also makes it impossible to abstract over structure, so that (for example) replacing LSTM with GRU in a model requires significant refactoring rather than being a one-line change. Furthermore, it hampers performance: standard layers like LSTM cannot take advantage of the advanced static optimisations that are possible in other frameworks.

I want to discuss and get feedback on a new design that attempts to solve these problems without compromising Knet's simplicity and flexibility.

AutoGrad Variables

Using a layer shouldn't require knowledge of its internals, so layers need to be able to create and manage their own parameters. In other words, tracked variables need to be part of the autograd API, rather than being hidden inside the grad function. Here's a sketch of what this could look like:

x = AutoGrad.Variable([1,2,3])
y = sum(2*x) # y::AutoGrad.Variable
back!(y, 1)
x.grad == [2, 2, 2]

A Variable is essentially the current Rec; it stores at least a value, a gradient, and the operation / inputs that created it. The back! function performs backpropagation and updates the gradients of input variables recursively, so x now stores a gradient that we can use for training.

Layers

A layer is now just a function with some differentiable parameters.

struct Affine
  W
  b
end

Affine(in, out) = Affine(Variable(randn(in, out)),
                         Variable(randn( 1, out)))

(a::Affine)(x) = x*a.W .+ a.b

We can construct an Affine layer and then call it like a function on a normal array. Because the params are Variables, the output will also be a Variable which we can use for backpropagation. This design makes Keras-like abstractions pretty trivial:

struct Chain
  layers::Vector{Any}
end

Chain(layers...) = Chain(collect(Any, layers))

(c::Chain)(x) = foldl((x, layer) -> layer(x), x, c.layers)

Chain(Affine(10, 20), Affine(20, 5))

This method of defining layers is not at all incompatible with Knet's current approach (via grad); one can freely mix and match plain Julia functions with layers defined in this way.
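For example (a minimal sketch, assuming the Affine and Chain definitions above and that broadcasting over a Variable is tracked by the AD):

# Hypothetical: an anonymous function sits between two Affine layers and
# simply receives and returns Variables like any other layer.
mlp = Chain(Affine(10, 20), x -> max.(x, 0), Affine(20, 5))
y = mlp(randn(1, 10))   # y is a Variable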

Training

Training does not need to change significantly; the only real difference is that in addition to the parameters and gradients supplied by grad, you'll have ones internal to the model. We can make it easy to pull these out by defining a common interface like params (e.g. params(a::Affine) = [a.W, a.b]); then you can train these variables in the obvious way.
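For concreteness, here is a minimal sketch of such a training step, written against the hypothetical Variable / back! / params interface above with a plain SGD update (none of these names are an existing Knet API):

# Hypothetical SGD step over a model's internal parameters.
function sgd_step!(model, x, y; lr = 0.1)
    loss = sum((model(x) .- y).^2)   # a Variable, because the parameters are Variables
    back!(loss, 1)                   # fills in .grad on every parameter recursively
    for p in params(model)
        p.value .-= lr .* p.grad     # assumes a Variable stores .value and .grad
    end
end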

Having a common interface for collecting parameters also makes it easier to abstract out the training process itself; hopefully in future we'll have generic optimisers which hide this whole process.

Longer Term Plans

If you've followed Flux at all you may notice that the design is very similar. We've been working on things like generic optimisers, and a design like this would enable us to share those between Flux and Knet. More ambitiously, my hope is that we can use static, define-before-run models like Affine inside of Knet models. The static models can then be heavily optimised – e.g. with custom JITed GPU kernels or parallelism – but still seamlessly be used inside very dynamic Knet models.

Hopefully the fact that this only adds to Knet's current interface should be enough evidence that it doesn't sacrifice flexibility. If you're still concerned I recommend checking out PyTorch, which has been very successful at achieving performance, flexibility and modularity with a very similar API. Overall I think it will make it much easier to create, modify and reuse Knet models.

Any thoughts?

AStupidBear commented 7 years ago

Cannot agree with you more

denizyuret commented 7 years ago

Summarizing some points from this week's discussion:

  1. We need a way to run in "test mode", without the overhead of AD and tapes, that lets us see the real loss rather than a value wrapped in a Variable.
  2. In the future we may want a more functional interface (possibly built on this one) where either (i) back returns an iterable of gradients parallel to params(model), or (ii) we have some interface like grad(model) that does the same (see the sketch after this list).
  3. We need to discuss the "update" function for optimization, in particular (i) how it can easily be applied to a composite model with many subparts, (ii) where it stores the optimization state (first and second moments for Adam, learning rates, etc.), (iii) the possibility of a totally different "solver" interface.
  4. We need to think about the initialization of (i) model parameters, (ii) optimization parameters, (iii) data (minibatching etc.).
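As a rough sketch of option (i) in point 2, assuming back returns gradients parallel to params(model) and that opts holds per-parameter optimizer state (all of these names are hypothetical):

loss = model(x, y)                       # forward pass records the tape
gs   = back(loss, params(model))         # hypothetical: gradients parallel to params(model)
for (w, g, o) in zip(params(model), gs, opts)
    update!(w, g, o)                     # Knet's update! with per-parameter optimizer state
end
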
jbrea commented 6 years ago

I really like the idea of the modular interface. A simple, but not so elegant, way to have this interface right now is to add functions like (a::Affine)(w, x) = x*w[1] + w[2].

Here is a more complete example:

using Knet

# type defs

struct Dense{T,Tu}
    w::Array{T, 2}
    b::Array{T, 1}
    unit::Tu
end
Dense(dimin, dimout; T = Float32, unit = relu, initfun = xavier) = 
    Dense(initfun(T, dimout, dimin), zeros(T, dimout), unit)

params(l::Dense) = Any[l.w, l.b]
(l::Dense)(x) = l.unit.(l.w * mat(x) .+ l.b)
(l::Dense)(w, x) = l.unit.(w[1] * mat(x) .+ w[2]) # needed for grad

struct Conv{T,Tu}
    w::Array{T, 4}
    b::Array{T, 4}
    unit::Tu
    convkargs
end
Conv(xdim, ydim, inc, outc; T = Float32, unit = relu, initfun = xavier, convkargs...) =
    Conv(initfun(T, xdim, ydim, inc, outc), zeros(T, 1, 1, outc, 1), unit, convkargs)

params(l::Conv) = Any[l.w, l.b]
(l::Conv)(x) = l.unit.(conv4(l.w, x; l.convkargs...) .+ l.b)
(l::Conv)(w, x) = l.unit.(conv4(w[1], x; l.convkargs...) .+ w[2])

struct Pooling
    poolkargs
end
Pooling(; kargs...) = Pooling(kargs)

params(l::Pooling) = Any[]
(l::Pooling)(x) = pool(x; l.poolkargs...)
(l::Pooling)(w, x) = l(x)

struct Chain
    layers::Array{Any, 1}
end
Chain(l...) = Chain(collect(l))
params(c::Chain) = [params(l) for l in c.layers]
(c::Chain)(x) = foldl((x, c) -> c(x), x, c.layers)
(c::Chain)(w, x) = foldl((x, c) -> c[2](c[1], x), x, zip(w, c.layers))

struct Model
    chain::Chain
    w::Array{Any, 1}
end
Model(c) = Model(c, params(c))

# utility functions

loss(w, c, x, y) = nll(c(w, x), y)
gradfun = grad(loss)
function trainepoch!(model, data, gradfun, opt)
    for (x, y) in data
        update!(model.w, gradfun(model.w, model.chain, x, y), opt)
    end
end

import Knet.optimizers
optimizers(m::Model, opt) = optimizers(m.w, opt)

import Knet.accuracy
accuracy(m::Model, data) = accuracy(m.w, data, m.chain)

# example

include(Knet.dir("data","mnist.jl"))
xtrn, ytrn, xtst, ytst = mnist()
dtrn = minibatch(xtrn, ytrn, 100)
dtst = minibatch(xtst, ytst, 100)

model = Model(Chain(Conv(5, 5, 1, 20), Pooling(), Conv(5, 5, 20, 50), Pooling(), 
              Dense(800, 500), Dense(500, 10, unit = identity)))
opt = optimizers(model, Adam)
@time for _ in 1:5
    trainepoch!(model, dtrn, gradfun, opt)
    println(accuracy(model, dtst))
end

Are you interested in turning this into a pull request, or are there in any case many changes planned for Julia v0.7 that would allow a more elegant way to implement this interface?

edit: After @CarloLucibello's comment I wrapped the chain and its parameters into the struct Model to avoid passing around w and the chain. This does not, however, solve the issue of method duplication.

CarloLucibello commented 6 years ago

@denizyuret, would AutoGrad allow for a PyTorch-style interface like:

x = Rec(rand(10)) # the parameters of our network
y = sum(x)  # our loss' output, still a Rec
backprop!(y) 
x.grad # now contains the gradient

If so, we could take PyTorch's (and Flux's) approach and initialize the parameters in the modules as Recs. This would let us avoid this kind of method duplication:

(l::Dense)(x) = l.unit.(l.w * mat(x) .+ l.b)
(l::Dense)(w, x) = l.unit.(w[1] * mat(x) .+ w[2]) # needed for grad

and avoid passing both w and the model around.
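A rough sketch of what that could look like, assuming a Rec that stays tracked through ordinary field access and method calls (which the current Rec does not support; the names follow the Dense example above):

struct Dense{Tu}
    w::Rec    # parameters wrapped once, at construction time
    b::Rec
    unit::Tu
end
Dense(dimin, dimout; T = Float32, unit = relu) =
    Dense(Rec(xavier(T, dimout, dimin)), Rec(zeros(T, dimout)), unit)

(l::Dense)(x) = l.unit.(l.w * mat(x) .+ l.b)   # the single method; no (l::Dense)(w, x) duplicate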

Evizero commented 6 years ago

Whatever ends up happening, please don't replace the current low-level interface with this member-variable approach. I really prefer working with Knet the way it currently works over the alternatives that are around, and I would be sad to see that go away.

ngphuoc commented 6 years ago

It seems the proposed approach doesn't replace the current low-level interface but provides a high-level wrapper.

Another straightforward way to avoid the method duplication is to define a helper macro @w

@w (l::Dense)(x) = unit.(w * mat(x) .+ b)

which translates to

(l::Dense)(ws, x) = begin
    w, b = ws[l.name]
    unit.(w * mat(x) .+ b)
end
jbrea commented 6 years ago

@CarloLucibello Yes, the method duplication is the non-elegant part. Not having to pass around w and chain is easy to fix (see the updated comment above). @ngphuoc Some macros may help, but it could be a bit tricky to come up with a general solution (see e.g. the method duplication for the pooling layer, or imagine a similar thing for LSTM layers). @Evizero My proposal strictly adds a high-level wrapper, leaving the low-level interface untouched.

MikeInnes commented 6 years ago

@Evizero You call the grad approach low-level, but I actually think the variable version is more primitive. You can implement grad in Flux in about three lines (it's included in the library if you want a Knet-style API), but you can't do modular variables with only grad. In either case, there's certainly no replacement going on.
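As a rough sketch (a few more than three lines, and written against the hypothetical Variable / back! API from the top of this thread rather than Flux's actual internals):

# Hypothetical: differentiate f with respect to its first argument only.
function grad(f)
    return function (w, rest...)
        wv = Variable(w)           # wrap the weights so operations on them are recorded
        back!(f(wv, rest...), 1)   # run f and backpropagate from its scalar output
        return wv.grad
    end
end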

MikeInnes commented 6 years ago

In case that's not totally clear:

julia> using Flux.Tracker: gradient

julia> gradient((x,y) -> sum(x.*y), [1,2,3], [4,5,6])
([4.0, 5.0, 6.0], [1.0, 2.0, 3.0])

In future this will become the primary interface to Flux's AD (it's the only sensible way to get nested derivatives), but I want to do it in a way that supports both a grad operator and modularity at the same time.

davidssmith commented 6 years ago

I'm a bit late to the party, but without knowing about this thread I made my own modular interface for Knet. I cleaned it up and posted it in a repo in case it is of use:

https://github.com/davidssmith/KnetLayers/blob/master/src/KnetLayers.jl

Here's a preview of my still very dirty implementation:

abstract type Layer end

mutable struct NeuralNet 
    layers::Array{Layer}
end
NeuralNet() = NeuralNet(Layer[LinearLayer(1)])

depth(N::NeuralNet) = length(N.layers)

mutable struct LinearLayer <: Layer
    size::Int
end
operator(L::LinearLayer) = (x, w, b) -> w*mat(x) .+ b   # Fully connected layer
nparams(L::LinearLayer) = 2
nparams(L::LinearLayer, indims) =  L.size * prod(indims[1:end-1])
string(L::LinearLayer) = "Dense($(L.size))"
outdims(L::LinearLayer, indims) = (L.size, indims[end])
weights(L::LinearLayer, indims) = (xavier(L.size, prod(indims[1:end-1])),zeros(L.size))
flops(L::LinearLayer, indims) = prod(indims[1:end-1]).^2*L.size + L.size
==(a::LinearLayer, b::LinearLayer) = a.size == b.size
const Linear = LinearLayer
const Dense = LinearLayer

etc...

operator(A::NeuralNet) = function (w, x)
    i = 1
    for L in A.layers
        n = nparams(L)
        x = operator(L)(x, w[i:i+n-1]...)
        i += n
    end
    return x
end
function weights(net::NeuralNet, indims; atype=Array{Float32})
    W = Any[]
    D = indims  # running value of data size as it passes through net
    for L in net.layers
        push!(W, weights(L, D)...)
        D = outdims(L, D)
    end
    np = sum([length(w) for w in W])
    return (map(a -> convert(atype, a), W), np)
end

etc...

predict = operator(net)

I've found no significant speed penalty for running Knet using my Layers interface.

One perk of the way I wrote it is that you can define the network without knowing anything about the input data. For example you can just write Dense(100) for a linear layer with 100 hidden units, rather than the PyTorch-style Dense(768,100), where you have to know the size of the input data to each layer.

Also, my objects don't carry around their weights. Weight generation and use is done entirely in the train method, so the objects end up very lightweight and can probably all be immutable.

ngphuoc commented 6 years ago

@denizyuret, are struct methods, i.e. (a::Affine)(w, x) = x*w[1] + w[2], fully supported by AutoGrad? I mean, will AutoGrad lose track of some global weights and cause gradcheck to fail? If so, we should only use normal functions and wait until this is fully supported.

f(w, x) = x*w[1] + w[2]   # f::Function, whereas a::Affine

denizyuret commented 6 years ago

OK, I finally got Knet and AutoGrad to catch up with the times (see the latest master). The problem was that the old AutoGrad could not see into or create structs, and g=grad(f) was a very rigid interface for specifying what the differentiation should be with respect to. Inspired by Flux, and after some thinking, I fit the new AutoGrad interface into four functions (Param, differentiate, gradient, and value, described below), which allows Flux-like models if you have some familiarity with function-like objects. Everything should be backward compatible, so all examples with the old interface should still work (regular functions, parameters in the first argument, using g=grad(f), etc.).

I am really happy with the memory management: all differentiation-related allocation is stored under the return value of differentiate, and as soon as it goes out of scope all of that memory can be garbage collected (whereas before I had two-way pointers between parameters and temporary storage, which sometimes caused issues if you went around the grad interface). Also, Ekin made some useful suggestions on clearing memory during the backward pass, which allows much larger recurrent models to be processed now. The interfaces for RNNs and optimization parameters are still work in progress; comments welcome.

### new AutoGrad interface: Param, differentiate, value, gradient.
x = Param(rand(3,3))      # user can declare parameters
sum(x)  => 4.5            # they act like regular values outside of differentiation
y = differentiate(sum,x)  # however if you want the gradients
y => T(4.5)               # you get a struct
value(y) => 4.5           # which represents the same value
gradient(y,x) => Array{Float32}(3,3)  # but also contains gradients for any Params

# This allows a Flux-like interface:
struct Linear <: Model; w; b; end       # user declares a subtype of Model                                                                             
(f::Linear)() = (f.w, f.b)              # 0-arg call returns iterator over parameters
(f::Linear)(x) = (f.w * mat(x) .+ f.b)  # 1-arg call returns a prediction
(f::Linear)(x,y) = nll(f(x), y)         # 2-arg call returns loss

# Now we can simply train a model like this:
for (x,y) in data
    J = differentiate(f,x,y)
    for w in f()
        g = gradient(J,w)
        update!(w,g)
    end
end

Here is an example notebook in the new style.

ekinakyurek commented 6 years ago

This interface lacks a way of deactivating Params (i.e. differentiable weights) for specific parts of a model that is a combination of pre-existing models. We want to have a deactivate(::MyStruct) or detach(::MyStruct) that basically turns Params temporarily into normal parameters.
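A hypothetical sketch of the kind of helper being requested, written against the Param / value interface above (neither detach nor this traversal exists in Knet; freezing a composite model would need a method per model type or a generic traversal):

detach(x::Param) = value(x)   # unwrap, so AutoGrad no longer tracks this weight
detach(x) = x                 # leave non-Param values untouched
detach(f::Linear) = Linear(detach(f.w), detach(f.b))   # per-model method, reusing Linear above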

denizyuret commented 6 years ago

Please check out issue #347.