keras-team / keras-io

Keras documentation, hosted live at keras.io

Possible parameter error in time series Transformer model example #845

Open stallam-unb opened 2 years ago

stallam-unb commented 2 years ago

In the example for time series classification with a Transformer, the function build_model() is defined as:

def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    inputs = keras.Input(shape=input_shape)
    x = inputs
    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    x = layers.GlobalAveragePooling1D(data_format="channels_first")(x)  # <- This line appears to be wrong.
    for dim in mlp_units:
        x = layers.Dense(dim, activation="relu")(x)
        x = layers.Dropout(mlp_dropout)(x)
    outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)

The pooling layer is initialised as x = layers.GlobalAveragePooling1D(data_format="channels_first")(x). However, shouldn't the data format be channels_last?

SuryanarayanaY commented 1 year ago

Hi @stallam-unb ,

The tutorial on time series classification with a Transformer uses the FordA dataset which, as per my understanding, is time series data collected at 500 timestamps for one feature (i.e. noise). So the dataset, which is initially of shape (x, 500), is converted into (x, 500, 1), and theoretically it should be of the form (batch_size, steps, features).
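As a small illustration of that reshape, here is a NumPy sketch. The array is random placeholder data and the batch size of 4 is an arbitrary assumption; it is not the real FordA data.

```python
import numpy as np

# Placeholder data standing in for FordA: 4 series (arbitrary), 500 time steps each.
x = np.random.rand(4, 500)

# Add a trailing feature axis so the layout is (batch_size, steps, features).
x = x.reshape((x.shape[0], x.shape[1], 1))

print(x.shape)  # (4, 500, 1)
```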

The tutorial page itself, in the model summary section, states the following:

Our model processes a tensor of shape (batch size, sequence length, features), where sequence length is the number of time steps and features is each input timeseries.

But as per the documentation of GlobalAveragePooling1D, data_format="channels_first" means the data should be of shape (batch_size, features, steps).

Given the shape of the data we are passing, it should be data_format="channels_last".
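To make the difference concrete, here is a rough NumPy stand-in for the pooling step. It is not the Keras layer itself: the helper name global_average_pool_1d is hypothetical, and it only mimics which axis gets averaged under each data_format, as I understand it from the layer's documentation. The input is random placeholder data with an arbitrary batch size of 4.

```python
import numpy as np


def global_average_pool_1d(x, data_format="channels_last"):
    """Hypothetical NumPy stand-in for GlobalAveragePooling1D (not the Keras layer)."""
    # channels_last:  input is (batch, steps, features) -> average over axis 1
    # channels_first: input is (batch, features, steps) -> average over axis 2
    steps_axis = 1 if data_format == "channels_last" else 2
    return x.mean(axis=steps_axis)


# Placeholder input laid out as (batch_size, steps, features).
x = np.random.rand(4, 500, 1)

last = global_average_pool_1d(x, "channels_last")
first = global_average_pool_1d(x, "channels_first")

print(last.shape)   # (4, 1)
print(first.shape)  # (4, 500)
```

With channels_first, the size-1 axis is treated as the steps axis, so nothing is averaged across time and the 500 time steps pass through as if they were features. With channels_last, the 500 steps are averaged into a single value per feature.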

I may raise a PR for the same.