tallamjr / astronet

Efficient Deep Learning for Real-time Classification of Astronomical Transients and Multivariate Time-series
Apache License 2.0

[MR/35] Implement `PositionalEncoding` class #36

Closed tallamjr closed 3 years ago

tallamjr commented 3 years ago

As discussed in #35, it seems a PositionalEncoding class is required to restore temporal information to the input sequence.

From the Hands-On ML book:

The positional embeddings are simply dense vectors (much like word embeddings) that represent the position of a word in the sentence. The nth positional embedding is added to the word embedding of the nth word in each sentence. This gives the model access to each word’s position, which is needed because the Multi-Head Attention layers do not consider the order or the position of the words; they only look at their relationships. Since all the other layers are time-distributed, they have no way of knowing the position of each word (either relative or absolute). Obviously, the relative and absolute word positions are important, so we need to give this information to the Transformer somehow, and positional embeddings are a good way to do this.
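
The claim that attention ignores word order can be checked with a small sketch. The function below is a bare, unparameterized scaled dot-product self-attention (Q = K = V = the input, no learned weights, a single head), which is a simplifying assumption and not the Multi-Head Attention layer itself; the names self_attention, tokens and perm are hypothetical. Shuffling the input rows only shuffles the output rows in the same way, so nothing in the output depends on position.

```python
import math

def self_attention(x):
    """Unparameterized scaled dot-product self-attention with Q = K = V = x."""
    d = len(x[0])
    out = []
    for q in x:
        # Similarity of this query with every key, scaled by sqrt(d).
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in x]
        # Numerically stable softmax over the scores.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value vectors.
        out.append([sum(w * v[j] for w, v in zip(weights, x)) for j in range(d)])
    return out

tokens = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
perm = [2, 0, 1]
shuffled = [tokens[p] for p in perm]

out = self_attention(tokens)
out_shuffled = self_attention(shuffled)

# Row k of the shuffled output should match row perm[k] of the original output.
same = all(
    math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-12)
    for p, row in zip(perm, out_shuffled)
    for a, b in zip(out[p], row)
)
print(same)
```

Adding a position-dependent vector to each row before the attention step breaks this symmetry, which is the role the positional encoding plays.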

Examples can be found at https://www.tensorflow.org/tutorials/text/transformer which uses functions akin to:

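import numpy as np
import tensorflow as tf

def get_angles(pos, i, d_model):
  # columns 2k and 2k+1 share the same angle rate 1 / 10000^(2k / d_model)
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates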

def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)
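
To make the formula concrete, here is a minimal pure-Python sketch of the same sinusoidal scheme: sine on even columns, cosine on odd columns, with angle pos / 10000^(2*(i//2)/d_model). The name sinusoidal_table is hypothetical and not part of the tutorial. It drops NumPy and TensorFlow and returns nested lists rather than a tensor with a leading batch axis.

```python
import math

def sinusoidal_table(num_positions, d_model):
    """Return a num_positions x d_model list of lists of positional encodings."""
    table = []
    for pos in range(num_positions):
        row = []
        for i in range(d_model):
            # Columns 2k and 2k+1 share the same angle rate.
            angle = pos / (10000 ** ((2 * (i // 2)) / d_model))
            # Even columns take the sine, odd columns the cosine.
            row.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

table = sinusoidal_table(num_positions=50, d_model=8)
print(len(table), len(table[0]))
# Position 0 gives sin(0) on even columns and cos(0) on odd columns.
print(table[0])
```

Low column indices oscillate quickly with position while high indices change slowly, so each position gets a distinct pattern across the d_model columns.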

This is then used later in an Encoder layer like so:

class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)

    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]

    # adding embedding and position encoding.
    x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
    x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)

Alternatively, in the Hands-On ML book (p. 558), a PositionalEncoding class is defined as follows:

import numpy as np
import tensorflow as tf
from tensorflow import keras

class PositionalEncoding(keras.layers.Layer):
    def __init__(self, max_steps, max_dims, dtype=tf.float32, **kwargs):
        # two-argument super(): pass the instance so Layer.__init__ is reached
        super(PositionalEncoding, self).__init__(dtype=dtype, **kwargs)
        if max_dims % 2 == 1:
            max_dims += 1  # max_dims must be even
        p, i = np.meshgrid(np.arange(max_steps), np.arange(max_dims // 2))
        pos_emb = np.empty((1, max_steps, max_dims))
        pos_emb[0, :, ::2] = np.sin(p / 10000 ** (2 * i / max_dims)).T
        pos_emb[0, :, 1::2] = np.cos(p / 10000 ** (2 * i / max_dims)).T
        self.positional_embedding = tf.constant(pos_emb.astype(self.dtype))

    def call(self, inputs):
        shape = tf.shape(inputs)
        return inputs + self.positional_embedding[:, : shape[-2], : shape[-1]]

To be used elsewhere as:

positional_encoding = PositionalEncoding(max_steps, max_dims=embed_size) 
encoder_in = positional_encoding(conv_embeddings)

The diff to model.py would then perhaps look something like:

diff --git a/astronet/t2/model.py b/astronet/t2/model.py
index 50eb080..89351a4 100644
--- a/astronet/t2/model.py
+++ b/astronet/t2/model.py
@@ -2,7 +2,7 @@ import tensorflow as tf
 from tensorflow import keras
 from tensorflow.keras import layers

-from astronet.t2.transformer import ConvEmbedding, TransformerBlock
+from astronet.t2.transformer import ConvEmbedding, PositionalEncoding, TransformerBlock

 class T2Model(keras.Model):
@@ -22,6 +22,7 @@ class T2Model(keras.Model):
         self.num_classes    = num_classes

         self.embedding      = ConvEmbedding(num_filters=self.num_filters, input_shape=input_dim)
+        self.pos_encoding   = PositionalEncoding(max_steps=input_dim[1], max_dims=embed_dim)
         self.encoder        = TransformerBlock(self.embed_dim, self.num_heads, self.ff_dim)
         self.pooling        = layers.GlobalAveragePooling1D()
         self.dropout1       = layers.Dropout(0.1)
@@ -32,6 +33,7 @@ class T2Model(keras.Model):
     def call(self, inputs, training=None):

         x = self.embedding(inputs)
+        x = self.pos_encoding(x)
         x = self.encoder(x)
         x = self.pooling(x)
         if training:

tallamjr commented 3 years ago

Useful blog post regarding Positional Encoding: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/

tallamjr commented 3 years ago

It seems Geron's implementation used as is throws the following error:

  File "train.py", line 181, in <module>
    training()
  File "train.py", line 95, in __call__
    num_classes=num_classes,
  File "/Users/tallamjr/github/tallamjr/origin/astronet/astronet/t2/model.py", line 26, in __init__
    self.pos_encoding   = PositionalEncoding(max_steps=self.sequence_length, max_dims=self.embed_dim)
  File "/Users/tallamjr/github/tallamjr/origin/astronet/astronet/t2/transformer.py", line 24, in __init__
    super(PositionalEncoding).__init__(dtype=dtype, **kwargs)
TypeError: super() takes no keyword arguments

It may be easier to use the other proposed implementation. Alternatively, there is another one in /Users/tallamjr/github/tallamjr/forks/tfmodels/official/nlp/modeling/layers/position_embedding.py, which defines a class RelativePositionEmbedding(tf.keras.layers.Layer) as follows:

class RelativePositionEmbedding(tf.keras.layers.Layer):
  """Creates a positional embedding.

  This layer calculates the position encoding as a mix of sine and cosine
  functions with geometrically increasing wavelengths. Defined and formulized in
   "Attention is All You Need", section 3.5.
  (https://arxiv.org/abs/1706.03762).

  Arguments:
    hidden_size: Size of the hidden layer.
    min_timescale: Minimum scale that will be applied at each position
    max_timescale: Maximum scale that will be applied at each position.
  """

  def __init__(self,
               hidden_size,
               min_timescale=1.0,
               max_timescale=1.0e4,
               **kwargs):
    # We need to have a default dtype of float32, since the inputs (which Keras
    # usually uses to infer the dtype) will always be int32.
    # We compute the positional encoding in float32 even if the model uses
    # float16, as many of the ops used, like log and exp, are numerically
    # unstable in float16.
    if "dtype" not in kwargs:
      kwargs["dtype"] = "float32"

    super(RelativePositionEmbedding, self).__init__(**kwargs)
    self._hidden_size = hidden_size
    self._min_timescale = min_timescale
    self._max_timescale = max_timescale

  def get_config(self):
    config = {
        "hidden_size": self._hidden_size,
        "min_timescale": self._min_timescale,
        "max_timescale": self._max_timescale,
    }
    base_config = super(RelativePositionEmbedding, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, length=None):
    """Implements call() for the layer.

    Args:
      inputs: A tensor whose second dimension will be used as `length`. If
        `None`, the other `length` argument must be specified.
      length: An optional integer specifying the number of positions. If both
        `inputs` and `length` are specified, `length` must be equal to the second
        dimension of `inputs`.

    Returns:
      A tensor in shape of [length, hidden_size].
    """
    if inputs is None and length is None:
      raise ValueError("If inputs is None, `length` must be set in "
                       "RelativePositionEmbedding().")
    if inputs is not None:
      input_shape = tf_utils.get_shape_list(inputs)
      if length is not None and length != input_shape[1]:
        raise ValueError(
            "If inputs is not None, `length` must equal to input_shape[1].")
      length = input_shape[1]
    position = tf.cast(tf.range(length), tf.float32)
    num_timescales = self._hidden_size // 2
    min_timescale, max_timescale = self._min_timescale, self._max_timescale
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        (tf.cast(num_timescales, tf.float32) - 1))
    inv_timescales = min_timescale * tf.exp(
        tf.cast(tf.range(num_timescales), tf.float32) *
        -log_timescale_increment)
    scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
        inv_timescales, 0)
    position_embeddings = tf.concat(
        [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
    return position_embeddings

Here, hidden_size seems to simply be the dimension of the model.
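
For comparison with the two implementations above, here is a rough pure-Python sketch of the arithmetic in that call() method. The function name relative_position_embedding and the use of plain lists are assumptions made for illustration; the real layer works on tensors. Two differences stand out: the sine and cosine halves are concatenated rather than interleaved, and the layer returns the embeddings instead of adding them to its inputs.

```python
import math

def relative_position_embedding(length, hidden_size,
                                min_timescale=1.0, max_timescale=1.0e4):
    """Return a length x (2 * (hidden_size // 2)) table: [sin..., cos...] per row."""
    # Assumes hidden_size >= 4, otherwise the division below is by zero.
    num_timescales = hidden_size // 2
    log_increment = (math.log(max_timescale / min_timescale)
                     / (num_timescales - 1))
    # Geometric progression of inverse timescales, starting at min_timescale.
    inv_timescales = [min_timescale * math.exp(-k * log_increment)
                      for k in range(num_timescales)]
    rows = []
    for pos in range(length):
        scaled = [pos * t for t in inv_timescales]
        # First half sines, second half cosines (concatenated, not interleaved).
        rows.append([math.sin(s) for s in scaled] + [math.cos(s) for s in scaled])
    return rows

emb = relative_position_embedding(length=4, hidden_size=6)
print(len(emb), len(emb[0]))
# Position 0: all sines are 0 and all cosines are 1.
print(emb[0])
```

Because the result has shape [length, hidden_size] with no batch axis, a caller would still need to add it to the embedded inputs, relying on broadcasting over the batch dimension.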

tallamjr commented 3 years ago

"It seems Geron's implementation used as is throws the following error:"

It turns out this was happening because of a mistake on my end, i.e.

diff --git a/astronet/t2/transformer.py b/astronet/t2/transformer.py
index d5be210..2d126e8 100644
--- a/astronet/t2/transformer.py
+++ b/astronet/t2/transformer.py
@@ -22,18 +22,18 @@ class ConvEmbedding(layers.Layer):

 class PositionalEncoding(keras.layers.Layer):
     def __init__(self, max_steps, max_dims, dtype=tf.float32, **kwargs):
-        super(PositionalEncoding).__init__(dtype=dtype, **kwargs)
+        super(PositionalEncoding, self).__init__(dtype=dtype, **kwargs)

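This was breaking the method resolution that was expected: the one-argument form super(PositionalEncoding) returns an unbound super object, so .__init__ resolves to super's own initializer, which accepts no keyword arguments, rather than to the next __init__ in the __mro__ (keras.layers.Layer).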
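
The failure can be reproduced without Keras. In this minimal sketch, Base stands in for keras.layers.Layer; Base, Broken, Fixed and raises_type_error are hypothetical names used only for illustration.

```python
class Base:
    def __init__(self, dtype=None, **kwargs):
        self.dtype = dtype

class Broken(Base):
    def __init__(self, dtype="float32", **kwargs):
        # One-argument super() is unbound: this calls super's own __init__,
        # which rejects keyword arguments.
        super(Broken).__init__(dtype=dtype, **kwargs)

class Fixed(Base):
    def __init__(self, dtype="float32", **kwargs):
        # Two-argument super() is bound to self and follows the MRO to Base.
        super(Fixed, self).__init__(dtype=dtype, **kwargs)

def raises_type_error(cls):
    """Return True if instantiating cls raises TypeError."""
    try:
        cls()
    except TypeError:
        return True
    return False

print(raises_type_error(Broken))
print(raises_type_error(Fixed))
print(Fixed().dtype)
```

In Python 3 the zero-argument form super().__init__(dtype=dtype, **kwargs) is equivalent to the two-argument call and avoids repeating the class name.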

Refs: