Closed: adamlerer closed this issue 2 years ago
Hey Adam, I reviewed this and I don't see your point. With the README provided example we get:
(Pdb) token_probs.shape
torch.Size([1, 400, 264, 33])
aka [batch size always 1, MSA_depth, seq_len, vocab_size]
The indexing on L163, token_probs[:, 0, i],
grabs the log-probs of the first sequence of the MSA at position i in [0, seq_len),
which is where we put the mask token on L158.
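As a concrete sketch of that indexing (shapes taken from the pdb output above; the random tensor is just a stand-in for the model's actual output):

```python
import torch

# Stand-in for the model output described above:
# [batch size always 1, MSA_depth, seq_len, vocab_size]
token_probs = torch.randn(1, 400, 264, 33).log_softmax(dim=-1)

i = 42  # position of the masked token within the query (first MSA row)
row0_probs = token_probs[:, 0, i]  # log-probs for MSA row 0 at position i
print(row0_probs.shape)  # torch.Size([1, 33])
```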
Do you agree?
@tomsercu sorry for the delay, here's the context:
Typically, MSA FASTA files include insertions relative to the source sequence, so e.g. for a query sequence ABCDEF, the aligned FASTA MSA could be
.AB..C.DEF
QArxCeD-F
...
Thus, the position of A is 0 in the query sequence but 1 in the MSA.
However, MSA Transformer expects MSAs with no insertions relative to the source sequence, so IIRC the input should actually be:
ABCDEF
ArCD-F
...
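One common way to get from the first format to the second is to drop the insertion columns, which a3m-style alignments mark with lowercase letters and '.'; a minimal sketch (the helper name and approach here are illustrative, not code from the repo):

```python
import string

# Delete lowercase letters, '.' and '*', which mark insertions relative
# to the query in a3m-style alignments (assumption: input follows that
# convention).
_deletekeys = dict.fromkeys(string.ascii_lowercase)
_deletekeys["."] = None
_deletekeys["*"] = None
_translation = str.maketrans(_deletekeys)

def remove_insertions(sequence: str) -> str:
    return sequence.translate(_translation)

print(remove_insertions(".AB..C.DEF"))  # ABCDEF
print(remove_insertions("QArxCeD-F"))   # QACD-F
```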
Therefore you're right that there's no bug here per se. It would be nice to add an assert in read_msa
that the input MSA is in the right format (i.e. each row the same length as the input sequence, containing only letters plus dashes, and the first row equal to the source sequence).
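A minimal sketch of such an assert (the function name and messages are hypothetical, not the repo's actual read_msa):

```python
import re

def validate_msa(msa: list[str], query: str) -> None:
    # Hypothetical checks matching the format described above.
    assert msa[0] == query, "first MSA row must equal the query sequence"
    for row in msa:
        assert len(row) == len(query), "every row must match the query length"
        assert re.fullmatch(r"[A-Za-z-]+", row), "rows may contain only letters and dashes"
```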
Longer term, even better would be to add utilities that produce the MSA for the user in a way that's consistent with how it was trained.
Got it, thx. Yes, the tooling around MSA reading is too minimal right now. Let me add this assert.
Bug description
I was running MSA Transformer for variant prediction, and it looks like there's a bug. It's indexing token_probs based on the index of the mutation within the sequence, but token_probs has length len(MSA), not len(seq), so I think the indexes are incorrect. Can anyone confirm if this is a bug or my misunderstanding? https://github.com/facebookresearch/esm/blob/main/variant-prediction/predict.py#L154-L169
If it is a bug, I have a fix I can send as a PR.