Firstly, I want to extend my appreciation for the assistance provided thus far.
After my latest runs on both the training and test sets, I get an accuracy of 8%. That is a step forward, but with 10 classes it is below the 10% that random guessing would give, so something still seems to be wrong.
Checkpoint choice: I used the checkpoint in lstm_kmean/experiments/best_ckpt/ (trained for 2420 epochs) rather than the one in the main folder's experiments/best_ckpt/ (only 420 epochs). Can someone confirm that this is the right one?
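Because the directory passed to `tf.train.CheckpointManager` decides which checkpoint is loaded, it may help to confirm that a restore actually happened. When `latest_checkpoint` finds nothing it returns `None`, and `tf.train.Checkpoint.restore(None)` does not raise, so training silently starts from freshly initialized weights. Below is a minimal sketch of a stricter restore; `restore_latest` is a hypothetical helper name, not part of the repository. It assumes the model has been called once on a batch (for example `triplenet(X)`) before restoring, so that its variables exist and the match check has something to compare.

```python
import os

import tensorflow as tf


def restore_latest(ckpt, directory):
    """Restore `ckpt` from the newest checkpoint in `directory`, or fail loudly.

    Hypothetical helper: unlike ckpt.restore(manager.latest_checkpoint), it
    raises instead of silently continuing when no checkpoint is found.
    """
    latest = tf.train.latest_checkpoint(directory)
    if latest is None:
        raise FileNotFoundError(
            f"No checkpoint found in {os.path.abspath(directory)!r}")
    status = ckpt.restore(latest)
    # Raises if any variable that already exists (model weights, step, ...)
    # has no matching entry in the checkpoint file.
    status.assert_existing_objects_matched()
    return latest
```

Printing the returned path also shows which `ckpt-N` file was loaded, which makes it easier to tell the two `best_ckpt` folders apart.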
Accuracy calculation: I uncommented your code for computing accuracy, but the results are still poor. Any suggestions on what to change to improve accuracy would be appreciated.
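One possible explanation, offered as an assumption rather than a diagnosis: if `TripleNet` returns a feature embedding (for example of size `n_feat = 128`, trained with a triplet loss) instead of one score per class, then `SparseCategoricalAccuracy` takes the argmax over embedding dimensions and compares it with the label. That would give roughly chance-level numbers like the 8% above. The `lstm_kmean` folder name hints that accuracy may instead be meant to be measured by clustering the embeddings. Below is a minimal sketch of that kind of evaluation, using scikit-learn's `KMeans` and SciPy's `linear_sum_assignment` to match clusters to classes. The function names `cluster_accuracy` and `kmeans_accuracy` are hypothetical and do not come from the repository.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def cluster_accuracy(y_true, y_pred):
    """Accuracy of cluster ids against class labels under the best 1-to-1 mapping.

    K-means cluster ids are arbitrary (cluster 3 need not mean class 3), so
    the cluster-to-class assignment that maximises the number of matching
    samples is chosen with the Hungarian algorithm.
    """
    y_true = np.asarray(y_true, dtype=np.int64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.int64).ravel()
    n = int(max(y_true.max(), y_pred.max())) + 1
    # counts[i, j] = number of samples put in cluster i whose true class is j
    counts = np.zeros((n, n), dtype=np.int64)
    np.add.at(counts, (y_pred, y_true), 1)
    # linear_sum_assignment minimises total cost, so negate to maximise matches
    rows, cols = linear_sum_assignment(-counts)
    return counts[rows, cols].sum() / y_true.size


def kmeans_accuracy(embeddings, labels, n_classes, seed=0):
    """Cluster embeddings into n_classes groups and score them with cluster_accuracy."""
    # Imported here so cluster_accuracy can be used without scikit-learn installed
    from sklearn.cluster import KMeans

    km = KMeans(n_clusters=n_classes, n_init=10, random_state=seed)
    cluster_ids = km.fit_predict(np.asarray(embeddings))
    return cluster_accuracy(labels, cluster_ids)
```

To use it, one could collect the test-set embeddings, e.g. `np.concatenate([triplenet(X, training=False).numpy() for X, _ in test_batch])`, together with the matching integer labels, and pass both to `kmeans_accuracy` with `n_classes=10`. If the model does output class scores after all, this sketch does not apply.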
Thank you for your continued support in resolving these issues. (I have already checked the other issues, but unfortunately they did not help.)
Best regards,
Nady
Here is my code (I removed the unnecessary comments):
```python
import pickle

import tensorflow as tf
from tqdm import tqdm

# load_complete_data, TripleNet, train_step and test_step are defined in the
# project's own modules; their imports were not included in this post.

if __name__ == '__main__':
    n_channels = 14
    n_feat = 128
    batch_size = 256
    test_batch_size = 1
    n_classes = 10

    # Load the pickled EEG train/test splits
    with open('data/eeg/image/data.pkl', 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    train_X = data['x_train']
    train_Y = data['y_train']
    test_X = data['x_test']
    test_Y = data['y_test']

    train_batch = load_complete_data(train_X, train_Y, batch_size=batch_size)
    val_batch = load_complete_data(test_X, test_Y, batch_size=batch_size)
    test_batch = load_complete_data(test_X, test_Y, batch_size=test_batch_size)
    X, Y = next(iter(train_batch))

    triplenet = TripleNet(n_classes=n_classes)
    opt = tf.keras.optimizers.legacy.Adam(learning_rate=3e-4)

    # Restore the most recent checkpoint (if any) from the lstm_kmean folder
    triplenet_ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=triplenet, optimizer=opt)
    triplenet_ckptman = tf.train.CheckpointManager(triplenet_ckpt, directory='lstm_kmean/experiments/best_ckpt', max_to_keep=500)
    triplenet_ckpt.restore(triplenet_ckptman.latest_checkpoint)
    # As laid out here, step is advanced once per epoch (at the end of the
    # loop), so dividing by len(train_batch) does not give the epoch count.
    START = int(triplenet_ckpt.step) // len(train_batch)
    if triplenet_ckptman.latest_checkpoint:
        print('Restored from the latest checkpoint, epoch: {}'.format(START))

    EPOCHS = 3000
    cfreq = 5  # Checkpoint frequency (in epochs)

    for epoch in range(START, EPOCHS):
        train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
        train_loss = tf.keras.metrics.Mean()
        test_acc = tf.keras.metrics.SparseCategoricalAccuracy()
        test_loss = tf.keras.metrics.Mean()

        # Training pass
        tq = tqdm(train_batch)
        for idx, (X, Y) in enumerate(tq, start=1):
            loss = train_step(triplenet, opt, X, Y)
            train_loss.update_state(loss)
            Y_cap = triplenet(X, training=False)
            train_acc.update_state(Y, Y_cap)
            tq.set_description('Train Epoch: {}, Loss: {}, Acc: {}'.format(epoch, train_loss.result(), train_acc.result()))
            # break

        # Validation pass
        tq = tqdm(val_batch)
        for idx, (X, Y) in enumerate(tq, start=1):
            loss = test_step(triplenet, X, Y)
            test_loss.update_state(loss)
            Y_cap = triplenet(X, training=False)
            test_acc.update_state(Y, Y_cap)
            tq.set_description('Test Epoch: {}, Loss: {}, Test Acc: {}'.format(epoch, test_loss.result(), test_acc.result()))
            # break

        triplenet_ckpt.step.assign_add(1)
        if (epoch % cfreq) == 0:
            triplenet_ckptman.save()
```