You use train_max_seq_length to pad all the sentences to one fixed length, instead of padding according to each batch's max_len. Then, in every batch, you use:
def collate_fn(batch):
"""
batch should be a list of (sequence, target, length) tuples...
Returns a padded tensor of sequences sorted from longest to shortest,
"""
all_input_ids, all_attention_mask, all_token_type_ids, all_lens, all_labels = map(torch.stack, zip(*batch))
max_len = max(all_lens).item()
all_input_ids = all_input_ids[:, :max_len]
all_attention_mask = all_attention_mask[:, :max_len]
all_token_type_ids = all_token_type_ids[:, :max_len]
all_labels = all_labels[:, :max_len]
    # here the tensors are only truncated to the batch max_len; no padding is done per batch.
return all_input_ids, all_attention_mask, all_token_type_ids, all_labels, all_lens
So here the sequences are only truncated, never padded.
In my opinion, we should not pad all the sentences to one fixed length. Instead, we can pad each batch's sentences to that batch's own max_length (dynamic padding). This avoids the truncation step entirely; of course, the batch's max_length still must not exceed 512. See the sketch below.
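Just to illustrate the idea, here is a minimal sketch of such a dynamic-padding collate_fn. It is not the repo's code: the name dynamic_collate_fn is mine, and it assumes each dataset item is a tuple of un-padded 1-D tensors (input_ids, attention_mask, token_type_ids, label_ids) with 0 as the pad id for every field.

import torch
from torch.nn.utils.rnn import pad_sequence

def dynamic_collate_fn(batch, max_seq_length=512):
    # Sketch only: each item is (input_ids, attention_mask, token_type_ids, label_ids),
    # all variable-length 1-D tensors that were NOT padded in the Dataset.
    all_input_ids, all_attention_mask, all_token_type_ids, all_labels = zip(*batch)
    all_lens = torch.tensor([x.size(0) for x in all_input_ids])

    # Pad every field only up to the longest sequence in THIS batch.
    all_input_ids = pad_sequence(list(all_input_ids), batch_first=True, padding_value=0)
    all_attention_mask = pad_sequence(list(all_attention_mask), batch_first=True, padding_value=0)
    all_token_type_ids = pad_sequence(list(all_token_type_ids), batch_first=True, padding_value=0)
    all_labels = pad_sequence(list(all_labels), batch_first=True, padding_value=0)

    # Keep the hard limit of BERT's position embeddings (512).
    if all_input_ids.size(1) > max_seq_length:
        all_input_ids = all_input_ids[:, :max_seq_length]
        all_attention_mask = all_attention_mask[:, :max_seq_length]
        all_token_type_ids = all_token_type_ids[:, :max_seq_length]
        all_labels = all_labels[:, :max_seq_length]
        all_lens = all_lens.clamp(max=max_seq_length)

    return all_input_ids, all_attention_mask, all_token_type_ids, all_labels, all_lens

It could then be plugged in with something like DataLoader(train_dataset, batch_size=..., collate_fn=dynamic_collate_fn), so each batch is only as wide as its own longest sentence.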