claCase / Attention-as-RNN

Non-official implementation of "Attention as an RNN" (https://arxiv.org/pdf/2405.13956). Both the efficient associative parallel prefix scan and the recurrent version are implemented.
MIT License
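For context, the paper's core observation is that softmax attention can be computed as a recurrence over a running numerator/denominator pair (and that recurrence in turn admits an associative parallel prefix scan). Below is a minimal NumPy sketch of the recurrent view for a single query; it only illustrates the idea and is not this repository's implementation:

```python
import numpy as np

def attention_as_rnn(q, keys, values):
    """Causal attention for one query via a numerically stable recurrence.

    Maintains a running max m, numerator a, and denominator c so that
    a_k / c_k equals softmax(q . k_{1..k}) @ v_{1..k} at every step.
    """
    a = np.zeros_like(values[0])  # running exp-weighted sum of values
    c = 0.0                       # running sum of exp-scores
    m = -np.inf                   # running max score (for stability)
    outputs = []
    for k, v in zip(keys, values):
        s = q @ k
        m_new = max(m, s)
        # rescale previous state to the new max before accumulating
        a = a * np.exp(m - m_new) + v * np.exp(s - m_new)
        c = c * np.exp(m - m_new) + np.exp(s - m_new)
        m = m_new
        outputs.append(a / c)
    return np.stack(outputs)
```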

Time series classification #5

Open lys15163238308 opened 3 months ago

lys15163238308 commented 3 months ago

Hello, could you please update the code for time series classification and add an example? Thank you!

claCase commented 3 months ago

Hi @lys15163238308 ,

I've updated the README with a classification example. The main idea is to squash the entire hidden-state output sequence into a fixed-size vector of per-feature sequence statistics (mean, max, min), computed with global mean, max, and min pooling over the time axis. A final fully-connected layer then classifies the sequence.

```python
import tensorflow as tf
# `models` is this repository's module providing ScanRNNAttentionModel;
# adjust the import path to match your checkout, e.g. `from src import models`.

ki = tf.keras.Input(shape=(None, 1))  # variable-length sequences, 1 feature
scan = models.ScanRNNAttentionModel([10, 10], [10, 10])
avg_pool = tf.keras.layers.GlobalAveragePooling1D()
max_pool = tf.keras.layers.GlobalMaxPooling1D()
# Keras has no GlobalMinPooling1D, so reduce over the time axis manually
min_pool = tf.keras.layers.Lambda(lambda x: tf.reduce_min(x, axis=-2))
conc = tf.keras.layers.Concatenate()
dense = tf.keras.layers.Dense(1, "sigmoid")  # binary classification head

h = scan(ki)        # hidden-state sequence: (batch, time, features)
avgp = avg_pool(h)  # each pooling output: (batch, features)
maxp = max_pool(h)
minp = min_pool(h)
mix = conc([avgp, maxp, minp])  # fixed-size summary of the sequence
o = dense(mix)

classification_model = tf.keras.Model(ki, o)
classification_model.compile("adam", "bce")
```
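To sanity-check the model end to end, you can fit it on random data; the shapes below are illustrative, with binary labels matching the single sigmoid output:

```python
import numpy as np

# 32 sequences of length 50 with 1 feature, and random binary labels
x = np.random.normal(size=(32, 50, 1)).astype("float32")
y = np.random.randint(0, 2, size=(32, 1)).astype("float32")
classification_model.fit(x, y, epochs=2, batch_size=8)
```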