Okay, I found the solution.

First, export the model with:

```
--decode_hparams="beam_size=4,alpha=0.6,write_beam_scores=True,return_beams=True"
```

Then, in query.py:

```python
import numpy as np

# With return_beams=True the outputs tensor gains a beam dimension.
shapeOfReturnedBeams = response.outputs["outputs"].tensor_shape.dim
shapeOfReturnedBeams = [int(i.size) for i in shapeOfReturnedBeams]
if len(shapeOfReturnedBeams) == 3:
    # Drop the leading (batch) dimension for a single query.
    beams = np.reshape(response.outputs["outputs"].int_val, shapeOfReturnedBeams[1:])
elif len(shapeOfReturnedBeams) == 2:
    # Already [beam, length].
    beams = np.reshape(response.outputs["outputs"].int_val, shapeOfReturnedBeams)
listOfBeamScores = response.outputs["scores"].float_val
```
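For completeness, here is a minimal sketch of how the returned beams could be turned back into text. `targets_encoder` is an assumed name for the problem's target-side text encoder (e.g. a SubwordTextEncoder), and PAD=0 / EOS=1 are the usual T2T defaults; verify both against your own problem.

```python
# Minimal sketch (not from the original post): decode each beam back to text.
# `targets_encoder` is an assumed name; PAD_ID/EOS_ID are the usual T2T defaults.
PAD_ID, EOS_ID = 0, 1
for rank, (token_ids, score) in enumerate(zip(beams, listOfBeamScores)):
    token_ids = [int(t) for t in token_ids if int(t) not in (PAD_ID, EOS_ID)]
    print("beam %d (log prob %.4f): %s" % (rank, score, targets_encoder.decode(token_ids)))
```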
It works, but for some reason prediction is about 5x slower when `return_beams=True` is specified. I don't think it should be slower, unless I am missing something.
I have no experience with T2T serving; the following experiments were done simply with t2t-decoder, translating a batch of 3000 sentences. Each experiment was run twice and the average time is reported.
| params | decode time |
|---|---|
| beam_size=1 | 4m25s |
| beam_size=2 | 2m40s |
| beam_size=2,return_beams=True | 5m50s |
| beam_size=4 | 4m10s |
| beam_size=4,return_beams=True | 9m10s |
| beam_size=12 | 14m50s |
| beam_size=12,return_beams=True | 22m30s |
So I confirm that returning beams (the n-best list) is slower, but I think the slowdown could be expected: it is more than double for beams 2 and 4 but less than double for beam 12, which is plausible because the total decode time includes loading the model and running the encoder, a constant cost shared by all of the experiments.
What is strange is that `beam_size=1` (i.e. greedy decoding) is slower than beam 2 or even beam 4 (if someone wants to discuss this, please open a new issue).
Another strange thing: during prediction, if you don't specify `return_beams=True` but set, for example, `beam_size=4`, you still get the score (log prob) of each of the 4 beams, just without the actual tokens. When you set `return_beams=True`, you get back exactly the same scores for the 4 beams, together with the tokens. If the 4 beams had to be computed in both cases anyway, why is there such a large time difference between `return_beams=True` and `return_beams=False`?
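A quick way to see this from the serving response (a hedged check, reusing the `response` object from the query.py snippet above):

```python
# Hedged check based on the observation above: "scores" should hold one entry
# per beam in both modes, while "outputs" only gains a beam dimension when
# return_beams=True.
scores = response.outputs["scores"].float_val
output_shape = [int(d.size) for d in response.outputs["outputs"].tensor_shape.dim]
print("number of beam scores:", len(scores))
print("outputs tensor shape:", output_shape)
```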
+1
I am also quite interested in this major time difference; perhaps the name of the issue could be changed to reflect this slightly different problem (or even a new issue opened?)
Can you suggest a new name? I'll change it.
E.g. Multiple beam results slowdown: "return_beams=True" increases decode time
Have investigated this, and may have found the cause:

`return_beams=True` causes `top_beams` to be set to a number greater than 1, e.g. on this line.

When `top_beams` is greater than 1, this turns off the `stop_early` clause, e.g. here and also here.

With the `stop_early` clause off, the beams keep computing until you hit the maximum decode length, instead of stopping early when all active beams hit an EOS token. See this line for the impact of the `stop_early` clause being set to False.
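To make the effect concrete, here is a simplified sketch of how the loop condition changes when `stop_early` is disabled. This is plain illustrative Python, not the actual beam_search.py code; the function name and arguments are made up for the example.

```python
# Simplified, illustrative sketch -- not the actual tensor2tensor code.
# With stop_early=False the loop only looks at the step counter, so decoding
# always runs to decode_length even after every beam has emitted EOS.
def should_continue(step, decode_length, stop_early,
                    lowest_finished_score, best_possible_alive_score):
    if step >= decode_length:
        return False                  # hard length limit in both modes
    if not stop_early:
        return True                   # return_beams=True ends up here
    # stop_early=True: stop as soon as no alive beam can still beat the
    # worst finished hypothesis (scores are length-penalised log probs).
    return best_possible_alive_score > lowest_finished_score
```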
Also, a majorly confusing misnomer: the `_is_finished()` function should really be called `_is_not_finished()` here; it's the `cond` parameter of a `tf.while_loop()`.
@rsepassi you seem to be the original author of beam_search.py, any thoughts on this issue? I'm likely going to investigate further and submit something, but curious to know if I seem on the right track - thanks!
Hi @walmsley, thanks for investigating. Yes, you seem to be on the right track. What do you think the right fix is that wouldn't break the current logic? First thought is that you're right that you can exit once all beams have hit EOS, but it'd be good if you could justify that in the context of the stopping logic already there too. Thanks for digging in!
This is the line of code that implements that behavior in the greedy case.
@rsepassi / @walmsley: is it enough to just check that we have hit EOS for all the beams? (I still have to check the code in detail, so I may be wrong.)
Asking because, if we still have alive beams, I suppose they can still be 'extended' and maybe improve their score to the point where they enter the top N (because of the length_penalty bonus)?
Or maybe when we reach EOS for all beams it is safe to assume that there is no alive beam around? (In that case, I suppose hitting EOS is enough.)
In short, it seems to me that the method https://github.com/tensorflow/tensor2tensor/blob/40758df26f92cdaa20869ef9d470da997ad89557/tensor2tensor/utils/beam_search.py#L469 (in the case stop_early = True) should be modified so that, instead of stopping when the first beam is 'unreachable' by the other alive beams, it stops when the last of the best N beams is unreachable.
Hi all, I took some time to look at the code. I think we indeed need two conditions to hold: EOS has been reached and the top N beams cannot be beaten by the other alive beams. Note that the second condition basically includes the first one, given that a beam that has not reached EOS has a score of -INF (and thus is always beatable).
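Roughly, the check looks like the following sketch. This is plain Python with illustrative names, not the actual code from the PR; the length penalty shown is the one T2T's beam search uses, `((5 + length) / 6) ** alpha`.

```python
# Illustrative sketch of the stopping condition described above -- not the
# actual PR code. All names are made up for the example.
def can_stop_early(top_n_finished_scores, alive_log_probs, alpha, decode_length):
    """top_n_finished_scores: length-penalised scores of the N finished slots
    (unfilled slots keep a -inf placeholder, so they are always beatable).
    alive_log_probs: raw log probs of the still-alive beams."""
    # Best score an alive beam could still reach if it ran to the maximum
    # length (T2T length penalty: ((5 + length) / 6) ** alpha).
    max_length_penalty = ((5.0 + decode_length) / 6.0) ** alpha
    best_possible_alive = max(alive_log_probs) / max_length_penalty
    # Stop only when even the *worst* of the top N finished beams can no
    # longer be beaten -- not just the best one, as the old check did.
    return min(top_n_finished_scores) > best_possible_alive
```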
I changed the code and indeed I am getting the same results as before (when annotating) without the delay mentioned above.
PR is here: https://github.com/tensorflow/tensor2tensor/pull/780
Any news regarding the PR?
I just rebased my PR (someone fixed the same test I had fixed, which resulted in a conflict); it should be OK now (still waiting for the checks to run, though).
Not sure how this works with PRs (it's my first one) - should I alert someone, or just wait?
Thanks, Mirko
I don't know.. Anybody home?
Had some CLA problems - re-opened it here: https://github.com/tensorflow/tensor2tensor/pull/965
Note that the Travis tests are failing, similarly to the other recent PRs. (I don't think it is a problem with the code changed in this PR, given that it was passing before.)
Currently, the T2T serving (export.py + query.py) returns one result per query.
How would you go about changing it to return all the beam results? For example, if the beam size is 4, return the 4 results together with their log probabilities, similarly to the hparam `write_beam_scores=True` that is used in t2t_decoder.py.
I assume the change is needed in both export.py and query.py. My question is what should be changed to support it.
Thanks.