westlake-repl / SaProt

[ICLR'24 spotlight] Saprot: Protein Language Model with Structural Alphabet
MIT License

Additional input values generated by the tokenizer #15

Closed: tpritsky closed this issue 4 months ago

tpritsky commented 4 months ago

Thanks for this impressive solution!

When I run the tokenizer on an input sequence, there are always two additional elements added to the tokenizer output. Why is this and what do the values represent?

For example:

print(len(sequence))  # -> 5
inputs = tokenizer(sequence, return_tensors="pt")
print(inputs['input_ids'].size())  # -> torch.Size([1, 7])

Additionally, I'm trying to generate a fixed length sequence embedding. I saw you answered how to do this with the ESM model, but is there a way to do so with the huggingface model?

Thanks for your help!

LTEnjoy commented 4 months ago

Hi!

For the first question, the tokenizer adds a start token and an end token by default when you tokenize a sequence. If you convert the input IDs back into tokens, the list looks like: ["<cls>", "xxx", "xxx", "<eos>"]

In ESM models, these tokens usually carry no residue information of their own; they are kept to stay consistent with BERT's training strategy. Also, the embedding of the "<cls>" token is often further fine-tuned for downstream tasks.
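To see why a 5-character input becomes 7 token IDs, here is a toy, framework-free sketch. The tokenizer `toy_tokenize` and its one-token-per-character split are hypothetical simplifications (SaProt's real vocabulary pairs each amino-acid letter with a Foldseek structure letter); only the wrapping of every sequence in a start and an end token mirrors the behavior described above.

```python
# Toy illustration only: a hypothetical tokenizer that, like the
# ESM-style tokenizer described above, wraps every sequence in a start
# token and an end token.
SPECIAL_START = "<cls>"
SPECIAL_END = "<eos>"


def toy_tokenize(sequence, add_special_tokens=True):
    """Split a sequence into one token per character (a simplification)
    and optionally add the start and end tokens around it."""
    tokens = list(sequence)
    if add_special_tokens:
        tokens = [SPECIAL_START] + tokens + [SPECIAL_END]
    return tokens


sequence = "MEVQL"
tokens = toy_tokenize(sequence)
print(len(sequence))  # 5
print(len(tokens))    # 7
print(tokens)         # ['<cls>', 'M', 'E', 'V', 'Q', 'L', '<eos>']
```

With the real Hugging Face tokenizer, `tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])` shows the added tokens, and passing `add_special_tokens=False` to the tokenizer call leaves them out.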

For the second question, the way I generate fixed-length embeddings is actually based on the Hugging Face model. If you check the function, you will find that I just add a keyword argument to the inputs.

So you can adjust the function to obtain embeddings from other models.
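One common way to turn per-token hidden states into a fixed-length vector is mean pooling over the residue positions, or simply taking the `<cls>` token's state. The sketch below assumes a single, unpadded sequence whose first and last positions are the special tokens; the function names and numbers are illustrative, and this is not necessarily what `get_hidden_states` does internally.

```python
# Framework-free sketch of mean pooling: average the per-token hidden
# states into one fixed-length vector, skipping the start and end
# special tokens. The hidden states below are made-up numbers.


def mean_pool(hidden_states, exclude_special=True):
    """hidden_states: per-token vectors, shape [seq_len][dim], where
    position 0 is the start token and the last position is the end
    token. Returns one vector of length dim."""
    rows = hidden_states[1:-1] if exclude_special else hidden_states
    dim = len(rows[0])
    return [sum(row[j] for row in rows) / len(rows) for j in range(dim)]


def cls_embedding(hidden_states):
    """Alternative: use the hidden state of the start (<cls>) token."""
    return list(hidden_states[0])


# 5 residues + 2 special tokens, with a toy hidden size of 3
hidden = [
    [9.0, 9.0, 9.0],  # <cls>
    [1.0, 2.0, 3.0],
    [3.0, 2.0, 1.0],
    [2.0, 2.0, 2.0],
    [0.0, 4.0, 2.0],
    [4.0, 0.0, 2.0],
    [9.0, 9.0, 9.0],  # <eos>
]

print(mean_pool(hidden))      # [2.0, 2.0, 2.0]
print(cls_embedding(hidden))  # [9.0, 9.0, 9.0]
```

For padded batches, weight each position by the attention mask instead of slicing, so padding tokens do not enter the average. Either way, the pooled vector's length equals the model's hidden size, regardless of sequence length.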

I hope this resolves your questions, and let me know if you have any others!

tpritsky commented 4 months ago

Thanks for your answer! This makes sense. As a sanity check, the embedding vector dimension is 480?

LTEnjoy commented 4 months ago


For the 35M model, the embedding vector dimension is 480, and for the 650M model, it is 1280 :)

tpritsky commented 4 months ago

Thanks! One other point of feedback: I was getting out-of-memory errors (on an A100 GPU) until I added 'with torch.no_grad():' to the get_hidden_states function. I'm sure you considered this, but I'm adding it here in case it helps anyone :)

LTEnjoy commented 4 months ago

Thank you for pointing that out! We did consider this problem. We actually suggest calling 'with torch.no_grad()' yourself instead of putting it inside the function, for example:

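import torch

# `inputs` is the tokenizer output; no autograd graph is recorded here,
# so intermediate activations are not kept in GPU memory
with torch.no_grad():
    embeddings = model.get_hidden_states(inputs)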

This is more flexible across different situations, i.e. depending on whether or not you need gradients to flow through the resulting embeddings.
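A minimal PyTorch sketch of why leaving `torch.no_grad()` to the caller is useful. A single `torch.nn.Linear` layer stands in for the model and `get_embeddings` is a hypothetical helper, not SaProt's API: the same function serves memory-friendly inference when wrapped in `no_grad()` and fine-tuning when called normally.

```python
import torch

# Hypothetical stand-in for the model: one linear layer
layer = torch.nn.Linear(4, 2)


def get_embeddings(x):
    # Deliberately NOT wrapped in torch.no_grad(): the caller decides
    return layer(x)


x = torch.randn(3, 4)

# Feature extraction: no autograd graph is recorded, so activations
# needed for backpropagation are not kept around
with torch.no_grad():
    frozen = get_embeddings(x)

# Fine-tuning: the graph is kept, so gradients flow back into the layer
trainable = get_embeddings(x)
trainable.sum().backward()

print(frozen.requires_grad)           # False
print(trainable.requires_grad)        # True
print(layer.weight.grad is not None)  # True
```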