churchlab / UniRep

UniRep model, usage, and examples.

Making get_rep() faster? #4

Closed: spark157 closed this issue 3 years ago

spark157 commented 5 years ago

Hello,

The get_rep() function in unirep.py has a comment:

        Unfortunately, this method accepts one sequence at a time and is as such quite
        slow.

I'm thinking about how to improve the speed of this as it would be nice to be able to generate representations for larger datasets to then incorporate into other top models. [Note: I see there is actually a _top_final_hidden variable that can feed into a top model. It would be nice to also have access to other representations to build other top models not bolted on top of the TF implementation.]

Initially I tried looping through a list of sequences, but that didn't appear to provide any savings (I guessed that most of the time was being spent initializing the TensorFlow session each time, but I don't think that's the case).

Now I'm thinking the answer may have something to do with running in batches, but I'm not sure.

Any ideas/suggestions on how to make this run faster? It is a very useful function to have.

Thanks.

Scott

spark157 commented 5 years ago

I might have an idea, gathered essentially from the tutorial: use bucket_batch_pad on a file of formatted strings (the ones you would like representations for), then pass that in as your batch to sess.run and ask for the _final_state back, as in get_rep().

So kind of an amalgamation of:

bucket_op = b.bucket_batch_pad("formatted.txt", interval=1000)

and something like:

        with tf.Session() as sess:
            initialize_uninitialized(sess)

            batch = sess.run(bucket_op)      # padded integer batch from bucket_batch_pad
            lengths = _nonpad_len(batch)     # true (non-pad) length of each sequence
            final_state_, hs = sess.run([self._final_state, self._output], feed_dict={
                    self._batch_size_placeholder: batch.shape[0],
                    self._minibatch_x_placeholder: batch,
                    self._seq_length_placeholder: lengths,
                    self._initial_state_placeholder: self._zero_state
                    }
            )
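
For a whole file this would presumably need a loop over the bucket op, plus averaging each sequence's hidden states over its true (non-padded) length, roughly along these lines (untested sketch; n_batches is just however many batches it takes to cover the file):

        reps = []
        with tf.Session() as sess:
            initialize_uninitialized(sess)
            for _ in range(n_batches):
                batch = sess.run(bucket_op)
                lengths = _nonpad_len(batch)
                final_state_, hs = sess.run([self._final_state, self._output], feed_dict={
                        self._batch_size_placeholder: batch.shape[0],
                        self._minibatch_x_placeholder: batch,
                        self._seq_length_placeholder: lengths,
                        self._initial_state_placeholder: self._zero_state
                        })
                # average only over the non-padded positions of each sequence
                for i, length in enumerate(lengths):
                    reps.append(hs[i, :length, :].mean(axis=0))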

I'll try to give that a go to see if I can get something running locally. Comments welcome.

Scott

sandias42 commented 5 years ago

Hi Scott,

You should see some speedup switching to a GPU if you are still doing CPU-based inference. Otherwise, a faster implementation is non-trivial and something we are actively working on.
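
If you are not sure whether TensorFlow is actually seeing a GPU, a quick check along these lines (standard TF1 calls, nothing UniRep-specific) will tell you:

import tensorflow as tf
from tensorflow.python.client import device_lib

print(tf.test.is_gpu_available())                          # True if TF can see a CUDA-capable GPU
print([d.name for d in device_lib.list_local_devices()])   # e.g. ['/device:CPU:0', '/device:GPU:0']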

PRs for faster rep inference would be welcome. Ethan

spark157 commented 5 years ago

One observation I did make when trying to understand line 639 of unirep.py:

self._top_final_hidden = tf.gather_nd(self._output, tf.stack([tf.range(tf_get_shape(self._output)[0], dtype=tf.int32), indices], axis=1))

is that if the reference to indices were a slice, rather than simply the index at sequence length - 1, you would be able to sum (or average) along that axis (if that makes any sense).

For example, as I was playing around in a Jupyter notebook to understand the line above, I had:

np.stack([range(14),lengths],axis=1)

array([[  0, 248],
       [  1, 176],
       [  2, 148],
       [  3, 120],
       [  4, 265],
       [  5, 232],
       [  6,  87],
       [  7, 195],
       [  8, 239],
       [  9, 261],
       [ 10, 251],
       [ 11, 230],
       [ 12, 205],
       [ 13, 124]])

with lengths being the lengths of the sequences in my batch of 14, using the seqs.txt/formatted.txt data. Focusing on the sequence of length 120, I can generate the appropriate average hidden state with:

np.mean(hs[3, 0:120, :], axis=0)  # == seq_rep[0] when putting the length-120 seq through get_rep() on its own

with hs being the hidden states as usual, 3 being the batch index of the length-120 sequence, and (here was the observation) the slice 0:120 giving exactly the part of the tensor needed to properly calculate the mean.
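
Extending that, the same averaging can be done for the whole batch at once with a length mask in numpy, something like the following (quick sketch using the hs and lengths from above; I haven't checked every sequence against get_rep()):

mask = np.arange(hs.shape[1])[None, :] < np.asarray(lengths)[:, None]   # [batch, max_len], True on real positions
mask = mask[:, :, None].astype(hs.dtype)                                # broadcast over the hidden dimension
avg_hidden = (hs * mask).sum(axis=1) / np.asarray(lengths, dtype=hs.dtype)[:, None]
# avg_hidden[3] should then match np.mean(hs[3, 0:120, :], axis=0) from above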

From that observation I don't yet have a suggestion as to how to tweak something like line 639 to give back the slice instead of just the value at indices. However, if you could figure that out, then I think it would lead to batching in much the same way that self._top_final_hidden is derived.

Thought I would share if helpful.

Scott

smsaladi commented 4 years ago

By batching sequences of the same length, I've put together a faster implementation of get_rep. In my testing, its output matches what the original get_rep produces with a batch size of 1.

@sandias42 Do you see any issues with this implementation?

https://github.com/smsaladi/UniRep/blob/ccae3ed6a60a4651d700cd7c931a965d97472de5/run_inference.py#L68-L101

seq 8 | parallel "echo {}; cat test.faa" > test8.faa

time python run_inference.py --batch_size 1 --reverse test8.faa test-out.pkl
1997.99s user 387.24s system 3230% cpu 1:13.84 total

time python run_inference.py --batch_size 2 --reverse test8.faa test-out.pkl
1189.24s user 263.06s system 3013% cpu 48.200 total

time python run_inference.py --batch_size 4 --reverse test8.faa test-out.pkl
720.00s user 162.37s system 2682% cpu 32.894 total

time python run_inference.py --batch_size 8 --reverse test8.faa test-out.pkl
495.06s user 111.00s system 2387% cpu 25.384 total

sandias42 commented 4 years ago

Hi smsaladi,

Were these timings on CPU or GPU? Also, to confirm, you are timing full inference completion rather than per-batch completion (hence the lower total time at higher batch sizes)?

Can you also please provide hardware, Python env, and OS details? Thanks, Ethan

smsaladi commented 4 years ago

These are CPU-only (i7-6700, 32 GB RAM), Python 3.5, TensorFlow 1.13.1. Right, full inference completion.

Anyhow, I timed it on Colab and have the notebook here, so you can play with it yourself: https://github.com/smsaladi/UniRep/blob/master/unirep_inference.ipynb

With GPU: batch_size, total time (seconds)

[[16, 26.899244273000022], [32, 13.71079222100002], [64, 8.502014164999991], [128, 6.426210394000009], [156, 4.381791538000016]]

CPU-only: batch_size, total time (seconds)

[[16, 546.6658579950001], [32, 311.33970561499973], [64, 215.62030503999995], [128, 162.09657374400012], [156, 135.92662289999998]]

Again, it's much faster with larger batch sizes.

edit: There are 156 sequences in the set, so the last entry is simply the whole set in a single batch.

mengqvist commented 4 years ago

(Quoting @smsaladi's batched get_rep implementation and timing results above.)

This works great!

pkinn commented 4 years ago

smsaladi and mengqvist,

Did you check to see if your outputs from the batched version of get_rep are identical to the original version? I tweaked your code slightly and made it a function in the unirep class, and then wrote a separate function modeled off of your run_inference code. I opted for this option instead of overwriting the get_rep function so I could compare the two.

When I compare outputs for ~10 sequences acquired with the original get_rep and my new get_rep_batch, they're different. Did you encounter this at all?

Code is below:

def get_rep_batch(self, seqs, sess):
        """
        Input a pandas Series of sequence strings (or a single sequence string).
        All sequences passed in one call should have the same length.
        """

        if isinstance(seqs, str):
            seqs = pd.Series([seqs])

        coded_seqs = [aa_seq_to_int(s) for s in seqs]
        n_seqs = len(coded_seqs)

        # trim the zero initial state if the group is smaller than the model batch size
        if n_seqs == self._batch_size:
            zero_batch = self._zero_state
        else:
            zero = self._zero_state[0]
            zero_batch = [zero[:n_seqs, :], zero[:n_seqs, :]]

        final_state_, hs = sess.run(
                [self._final_state, self._output], feed_dict={
                    self._batch_size_placeholder: n_seqs,
                    self._minibatch_x_placeholder: coded_seqs,
                    self._initial_state_placeholder: zero_batch
                })

        final_cell_all, final_hidden_all = final_state_
        # no padding within a same-length group, so a plain mean over time is the avg hidden state
        avg_hidden = np.mean(hs, axis=1)
        # turn hs into a list of per-sequence arrays for filling the dataframe
        hs_list = [hs[ii, :, :] for ii in range(hs.shape[0])]
        df = seqs.to_frame()
        df['seq'] = seqs
        df['final_hs'] = np_to_list(final_hidden_all)[:n_seqs]
        df['final_cell'] = np_to_list(final_cell_all)[:n_seqs]
        df['avg_hs'] = np_to_list(avg_hidden)[:n_seqs]
        df['hs'] = hs_list[:n_seqs]
        return df

And then I wrote this function, modeled on smsaladi's run_inference function, to group the sequences by length and build the output dataframe:

def run_inference_pck_1900(seq_series, out_fn, batch_size, reverse):
    # set up babbler object
    print('starting run_inference')
    b = unirep.babbler1900(batch_size=batch_size, model_path="./1900_weights")
    seqs = seq_series

    # read sequences
    # seqs = series_from_seqio(seq_fn, 'fasta')
    # seqs = seqs.str.rstrip('*')

    # reverse before building the dataframe, so the reversed sequences are what actually get run
    if reverse:
        seqs = seqs.str[::-1]

    df_seqs = seqs.to_frame()  # assumes seq_series is named 'seq'

    # sort by length so same-length sequences fall into the same groups
    df_seqs['len'] = df_seqs['seq'].str.len()
    df_seqs.sort_values('len', inplace=True)
    df_seqs.reset_index(drop=True, inplace=True)

    # within each length, split into groups of at most batch_size sequences
    df_seqs['grp'] = df_seqs.groupby('len')['len'].transform(lambda x: np.arange(np.size(x))) // batch_size
    print('length df_seqs = ' + str(len(df_seqs)))
    print('df_seqs.grp = ' + str(df_seqs['grp']))

    # set up tf session, then run inference one same-length group at a time
    print('starting tf session')
    with tf.Session() as sess:
        unirep.initialize_uninitialized(sess)
        df_calc = df_seqs.groupby(['grp', 'len'], as_index=False, sort=False).apply(lambda d: b.get_rep_batch(d['seq'], sess=sess))

    df_calc.to_pickle(out_fn)
    return df_seqs, df_calc
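
For what it's worth, I call it roughly like this (just an illustration; my_protein_strings and the output file name are placeholders, and the input Series needs to be named 'seq' so that df_seqs['seq'] resolves):

import pandas as pd

seqs = pd.Series(my_protein_strings, name='seq')   # my_protein_strings: your list of amino-acid sequence strings
df_seqs, df_calc = run_inference_pck_1900(seqs, 'reps.pkl', batch_size=8, reverse=False)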

mengqvist commented 4 years ago

@pkinn

The results are the same to a certain level of precision, about 1e-6 for the sequences I tested. I don't know why they are not exactly the same.
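
The kind of check I mean is roughly this (illustrative only; rep_original and rep_batched stand in for the average hidden-state vector of the same sequence from the two code paths):

import numpy as np

print(np.max(np.abs(rep_original - rep_batched)))         # around 1e-6 on the sequences I tested
print(np.allclose(rep_original, rep_batched, atol=1e-5))  # True at a tolerance just above that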

I've made an installable implementation here: https://github.com/EngqvistLab/UniRep50

Check out validate_script.ipynb for the comparison.