Open wuguangshuo opened 7 months ago
    def __call__(self, batch):
        # Longest sequence in the batch determines the padded length.
        lengths = [len(instance["input_ids"]) for instance in batch]
        batch_max_len = max(lengths)

        input_ids_batch, labels_batch = [], []
        for instance in batch:
            input_ids = instance["input_ids"]
            labels = instance["labels"]
            padding_len = batch_max_len - len(input_ids)
            # Padding is appended on the right; padded label positions are -100.
            input_ids = input_ids + [self.pad_token_id] * padding_len
            labels = labels + [-100] * padding_len
            input_ids_batch.append(input_ids)
            labels_batch.append(labels)

Shouldn't `input_ids` be padded on the left instead?
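For comparison, here is a minimal sketch of the same padding logic with the side made configurable. The function name `pad_batch`, the `padding_side` argument, and the returned `attention_mask` are assumptions added for illustration and are not part of the collator quoted above. As far as I understand, right-padding as in the snippet is usually fine for training, because padded label positions are set to -100 and are typically ignored by the loss; left-padding mainly matters for batched generation with decoder-only models, where new tokens are appended after the last real token.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # same label value the snippet above uses for padded positions


def pad_batch(
    batch: List[Dict[str, List[int]]],
    pad_token_id: int,
    padding_side: str = "right",  # hypothetical parameter, not in the original collator
) -> Dict[str, List[List[int]]]:
    """Pad input_ids and labels to the longest sequence in the batch."""
    if padding_side not in ("left", "right"):
        raise ValueError("padding_side must be 'left' or 'right'")

    batch_max_len = max(len(instance["input_ids"]) for instance in batch)
    input_ids_batch, labels_batch, attention_mask_batch = [], [], []

    for instance in batch:
        input_ids = list(instance["input_ids"])
        labels = list(instance["labels"])
        padding_len = batch_max_len - len(input_ids)

        pad_ids = [pad_token_id] * padding_len
        pad_labels = [IGNORE_INDEX] * padding_len
        real_mask = [1] * len(input_ids)  # 1 marks real tokens
        pad_mask = [0] * padding_len      # 0 marks padding

        if padding_side == "left":
            input_ids_batch.append(pad_ids + input_ids)
            labels_batch.append(pad_labels + labels)
            attention_mask_batch.append(pad_mask + real_mask)
        else:
            input_ids_batch.append(input_ids + pad_ids)
            labels_batch.append(labels + pad_labels)
            attention_mask_batch.append(real_mask + pad_mask)

    return {
        "input_ids": input_ids_batch,
        "labels": labels_batch,
        "attention_mask": attention_mask_batch,
    }
```

Calling `pad_batch(batch, pad_token_id=0, padding_side="left")` on the same batch moves the pad tokens to the front of each shorter sequence; in either case the attention mask marks which positions are real tokens.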