This looks great. Just one note: since this syntax exists in a module body which will most likely be used to define `defn` functions, the activation function most likely cannot be a closure. So you will probably support things like this:

```elixir
activation: :relu
activation: {:relu, some_option: 123}
activation: {MyCustomActivation, :some_option, []} # the result is passed as first argument
```
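For illustration, here is a rough sketch of how those three accepted forms might be normalized into an arity-1 function at build time. The `ActivationSpec` module is made up for this sketch, and the `Axon.Activations` module name is an assumption:

```elixir
defmodule ActivationSpec do
  # Hypothetical helper (not part of any proposed API): normalizes the three
  # accepted activation specs above into an arity-1 function.
  # The Axon.Activations module name is assumed for illustration.

  def resolve(name) when is_atom(name) do
    # :relu -> &Axon.Activations.relu/1
    Function.capture(Axon.Activations, name, 1)
  end

  def resolve({name, opts}) when is_atom(name) and is_list(opts) do
    # {:relu, some_option: 123}
    fn x -> apply(Axon.Activations, name, [x, opts]) end
  end

  def resolve({mod, fun, args}) when is_atom(mod) and is_atom(fun) and is_list(args) do
    # {MyCustomActivation, :some_option, []} - the layer output becomes the first argument
    fn x -> apply(mod, fun, [x | args]) end
  end
end
```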
I also hear a lot about PyTorch modules. How would they fit into this design? Is it the ability to do your own layers? Or is it precisely the ability to compile a NN to a module?
If I understand correctly, PyTorch modules are just neural network building blocks. Theoretically, we could support a `Module` within this same scope as a behaviour, but each method would have to work on the Axon data structure rather than on the tensors themselves, which might not necessarily be ideal for trying to express computations using `Nx` or lower-level `Axon` functions.
Another option is to support modules through somewhat of a metaprogramming approach:
```elixir
deflayer SomeLayerNet do
  defn forward_pass, do: ...
end
```
Or something similar. Another option is to treat the Axon data structure as being equivalent to a PyTorch module, and then implement some of this functionality on it: https://pytorch.org/docs/stable/generated/torch.nn.Module.html

Basically, the data structure implicitly defines the forward and backward pass based on whatever building blocks you use, and then some of the other functions can be implemented to accept and work on an `%Axon{}` struct. My understanding is that Modules are really an OOP approach to deep learning, and we more or less have to rethink that paradigm.
cc @jeffreyksmithjr - curious if you can give some feedback wrt PyTorch modules. cc @elbow-jason for discussion on this API.
I have been thinking more and more about this, and I think the best way to extend this would be to introduce two more building blocks:
- `Axon.Transformation`/`Axon.Function`
- `Axon.Parameter`

Every `Axon` struct would consist of:

- `input` - another Axon struct
- `output_shape` - output shape of this layer, for inferring between inputs/outputs
- `transformation`/`function` - either an atom that maps to a known transformation or a custom transformation

Transformations have at a minimum:

- `function` - the actual function that does the transform, probably a nesting of other transforms or a primitive
- `activation` - optional, we can also build "activation layers" with an identity function plus activation
- `parameters` - a list of trainable parameters attached to the transformation, these are `Axon.Parameter` structs
- `options` - other layer-specific options

Parameters have at a minimum:

- `name` - for attaching to pretrained models
- `shape` - shape of the parameter
- `initializer` - parameter initialization
- `options` - parameter-specific options
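A rough sketch of what those three structs might look like, with field names taken from the lists above and everything else (defaults, field order) an assumption:

```elixir
defmodule Axon do
  # Sketch only - field names follow the list above.
  defstruct [:input, :output_shape, :transformation]
end

defmodule Axon.Transformation do
  # :function is either an atom naming a known transformation or a custom function.
  defstruct [:function, :activation, parameters: [], options: []]
end

defmodule Axon.Parameter do
  # :name is used for attaching to pretrained models.
  defstruct [:name, :shape, :initializer, options: []]
end
```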
I see above:

> allow input layer shapes to have `nil` batch dimensions to represent arbitrary sized batches
With reference to shapes, there is a caveat: the fully qualified shape of a layer may not be available until the shape of the input is known.

In Annex, I referred to these two kinds of shapes as "concrete" and "abstract". A "concrete" shape meant all of the shape's dimensions were positive integers. An "abstract" shape was a shape that had a placeholder (I chose `:any` due to the ambiguity of `nil`) as the batch dimension. A better name than "abstract" may have been "partial".
If the input layer has `nil` as its batch dim, then all layers in the same neural network will also have `nil` in the batch dimension of their shape.

The exception is a network that has at least one layer with all of its shape information. The layers after this one would have full shapes. This would be something like a `reduce_sum` or `reduce_mean` or some batch-size-agnostic loss function, after which the batch dimension of the output is either contracted away or changed to a known size. It's quite uncommon to see. I've never actually seen it in a working neural network's layer definitions - only used for troubleshooting.
Not knowing the full input or output shape of a neural network is quite common among NN frameworks. For example, in TensorFlow Keras:

```python
In [1]: import tensorflow as tf

In [2]: model = tf.keras.Sequential([
   ...:     tf.keras.layers.Flatten(),
   ...:     tf.keras.layers.Dense(128, activation='relu'),
   ...:     tf.keras.layers.Dense(10)
   ...: ])

In [3]: model.compute_output_shape()
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-f4e0e849bae2> in <module>
----> 1 model.compute_output_shape()

TypeError: compute_output_shape() missing 1 required positional argument: 'input_shape'

In [4]: model.compute_output_shape((400, 28, 28))
Out[4]: TensorShape([400, 10])
```
I think having partial shapes is useful.
As long as we are strict about where partial shapes are and are not used, I think we'll be okay.
It may also be useful to allow (or require?) the unknown batch dim to be named. Just mulling it over.
> Every Axon struct would consist of:
>
> - `input` - another axon struct

You are creating a singly-linked-list-like structure here. I think you would need:

- `input` - another axon struct or `nil`

where `nil` is your recursive base case.

I am not convinced that a recursive struct is necessary or better than a struct that wraps a list though.
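For what it's worth, the `nil` base case keeps traversal simple. A sketch of folding over the chain (matching on plain maps so it works regardless of which struct shape is settled on; the `AxonWalk` module is made up):

```elixir
defmodule AxonWalk do
  # Sketch: fold over the singly-linked chain of layers, bottoming out on nil.
  def reduce(nil, acc, _fun), do: acc

  def reduce(%{input: input} = layer, acc, fun) do
    reduce(input, fun.(layer, acc), fun)
  end
end
```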
I wouldn't worry too much about the struct right now. It is internal representation and we can always refactor it. My suggestion would be to start with the simplest thing that allows you to model the examples above. And then as the examples grow in size and complexity (different kinds of backprop, custom modules, etc), revisit the internal structure accordingly.
One question regarding the modules discussion: can we always model a layer with a single function? Or will more complex layers have multiple functions depending on how they are plugged into the NN (forward pass, backward pass, etc.)?
I was reading about the PyTorch modules above and was reminded of Plug. Specifically, it brought to mind the module-and-function duality of the `Plug` architecture.
In the `Plug` architecture, "plugs" are used to compose a series of transformations to a `Plug.Conn` struct during a web request, and the implementation of a "plug" can be either a module that implements two very simple callbacks or a function.

https://github.com/elixir-plug/plug
A Plug module has 2 callbacks: `init/1` and `call/2`. The `init/1` callback prepares/initializes the second arg to `call/2`, and the `call/2` callback takes a `Plug.Conn` struct as the first arg and anything as the second arg. The `call/2` function must return a `Plug.Conn` struct (hopefully the same struct, but modified - or not modified in a side-effect-only plug).

A Plug function is quite simple: it must take 2 args - once again, a `Plug.Conn` struct and anything as the second arg - and, once again, it must return a `Plug.Conn` struct.
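For readers unfamiliar with Plug, a minimal module plug looks like this. The callbacks are the standard Plug behaviour; the particular plug (`MyApp.AssignUser`) is made up for illustration:

```elixir
defmodule MyApp.AssignUser do
  @behaviour Plug

  # init/1 runs once and prepares the second argument passed to call/2.
  @impl true
  def init(opts), do: Keyword.get(opts, :default_user, :anonymous)

  # call/2 takes a %Plug.Conn{} plus the prepared options and returns a %Plug.Conn{}.
  @impl true
  def call(conn, default_user) do
    Plug.Conn.assign(conn, :current_user, default_user)
  end
end
```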
The cognitive overhead for how to use a Plug effectively is very low, but the abstraction is super extensible. Chris McCord used Plug and some very good metaprogramming to make the Phoenix framework a highly extensible, easy-to-understand web framework.
In my imagination... if we attempt to extend this dual-form architecture to Axon what we end up with is axon-plug...
The axon-plug function:

- must take an Axon struct as the first arg and anything as a second arg
- must return an Axon struct (either the given `%Axon{}` or the new head `%Axon{}` in the recursive case)
- in order to "add an operation", the function must append/concat/add a closure or MFA to the Axon struct's list of closures
- could update or add metadata, shape, etc.

The "closure or MFA" must:

- return a `{backend, operation}` tuple with a backend and a backend-appropriate representation of the operation - e.g. an `Nx.Defn.Expr` struct (just like a `defn` function) when called/applied when targeting a Defn backend
- have arity 1 for the closure
- have arity `length(args) + 1` for the MFA (the first arg is passed just like an arity-1 closure)

The axon-plug module is very similar: `init/1` is used to prepare the options (just like Plug), and `call/2` has the exact same description as the axon-plug function above.
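A function-form axon-plug would then just be any arity-2 function honoring that contract. A tiny sketch, reusing the hypothetical `Axon.add_defn` helper that also appears in the module example below plus a made-up `Axon.Activations.relu`:

```elixir
# Hypothetical function-form axon-plug: %Axon{} in, %Axon{} out.
scaled_relu = fn axon, coeff ->
  Axon.add_defn(axon, fn x -> Nx.multiply(Axon.Activations.relu(x), coeff) end)
end
```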
This would let us implement an axon-plug module like this:
```elixir
defmodule BackwardRelu do
  import Nx.Defn

  defn backward_relu(x, coeff) do
    # This github syntax highlighter already correctly handles defn ?!?!?
    # Element-wise select (defn's if/cond expects a scalar predicate).
    Nx.select(Nx.less(x, 0), x * coeff, 0)
  end

  def init(coeff \\ nil), do: coeff || 1.0

  def call(axon, coeff) do
    Axon.add_defn(axon, fn x -> backward_relu(x, coeff) end)
  end
end
```
and compose layers like this:

```elixir
Axon.input({nil, 128})
|> Axon.dense(256)
|> Axon.tanh()
|> Axon.dense(size: 512, activation: :tanh)
|> Axon.dense(784)
|> Axon.activation(&Axon.Activations.tanh/1)
|> Axon.reshape({28, 28})
|> Axon.plug({BackwardRelu, 0.05})
```
Each of these `Axon` function calls matches the axon-plug architecture - `Axon.tanh/1` has an ignored default 2nd arg.
Though, I think I like `Axon.plug/3` better than `Axon.plug/2`:

```elixir
|> Axon.plug(BackwardRelu, 0.05)
```
And, of course, the function name `plug` is a head nod to Plug but could be anything.
I think this axon-plug idea is pretty cool and was quite fun to theory craft with.
Thoughts?
After thinking about the plug-axon idea some more, it is missing at least one glaringly obvious thing: state.
So we probably can't just copy Plug exactly one-for-one.
I think there are a couple approaches we can take here. For example, one option is for us to say the pluggable layers receive Nx expressions and return Nx expressions (i.e. they are transforms). This way they can do whatever they want but with some sharp edges.
The other option is to say they need to adhere to some behaviour and plug the layer based on the functions implemented on the behaviour.
The best approach is most likely to generalize whatever we use to implement our built-in layers - so I would focus on that instead, and then look at the patterns we want to generalize and make extensible.
In other words, I would suggest to stitch some networks together with built-in layers to solve different problems and then change the internal structure and the main DSL accordingly. :)
After taking some time to stitch some networks together and come up with an initial API, here are my takeaways for what this problem entails and goals of the API moving forward.
In order to avoid tying too much to the model abstraction initially, I think it's best to frame the model as two functions: `init` and `predict`/`apply`. `init` initializes the parameters and variables of the model and returns a tuple of parameters and a tuple of variables. For simplicity, we can assume `predict` takes parameters, variables, and input and returns a prediction and updated variables; however, it may be best to actually include two `predict` functions: one for training and one for inference. This is because you probably won't want to update the model variables during inference. Additionally, we will need the ability to turn off `dropout` layers during inference - a challenge for another day.

Given this framework, the layer API basically just needs to be a series of combinators that work on `init` and `predict`. This is the same approach Jax takes.
Given the simplified abstraction, we're now faced with 3 problems:
1) Tracking trainable parameters for each layer. We need to be able to determine the shape at each layer so we can determine the shape of each parameter. We also need a way to specify an initializer and some additional options for each parameter.
2) Tracking running variables for each layer. Again, we need to determine shapes and possibly initializers.
3) Building the actual expression for each layer.
Keeping in mind the problems defined above, we also need to keep in mind the following goals or requirements:
1) We need to retain high-level information about the model. This is necessary for importing and exporting models to other formats. We need to be able to say this is a `:dense` layer with these weights/biases, this is a `:conv`, etc.
2) Layers should be arbitrarily composable with regular `Nx` functions and other numerical definitions. I believe this is absolutely necessary and could really set us apart.
3) Models should be arbitrarily composable as well. This is important for building reusable parts of the model (see the resnet example where building blocks are defined with `def` in a separate module). With this, users should be able to mark trainable parameters and variables as well. One benefit here is we can actually implement some high-level layers in terms of existing primitives. It may be that using regular Elixir functions as the building block is okay - the context switching can be strange, but it might be fine.
4) Multi-input, multi-output models. I'm just including these because I haven't tried to implement one yet, so I'm not sure how difficult they will be to implement. You end up with multiple graphs, but parts of the graph are actually shared, so we'll need some additional metadata, probably on composite types.
The current approach is to build the `Axon` struct from within a `model do..end` block. The `Axon` struct tracks an `id`, `name`, `output_shape`, `op`, `params`, `vars`, and additional `opts`. We use `__jit_params` and `__jit_predict__` to build an expression for both `init_random_params` and `predict`. The `Axon` struct starts at the bottom-most node, so we traverse up the graph until we hit the root (input nodes), and build the expression by applying existing numerical definitions along the way.
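A sketch of that traversal for a single dense layer - matching on plain maps rather than a struct, with the params-by-node-id lookup being an assumption and `Axon.Layers.dense` standing in for the functional implementation mentioned above:

```elixir
defmodule PredictBuilder do
  # Sketch only: walk from the bottom-most node up to the :input root,
  # applying the existing functional implementations along the way.
  def build(%{op: :input}, _params, input), do: input

  def build(%{op: :dense, id: id, input: parent}, params, input) do
    %{kernel: kernel, bias: bias} = params[id]

    parent
    |> build(params, input)
    |> Axon.Layers.dense(kernel, bias)
  end
end
```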
From an API perspective we currently support:
```elixir
model do
  input() |> dense() |> dense()
end
```

which yields `init_random_params` and `predict` as functions, as well as:

```elixir
model <name> do
  input() |> dense() |> dense()
end
```

which yields `init_<name>` and `<name>` as functions. As @josevalim mentioned, an alternative is:

```elixir
input() |> dense() |> dense() |> model()
```
My only preference for using `do...end` blocks and treating models like function calls would be the ability to do something like:

```elixir
model generator(latent) do
  latent |> dense() |> dense()
end
```

in order to eliminate the need for `input` calls, although we would completely lose shape information. I'm also not really sold on using the term `model`; maybe `defaxon`, `defmodel`, `defpredict`, `defmlp`... not sure.
From an architectural standpoint, a completely different approach is to use `Nx.Defn.Expr` metadata nodes to retain the necessary information. So we could wrap all of the layer implementations, activation implementations, etc. in a `transform` that adds the necessary metadata, and then use a `jit_expr` in Nx to compile the expressions. I haven't experimented with this yet, so I'm not sure exactly how it would work, but it would make solving 2 and 3 in the above considerations much easier.
I think this might frame the problem a little better.
Closing with new issues.
The next step after #1 is to implement higher-level constructs on top of the lower-level functional implementations. The goal of the higher-level API is to provide abstractions for building neural networks that:
For simplicity, we'll leave the discussion of efficient handling of network state to a later issue. This issue will only focus on the architecture/representation of a network.
Axon Struct
We will introduce an `%Axon{}` struct that represents a constructed network. For now, the struct will have the following attributes:

- `:input` - input to this layer/model/etc., or in the case of a literal `input` layer, some metadata
- `:shape` - shape of the layer's parameters, can be inferred from `:input`; we can also allow input layer shapes to have `nil` batch dimensions to represent arbitrary sized batches
- `:transformation` - how does this layer transform the input; can be an atom (like `:dense`) which resolves to an already implemented layer, or a numerical definition for arbitrarily complex transformations
- `:initializer` and initializer options, `:activation` and activation options, layer-specific options, possibly constraints, callbacks, etc.

The struct would be built up with calls to high-level functions in the root Axon namespace. For example, MNIST:
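The MNIST snippet isn't reproduced in this thread; as a rough sketch, following the pipeline style used earlier in the discussion, it might look something like:

```elixir
model =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)
```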
Then a model would be compiled, perhaps like:
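The original snippet isn't included here either; based on the description below, compilation might look roughly like the following, where the return shape and the `compiler: EXLA` option are both assumptions:

```elixir
# Hypothetical: returns initialized parameters and a compiled prediction function.
{params, predict_fn} = Axon.compile(model, compiler: EXLA)
```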
At a minimum, `Axon.compile` initializes parameters and returns a compiled function using whatever backend the user specifies. It can also be abstracted away in some higher-level training logic, but that's a discussion for another issue.

Structuring the network in this way makes arbitrary compilation easy; the API is simple and easy to understand, and flexible enough for complex models. For example, a GAN:
Layers

High-level layers are tied directly to their functional implementations in `Axon.Layers`. Some of them have layer-specific options which can be passed during layer creation.

Combinators
We will use combinators similar to https://thinc.ai/docs/api-layers/#combinators to represent more complex relationships between layers. At a minimum we'd have:

- `compose` - function composition
- `add` - adds layers
- `concat` - concats layers
- `residual` - residual output
- `parallel`/`split` - something to represent multiple-model outputs