agemagician / ProtTrans

ProtTrans provides state-of-the-art pretrained language models for proteins. ProtTrans was trained on thousands of GPUs from Summit and hundreds of Google TPUs using transformer models.
Academic Free License v3.0

Sanity check of inputs #115

Closed BSharmi closed 1 year ago

BSharmi commented 1 year ago

Hi @mheinzinger!

Continuing from https://github.com/agemagician/ProtTrans/issues/113, I think I am almost there with the preprocessing by re-using https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py#L337 with slight modifications (the mean span length set to 1.0 being one of them).
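For context, this is roughly how a sequence gets prepared before it goes through the script's preprocessing (a minimal sketch only; the Rostlab/prot_t5_xl_uniref50 checkpoint is just the public ProtT5 tokenizer, used here for illustration):

```python
# Minimal sketch of my tokenization step (the checkpoint name is just the
# public ProtT5 tokenizer used for illustration; swap in your own vocabulary).
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)

raw = "MGKGVGRDKYEPAAVSEQGDKK"   # raw amino-acid sequence
spaced = " ".join(raw)           # ProtT5 expects space-separated residues
encoding = tokenizer(spaced)

# The tokenizer appends </s>, so input_ids ends up one token longer than the sequence.
print(len(raw), len(encoding["input_ids"]))  # e.g. 22 23
```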

Just have a few remaining questions I want to clarify

  1. In the group_texts function https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py#L688, the first key in examples is input_text, which has a different length (by 2) from the input_ids key due to special tokens. Is this expected? Do we want total_length to match the input text rather than the input ids?
  2. Where does a sample that does not have the max length actually get dropped from a batch? https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py#L694. There is always a possibility that the last sample has a different length, so at what stage does it get dropped from the batch? For example, after group_texts was applied, three of my samples had length 512 and the last one 181; unless I specify a batch size of 3, it is still there. This seems hand-wavy to me.
  3. Related to the above, https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py#L348: the inputs always have </s> at the end, so should the check not be `if batch["input_ids"].shape[-1] != self.input_length + 1:` rather than `if batch["input_ids"].shape[-1] != self.input_length:`? (My rough length bookkeeping is sketched right after this list.)
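For reference, here is my rough length bookkeeping, assuming the standard T5 span-corruption accounting and a 15% noise density (the mean span length of 1.0 is what I set above); this is my own arithmetic, not taken from the script:

```python
# Back-of-the-envelope length accounting (my own arithmetic, assuming the
# standard T5 span-corruption bookkeeping and a 15% noise density).
chunk_len   = 512                         # tokens per chunk after group_texts
noise       = round(0.15 * chunk_len)     # masked tokens -> 77 (matches <extra_id_0>..<extra_id_76> below)
spans       = noise                       # mean span length 1.0 -> one sentinel per masked token
inputs_len  = chunk_len - noise + spans + 1  # non-noise tokens + sentinels + </s> -> 513
targets_len = noise + spans + 1              # masked tokens + sentinels + </s>   -> 155
print(inputs_len, targets_len)               # 513 155
```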

For example, my final inputs look like this:

raw sequence:

MGKGVGRDKYEPAAVSEQGDKKGKKGKKDRDMDELKKEVSMDDHKLSLDELHRKYGTDLSRMEEIQRYLQPDRSQQHNFLYPLIFQEYIYALAHDHGLNRNRSILLENPGYNNKLSFLIVKRLITRMYQQNHFLISTNDSNKNSFLGCNKSLYSQMISEGFAFIVEIPFSLRLISSLSSFEGKKIFKSYNLRSIHSTFPFLEDNFSHLNYVLDILIPYPVHLEILVQTLRYWVKDASSLHLLRFFLHEFWNLNSLITSKKPGYSFSKKNQRFFFFLYNSYVYECESTFVFLRNQSSHLRSTSFGALLERIYFYGKIERLVEVFAKDFQVTLWLFKDPFMHYVRYQGKSILASKGTFLLMNKWKFYLVNFWQCHCSLCFHTGRIHINQLSNHSRDFMGYLSSVRLNPSMVRSQMLENSFLINNAIKKFDTLVPIIPLIGSLAKANFCTVLGHPISKPVWSDLSDSDIIDRFGRICRNLFHYYSGSSKKKTLYRIKYILRLSCARTLARKHKST

inputs:

M G K G V<extra_id_0> R D<extra_id_1> Y E P A A V S E Q<extra_id_2> D K K G K K<extra_id_3> K K D R D M D E<extra_id_4> K K E<extra_id_5> S<extra_id_6> D D H K L S<extra_id_7> D E L H R K Y G T D L S R</s> <extra_id_8> E<extra_id_9> I Q R Y L Q P D<extra_id_10> S Q Q<extra_id_11> N F L Y<extra_id_12> L<extra_id_13> F Q E Y I Y A L A H D H<extra_id_14> L N R N R S I L L E N P G Y N N K L S F<extra_id_15> I V K R L I T R<extra_id_16> Y Q<extra_id_17> N H F L I S T N D<extra_id_18> N K N S F L<extra_id_19> C N K S L<extra_id_20> S Q<extra_id_21> I S E G F<extra_id_22> F<extra_id_23> V E I P<extra_id_24> S L R L I S S L S S F E G K<extra_id_25> I<extra_id_26> K S Y<extra_id_27> L R<extra_id_28> I H S T F P F L E D N F S H L N Y V L<extra_id_29> I<extra_id_30> I<extra_id_31> Y P V H L<extra_id_32> I L V Q T L<extra_id_33> Y W V K D<extra_id_34> S S L H L L<extra_id_35> F F L H E F W N L N S L I<extra_id_36> S K K P<extra_id_37> Y S F<extra_id_38> K K N Q R<extra_id_39> F F F<extra_id_40> Y N S Y V<extra_id_41> E C E S T F V F L R N Q S S H L R S T S F G A L L E<extra_id_42> I<extra_id_43> F Y G K<extra_id_44> E R<extra_id_45> V E V F A K<extra_id_46> F Q V<extra_id_47> L W L F<extra_id_48> D P F<extra_id_49> H<extra_id_50> V R<extra_id_51> Q G K S<extra_id_52> L A S<extra_id_53> G T F L L M<extra_id_54> K W K F Y L V N F W Q C H C S<extra_id_55> C F<extra_id_56> T G R I<extra_id_57> I N Q L S<extra_id_58> H<extra_id_59> R D F M G Y L<extra_id_60> S V R L N P<extra_id_61> M V R S Q M L E N S F L I N N A I K K<extra_id_62> D T L V P I I<extra_id_63> L I G S L<extra_id_64> K A N F C<extra_id_65> V L G H P I S<extra_id_66> P<extra_id_67> W S<extra_id_68> L S<extra_id_69> S<extra_id_70> I I D R F G R I C<extra_id_71> N L F H Y Y S G S S K<extra_id_72> K T L Y R I K Y I<extra_id_73> R<extra_id_74> S C A R T<extra_id_75> A R K H K<extra_id_76> </s>

and labels

<extra_id_0> G<extra_id_1> K<extra_id_2> G<extra_id_3> G<extra_id_4> L<extra_id_5> V<extra_id_6> M<extra_id_7> L<extra_id_8> M<extra_id_9> E<extra_id_10> R<extra_id_11> H<extra_id_12> P<extra_id_13> I<extra_id_14> G<extra_id_15> L<extra_id_16> M<extra_id_17> Q<extra_id_18> S<extra_id_19> G<extra_id_20> Y<extra_id_21> M<extra_id_22> A<extra_id_23> I<extra_id_24> F<extra_id_25> K<extra_id_26> F<extra_id_27> N<extra_id_28> S<extra_id_29> D<extra_id_30> L<extra_id_31> P<extra_id_32> E<extra_id_33> R<extra_id_34> A<extra_id_35> R<extra_id_36> T<extra_id_37> G<extra_id_38> S<extra_id_39> F<extra_id_40> L<extra_id_41> Y<extra_id_42> R<extra_id_43> Y<extra_id_44> I<extra_id_45> L<extra_id_46> D<extra_id_47> T<extra_id_48> K<extra_id_49> M<extra_id_50> Y<extra_id_51> Y<extra_id_52> I<extra_id_53> K<extra_id_54> N<extra_id_55> L<extra_id_56> H<extra_id_57> H<extra_id_58> N<extra_id_59> S<extra_id_60> S<extra_id_61> S<extra_id_62> F<extra_id_63> P<extra_id_64> A<extra_id_65> T<extra_id_66> K<extra_id_67> V<extra_id_68> D<extra_id_69> D<extra_id_70> D<extra_id_71> R<extra_id_72> K<extra_id_73> L<extra_id_74> L<extra_id_75> L<extra_id_76> S</s>

The assertion on the targets' length matches.

Thanks so much for your help! I am just hoping to start training soon!

mheinzinger commented 1 year ago

Hi!

  1. Yes, I expect the input_ids to be longer than the input_text. For each sequence, a special token gets appended to the end, so if you observe a difference of 2 in length, I guess you grouped two sequences together and there are two additional special tokens in the input_ids. total_length should give you the number of samples after concatenation, so it is less than the number of sequences in your FASTA because multiple sequences got concatenated into a single input.
  2. The dropping of the last, potentially shorter sample happens here: https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py#L695 . This works because in the lines before, total_length gets rounded down: https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py#L692 . Batch size should not affect this, or I am misunderstanding the logic here; let me know if that's the case. (A toy sketch of this rounding is included right after this list.)
  3. The if-statement compares the input IDs after the pre-processing (so after tokenization, concatenation, splitting, span-denoising, etc.) against the parameter you set for the upper bound of the sample length. This already factors in the appended special tokens.
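To illustrate the rounding, here is a toy re-implementation of the group_texts logic as I read it (simplified, not the exact HF code; expanded_inputs_length stands in for whatever chunk length your run computes):

```python
# Toy re-implementation of group_texts as I read it (simplified, not the HF
# code verbatim): concatenate all tokenized sequences, round the total length
# down to a multiple of expanded_inputs_length, and split into equal chunks.
from itertools import chain

def group_texts(examples, expanded_inputs_length=512):
    concatenated = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated[list(examples.keys())[0]])
    if total_length >= expanded_inputs_length:
        # Round down: the trailing fragment shorter than expanded_inputs_length is dropped here.
        total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
    return {
        k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
        for k, t in concatenated.items()
    }

# With 1717 concatenated tokens you get three chunks of 512; the leftover 181 tokens disappear.
toy = {"input_ids": [list(range(600)), list(range(500)), list(range(617))]}
print([len(chunk) for chunk in group_texts(toy)["input_ids"]])  # -> [512, 512, 512]
```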

I hope this clarified some things. If something above is wrong, feel free to correct me (I am also not 100% sure about some of these things).

BSharmi commented 1 year ago

Thank you tons for the detailed response! At this point my code runs fine if I drop part of the sequence, and model training looks good. But I got ambitious and tried padding and attention masks; that works fine until it gets to the collator, where I get errors due to the assertions. I think I might just go ahead without padding and drop the fragments unless I figure it out.
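For what it's worth, the padding attempt was along these lines (just a sketch; the checkpoint name and max_length=512 are placeholders, and the collator's length assertions would still need to be adapted downstream):

```python
# Sketch of the padding/attention-mask attempt (checkpoint name and max_length
# are placeholders); the failure happens later, in the collator's assertions.
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("Rostlab/prot_t5_xl_uniref50", do_lower_case=False)

seqs = [" ".join(s) for s in ["MGKGVGRDKYE", "MGKGV"]]
batch = tokenizer(
    seqs,
    padding="max_length",        # pad every sequence to the same length
    max_length=512,
    truncation=True,
    return_attention_mask=True,
)
print(len(batch["input_ids"][0]), sum(batch["attention_mask"][0]))  # 512 and the real token count
```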

Thanks again for your help

mheinzinger commented 1 year ago

Update: https://github.com/agemagician/ProtTrans/issues/137#issuecomment-1817576165