philipperemy / keras-tcn

Keras Temporal Convolutional Network.
MIT License

How to freeze TCN layer? #212

Closed ktakanopy closed 2 years ago

ktakanopy commented 3 years ago

I realized that the TCN layer does not have a 'trainable' property. How can we freeze its layers?

philipperemy commented 2 years ago

Great question. The TCN is a Temporal Convolutional Network. It's not technically a layer, although it inherits from the Keras Layer object; it's closer to a Model. The choice was motivated by the desire to easily swap the LSTM and TCN classes for comparison.

Now, yeah, we could add a trainable property. In theory, we have the list of all the layers when we build the TCN layer, so we could just loop over each of them and set it.

philipperemy commented 2 years ago

I checked, and you can actually just set .trainable = False on your TCN layer. All the nested layers will be recursively set to trainable=False. It's documented here: https://keras.io/guides/transfer_learning/ (look for "Recursive setting of the trainable attribute").
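The recursive behavior can be illustrated without Keras at all. Here is a framework-free toy sketch of the idea, not the actual Keras implementation: setting trainable on a container propagates the value down to every sub-layer.

```python
# Toy illustration of Keras's "recursive setting of the trainable attribute".
# ToyLayer is a hypothetical stand-in for a Keras Layer/Model, not real Keras code.

class ToyLayer:
    def __init__(self, sub_layers=()):
        self.sub_layers = list(sub_layers)
        self._trainable = True

    @property
    def trainable(self):
        return self._trainable

    @trainable.setter
    def trainable(self, value):
        self._trainable = value
        # Recurse into nested layers, mirroring what Keras does for
        # container layers and models.
        for layer in self.sub_layers:
            layer.trainable = value


conv1, conv2 = ToyLayer(), ToyLayer()
tcn_like = ToyLayer(sub_layers=[conv1, conv2])

tcn_like.trainable = False
print([conv1.trainable, conv2.trainable])  # -> [False, False]
```

This is why a single tcn.trainable = False is enough: the residual blocks inside the TCN are sub-layers, so they are frozen along with their parent.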

philipperemy commented 2 years ago

Here is an example.

"""
Trains a TCN on the IMDB sentiment classification task.
Output after 1 epoch on CPU: ~0.8611
Time per epoch on CPU (Core i7): ~64s.
Based on: https://github.com/keras-team/keras/blob/master/examples/imdb_bidirectional_lstm.py
"""
import numpy as np
from tensorflow.keras import Sequential
from tensorflow.keras.datasets import imdb
from tensorflow.keras.layers import Dense, Dropout, Embedding
from tensorflow.keras.preprocessing import sequence

from tcn import TCN

max_features = 20000
# cut texts after this number of words
# (among top max_features most common words)
maxlen = 100
batch_size = 32

print('Loading data...')
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
print(len(x_train), 'train sequences')
print(len(x_test), 'test sequences')

print('Pad sequences (samples x time)')
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
y_train = np.array(y_train)
y_test = np.array(y_test)

tcn2 = TCN(
    nb_filters=64,
    kernel_size=6,
    dilations=[1, 2, 4, 8, 16, 32, 64]
)

# Freeze the whole TCN; Keras recursively sets trainable=False
# on all of its nested layers.
tcn2.trainable = False

model = Sequential()
model.add(Embedding(max_features, 128, input_shape=(maxlen,)))
model.add(tcn2)
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))

model.summary()

model.compile('adam', 'binary_crossentropy', metrics=['accuracy'])

# Weights of the first residual block before training.
print(tcn2.residual_blocks[0].get_weights()[0])

print('Train...')
model.fit(
    x_train, y_train,
    batch_size=batch_size,
    validation_data=(x_test, y_test),
    epochs=1
)

# Should be identical to the pre-training weights, confirming the freeze.
print(tcn2.residual_blocks[0].get_weights()[0])
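Instead of eyeballing the two prints, the check can be made explicit with numpy. A sketch, assuming w_before and w_after are the snapshots taken around fit() (in the real script they would come from tcn2.residual_blocks[0].get_weights()[0]; here a stand-in array simulates them):

```python
import numpy as np

# Stand-in for the kernel snapshot taken before fit().
w_before = np.linspace(-1.0, 1.0, 12, dtype=np.float32).reshape(3, 4)
# With trainable=False, fit() leaves the weights untouched,
# so the post-training snapshot is identical.
w_after = w_before.copy()

assert np.array_equal(w_before, w_after)
print('weights unchanged:', np.array_equal(w_before, w_after))
```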