elixir-nx / axon

Nx-powered Neural Networks

Add model import/export API #26

Closed by seanmor5 2 years ago

seanmor5 commented 3 years ago

We need the ability to serialize models to and from external formats. Model serialization here means serializing the actual computation graph. We should also be able to save and load model parameters, but I believe part of that discussion needs to happen upstream, with a common Nx tensor serialization format. See e.g. https://github.com/elixir-nx/nx/issues/354
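For the parameter side specifically, a round trip is already possible with nothing but Nx.to_binary/1 and Nx.from_binary/2, though the container below (:erlang.term_to_binary) is BEAM-specific, which is exactly why a common format needs to be agreed on upstream. The file name and layout are just for illustration:

# Store each tensor's raw bytes together with its type and shape so it can be rebuilt.
serialized =
  for {name, t} <- params, into: %{} do
    {name, {Nx.to_binary(t), Nx.type(t), Nx.shape(t)}}
  end

File.write!("params.bin", :erlang.term_to_binary(serialized))

restored =
  for {name, {bin, type, shape}} <- :erlang.binary_to_term(File.read!("params.bin")),
      into: %{} do
    {name, bin |> Nx.from_binary(type) |> Nx.reshape(shape)}
  end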

seanmor5 commented 3 years ago

I have thought a bit about this issue and have the following proposal.

First, the training API needs to be updated to accept either a model initialization function, already-existing parameters, or a combination of both. This was brought up by @arpieb when doing some RL applications. It will also be necessary for transfer learning, where we import part of a model and initialize another part for training.
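To illustrate, the training entry point could normalize its second argument along these lines (train/3, init_params/1, and do_train/3 are hypothetical names, not the actual Axon training API):

# Accept either an init function or a ready-made params map; a partial map
# is merged over freshly initialized params, which is what transfer learning needs.
def train(model, init_or_params, opts \\ []) do
  params =
    case init_or_params do
      fun when is_function(fun, 0) -> fun.()
      %{} = partial -> Map.merge(init_params(model), partial)
    end

  do_train(model, params, opts)
end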

Second, for now I would introduce the following functions into the Axon API:

load(file_or_url, opts \\ []) :: {:ok, {model, params}} | {:ok, model} | {:error, msg}

export(model, opts \\ []) :: :ok | {:error, msg}
export(model, params, opts \\ []) :: :ok | {:error, msg}

The reason there are no functions for loading/saving weights specifically is that I believe that is mostly covered by https://github.com/elixir-nx/nx/issues/354 and is a responsibility that should fall on Nx.

An example of working with this API would be something like:

file = ...get file from somewhere (like tfhub)...

{:ok, {model, params}} = Axon.load(file, format: TFLite)

Axon.predict(model, params, inp)

TFLite here is a module that implements some behaviour for converting .tflite models to Axon structs.
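To make that concrete, the behaviour and the dispatch inside Axon.load/2 could be sketched like this (the Axon.Format name and callback shapes are my assumptions, not a settled design):

# Each format module (TFLite, ONNX, ...) implements these callbacks.
defmodule Axon.Format do
  @callback load(path :: String.t(), opts :: keyword()) ::
              {:ok, {Axon.t(), map()}} | {:ok, Axon.t()} | {:error, String.t()}
  @callback export(model :: Axon.t(), params :: map(), opts :: keyword()) ::
              :ok | {:error, String.t()}
end

# Axon.load/2 then only has to dispatch on the :format option:
def load(file_or_url, opts \\ []) do
  {format, opts} = Keyword.pop!(opts, :format)
  format.load(file_or_url, opts)
end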

In order to include part of a model in something like transfer learning, we'd need a way to indicate that certain layers are frozen. It's an Axon struct, so I would suggest something like:

{:ok, {feature_extractor, fe_params}} = Axon.load(file, format: TFLite)

model =
  feature_extractor
  |> Axon.freeze(fe_params, filter: :all)
  |> Axon.flatten()
  |> Axon.dense(10, activation: :softmax)

The :filter option exists so we can conditionally freeze some layers in a large model but not all of them; see the sketch below. We could also add a :frozen option to each high-level layer so layers can be declared frozen at the point of use.
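As a concrete (hypothetical) example, freezing only the convolutional layers of the imported model while leaving the rest trainable might look like this (the :op field on a layer is an assumption about the struct's shape):

model =
  feature_extractor
  |> Axon.freeze(fe_params, filter: fn layer -> layer.op == :conv end)
  |> Axon.flatten()
  |> Axon.dense(10, activation: :softmax)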

arpieb commented 3 years ago

Might I suggest defining an import/export protocol to standardize the workflow for other possible formats? Axon should definitely roll its own implementation for completeness, expediency, and tight integration with Nx. That said, a well-defined protocol would also let external modules import/export Axon models in formats like ONNX, HDF5, XML (don't laugh too hard: you can encode binary data, and I've seen it in the wild!), etc., ad nauseam; pretty much any format that can represent a graph architecture as well as stateful data (weights, biases, recurrent components). Some day not far in the future, someone will train a model in Axon leveraging distributed, streaming training and want to export it for production, possibly on a non-BEAM system... ;)
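To sketch what that could look like, an external package would just implement the format behaviour proposed above (AxonOnnx.Format is a made-up module name for illustration):

defmodule AxonOnnx.Format do
  @behaviour Axon.Format

  @impl true
  def load(_path, _opts) do
    # ...parse the ONNX graph and rebuild it as an Axon struct...
    {:error, "not implemented yet"}
  end

  @impl true
  def export(_model, _params, _opts) do
    # ...walk the Axon struct and emit the corresponding ONNX nodes...
    :ok
  end
end

# Usage then mirrors the built-in formats:
# {:ok, {model, params}} = Axon.load("model.onnx", format: AxonOnnx.Format)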

Freezing layers and the ability to cherry-pick layers from an existing model as input or output for an Axon model is definitely going to be needed for transfer learning, but I'd argue that's a separate issue from import/export. Glad to see it's on the roadmap, though!