rafiepour / CTran

Complete code for the proposed CNN-Transformer model for natural language understanding.
https://github.com/rafiepour/CTran
Apache License 2.0

Training Saturates before Epoch ~10 #5

Closed: EthanEpp closed this 1 month ago

EthanEpp commented 1 month ago

When training on either ATIS or SNIPS, my model seems to show no improvement beyond epoch ~10, with training SlotFilling F1 and IntentDet Prec reaching ~1.0. The test set hits ~0.98, which I understand is quite close to the training score, but the fast convergence makes me wonder whether this is expected of the model, especially since the paper mentions that improvements are not observed beyond epoch 50.

I believe I have set my learning rates and optimizers according to the paper (with BERT-large), so I am just wondering whether there is something I am missing, whether I should add more ways of preventing overfitting, or whether this behavior is expected.

Here are my training parameters, just in case:

import torch
import torch.nn as nn
import torch.optim as optim

BATCH_SIZE = 16
LENGTH = 60
STEP_SIZE = 50

# slot-filling loss ignores the padding index; intent loss is standard cross-entropy
loss_function_1 = nn.CrossEntropyLoss(ignore_index=0)
loss_function_2 = nn.CrossEntropyLoss()

dec_optim = optim.AdamW(decoder.parameters(), lr=0.0001)
enc_optim = optim.AdamW(encoder.parameters(), lr=0.001)
ber_optim = optim.AdamW(bert_layer.parameters(), lr=0.00001)
mid_optim = optim.AdamW(middle.parameters(), lr=0.0001)

# multiply each learning rate by 0.96 at every scheduler step
enc_scheduler = torch.optim.lr_scheduler.StepLR(enc_optim, 1, gamma=0.96)
dec_scheduler = torch.optim.lr_scheduler.StepLR(dec_optim, 1, gamma=0.96)
mid_scheduler = torch.optim.lr_scheduler.StepLR(mid_optim, 1, gamma=0.96)
ber_scheduler = torch.optim.lr_scheduler.StepLR(ber_optim, 1, gamma=0.96)
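In case it helps, my epoch loop steps these together roughly like this (a simplified sketch, not my exact code; forward_pass, train_loader, and the batch field names below are placeholders):

# Simplified sketch of one training epoch
for batch in train_loader:
    enc_optim.zero_grad(); dec_optim.zero_grad()
    mid_optim.zero_grad(); ber_optim.zero_grad()

    # placeholder forward pass: slot_logits (B, LENGTH, n_slots), intent_logits (B, n_intents)
    slot_logits, intent_logits = forward_pass(batch)

    slot_loss = loss_function_1(slot_logits.reshape(-1, slot_logits.size(-1)),
                                batch["slot_ids"].reshape(-1))
    intent_loss = loss_function_2(intent_logits, batch["intent_ids"])
    (slot_loss + intent_loss).backward()

    enc_optim.step(); dec_optim.step()
    mid_optim.step(); ber_optim.step()

# decay all learning rates once per epoch
enc_scheduler.step(); dec_scheduler.step()
mid_scheduler.step(); ber_scheduler.step()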

Thank you!

rafiepour commented 1 month ago

Hi Ethan, yes, this is expected of the model. Our model benefited from a phenomenon called "deep double descent". The training accuracy remains almost the same throughout, but the test and development accuracy fluctuate: first you'll see them rise to a local maximum, then steadily decrease until around epoch 20 or 30, and then subtly increase until they peak. You could let it train for longer if you have the time and see what kind of results you get.
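If you do run it longer, a simple way to catch that second rise is to track the best dev score each epoch and keep that checkpoint. Something along these lines (a rough sketch, not code from this repo; train_one_epoch and evaluate_dev_f1 are placeholders for your own loop and evaluation):

# Rough sketch: track the best dev slot F1 over a longer run
best_f1, best_epoch = 0.0, -1
for epoch in range(150):          # well past the early plateau
    train_one_epoch()             # placeholder for your training loop
    dev_f1 = evaluate_dev_f1()    # placeholder for your evaluation code
    if dev_f1 > best_f1:
        best_f1, best_epoch = dev_f1, epoch
        torch.save({"encoder": encoder.state_dict(),
                    "decoder": decoder.state_dict()},  # add bert_layer/middle for a full checkpoint
                   "best_checkpoint.pt")
print(f"best dev F1 {best_f1:.4f} at epoch {best_epoch}")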

Very good question, btw. Sadly, we did not address this in our original paper. We did observe a number of phenomena in our tests, but we were not confident enough in their validity, hence we did not mention them.