google / gemma_pytorch

The official PyTorch implementation of Google's Gemma models
https://ai.google.dev/gemma
Apache License 2.0

Can't disable sampling #42

Closed: joselpart closed this issue 4 months ago

joselpart commented 4 months ago

Currently, the generate() method doesn't seem to allow disabling sampling. The forward() method of the Sampler class performs greedy search when its temperatures argument is None, but GemmaForCausalLM's generate() method doesn't allow the temperature argument to be set to None because of this line: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L508. Setting the temperature to 0 also fails, with the following error: RuntimeError: probability tensor contains either inf, nan or element < 0.
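A minimal sketch of why temperature 0 breaks and what a greedy fallback could look like. Dividing the logits by 0 produces +/-inf (or nan for 0/0), softmax turns those into nan, and torch.multinomial refuses to sample from such a tensor, which matches the RuntimeError above. The helper sample_next_token is hypothetical and not part of gemma_pytorch; it only assumes logits of shape [batch, vocab] and treats a temperature of None or 0 as greedy decoding (argmax). It omits any other filtering the real Sampler may apply.

```python
from typing import Optional

import torch


def sample_next_token(logits: torch.Tensor, temperature: Optional[float]) -> torch.Tensor:
    """Pick one token id per row of `logits` (shape [batch, vocab]).

    Hypothetical helper, not the gemma_pytorch API: a temperature of None
    or 0 selects greedy decoding; a positive value selects sampling.
    """
    if temperature is None or temperature == 0:
        # Greedy search: skip the division by temperature, so no inf/nan appears.
        return torch.argmax(logits, dim=-1)
    if temperature < 0:
        raise ValueError("temperature must be None, 0, or positive")
    # Temperature scaling, then draw one sample per row from the distribution.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


logits = torch.tensor([[2.0, 1.0, -1.0], [0.5, 3.0, 0.0]])

# Reproduce the failure: logits / 0 gives +/-inf (and nan for 0/0),
# softmax propagates nan, and torch.multinomial raises a RuntimeError.
bad_probs = torch.softmax(logits / 0.0, dim=-1)
try:
    torch.multinomial(bad_probs, num_samples=1)
    zero_temperature_failed = False
except RuntimeError as err:
    zero_temperature_failed = True
    print("temperature=0 fails:", err)

print(sample_next_token(logits, None))  # greedy
print(sample_next_token(logits, 0.0))   # greedy
print(sample_next_token(logits, 0.8))   # sampled, varies between runs
```

Branching before the division, rather than clamping the temperature to a tiny epsilon, keeps greedy decoding exact and avoids overflow in the scaled logits.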