elixir-nx / axon

Nx-powered Neural Networks
Apache License 2.0

Introduce a high-level layer API #2

Closed seanmor5 closed 3 years ago

seanmor5 commented 3 years ago

The next step after #1 is to implement higher-level constructs on top of the lower-level functional implementations. The goal of the higher-level API is to provide abstractions for building neural networks that:

For simplicity, we'll leave the discussion of efficient handling of network state to a later issue. This issue will only focus on the architecture/representation of a network.

Axon Struct

We will introduce an %Axon{} struct that represents a constructed network. For now, the struct will have the following attributes:

The struct would be built up with calls to high-level functions in the root Axon namespace. For example, MNIST:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

Then a model would be compiled, perhaps like:

compiled_model = Axon.compile(model, options)

At a minimum, Axon.compile initializes parameters and returns a compiled function using whatever backend the user specifies. It can also be abstracted away in some higher-level training logic, but that's a discussion for another issue.

Structuring the network in this way makes arbitrary compilation easy. The API is simple and easy to understand, yet flexible enough for complex models. For example, a GAN:

generator =
   Axon.input({nil, 128})
   |> Axon.dense(256, activation: :tanh)
   |> Axon.dense(512, activation: :tanh)
   |> Axon.dense(784, activation: :tanh)
   |> Axon.reshape({28, 28})

discriminator =
   Axon.input({nil, 784})
   |> Axon.dense(128, activation: :relu)
   |> Axon.dense(1, activation: :sigmoid)

combined =
  generator
  |> Axon.compose(discriminator)

Layers

High-level layers are tied directly to their functional implementations in Axon.Layers. Some of them have layer specific options which can be passed during layer creation.

Combinators

We will use combinators similar to: https://thinc.ai/docs/api-layers/#combinators to represent more complex relationships between layers. At a minimum we'd have:

josevalim commented 3 years ago

This looks great. Just one note: since this syntax exists in a module body which will most likely be used to define defn functions, the activation function most likely cannot be a closure. So you will probably support things like this:

activation: :relu
activation: {:relu, some_option: 123}
activation: {MyCustomActivation, :some_option, []} # the result is passed as first argument
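As a rough sketch of how those three option forms could be normalized into a single one-argument function, something like the following might work. Everything here is hypothetical: `ActivationSketch`, `resolve/1`, and the float-based `relu`/`leaky_relu` are illustrative names only, not part of Axon or Nx.

```elixir
defmodule ActivationSketch do
  # Hypothetical helper, not part of Axon: turns an `:activation` option
  # into a one-argument function without requiring a closure from the caller.

  # Stand-in activations that work on plain floats, for illustration only.
  def relu(x, _opts \\ []), do: max(x, 0.0)

  def leaky_relu(x, opts \\ []) do
    alpha = Keyword.get(opts, :alpha, 0.01)
    if x >= 0, do: x, else: alpha * x
  end

  # activation: :relu
  def resolve(name) when is_atom(name), do: resolve({name, []})

  # activation: {:leaky_relu, alpha: 0.2}
  def resolve({name, opts}) when is_atom(name) and is_list(opts) do
    fn x -> apply(__MODULE__, name, [x, opts]) end
  end

  # activation: {Module, :function, extra_args}
  # The layer output is passed as the first argument.
  def resolve({mod, fun, args}) when is_atom(mod) and is_atom(fun) and is_list(args) do
    fn x -> apply(mod, fun, [x | args]) end
  end
end
```

The closure is only created inside `resolve/1`, so the user-facing option stays plain data that can live in a module body.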

I also hear a lot about PyTorch modules. How would they fit on this design? Is it the ability to do your own layers? Or is it precisely the ability to compile a NN to a module?

seanmor5 commented 3 years ago

If I understand correctly, PyTorch modules are just neural network building blocks. Theoretically, we could support a Module within this same scope as a behaviour, but each method would have to work on the Axon data structure rather than on the tensors themselves, which might not necessarily be ideal for trying to express computations using Nx or lower-level Axon functions.

Another option is to support modules through somewhat of a metaprogramming approach:

deflayer SomeLayerNet do
   defn forward_pass, do: ...
end

Or something similar. Another option is to treat the Axon data structure as being equivalent to a PyTorch module, and then implement some of this functionality on it: https://pytorch.org/docs/stable/generated/torch.nn.Module.html

Basically the data structure implicitly defines the forward and backward pass based on whatever building blocks you use, and then some of the other functions can be implemented to accept and work on an %Axon{} struct. My understanding is that Modules are really an OOP approach to Deep Learning, and we more or less have to rethink that paradigm.

cc @jeffreyksmithjr curious if you can give some feedback wrt PyTorch modules

cc @elbow-jason for discussion on this API

seanmor5 commented 3 years ago

I have been thinking more and more about this, and I think the best way to extend this would be to introduce two more building blocks:

Axon.Transformation/Axon.Function and Axon.Parameter

Every Axon struct would consist of:

Transformations have at a minimum:

Parameters have at a minimum:

elbow-jason commented 3 years ago

I see above:

allow input layer shapes to have nil batch dimensions to represent arbitrary sized batches

With reference to shapes there exists a caveat.

The fully qualified shape of layers may not be available until the shape of the input is known.

In Annex, I referred to these two kinds of shapes as "concrete" and "abstract". A "concrete" shape was one whose dimensions were all positive integers. An "abstract" shape was one that had a placeholder (I chose :any due to the ambiguity of nil) as the batch dimension. A better name than "abstract" may have been "partial".

If the input layer has nil as its batch dimension, then all layers in the same neural network will also have nil in the batch dimension of their shape.

The exception is a network that has at least one layer with all of its shape information. The layers after this one would have full shapes. This would be something like a reduce_sum or reduce_mean or some batch-size-agnostic loss function, after which the batch dimension of the output is either contracted away or changed to a known size. It's quite uncommon to see. I've never actually seen it in a working neural network's layer definitions - only used for troubleshooting.

Not knowing the full input or output shapes of a neural network is quite common among NN frameworks. For example, in TensorFlow Keras:

In [1]: import tensorflow as tf

In [2]: model = tf.keras.Sequential([
   ...:     tf.keras.layers.Flatten(),
   ...:     tf.keras.layers.Dense(128, activation='relu'),
   ...:     tf.keras.layers.Dense(10)
   ...: ])

In [3]: model.compute_output_shape()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-f4e0e849bae2> in <module>
----> 1 model.compute_output_shape()

TypeError: compute_output_shape() missing 1 required positional argument: 'input_shape'

In [4]: model.compute_output_shape((400, 28, 28))
Out[4]: TensorShape([400, 10])

I think having partial shapes is useful.

As long as we are strict about where partial shapes are used and where they are not used I think we'll be okay.
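To make the concrete/partial distinction above a bit more tangible, here is a minimal sketch using plain tuples. `ShapeSketch` and its functions are hypothetical names for illustration, and nil stands in for the unknown batch dimension as in the proposal.

```elixir
defmodule ShapeSketch do
  # Hypothetical helpers illustrating "concrete" vs "partial" shapes.
  # A shape is a tuple; nil in the first position marks an unknown batch size.

  # A shape is concrete when every dimension is a positive integer.
  def concrete?(shape) do
    shape |> Tuple.to_list() |> Enum.all?(&(is_integer(&1) and &1 > 0))
  end

  # Output shape of a dense layer: the batch dimension passes through
  # untouched, so a nil batch stays nil in every following layer.
  def dense({batch, _features}, units), do: {batch, units}

  # Fill in the batch dimension of a partial shape once the input is known.
  def with_batch(shape, batch)
      when elem(shape, 0) == nil and is_integer(batch) and batch > 0 do
    put_elem(shape, 0, batch)
  end
end
```

Being strict here could mean that only functions like `with_batch/2` accept partial shapes, while anything that allocates parameters or runs a computation requires `concrete?/1` to hold.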

elbow-jason commented 3 years ago

It may also be useful to allow (or require?) the unknown batch dim to be named. Just mulling it over.

elbow-jason commented 3 years ago

Every Axon struct would consist of:

  • input - another axon struct

You are creating a singly-linked-list-like structure here. I think you would need:

  • input - another axon struct or nil

Where nil is your recursive base case.

I am not convinced that a recursive struct is necessary or better than a struct that wraps a list though.
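For comparison, here is a small sketch of the recursive form with nil as the base case, plus a traversal that flattens it into the list form. `LinkedSketch` and its fields are hypothetical and much simpler than whatever the real struct ends up holding.

```elixir
defmodule LinkedSketch do
  # Hypothetical node: each layer points at the layer that feeds it.
  # `input: nil` marks the root and is the recursive base case.
  defstruct [:op, :input]

  def input(), do: %__MODULE__{op: :input, input: nil}

  def layer(%__MODULE__{} = parent, op), do: %__MODULE__{op: op, input: parent}

  # Walk from the last layer back to the root, returning ops in forward order.
  def to_list(node), do: to_list(node, [])

  defp to_list(nil, acc), do: acc

  defp to_list(%__MODULE__{op: op, input: parent}, acc) do
    to_list(parent, [op | acc])
  end
end
```

For a purely linear chain the two representations carry the same information, which is why a wrapped list may be enough. The recursive form mainly pays off once a node can have more than one parent.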

josevalim commented 3 years ago

I wouldn't worry too much about the struct right now. It is internal representation and we can always refactor it. My suggestion would be to start with the simplest thing that allows you to model the examples above. And then as the examples grow in size and complexity (different kinds of backprop, custom modules, etc), revisit the internal structure accordingly.

One question regarding the modules discussion: can we always model a layer with a single function? Or will more complex layers have multiple functions depending on how they are plugged into the NN (forward pass, back pass, etc)?

elbow-jason commented 3 years ago

An idea

I was reading about the PyTorch modules above and was reminded of Plug. Specifically it brought to mind the module-and-function duality of the Plug architecture.

In the Plug architecture, "plugs" are used to compose a series of transformations to a Plug.Conn struct during a web request. In the Plug architecture, the implementation for a "plug" can be either a module that implements the 2 very simple callbacks or a function.

https://github.com/elixir-plug/plug

A Plug module has 2 callbacks: init/1 and call/2. In a Plug module, the init/1 callback prepares/initializes the second arg to call/2, and the call/2 callback takes a Plug.Conn struct as the first arg and anything as the second arg. The call/2 function must return a Plug.Conn struct (hopefully the same struct, but modified - or not modified in a side-effect only plug).

A Plug function is quite simple, it must take 2 args - once again - a Plug.Conn struct and anything as the second arg and - once again - it must return a Plug.Conn struct.

The cognitive overhead for how to use a Plug effectively is very low, but the abstraction is super extensible. Chris McCord used Plug and some very good metaprogramming to make the Phoenix framework a highly extensible easy-to-understand web framework.

As if Axon were Plug

In my imagination... if we attempt to extend this dual-form architecture to Axon what we end up with is axon-plug...

The axon-plug function:

The "closure or MFA" must:

The axon-plug module is very similar. The init/1 is used to prepare the options (just like Plug), and call/2 has the exact same description as the axon-plug function above.

This would let us implement an axon-plug module like this:

defmodule BackwardRelu do
  import Nx.Defn

  defn backward_relu(x, coeff) do
    # This github syntax highlighter already correctly handles defn ?!?!?
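    # Nx.select picks element-wise, so x can be a tensor of any shape
    # (an if here would need a scalar predicate).
    Nx.select(Nx.less(x, 0), x * coeff, 0)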
  end

  def init(coeff \\ nil), do: coeff || 1.0

  def call(axon, coeff) do
    Axon.add_defn(axon, fn x -> backward_relu(x, coeff) end)
  end
end

and compose layers like this:

Axon.input({nil, 128})
|> Axon.dense(256)
|> Axon.tanh()
|> Axon.dense(size: 512, activation: :tanh)
|> Axon.dense(784)
|> Axon.activation(&Axon.Activations.tanh/1)
|> Axon.reshape({28, 28})
|> Axon.plug({BackwardRelu, 0.05})

Each of these Axon function calls matches the axon-plug architecture - Axon.tanh/1 has an ignored default 2nd arg.

Though I really like Axon.plug/2 as Axon.plug/3 better:

|> Axon.plug(BackwardRelu, 0.05)

And, of course, the function name plug is a head nod to Plug but could be anything.
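The dispatch between the function form and the module form could be sketched roughly as below. This is only an illustration under heavy assumptions: `AxonPlugSketch.plug/3` and `ScaleLayer` are hypothetical, and the "model" is just a list of layer descriptions (newest first) rather than a real Axon struct.

```elixir
defmodule AxonPlugSketch do
  # Hypothetical stand-in for an Axon.plug/3. The "model" here is just a
  # list of layer descriptions, newest first.

  # Function form: any 2-arity function taking the model and options.
  def plug(model, fun, opts) when is_function(fun, 2), do: fun.(model, opts)

  # Module form: init/1 prepares the options, call/2 receives the model.
  def plug(model, module, opts) when is_atom(module) do
    module.call(model, module.init(opts))
  end
end

defmodule ScaleLayer do
  # Hypothetical axon-plug module following the init/1 and call/2 shape.
  def init(coeff), do: coeff || 1.0

  def call(model, coeff), do: [{:scale, coeff} | model]
end
```

Both forms return the updated model, so they compose with `|>` in the same way as the built-in layer functions in the example above.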

I think this axon-plug idea is pretty cool and was quite fun to theory craft with.

Thoughts?

elbow-jason commented 3 years ago

After thinking about the axon-plug idea some more, I realize it is missing at least one glaringly obvious thing: state.

So we probably can't just copy Plug exactly one-for-one.

josevalim commented 3 years ago

I think there are a couple approaches we can take here. For example, one option is for us to say the pluggable layers receive Nx expressions and return Nx expressions (i.e. they are transforms). This way they can do whatever they want but with some sharp edges.

The other option is to say they need to adhere to some behaviour and plug the layer based on the functions implemented on the behaviour.

The best approach is most likely to generalize whatever we use to implement our built-in layers - so I would focus on that instead, and then look at the patterns we want to generalize and make extensible.

In other words, I would suggest stitching some networks together with built-in layers to solve different problems and then changing the internal structure and the main DSL accordingly. :)

seanmor5 commented 3 years ago

After taking some time to stitch some networks together and come up with an initial API, here are my takeaways for what this problem entails and goals of the API moving forward.

Model Abstraction

In order to avoid tying too much to the model abstraction initially, I think it's best to frame the model as two functions: init and predict/apply. init initializes the parameters and variables of the model and returns a tuple of parameters and a tuple of variables. For simplicity, we can assume predict takes parameters, variables, and input and returns a prediction and updated variables; however, it may be best to actually include two predict functions: one for training and one for inference. This is because you probably won't want to update the model variables during inference. Additionally, we will need the ability to turn off dropout layers during inference - a challenge for another day.

Given this framework, the layer API basically just needs to be a series of combinators that works on init and predict. This is the same approach Jax takes.
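A minimal sketch of that combinator idea, under simplifying assumptions, might look like this. `CombinatorSketch` is a hypothetical name, inputs are plain lists of floats for a single example rather than Nx tensors, the initializer is a constant so the result is deterministic, and model variables are left out entirely.

```elixir
defmodule CombinatorSketch do
  # Hypothetical sketch: every layer is an {init_fn, predict_fn} pair.
  # init_fn takes an input shape and returns {output_shape, params};
  # predict_fn takes params and an input and returns the output.

  def dense(units, opts \\ []) do
    # Constant initializer for illustration; a real one would be random.
    value = Keyword.get(opts, :init, 0.0)

    init = fn {batch, features} ->
      kernel = for _ <- 1..features, do: List.duplicate(value, units)
      {{batch, units}, %{kernel: kernel, bias: List.duplicate(0.0, units)}}
    end

    # x is a list of floats, one per input feature: computes x . kernel + bias.
    predict = fn %{kernel: kernel, bias: bias}, x ->
      kernel
      |> Enum.zip(x)
      |> Enum.reduce(bias, fn {row, xi}, acc ->
        acc |> Enum.zip(row) |> Enum.map(fn {a, w} -> a + xi * w end)
      end)
    end

    {init, predict}
  end

  def relu do
    {fn shape -> {shape, %{}} end, fn _params, x -> Enum.map(x, &max(&1, 0.0)) end}
  end

  # Chain layers: shapes flow through init, values flow through predict.
  def serial(layers) do
    init = fn shape ->
      {out, params} =
        Enum.reduce(layers, {shape, []}, fn {layer_init, _}, {s, acc} ->
          {next, p} = layer_init.(s)
          {next, [p | acc]}
        end)

      {out, Enum.reverse(params)}
    end

    predict = fn params, x ->
      layers
      |> Enum.zip(params)
      |> Enum.reduce(x, fn {{_, layer_predict}, p}, acc -> layer_predict.(p, acc) end)
    end

    {init, predict}
  end
end
```

Because `serial/1` itself returns an `{init, predict}` pair, a composed model can be nested inside another one like any single layer.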

Problem

Given the simplified abstraction, we're now faced with 3 problems:

1) Tracking trainable parameters for each layer. We need to be able to determine the shape at each layer so we can determine the shape of each parameter. We also need a way to specify an initializer and some additional options for each parameter.

2) Tracking running variables for each layer. Again, we need to determine shapes and possibly initializers.

3) Building the actual expression for each layer.

Considerations

Keeping in mind the problems defined above, we also need to keep in mind the following goals or requirements:

1) We need to retain high-level information about the model. This is necessary for importing and exporting models to other formats. We need to be able to say this is a :dense layer with these weights/biases, this is a :conv, etc.

2) Layers should be arbitrarily composable with regular Nx functions and other numerical definitions. I believe this is absolutely necessary and could really set us apart.

3) Models should be arbitrarily composable as well. This is important for building reusable parts of the model (see the resnet example where building blocks are defined with def in a separate module). With this, users should be able to mark trainable parameters and variables as well. One benefit here is we can actually implement some high-level layers in terms of existing primitives. It may be that using regular Elixir functions as the building block is okay - the context switching can be strange, but it might be fine.

4) Multi-input, multi-output models. I'm just including these because I haven't tried to implement one yet, so I'm not sure how difficult they will be to implement. You end up with multiple graphs, but parts of the graph are actually shared, so we'll probably need some additional metadata on composite types.

Current Approach

The current approach is to build the Axon struct from within a model do...end block. The Axon struct tracks an id, name, output_shape, op, params, vars, and additional opts. We use __jit_params__ and __jit_predict__ to build an expression for both init_random_params and predict. The Axon struct starts at the bottom-most node, so we traverse up the graph until we hit the root (input nodes), and build the expression by applying existing numerical definitions along the way.

From an API perspective we currently support:

```elixir
model do
  input() |> dense() |> dense()
end
```

which yields `init_random_params` and `predict` as functions, as well as:

```elixir
model <name> do
  input() |> dense() |> dense()
end
```

which yields init_<name> and <name> as functions. As @josevalim mentioned, an alternative is:

  input() |> dense() |> dense() |> model()

My only preference for using do...end blocks and treating models like function calls would be the ability to do something like

model generator(latent) do
  latent |> dense() |> dense()
end

in order to eliminate the need for input calls, although we would completely lose shape information. I'm also not really sold on using the term model; maybe defaxon, defmodel, defpredict, defmlp... not sure.

From an architectural standpoint, a completely different approach is to use Nx.Defn.Expr metadata nodes to retain the necessary information. So we could wrap all of the layer implementations, activation implementations, etc. in a transform that adds the necessary metadata and then use a jit_expr in Nx to compile the expressions. I haven't experimented with this yet, so I'm not sure exactly how it would work, but it would make solving 2 and 3 in the above considerations much easier.

I think this might frame the problem a little better.

seanmor5 commented 3 years ago

Closing with new issues.