ElArkk / jax-unirep

Reimplementation of the UniRep protein featurization model.
GNU General Public License v3.0

Why is the embedding sequence one longer than the protein sequence? #85

Closed: konstin closed this issue 3 years ago

konstin commented 4 years ago

I'm trying to use the per-residue embeddings of UniRep and have observed that the hidden-state sequence is one element longer than the number of amino acids in the protein sequence. Judging from the test_mLSTM1900 test, this is intended. Could you tell me why there is one extra hidden state?

For reference, this is the minimized code I'm using (full source):

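from functools import partial  # partial comes from functools, not jax

from jax import vmap
from jax_unirep.featurize import apply_fun
from jax_unirep.utils import get_embeddings, load_params_1900

def per_residue_hidden_states(sequence):
    # Embed one sequence, then run the mLSTM1900 forward pass over the batch.
    embedded_seqs = get_embeddings([sequence])
    h_final, c_final, h = vmap(partial(apply_fun, load_params_1900()))(embedded_seqs)
    return h[0]  # len(h[0]) == len(sequence) + 1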
ElArkk commented 4 years ago

Hi @konstin,

I'd say it's because we add a "start" amino acid (aa) when integer-encoding the protein sequence before repping. (Actually, true to the original implementation, a "start" and a "stop" aa are added to the two ends of the sequence, but when using get_reps the "stop" aa is removed again, so only the extra "start" position remains.)

The exact function in jax-unirep where this happens is get_embedding in utils.py. In the original implementation it's in get_rep in unirep.py.
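To make the off-by-one concrete, here is a minimal pure-Python sketch of the encoding scheme described above. It is not jax-unirep's code: the vocabulary, the START/STOP indices and the toy recurrence are all illustrative assumptions. It only shows why prepending a start token and dropping the stop token leaves one more input step, and therefore one more hidden state, than there are residues.

```python
from typing import List

# Hypothetical vocabulary: 20 standard amino acids mapped to 1..20.
# START/STOP values are illustrative assumptions, not jax-unirep's actual indices.
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
AA_TO_INT = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}
START, STOP = 24, 25


def encode(sequence: str) -> List[int]:
    """Integer-encode with a start token prepended and a stop token appended."""
    return [START] + [AA_TO_INT[aa] for aa in sequence] + [STOP]


def encode_for_rnn(sequence: str) -> List[int]:
    """Drop the trailing stop token but keep the leading start token."""
    return encode(sequence)[:-1]


def run_rnn(tokens: List[int]) -> List[float]:
    """Toy recurrence standing in for the mLSTM: one hidden state per input token."""
    h, states = 0.0, []
    for t in tokens:
        h = 0.5 * h + t  # placeholder update, not the real mLSTM cell
        states.append(h)
    return states


seq = "MKV"
tokens = encode_for_rnn(seq)
hidden = run_rnn(tokens)
print(len(seq), len(tokens), len(hidden))  # 3 4 4
```

Under this scheme, hidden[0] is the state after reading only the start token, and hidden[i] for i >= 1 is the state after reading residue i. Slicing hidden[1:] is one way to get exactly one state per residue; whether that or some other alignment suits a given downstream task is a judgment call, not something the thread specifies.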

ElArkk commented 4 years ago

See also https://github.com/ElArkk/jax-unirep/issues/55

konstin commented 3 years ago

Thanks!