marcinkoziej opened 1 year ago
Staring some more at the Axon code, I noticed that Axon.Layers.lstm and friends are defined using an arcane macro. Is something like this always necessary to create a recurrent network with Axon, or is it just a means to avoid duplicating code?
Does this example help? https://github.com/elixir-nx/axon/blob/main/examples/generative/text_generator.exs
@polvalente thanks for a quick reply!
I have seen this guide – the problem with it is that it uses lstm as a black box, which encapsulates and abstracts away the iteration over the input data.
I wanted to write a simple RNN myself, in which I specify what the "cell" function looks like, so I tried to use dynamic_unroll(cell_fn, input_sequence, carry, input_kernel, recurrent_kernel, bias) to build something similar myself.
However, I cannot figure out how to use it properly: it has little documentation, its arguments and return values are not described, and even its argument list seems quite arbitrary (why carry, input_kernel, recurrent_kernel, bias, and not just parameters; or, on the other hand, why just one bias instead of input_bias and recurrent_bias?)...
Right now I figure maybe I should avoid Axon's unroll_* functions and just use Nx.while myself in a custom layer?
In general, I am not sure whether Axon intends to provide a reusable building block (like Axon.Loop) to scan/unroll inputs, or whether these functions are tailor-made for the 3 recurrent NNs implemented in the Axon package. If the former, it would be great to see a guide on how to use such an abstraction to implement a custom RNN model (not lstm, not gru, etc.).
I can't speak too much about the intention behind the design, but since we don't have an explicit @doc false, I'd expect this to be a public interface.
Some things I concluded (pending any corrections by @seanmor5) that might help you:
dynamic_unroll and static_unroll are, from the way I see it, ways to apply cell_fn over the "sequence" axis of your input - that is, axis 1, the first axis after the batch dimension - carrying the results forward to the next entry in that dimension, effectively doing the equivalent of Enum.scan over those entries. The only difference is that dynamic_ and static_ refer to whether you're building this scan onto your computational graph unrolled (static) or as a while loop (dynamic).
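As a rough mental model in plain Elixir (nothing Axon-specific, with a deliberately trivial "cell"), the two variants compute the same thing and differ only in how the scan is expressed:

```elixir
# Toy stand-in for a recurrent cell: the carry is just a running sum.
cell = fn x, h -> x + h end
[x1, x2, x3] = [1, 2, 3]
h0 = 0

# "Static unroll": every step is written out explicitly, so the sequence
# length is fixed when the computation is built.
h1 = cell.(x1, h0)
h2 = cell.(x2, h1)
h3 = cell.(x3, h2)
static_outputs = [h1, h2, h3]

# "Dynamic unroll": a single loop threads the carry through the sequence.
{_final_h, rev_outputs} =
  Enum.reduce([x1, x2, x3], {h0, []}, fn x, {h, acc} ->
    h = cell.(x, h)
    {h, [h | acc]}
  end)

dynamic_outputs = Enum.reverse(rev_outputs)
# static_outputs == dynamic_outputs #=> true
```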
So if you have an input which is of the shape {batch, sequence, m, n, ...}, cell_fn is a function which takes {batch, m, n, ...} tensors, as well as a carry, which is whichever state you need to carry over to the next entry in the sequence, and outputs a batch of output and a batch of carry.
Also, you want to receive input_kernel, recurrent_kernel, bias, which are the trainable params and whose shapes will depend on your actual cell_fn definition. For instance, in the code for lstm_cell below, we can see that input_kernel and recurrent_kernel (the latter appearing as hidden_kernel in the code) are weights that, together with bias, compose linear transformations in the form of an Axon.dense layer.
Your cell_fn could very well just have those as empty maps or constant values if you didn't want to apply any transformations whatsoever to your input and carry values.
```elixir
# Excerpt from Axon's lstm_cell: each gate is a dense transformation of the
# current input plus a dense transformation of the hidden state from the carry.
{cell, hidden} = carry
{wii, wif, wig, wio} = input_kernel
{whi, whf, whg, who} = hidden_kernel
{bi, bf, bg, bo} = bias

i = gate_fn.(dense(input, wii, bi) + dense(hidden, whi, 0))        # input gate
f = gate_fn.(dense(input, wif, bf) + dense(hidden, whf, 0))        # forget gate
g = activation_fn.(dense(input, wig, bg) + dense(hidden, whg, 0))  # candidate cell state
o = gate_fn.(dense(input, wio, bo) + dense(hidden, who, 0))        # output gate

new_c = f * cell + i * g
new_h = o * activation_fn.(new_c)
```
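For comparison, a stripped-down vanilla RNN cell following the same {output, carry} pattern could look like the sketch below. This is not Axon's API: the module and function names are mine, and the exact argument order and carry format that dynamic_unroll expects from cell_fn should be checked against the Axon source.

```elixir
defmodule MyRNN do
  import Nx.Defn

  # One step of a vanilla RNN, illustrative only.
  # Shapes assumed: input {batch, features}, hidden {batch, units},
  # input_kernel {features, units}, recurrent_kernel {units, units}, bias {units}.
  defn vanilla_rnn_cell(input, carry, input_kernel, recurrent_kernel, bias) do
    {hidden} = carry

    new_hidden =
      Nx.tanh(Nx.dot(input, input_kernel) + Nx.dot(hidden, recurrent_kernel) + bias)

    # output for this step, and the carry handed to the next step
    {new_hidden, {new_hidden}}
  end
end
```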
output and carry are the conceptual equivalent of the {x, acc} result in Enum.scan or Enum.map_reduce.
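Concretely, the same {output, carry} pairing in plain Enum.map_reduce/3 (again, not Axon code):

```elixir
# Each step emits an output and threads a carry to the next step,
# just like cell_fn does across the sequence axis.
{outputs, final_carry} =
  Enum.map_reduce([1, 2, 3, 4], 0, fn x, carry ->
    new_carry = carry + x
    {new_carry, new_carry}  # {output for this step, carry for the next}
  end)

outputs      #=> [1, 3, 6, 10]
final_carry  #=> 10
```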
Thanks for the clarifications.
I am trying to go forward without using the _unroll APIs, but now I have stumbled on another problem with RNNs and sequential data: uneven input lengths.
I am working on an example livebook where I rewrite a PyTorch example in Axon.
In the PyTorch example, it was possible to work on uneven data (last names that have a variable number of letters), but it seems Axon prefers fixed-length input (in the LSTM example, there is a fixed sequence_size).
I learned that I should pad the data with 0 to make all inputs the same length, but there should be some way to tell Axon to ignore this padding (something like the masking and padding described here for TensorFlow). Does Axon have a concept like this?
There is a new API, Axon.mask, which does this; you can pass the mask to Axon.lstm and other RNNs. Something like this should work:
```elixir
input = Axon.input("seq")
# pad token is 0
mask = Axon.mask(input, 0)
embed = Axon.embedding(input, ...)
{seq, state} = Axon.lstm(embed, 32, mask: mask)
```
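For completeness, the variable-length sequences themselves still have to be padded to a common length with the same pad token before they can be batched. A plain-Elixir sketch (max_len and the token ids are made up for illustration):

```elixir
# Pad each token-id sequence with the pad token (0) up to max_len,
# so they can be stacked into a single {batch, max_len} tensor.
max_len = 8
names = [[5, 12, 9], [7, 3, 3, 14, 2]]

padded =
  Enum.map(names, fn ids ->
    ids ++ List.duplicate(0, max_len - length(ids))
  end)

batch = Nx.tensor(padded)  # shape {2, 8}; the zeros are what Axon.mask(input, 0) flags
```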
Thanks! I saw it was committed just a few days ago! When will it be released?
I would like to reiterate that an example of a custom RNN using all these features (unrolling, masking, how to implement a "cell", whether we can call other layers from an RNN "cell") would be awesome to see in the Axon guide!
Hi! Axon beginner here.
I am struggling to figure out how to write a very simple RNN. Basically, I want to rewrite this PyTorch example from a tutorial.
However, the Axon API makes it a bit convoluted to create networks that "scan" or "unroll" the input. After some digging I realized that I need to create something similar to lstm_cell and lstm, but these APIs are not well documented (what do the dynamic_unroll arguments mean?). I am also not sure how to handle parameters in that case so that the training mechanism (Axon.Loop.trainer with standard optimizer and loss functions) can do its job.