google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0
5.26k stars, 503 forks

Early stop when all sequences reach EOS #57

Open je1lee opened 6 months ago

je1lee commented 6 months ago

model.generate() takes too long even when every sequence has already finished with an EOS token, because it currently keeps generating until it reaches output_len.

This change fixes the generate method so that it stops once every sequence has generated an EOS token.
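
To illustrate the idea, here is a minimal, framework-agnostic sketch of the early-stop check. It is not the code from this PR or from gemma_pytorch: `generate_with_early_stop`, `next_tokens`, and the scripted sampler are hypothetical names made up for this example, and the real generate method works on tensors rather than Python lists.

```python
from typing import Callable, List


def generate_with_early_stop(
    next_tokens: Callable[[int], List[int]],  # hypothetical: returns one sampled token per sequence
    batch_size: int,
    output_len: int,
    eos_id: int,
) -> List[List[int]]:
    """Decode for at most output_len steps, stopping once every sequence has emitted EOS."""
    outputs: List[List[int]] = [[] for _ in range(batch_size)]
    finished = [False] * batch_size

    for step in range(output_len):
        tokens = next_tokens(step)
        for i, tok in enumerate(tokens):
            if finished[i]:
                continue  # ignore anything sampled after this sequence's EOS
            if tok == eos_id:
                finished[i] = True
            else:
                outputs[i].append(tok)
        # The early-stop check: no sequence is still generating.
        if all(finished):
            break

    return outputs


# Usage with a scripted "sampler" standing in for the model (assumed EOS id of 1).
EOS = 1
script = [[5, 7], [6, EOS], [EOS, 9], [8, 8], [8, 8]]
calls: List[int] = []


def scripted_sampler(step: int) -> List[int]:
    calls.append(step)  # record how many decode steps were actually run
    return script[step]


result = generate_with_early_stop(scripted_sampler, batch_size=2, output_len=5, eos_id=EOS)
print(result)      # [[5, 6], [7]]
print(len(calls))  # 3 decode steps instead of 5
```

In this sketch the loop runs only as many steps as the longest sequence needs, so the cost no longer depends on output_len once all sequences have finished.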

je1lee commented 5 months ago

@pengchongjin any thoughts on this?

pengchongjin commented 4 months ago

Thanks for the change. Could you please paste a few example outputs before and after this change?

Also please make sure to test both run.py and run_xla.py. Thanks!

je1lee commented 4 months ago

@pengchongjin I tested with both scripts.

BEFORE

[Screenshot 2024-06-03, 2:59:25 PM]

The model keeps generating tokens regardless of the EOS token, so the time spent in generation increases quadratically as output_len increases.

AFTER

[Screenshot 2024-06-03, 2:58:04 PM]

The model stops generating once it samples an EOS token, so the time spent in generation stays the same as output_len increases.