Hello @nzhang89,
The indexing in the GitHub (JAX) code is 1-based, whereas the indexing in the hidden_states array of the HuggingFace model is 0-based. This means that in order to get the same embedding as the GitHub model, you need to extract torch_outs['hidden_states'][19]. That is why the embeddings are slightly different: they are extracted one layer apart!
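A minimal sketch of what I mean, assuming the 500M human-ref checkpoint and the standard transformers API (the checkpoint name and sequence below are placeholders):

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Assumed checkpoint name; any nucleotide-transformer checkpoint should behave the same way.
model_name = "InstaDeepAI/nucleotide-transformer-500m-human-ref"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Placeholder sequence, just to get a hidden_states tuple back.
tokens = tokenizer("ATTCCGATTCCGATTCCG", return_tensors="pt")["input_ids"]
with torch.no_grad():
    torch_outs = model(tokens, output_hidden_states=True)

# hidden_states is 0-based, so index 19 is the candidate match for the
# layer-20 embedding returned by the GitHub/JAX code.
layer_19 = torch_outs.hidden_states[19]
print(layer_19.shape)  # (batch, sequence_length, hidden_dim)
```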
Best regards, Hugo
Hi @dallatt,
Thank you for your quick response. I tried torch_outs['hidden_states'][19] but got completely different results:

```python
>>> torch_outs['hidden_states'][20][0][0:4]
tensor([[ 0.5936, 7.0626, 8.0543, ..., -12.8166, -8.5917, 3.5761],
        [ -3.8114, -3.4322, -3.4315, ..., 8.1701, -3.5505, -1.0878],
        [ -4.4262, -3.1080, -2.7096, ..., 8.8625, -3.9117, -1.6279],
        [ -1.6622, -4.6036, -2.1546, ..., 8.9295, -2.1016, -0.5010]])
>>> torch_outs['hidden_states'][19][0][0:4]
tensor([[ 2.0145, 6.4966, 7.4038, ..., -14.8354, -7.2971, 1.8059],
        [ -3.0252, 0.5187, -0.5540, ..., 1.1396, -3.8698, 2.3325],
        [ -3.7960, 0.3949, -0.9983, ..., 2.2104, -4.1310, 1.4081],
        [ -1.7485, -1.2718, 0.1990, ..., 2.0984, -2.3466, 2.3011]])
```
torch_outs['hidden_states'] has a length of 25. I guess the first entry might be the output of the embedding layer, which is consistent with the original ESM model setup.
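If it helps, this is the kind of sanity check I have in mind (a sketch, reusing torch_outs from above; model here stands for the HuggingFace model object I loaded):

```python
# The hidden_states tuple should have num_hidden_layers + 1 entries:
# index 0 is the embedding-layer output, index k is the output of transformer block k.
num_layers = model.config.num_hidden_layers
print(num_layers)                        # 24, matching the 25 hidden states above
print(len(torch_outs['hidden_states']))  # 25 == num_layers + 1
```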
Thank you for your excellent work on releasing the nucleotide transformer model. I am trying to use it for a downstream regression task. Since my downstream model is written in PyTorch, I am considering the HuggingFace models.
However, I found that the embeddings from the JAX model and the HuggingFace model seem to be slightly different.
Here is an example.
The JAX model:
The embedding for the first sequence (first 4 tokens):
The PyTorch model:
The embedding for the first sequence (first 4 tokens):
The embeddings are slightly different. Is this expected? Do you know what might cause the difference? Thank you so much for your help!
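In case it is useful, here is a small helper I could use to scan every entry of hidden_states and see which PyTorch layer, if any, matches the embedding returned by the JAX model (a sketch; jax_emb and find_matching_layer are placeholder names of mine, not from either codebase):

```python
import numpy as np
import torch

def find_matching_layer(jax_emb, hidden_states):
    """Return (index, max_abs_diff) of the hidden_states entry closest to the JAX embedding.

    jax_emb: numpy array of shape (sequence_length, hidden_dim) for one sequence.
    hidden_states: tuple of tensors of shape (batch, sequence_length, hidden_dim).
    """
    target = torch.from_numpy(np.asarray(jax_emb)).float()
    diffs = []
    for h in hidden_states:
        layer = h[0]  # first sequence in the batch
        if layer.shape != target.shape:
            diffs.append(float("inf"))
        else:
            diffs.append((layer - target).abs().max().item())
    best = int(np.argmin(diffs))
    return best, diffs[best]

# Usage (sketch): a small max_abs_diff at some index would tell me which
# hidden_states entry corresponds to the layer-20 embedding from the JAX model.
# best_idx, diff = find_matching_layer(jax_emb, torch_outs['hidden_states'])
```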