tallamjr / astronet

Efficient Deep Learning for Real-time Classification of Astronomical Transients and Multivariate Time-series
Apache License 2.0

Stack encoders to form Vaswani et al. Nx6 architecture #46

Closed by tallamjr 3 years ago

tallamjr commented 3 years ago

These changes stack repeated TransformerBlock layers so that the model matches the encoder described in the Vaswani et al. paper, which uses N = 6 identical layers.

Inspecting the output of model.summary() confirms that the six repeated blocks are present:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv_embedding (ConvEmbeddin (None, 200, 32)           128
_________________________________________________________________
positional_encoding (Positio (None, 200, 32)           0
_________________________________________________________________
transformer_block (Transform (None, 200, 32)           6464
_________________________________________________________________
transformer_block_1 (Transfo (None, 200, 32)           6464
_________________________________________________________________
transformer_block_2 (Transfo (None, 200, 32)           6464
_________________________________________________________________
transformer_block_3 (Transfo (None, 200, 32)           6464
_________________________________________________________________
transformer_block_4 (Transfo (None, 200, 32)           6464
_________________________________________________________________
transformer_block_5 (Transfo (None, 200, 32)           6464
_________________________________________________________________
global_average_pooling1d (Gl (None, 32)                0
_________________________________________________________________
dropout_12 (Dropout)         multiple                  0
_________________________________________________________________
dense_36 (Dense)             (None, 20)                660
_________________________________________________________________
dropout_13 (Dropout)         multiple                  0
_________________________________________________________________
dense_37 (Dense)             (None, 6)                 126
=================================================================
Total params: 39,698
Trainable params: 39,698
Non-trainable params: 0
_________________________________________________________________
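As a quick sanity check on the summary above, the total can be recomputed by hand. This is only a sketch: the per-layer counts are copied from the table, while the dense-layer input widths (32 and 20) and the presence of bias terms are assumptions inferred from the output shapes, not taken from the repository code. The helper `dense_params` is a hypothetical name used for illustration.

```python
# Parameter counts copied from the model.summary() table above.
EMBEDDING_PARAMS = 128   # conv_embedding
BLOCK_PARAMS = 6464      # each transformer_block
NUM_BLOCKS = 6           # transformer_block through transformer_block_5


def dense_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a fully connected layer (bias assumed)."""
    return n_in * n_out + n_out


# Assumed widths: pooled features (32) -> 20 hidden units -> 6 classes.
head_params = dense_params(32, 20) + dense_params(20, 6)

total = EMBEDDING_PARAMS + NUM_BLOCKS * BLOCK_PARAMS + head_params
print(total)
```

Under these assumptions the result agrees with the reported total of 39,698, with the six stacked blocks accounting for 38,784 of the parameters.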

Closes #37