# Column names of the "train" split (presumably kept so the raw columns can be
# dropped once the batch is tokenized — confirm against the caller).
column_names = raw_datasets["train"].column_names
def _drop_edge_special_token(encoding, special_ids, at_end):
    """Strip one special token from the same edge of every sample that has one.

    The check is made per sample, so a sample is only shortened when its own
    edge token is special. ``attention_mask`` is trimmed in lockstep with
    ``input_ids``. ``encoding`` is updated in place.
    """
    edge = -1 if at_end else 0
    # Decide per sample *before* touching input_ids, so both keys are trimmed
    # according to the same flags.
    flags = [
        len(ids) > 0 and ids[edge] in special_ids
        for ids in encoding['input_ids']
    ]
    for key in ('input_ids', 'attention_mask'):
        encoding[key] = [
            (seq[:-1] if at_end else seq[1:]) if flagged else seq
            for seq, flagged in zip(encoding[key], flags)
        ]


def tokenize_function(examples):
    """Tokenize the context and target of a batch and trim boundary special tokens.

    The context and target texts are obtained from ``task.get_context`` and
    ``task.get_target`` and each is passed through ``tokenizer``. Afterwards:

    * a special token at the *end* of a context sample is removed;
    * a special token at the *start* of a target sample is removed.

    Each sample is inspected individually. An empty batch or an empty sample
    is left as is.

    Args:
        examples: the batch handed to ``task.get_context`` / ``task.get_target``
            (assumed to be a batch of samples — TODO confirm against the
            ``map`` call that uses this function).
    """
    context = task.get_context(examples)
    target = task.get_target(examples)
    context = tokenizer(context)
    target = tokenizer(target)

    # Build the lookup once; membership is then tested once per sample.
    special_ids = frozenset(tokenizer.all_special_ids)

    # If a context ends with a special token, remove it.
    _drop_edge_special_token(context, special_ids, at_end=True)
    # If a target starts with a special token, remove it.
    _drop_edge_special_token(target, special_ids, at_end=False)

    # NOTE(review): the visible snippet ends here without combining or
    # returning `context` / `target`; the rest of the function is not shown.
'''
Do we need to add an outer loop around the special-token removal code, since `examples` is a batch of samples rather than a single sample?
In data.py, inside process_text2text_datasets:
'''python
def process_text2text_datasets(raw_datasets, args, tokenizer, accelerator):
    task = task_dict[args.dataset_name]
'''
Do we need to add an outer loop around the special-token removal code, since `examples` is a batch of samples rather than a single sample?