Shivanandroy / simpleT5

simpleT5 is built on top of PyTorch Lightning⚡️ and Transformers🤗 and lets you quickly train your T5 models.
MIT License

about labels and decoder_attention_mask #59

Open gitfor20 opened 8 months ago

gitfor20 commented 8 months ago

    target_text_encoding = self.tokenizer(
        data_row["target_text"],
        max_length=self.target_max_token_len,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )

    labels = target_text_encoding["input_ids"]
    labels[labels == 0] = -100  # to make sure we have correct labels for T5 text generation

    return dict(
        source_text_input_ids=source_text_encoding["input_ids"].flatten(),
        source_text_attention_mask=source_text_encoding["attention_mask"].flatten(),
        labels=labels.flatten(),
        labels_attention_mask=target_text_encoding["attention_mask"].flatten(),
    )

As far as I know, decoder_input_ids are by default obtained by shifting the labels one position to the right, but in this code the decoder_attention_mask is built to match the un-shifted labels. So I think the decoder_input_ids the model prepares will not line up with decoder_attention_mask, as the sketch below illustrates.
Is this a bug, or is my understanding wrong?
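
For reference, a minimal sketch (not from simpleT5) of the right-shift that transformers applies when decoder_input_ids are not passed explicitly; "t5-small", the example sentence, and max_length=10 are placeholders chosen only for illustration. It shows that the attention mask computed for the un-shifted target ends one position before the last real token in the shifted decoder inputs.

    from transformers import AutoTokenizer, T5Config

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    config = T5Config.from_pretrained("t5-small")  # pad_token_id == decoder_start_token_id == 0 for T5

    target = tokenizer(
        "a short target sentence",
        max_length=10,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    labels = target["input_ids"].clone()
    labels[labels == tokenizer.pad_token_id] = -100

    # Replicate the shift the model performs internally when no decoder_input_ids
    # are given: prepend decoder_start_token_id, drop the last position, and
    # replace -100 with pad_token_id.
    decoder_input_ids = labels.new_zeros(labels.shape)
    decoder_input_ids[..., 1:] = labels[..., :-1].clone()
    decoder_input_ids[..., 0] = config.decoder_start_token_id
    decoder_input_ids.masked_fill_(decoder_input_ids == -100, config.pad_token_id)

    print(target["input_ids"][0].tolist())       # [t1, ..., </s>, 0, 0, ...]
    print(decoder_input_ids[0].tolist())         # [0, t1, ..., </s>, 0, ...]  (shifted right by one)
    print(target["attention_mask"][0].tolist())  # 1s stop one position before </s> in the shifted ids

So if target_text_encoding["attention_mask"] is passed as decoder_attention_mask, it is aligned with labels rather than with the shifted decoder_input_ids.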