elixir-nx / axon

Nx-powered Neural Networks
Apache License 2.0

Unable to replace existing model layers #529

Closed wtedw closed 3 weeks ago

wtedw commented 11 months ago

I'm currently trying to implement the LoRA algorithm in Axon. It involves freezing the original model's weights, adding new weights inside an existing layer, and feeding the input through both the original weights and the new weights.
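For reference, the core computation is just the frozen-weight path plus a low-rank update. A minimal plain-Nx sketch (shapes are assumed here, not taken from any particular implementation):

defmodule LoraSketch do
  import Nx.Defn

  # Frozen weights w0 plus the low-rank update applied to x: W0 x + B A x.
  # Assumed shapes: x {batch, in}, w0 {in, out}, a {in, r}, b {r, out}.
  defn forward(x, w0, a, b) do
    Nx.add(Nx.dot(x, w0), x |> Nx.dot(a) |> Nx.dot(b))
  end
end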

I noticed that Axon.map_nodes used to be able to replace layers, but now it only maps over %Axon.Node{} structs. However, if I were to make a custom layer, it would return an %Axon{} struct. I figured it's possible to unravel the Axon struct to retrieve the node, but it doesn't seem right.

My skeleton draft atm:

# Define custom LoRA layer
defmodule Lora do
  import Nx.Defn

  def custom_layer(%Axon{} = input, %Axon{} = target_to_be_replaced) do
    ...
    Axon.layer(&custom_layer_impl/3, [input, target_to_be_replaced])
  end

  defn custom_layer_impl(input, target, _opts \\ []) do
    ...
  end
end

# Import model
{:ok, unet} =
  Bumblebee.load_model({:hf, repository_id, subdir: "unet"},
    params_filename: "diffusion_pytorch_model.bin"
  )

# Get Axon model
%{spec: spec, model: unet_model, params: params} = unet

# Replace attention nodes
new_model =
  Axon.map_nodes(unet_model, fn
    %Axon.Node{} = node ->
      # Can't use Lora.custom_layer() here because it returns an %Axon{} struct
      node

    node ->
      node
  end)

This was the previous way of replacing layers; I believe the documentation is now outdated: https://hexdocs.pm/axon/Axon.html#map_nodes/2

Another use case is to replace entire classes of layers with another. For example, you may want to replace all relu layers with tanh layers:

new_model = Axon.map_nodes(model, fn
  %Axon{op: :relu} = graph ->
    # Get nodes immediate parent
    parent = Axon.get_parent(graph)
    # Replace node with a tanh
    Axon.tanh(parent)

  graph ->
    graph
end)

seanmor5 commented 11 months ago

Thanks for bringing this up. The %Axon{} data structure used to be the only data structure representing a model/layer/etc., but now we have %Axon.Node{}, and, as you pointed out, that means Axon.map_nodes can no longer be used to replace layers in the way you described.

As it stands, you could probably just replace the node's internal properties to point to the custom layer implementation rather than using Lora.custom_layer, but honestly that feels hacky and I don't like it. I will need to think a bit about what a good API for this looks like.
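For the record, that hacky route would look roughly like the sketch below. It assumes %Axon.Node{} exposes the layer function via :op and the layer kind via :op_name, and MyLora.custom_impl/4 is a made-up function whose arity would have to match the node's existing inputs:

# Hacky sketch: repoint an existing node's implementation in place,
# keeping the node's parameters and parents untouched.
new_model =
  Axon.map_nodes(model, fn
    %Axon.Node{op_name: :dense} = node ->
      %{node | op: &MyLora.custom_impl/4}

    node ->
      node
  end)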

I haven't looked at a LoRA implementation yet, so what you need to be able to do is replace specific Axon nodes with a LoRA version of the layer?

wtedw commented 11 months ago

Gotcha, that makes sense.

I'm not well versed in the LoRA implementation, but it looks like you'd need a wrapper layer that keeps the original implementation but also creates two new trainable parameters (lora_A, lora_B).


So for the implementation function, the calculation looks something like this:

defn lora_embedding_impl(x, original_layer_impl, lora_A, lora_B, _opts \\ []) do
  # (?) still unsure how the original layer impl and the new
  # lora_A / lora_B parameters would actually be passed in
  original_output = original_layer_impl.(x)

  # LoRA has different layers for embedding / conv / linear,
  # but they all perform the same matrix operation: BAx
  after_a = Axon.Layers.embedding(x, lora_A)
  after_b = Nx.dot(after_a, lora_B)

  # Combine the original output with the LoRA calculation
  Nx.add(original_output, after_b)
end

#  reference: https://github.com/microsoft/LoRA/blob/main/loralib/layers.py#L79-L85

I think even if I replaced the node's internal properties to point to the custom implementation, it's not possible to inject new LoRA parameters with Axon.param().
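For context, a LoRA layer built the "normal" way would declare its new parameters with Axon.param and attach them via Axon.layer, which is exactly what yields an %Axon{} rather than a bare node. A rough sketch, with made-up names and shapes:

# Sketch only: declare new trainable LoRA parameters and attach them.
def lora_dense(%Axon{} = input, in_features, units, r) do
  # Shapes passed in explicitly to keep the sketch simple;
  # a real version would derive them from the input shape.
  lora_a = Axon.param("lora_a", fn _ -> {in_features, r} end)
  lora_b = Axon.param("lora_b", fn _ -> {r, units} end)

  Axon.layer(&lora_dense_impl/4, [input, lora_a, lora_b], op_name: :lora_dense)
end

defnp lora_dense_impl(x, lora_a, lora_b, _opts \\ []) do
  # BAx only; combining this with the frozen layer's output is the part
  # that still needs the node replacement discussed above
  x |> Nx.dot(lora_a) |> Nx.dot(lora_b)
end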

I'll keep hacking away at it to see what's possible. Appreciate your input!

wtedw commented 9 months ago

Hey @seanmor5, I was able to implement LoRA with a couple of tricks.

I ended up not using map_nodes. Instead, I created new nodes by extracting them out of Axon.layer. Afterwards, I added these new nodes into the original Axon struct and then wired existing nodes to connect to the new nodes.
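In very rough outline (simplified, with assumed field names: %Axon{} holding a nodes map keyed by node id plus an output id, %Axon.Node{} holding a list of parent ids; the linked code below handles more details), the surgery looks something like this:

# Simplified sketch: merge the new layer's nodes into the original graph
# and repoint anything that consumed the target node at the LoRA output.
def inject(%Axon{nodes: nodes} = model, target_id, %Axon{} = lora_layer) do
  rewired =
    Map.new(nodes, fn {id, node} ->
      parents =
        Enum.map(node.parent, fn parent_id ->
          if parent_id == target_id, do: lora_layer.output, else: parent_id
        end)

      {id, %{node | parent: parents}}
    end)

  # The LoRA layer's node map already contains the upstream graph, so merging
  # the rewired originals over it keeps both the old and the new nodes.
  %{model | nodes: Map.merge(lora_layer.nodes, rewired)}
end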

Thought I'd leave this comment for anybody who's trying to do something similar. For more details, see: https://github.com/wtedw/lorax/blob/main/lib/lorax.ex#L47