ElArkk / jax-unirep

Reimplementation of the UniRep protein featurization model.
GNU General Public License v3.0

Output hidden state along sequence #103

Closed: golin010 closed this issue 3 years ago

golin010 commented 3 years ago

Hello,

Thank you for the well documented code for UniRep and evotuning!

I'm trying to find the most robust way to output the hidden state vector along the sequence. For example, if my protein is 50 AA long, the output size would be (50,1900) where the h_avg output would be the average along axis=0.
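Here is a minimal JAX sketch of the requested shapes. Random numbers stand in for real UniRep output, and the array and its values are purely illustrative. A per-residue matrix of shape (50, 1900) averaged along axis 0 collapses to a single 1900-dimensional vector, which is the same kind of summary as h_avg.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-residue hidden states for a 50-residue protein:
# one 1900-dimensional vector per position (random stand-in data).
key = jax.random.PRNGKey(0)
per_residue = jax.random.normal(key, (50, 1900))

# Averaging over the sequence axis (axis=0) collapses the 50 positions
# into one 1900-dimensional vector, an h_avg-style summary.
h_avg_like = jnp.mean(per_residue, axis=0)

print(per_residue.shape)  # (50, 1900)
print(h_avg_like.shape)   # (1900,)
```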

Thank you in advance!

ericmjl commented 3 years ago

Hello @golin010! This isn't officially part of the API, but it can be assembled from the source. Going from memory (i.e., I haven't tested the following yet), I think you'd want to try something like this:

from jax_unirep.layers import mLSTM
from jax import vmap
from functools import partial
from jax_unirep.utils import (
    batch_sequences,
    get_embeddings,
    load_params,
    validate_mLSTM_params,
)

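seqs = ...  # put in a list of your own sequences; make sure they are all of the same length
# (wrap even a single sequence in a list, e.g. seqs = ["EXAMPLE"])
params = load_params()[1]  # pretrained weights for the mLSTM layer
_, apply_fun = mLSTM(output_dim=1900)
embedded_seqs = get_embeddings(seqs)
# vmap maps apply_fun over the leading (batch-of-sequences) axis
h_final, c_final, h = vmap(partial(apply_fun, params))(embedded_seqs)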

# `h` is the thing you'll want!

I mostly pieced this together by looking at featurize.py. Since I haven't run the code myself, there may be a missing input or variable, but you can probably figure out what's missing from there. If you need to handle sequences of different lengths, take a look at what the rep_arbitrary_lengths function does.
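For the variable-length case, here is a minimal sketch of one common approach. It assumes (without verifying against the source) that this is similar in spirit to what rep_arbitrary_lengths does. The idea is to group the sequences by length so that each group can be embedded and passed through vmap as one equal-length batch, while keeping the original indices so the results can be put back in order. `group_by_length` is a hypothetical helper written for this example and is not part of jax-unirep.

```python
from collections import defaultdict
from typing import Dict, List


def group_by_length(seqs: List[str]) -> Dict[int, List[int]]:
    """Map each sequence length to the indices of sequences with that length.

    Hypothetical helper for illustration; not part of jax-unirep.
    """
    groups: Dict[int, List[int]] = defaultdict(list)
    for idx, seq in enumerate(seqs):
        groups[len(seq)].append(idx)
    return dict(groups)


seqs = ["MKTV", "MKT", "MKTA", "MK"]
groups = group_by_length(seqs)
print(groups)  # {4: [0, 2], 3: [1], 2: [3]}

# Each group holds equal-length sequences; the stored indices record
# where each sequence's results belong in the original ordering.
for length, idxs in groups.items():
    batch = [seqs[i] for i in idxs]
    print(length, batch)
```

Each `batch` could then go through the snippet above in place of `seqs`.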

Can I ask, what's your use case for returning the hidden state along the sequence? Is it something you see as pretty common? If so, I'm sure we'd be open to hashing out the sanest way of handling the return shapes properly and then working through it in a PR.

golin010 commented 3 years ago

Thank you for the quick reply!

The code runs, but as written, h_final and c_final have shape [seq_len, 1900] and h has shape [seq_len, 2, 1900].

When I compare the averages over the sequence-length axis of h_final, c_final, h[:, 0, :], or h[:, 1, :] with the h_avg from `h_avg, _, _ = jax_unirep.get_reps()` for the same sequence, none of them match. Could I be missing something?

Also, we are currently studying a protein library in which only a few positions of the sequence are mutated. I'm not sure how common this is, but we wanted to see whether an embedding averaged over just those mutated positions would improve our predictive accuracy compared with an embedding averaged over the whole protein.
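Here is a minimal sketch of the comparison described above. It uses random stand-in data and hypothetical mutated positions. Given per-residue hidden states of shape (n_aa, 1900), the "mutated positions" embedding is just the mean over the selected rows. The indices must refer to rows of an array that holds exactly one state per residue. If the array still includes an extra, non-residue state, shift the indices to match.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-residue hidden states for one 50-residue protein
# (random numbers standing in for real UniRep output).
per_residue = jax.random.normal(jax.random.PRNGKey(1), (50, 1900))

# Hypothetical 0-based indices of the mutated positions in the library.
mutated_positions = jnp.array([10, 11, 37])

# Average over the whole protein vs. over just the mutated positions.
whole_avg = per_residue.mean(axis=0)
mutated_avg = per_residue[mutated_positions].mean(axis=0)

print(whole_avg.shape, mutated_avg.shape)  # (1900,) (1900,)
```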

ElArkk commented 3 years ago

@golin010, if you use @ericmjl's code above, make sure you put your sequences into a list, even if it is a single sequence (seqs = ["EXAMPLE"]). Otherwise, the mLSTM will process each amino acid of your sequence individually, which is why you get as many final hidden and cell states (h_final, c_final) as there are amino acids in your sequence.
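The following sketch shows why the shapes come out this way. vmap maps a function over the leading axis of its input, and, per the explanation above, a bare string is processed one amino acid at a time, so the batch axis ends up running over residues instead of over sequences. The example uses a toy recurrent cell (`toy_apply`, hypothetical and much simpler than the real mLSTM) that returns `(h_final, c_final, h)` like the snippet earlier in the thread. It compares a batch of one 7-residue sequence with a batch of seven 1-residue sequences.

```python
from functools import partial

import jax
import jax.numpy as jnp

HIDDEN = 4  # toy hidden size (the real mLSTM uses 1900)


def toy_apply(params, x):
    """Toy recurrent cell, NOT the jax-unirep mLSTM.

    x has shape (length, embed_dim); returns (h_final, c_final, h), where h
    stacks one hidden state per input position.
    """
    W, U = params

    def step(carry, x_t):
        h_prev, c_prev = carry
        c = c_prev + x_t @ W
        h = jnp.tanh(c + h_prev @ U)
        return (h, c), h

    init = (jnp.zeros(HIDDEN), jnp.zeros(HIDDEN))
    (h_final, c_final), h = jax.lax.scan(step, init, x)
    return h_final, c_final, h


embed_dim = 10
params = (jnp.ones((embed_dim, HIDDEN)) * 0.1, jnp.eye(HIDDEN) * 0.5)
apply_batch = jax.vmap(partial(toy_apply, params))

one_seq = jnp.ones((1, 7, embed_dim))     # 1 sequence with 7 positions
seven_seqs = jnp.ones((7, 1, embed_dim))  # 7 sequences with 1 position each

hf1, cf1, h1 = apply_batch(one_seq)
hf7, cf7, h7 = apply_batch(seven_seqs)
print(hf1.shape, h1.shape)  # (1, 4) (1, 7, 4)
print(hf7.shape, h7.shape)  # (7, 4) (7, 1, 4)
```

With the real model, each per-sequence `h` also carries the extra state, which is consistent with `h` coming out as (seq_len, 2, 1900) rather than (seq_len, 1, 1900) when a bare string was passed.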

h will always have the shape (n_sequences, n_aa + 1, mlstm_size), with the extra hidden state being the initial state of the mLSTM (IIRC).
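Here is the shape description above expressed in code, as a minimal sketch with a random stand-in for `h`. It assumes the extra state sits at position 0 (the comment above states this from memory). It also computes two candidate averages, because this thread does not establish whether the h_avg from `jax_unirep.get_reps()` includes that extra state. Compare both against `get_reps()` on your own sequences.

```python
import jax

n_sequences, n_aa, mlstm_size = 2, 50, 1900

# Random stand-in for the `h` returned by the vmap(apply_fun) snippet:
# shape (n_sequences, n_aa + 1, mlstm_size).
h = jax.random.normal(jax.random.PRNGKey(2), (n_sequences, n_aa + 1, mlstm_size))

# ASSUMPTION: the extra state is at position 0, so dropping it leaves
# one hidden state per amino acid.
per_residue = h[:, 1:, :]  # (n_sequences, n_aa, mlstm_size)

# Two candidate sequence averages; check which one matches h_avg from
# jax_unirep.get_reps() on your own data rather than trusting either.
avg_with_extra = h.mean(axis=1)               # (n_sequences, mlstm_size)
avg_residues_only = per_residue.mean(axis=1)  # (n_sequences, mlstm_size)

print(per_residue.shape, avg_with_extra.shape, avg_residues_only.shape)
# (2, 50, 1900) (2, 1900) (2, 1900)
```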

golin010 commented 3 years ago

Thank you both for the help! The outputs now match.