Closed fantypants closed 2 years ago
@fantypants My personal opinion is that this would be more similar to a combinator like in trax, which would match the Keras wrapper impl:
```elixir
model =
  Axon.input("sequence")
  |> Axon.embedding(32, 64)
  |> Axon.bidirectional(&Axon.lstm/3, merge: &Axon.concatenate/2)
```
Or something similar to that. Basically the interface is:
```elixir
def bidirectional(%Axon{} = input, layer_fun, opts \\ []) do
  opts = Keyword.validate!(opts, [:name, merge: &Axon.concatenate/2, axis: 1])
end
```
There would have to be a contract for `layer_fun` and `merge_fun` accepting and returning certain arguments, but at a high level this is what I'd envision.
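For intuition, the combinator pattern being proposed can be sketched outside Axon. Here is a minimal Python analogue with NumPy; the names `bidirectional` and `cumsum_layer` are illustrative stand-ins, not Axon API:

```python
import numpy as np

def bidirectional(layer_fun, merge_fun):
    """Wrap a sequence layer: run it on the input and on the reversed
    input, then merge the two results (a toy analogue of the proposed
    Axon.bidirectional combinator)."""
    def wrapped(seq):
        forward = layer_fun(seq)
        backward = layer_fun(seq[::-1])  # reverse along the time axis
        return merge_fun(forward, backward)
    return wrapped

# A stand-in "recurrent" layer: running cumulative sum over time.
cumsum_layer = lambda seq: np.cumsum(seq, axis=0)
merge_concat = lambda a, b: np.concatenate([a, b], axis=-1)

model = bidirectional(cumsum_layer, merge_concat)
out = model(np.array([[1.0], [2.0], [3.0]]))
# forward cumsum: [[1],[3],[6]]; backward (on reversed input): [[3],[5],[6]]
# concatenated:   [[1, 3], [3, 5], [6, 6]]
```

The contract is visible here: `layer_fun` must accept and return a sequence, and `merge_fun` must accept the two per-direction outputs and return one combined output.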
I really like that approach; I hadn't heard of trax before. In this context the above sounds like the correct solution, especially since the issue I was having was the two different implementations: the naming conventions would have been quite confusing and exhausting, and sticking to my approach would have meant more work per change like this.
I'll get some things together and start on the solution.
@seanmor5 I've been working on the implementation, and I'm getting stuck at the deep merge with the following error:
```
** (Protocol.UndefinedError) protocol Nx.Container not implemented for #Axon<
  inputs: ["inputs"]
> of type Axon (a struct), check the docs for Nx.Container for more information. This protocol is implemented for the following type(s): Any, Axon.None, Axon.StatefulOutput, Map, Tuple
```
The function I'm using is:
```elixir
@doc type: :bidirectional
def bidirectional(%Axon{} = input, layer_fn, opts \\ []) do
  opts = Keyword.validate!(opts, [:name, merge: &Axon.concatenate/2])

  forward_input = input
  backward_input = Axon.nx(input, &Nx.reverse(&1), op_name: :reverse)

  forward_result = layer_fn.(forward_input)
  backward_result = layer_fn.(backward_input)

  Axon.Shared.deep_merge(forward_result, backward_result, opts[:merge])
end
```
However, Axon.Shared.deep_merge throws the above error. I'm presuming it's a straightforward error, i.e. Nx.Container isn't implemented for the Axon struct; the Nx documentation says it can work with any type that derives Nx.Container, so I'm wondering if it's the Nx/EXLA dependencies being incorrect?
I've used (from the PR comments you added; I've tried both configs):
```elixir
# {:exla, "~> 0.2", [github: "elixir-nx/nx", sparse: "exla"]},
# {:nx, "~> 0.2", [github: "elixir-nx/nx", sparse: "nx", override: true]},
{:exla, "~> 0.3.0-dev", github: "elixir-nx/nx", sparse: "exla"},
{:nx, "~> 0.3.0-dev", github: "elixir-nx/nx", sparse: "nx", override: true}
```
Using apply/2 works; however, I don't think it's the same operation as deep_merge, and deep_merge sounds like the correct method here (since it's reducing the output layer, a deep traversal is probably much more accurate than a simple concat/merge function applied to the inputs):

```elixir
apply(merge_fn, [forward, backward])
```
Updates: After running some basic tests:

```elixir
inp1 = Axon.input({nil, 1}, "input_0")
inp2 = Axon.input({nil, 2}, "input_1")

Axon.container(%{a: inp1, b: inp2})
|> Nx.Container.reduce([], fn x, a -> x end)
```
it seems even using a plain Axon.container and running it through Nx.Container.reduce reproduces the issue. Are we missing the containers option somewhere? Similar to:
```elixir
@derive {Nx.Container, containers: [:field_name, :other_field]}
```
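The protocol error can be understood with a toy analogue: a reduce that only knows how to walk maps and tuples, treats numbers as leaves, and refuses any other type, much like Nx.Container.reduce refuses a struct that hasn't derived the protocol. This is plain Python with illustrative names, not Nx code:

```python
def container_reduce(container, acc, fun):
    """Fold over the leaves of nested dicts/tuples; raise for any
    type we have no 'protocol implementation' for (mirrors the
    Protocol.UndefinedError seen above for un-derived structs)."""
    if isinstance(container, dict):
        for v in container.values():
            acc = container_reduce(v, acc, fun)
        return acc
    if isinstance(container, tuple):
        for v in container:
            acc = container_reduce(v, acc, fun)
        return acc
    if isinstance(container, (int, float)):
        return fun(container, acc)
    raise TypeError(
        f"no container implementation for {type(container).__name__}")

# Walking a dict of leaves works; an arbitrary object does not.
collected = container_reduce({"a": 1, "b": (2, 3)}, [],
                             lambda x, acc: acc + [x])
# collected == [1, 2, 3]
```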
@fantypants You'll need to bring this implementation of `deep_merge` into axon.ex as a private helper:
```elixir
defp deep_merge(left, right, fun) do
  case Nx.Container.traverse(left, leaves(right), &recur_merge(&1, &2, fun)) do
    {merged, []} ->
      merged

    {_merged, _leftover} ->
      raise ArgumentError,
            "unable to merge arguments with incompatible" <>
              " structure"
  end
end

defp leaves(container) do
  container
  |> Nx.Container.reduce([], fn x, acc -> [x | acc] end)
  |> Enum.reverse()
end

defp recur_merge(left, [right | right_leaves], fun) do
  case {left, right} do
    {%Nx.Tensor{} = left, %Nx.Tensor{} = right} ->
      {fun.(left, right), right_leaves}

    {%Axon{} = left, %Axon{} = right} ->
      {fun.(left, right), right_leaves}

    {left, right} ->
      {deep_merge(left, right, fun), right_leaves}
  end
end
```
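A language-agnostic way to read this helper: gather the right side's leaves in traversal order, then walk the left structure, consuming one right leaf per left leaf and applying the merge function pairwise. A Python sketch of the same algorithm over nested dicts/tuples (illustrative only, not Axon code):

```python
def leaves(container):
    """Collect leaves of a nested dict/tuple structure in traversal
    order, mirroring the Elixir leaves/1 helper."""
    if isinstance(container, dict):
        return [x for v in container.values() for x in leaves(v)]
    if isinstance(container, tuple):
        return [x for v in container for x in leaves(v)]
    return [container]

def deep_merge(left, right, fun):
    """Zip the leaves of `right` against the structure of `left`,
    applying `fun` pairwise and rebuilding `left`'s shape."""
    rest = leaves(right)

    def recur(node):
        nonlocal rest
        if isinstance(node, dict):
            return {k: recur(v) for k, v in node.items()}
        if isinstance(node, tuple):
            return tuple(recur(v) for v in node)
        if not rest:
            raise ValueError(
                "unable to merge arguments with incompatible structure")
        head, rest = rest[0], rest[1:]
        return fun(node, head)

    merged = recur(left)
    if rest:  # leftover right leaves => structures did not line up
        raise ValueError(
            "unable to merge arguments with incompatible structure")
    return merged

# Merging two LSTM-like outputs {hidden, (cell_a, cell_b)} leaf-by-leaf:
merged = deep_merge({"h": 1, "c": (2, 3)},
                    {"h": 10, "c": (20, 30)},
                    lambda a, b: a + b)
# merged == {"h": 11, "c": (22, 33)}
```

This also shows why a plain top-level `apply(merge_fn, [forward, backward])` differs: it merges the two containers as wholes, while deep_merge pairs up corresponding leaves inside them.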
@seanmor5 that did the trick, results are now coming in! I have to do some branch shuffling with the other tutorial PR branch first, then I'll push this PR.
What's the purpose of the Axon.Layers functions vs the Axon functions, i.e. Axon.lstm vs Axon.Layers.lstm? Is this something I should include in the PR now?
@fantypants I have been thinking about this and it's actually more difficult than what I outlined. The implementation requires an implementation for #169, so I might have you hold off before implementing something :)
Closing as tracked in #119 :)
Hello,
@seanmor5 I had put up a bidirectional example in the Nx Slack channel. I'm currently planning out the implementation and looking at the use cases for it as follows:
TF/Keras uses a Bidirectional wrapper layer, and the direction handling for LSTMs etc. is contained within the LSTM layer itself via flags/opts (`go_backwards`, in the LSTM's case). My plan is to implement a simple bidirectional layer; the applicable layers would then have to be refactored to include the new feature.
What are everyone's thoughts on how this should be implemented? Similar to Keras, with opts within the layers plus a Bidirectional layer itself, or do we create bidirectional LSTM layers instead and use the LSTM layer as it is?
PR will be up soon for it, I'm assuming it'll need some good review!
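The two designs in question can be contrasted with a toy example in plain Python (hypothetical names; a running sum stands in for an LSTM cell). Both routes produce the same backward pass, the difference is only where the direction logic lives:

```python
def run_layer(seq, go_backwards=False):
    """Keras-style design: the direction flag lives inside the layer
    itself (analogue of the LSTM `go_backwards` option)."""
    if go_backwards:
        seq = list(reversed(seq))
    out, state = [], 0
    for x in seq:
        state += x          # trivial stand-in for an LSTM cell update
        out.append(state)
    return out

def bidirectional(layer, seq, merge):
    """Wrapper/combinator design: the wrapper reverses the input
    itself, so the inner layer needs no direction flag."""
    fwd = layer(seq)
    bwd = layer(list(reversed(seq)))
    return [merge(f, b) for f, b in zip(fwd, bwd)]

# Backward pass over [1, 2, 3] via the in-layer flag:
flag_route = run_layer([1, 2, 3], go_backwards=True)             # [3, 5, 6]
# Same backward pass via the wrapper (merge picks the backward half):
wrapper_bwd = bidirectional(run_layer, [1, 2, 3], lambda f, b: b)
```

The wrapper design keeps each recurrent layer simpler at the cost of one extra combinator, which is the trade-off the comments above settle in favor of the combinator.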