Currently, the `generate()` method doesn't seem to allow disabling sampling. The `forward()` method in the `Sampler` class performs greedy search if the `temperatures` argument is `None`, but `GemmaForCausalLM`'s `generate()` method doesn't allow setting the `temperature` argument to `None` because of this line: https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L508. Also, setting `temperature` to `0` fails with the following error: `RuntimeError: probability tensor contains either inf, nan or element < 0`.
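The `temperature=0` error is most likely a division by zero. Dividing the logits by `0` produces `inf` (or `nan` where a logit is exactly `0`), the softmax over those values becomes `nan`, and `torch.multinomial` then rejects the probability tensor. Below is a minimal sketch of the behavior being requested. `sample_next_token` is a hypothetical helper, not part of `gemma_pytorch`. It treats `temperature=None` or `temperature <= 0` as greedy decoding, matching what `Sampler.forward()` already does for `None`, and it omits the top-p/top-k filtering a real sampler would apply.

```python
from typing import Optional

import torch


def sample_next_token(logits: torch.Tensor, temperature: Optional[float]) -> torch.Tensor:
    """Pick one next-token id per row from logits of shape (batch, vocab).

    Hypothetical helper for illustration only (not the gemma_pytorch API):
    temperature=None or temperature <= 0 means greedy decoding, so the
    logits are never divided by zero.
    """
    if temperature is None or temperature <= 0:
        # Greedy search: take the highest-scoring token in each row.
        return torch.argmax(logits, dim=-1)
    # Temperature sampling: scale the logits, normalize, then draw one sample per row.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


if __name__ == "__main__":
    logits = torch.tensor([[1.0, 3.0, 2.0], [5.0, 0.0, -1.0]])
    print(sample_next_token(logits, None))  # greedy
    print(sample_next_token(logits, 0.0))   # also greedy, no RuntimeError
    print(sample_next_token(logits, 0.8))   # sampled, varies between runs
```

Routing `0` to `argmax` rather than clamping the temperature to a tiny positive value avoids overflow in `logits / temperature` and makes the greedy case deterministic.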