I wonder if the following lines are redundant:

```python
# need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens
# even though it does not matter what you pad input_ids with, it's just to make HF happy
inputs['input_ids'] = torch.cat([torch.full((1, n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1, n_tokens), 1), inputs['attention_mask']], 1)
```
Everything seems to be working fine without adding those for `transformers==4.23.1`, in the sense that I am not getting any errors. Are they needed for something else, though?
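For context, here is a minimal sketch of the kind of soft-embedding wrapper those padding lines seem to assume; the class and its names are my own illustration, not the actual implementation. My understanding is that the wrapper slices off the first n_tokens of input_ids (so their pad value is arbitrary), while attention_mask still has to cover the full n_tokens + seq_len so the prepended learned tokens are attended to.

```python
import torch
import torch.nn as nn

class SoftEmbeddingSketch(nn.Module):
    """Hypothetical stand-in for the wrapped embedding layer the padding targets."""

    def __init__(self, wte: nn.Embedding, n_tokens: int = 20):
        super().__init__()
        self.wte = wte                      # the model's original token embedding
        self.n_tokens = n_tokens            # number of learned prompt tokens
        # learned prompt vectors, one per virtual token (random init for the sketch)
        self.learned_embedding = nn.Parameter(torch.randn(n_tokens, wte.embedding_dim))

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        # input_ids arrives already padded to n_tokens + seq_len; the first
        # n_tokens are placeholders (e.g. 50256) and are simply sliced away,
        # which is why their value does not matter.
        token_embeds = self.wte(input_ids[:, self.n_tokens:])
        prompt = self.learned_embedding.unsqueeze(0).expand(input_ids.size(0), -1, -1)
        # the output is n_tokens + seq_len long, so attention_mask must match,
        # padded with 1s so the learned tokens are not masked out.
        return torch.cat([prompt, token_embeds], dim=1)
```

If that roughly matches the real code, the padding is mainly about keeping input_ids and attention_mask the same length as the embedding output; whether it can be skipped presumably depends on how the inputs are fed to the model.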