mlfoundations / open_clip

An open source implementation of CLIP.

Fix stopping_criteria result check in coca_model #860

Closed MengqingCao closed 2 weeks ago

MengqingCao commented 2 months ago

fix #847

The stopping criteria were changed in the latest transformers (v4.39.3 at the time of writing). They now return a tensor (torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)) instead of a bool value, which causes the bug in #847.

The related code in transformers: https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/generation/stopping_criteria.py#L76
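
To make the change concrete, here is a minimal sketch of why a bare truth test can break once the criterion returns a per-row tensor. The two helper functions are hypothetical stand-ins written for illustration, not the transformers or open_clip code; only the torch.full(...) expression is taken from the description above.

```python
import torch


def old_style_criterion(input_ids, max_length):
    # Hypothetical stand-in for the older behaviour: a plain Python bool.
    return input_ids.shape[-1] >= max_length


def new_style_criterion(input_ids, max_length):
    # Hypothetical stand-in for the newer behaviour: one bool per batch row,
    # built with the torch.full(...) expression quoted in the description.
    is_done = input_ids.shape[-1] >= max_length
    return torch.full(
        (input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool
    )


input_ids = torch.zeros((2, 5), dtype=torch.long)  # batch of 2, sequence length 5

old_result = old_style_criterion(input_ids, max_length=5)
new_result = new_style_criterion(input_ids, max_length=5)

# A bare `if tensor:` is ambiguous when the tensor has more than one element.
try:
    if new_result:
        pass
    truthiness_error = None
except RuntimeError as exc:
    truthiness_error = exc

# Reducing the tensor explicitly gives a single Python bool again.
reduced = bool(new_result.all())
```

With a batch of one row the bare truth test still works, so a failure of this kind may only show up with larger batches.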

MengqingCao commented 2 months ago

cc @rwightman @rom1504 @gabrielilharco @bryant1410 @mitchellnw

rwightman commented 2 months ago

@MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now.. @gpucce ?

MengqingCao commented 2 months ago

> @MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now.. @gpucce ?

Your concern is valid when users pass StopStringCriteria or EosTokenCriteria, which I had overlooked. I had only looked at the default stopping criterion, MaxLengthCriteria, which returns a bool tensor filled with a single value, is_done. For that case, I thought any() would be more efficient than all().
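
The difference between the two reductions can be sketched as follows. The flag tensors are made-up examples, not output from transformers: one mimics a criterion that reports the same value for every row, the other mimics a criterion that reports a separate value per row.

```python
import torch

# Hypothetical done flags for a batch of three sequences.
uniform_done = torch.tensor([True, True, True])    # same value in every row
per_row_done = torch.tensor([True, False, False])  # only row 0 has finished


def should_stop_any(done):
    # Stops as soon as at least one row reports done.
    return bool(done.any())


def should_stop_all(done):
    # Stops only when every row reports done.
    return bool(done.all())


uniform_any = should_stop_any(uniform_done)
uniform_all = should_stop_all(uniform_done)
per_row_any = should_stop_any(per_row_done)
per_row_all = should_stop_all(per_row_done)
```

When every row carries the same value the two reductions agree. With per-row flags, any() would end generation while rows 1 and 2 are still unfinished, which is the early stop raised in the review comment.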

The related code in Transformers: [image]

To handle StopStringCriteria and EosTokenCriteria as well, I think we have two choices:

  1. Change to all() here.
  2. Check whether stopping_criteria contains a StopStringCriteria or EosTokenCriteria; if not, use any(), otherwise use all(). This may run faster but requires more changes than option 1.
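
A rough sketch of what option 2 could look like, assuming per-row flags are combined with a logical OR across criteria. The classes MaxLengthLike and EosTokenLike and the function should_stop are hypothetical stand-ins for illustration only; they are not the transformers classes and not the code in this PR, and the EOS check is simplified to look at the last token only.

```python
import torch


class MaxLengthLike:
    """Hypothetical stand-in: the same done value for every row."""

    def __init__(self, max_length):
        self.max_length = max_length

    def __call__(self, input_ids):
        is_done = input_ids.shape[-1] >= self.max_length
        return torch.full((input_ids.shape[0],), is_done, dtype=torch.bool)


class EosTokenLike:
    """Hypothetical stand-in: a row is done if its last token is EOS."""

    def __init__(self, eos_token_id):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids):
        return input_ids[:, -1] == self.eos_token_id


# Criterion types whose result can differ from row to row (assumption).
PER_ROW_TYPES = (EosTokenLike,)


def should_stop(criteria, input_ids):
    # Combine the per-row flags of all criteria with a logical OR.
    done = torch.zeros(input_ids.shape[0], dtype=torch.bool)
    for criterion in criteria:
        done = done | criterion(input_ids)
    # Option 2: use all() only if a per-row criterion is present.
    if any(isinstance(c, PER_ROW_TYPES) for c in criteria):
        return bool(done.all())
    return bool(done.any())


input_ids = torch.tensor([[1, 2, 9], [1, 2, 3]])  # row 0 ends with EOS id 9

stop_mixed = should_stop([MaxLengthLike(10), EosTokenLike(9)], input_ids)
stop_length_only = should_stop([MaxLengthLike(3)], input_ids)
stop_both_done = should_stop([MaxLengthLike(3), EosTokenLike(9)], input_ids)
```

Option 1 corresponds to always taking the all() branch, which gives the same answers here at the cost of skipping the type check.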

MengqingCao commented 1 month ago

@rwightman @gpucce, I have implemented option 2 and updated the code. Please let me know if you have any suggestions, thanks!