Closed. MengqingCao closed this pull request 2 weeks ago.
cc @rwightman @rom1504 @gabrielilharco @bryant1410 @mitchellnw
@MengqingCao hmm, yeah this needs fixing. One q though, as I'm not intimately familiar with the gen code: does the implementation here support multiple sentences in a batch? If yes, using `any()` will stop too early, since the bool tensor is the done state for each row (sentence), no? If it's only supporting one sentence at a time, this is fine for now..
@gpucce ?
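For reference, a minimal sketch of the batching concern (illustrative values, not code from this PR):

```python
import torch

# per-row done state for a batch of 3 sentences: only row 0 has finished
is_done = torch.tensor([True, False, False])

print(is_done.any())  # tensor(True)  -> any() stops the whole batch now
print(is_done.all())  # tensor(False) -> all() keeps generating until every row is done
```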
Your concern is right in cases where users use `StopStringCriteria` or `EosTokenCriteria`, which I had overlooked. I had only considered the default `StoppingCriteria` method `MaxLengthCriteria`, which returns a bool tensor filled with one single bool value `is_done`. That is why I thought `any()` would be more efficient than `all()`.
The related code in Transformers:
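A close paraphrase of `MaxLengthCriteria.__call__` from the linked `stopping_criteria.py` (trimmed of docstrings and edge-case warnings; the return line is quoted verbatim in the PR description below):

```python
import torch

class MaxLengthCriteria:
    def __init__(self, max_length: int):
        self.max_length = max_length

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> torch.BoolTensor:
        # one bool for the whole batch: has the sequence hit max_length?
        is_done = input_ids.shape[-1] >= self.max_length
        # broadcast that single bool to one entry per row (sentence)
        return torch.full((input_ids.shape[0],), is_done,
                          device=input_ids.device, dtype=torch.bool)
```

Since every entry is the same value, `any()` and `all()` agree for this criterion; they only diverge for per-row criteria like `StopStringCriteria` and `EosTokenCriteria`.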
To handle `StopStringCriteria` and `EosTokenCriteria` at the same time, I think we have two choices:

1. Use `all()` here unconditionally.
2. Check whether `StopStringCriteria` or `EosTokenCriteria` appears in `stopping_criteria`; if not, use `any()`, otherwise use `all()`. This may run faster but brings more changes than option 1 (see the sketch after this list).

@rwightman @gpucce, I have implemented option 2 and updated the code; please give me some suggestions, thanks!
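A sketch of option 2 (a hypothetical helper, not the exact diff in this PR; it assumes `StopStringCriteria` and `EosTokenCriteria` are importable from `transformers.generation.stopping_criteria`, the module linked below):

```python
from transformers.generation.stopping_criteria import (
    EosTokenCriteria,
    StoppingCriteriaList,
    StopStringCriteria,
)

def needs_all(stopping_criteria: StoppingCriteriaList) -> bool:
    # per-row criteria can finish sentences at different times, so the
    # loop must wait until ALL rows are done; otherwise any() is safe
    return any(
        isinstance(c, (StopStringCriteria, EosTokenCriteria))
        for c in stopping_criteria
    )

# inside the generation loop, with `done` the per-row bool tensor:
#   done = stopping_criteria(input_ids, scores)
#   stop = done.all() if needs_all(stopping_criteria) else done.any()
```

In this sketch, `needs_all` can be evaluated once before the loop starts, so the per-step cost is just the tensor reduction.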
fix #847
The stopping criteria were updated in the latest `transformers` (v4.39.3 at the time of writing). The return value is now a bool tensor (`torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)`) instead of a plain bool, which causes the bug in #847. The related code in `transformers`:
https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/generation/stopping_criteria.py#L76
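For illustration, a sketch of the failure mode (an illustrative snippet, not the exact open_clip generation loop), assuming the old code truth-tested the criteria's return value directly:

```python
import torch

# new-style return for a batch of 4: one done flag per row
done = torch.full((4,), False, dtype=torch.bool)

try:
    if done:  # old check: fine for a plain bool, ambiguous for a tensor
        pass
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous

if done.all():  # fixed check: reduce the per-row tensor first (or switch
    pass        # between any()/all() as in option 2 above)
```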