kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

top-k sampling off by 1 bug #164

mar-muel closed this issue 2 years ago

mar-muel commented 2 years ago

Hi there

I think top-k sampling is currently incorrect. Using the test block at the end of sampling.py with top_k=1:

if __name__ == "__main__":
    import numpy as np
    logits = np.array([[-2, -1, 0, 0.8, 0, 0.1, 0.3, 0.4, 0.5, 0.6, 0.7, -3]])
    print(nucleaus_filter(logits, top_k=1))

gives me

[[-1.00000000e+10 -2.00000000e+10 -2.00000000e+10 8.00000012e-01 -4.00000000e+10 -4.99999990e+10 -6.00000020e+10 -7.00000010e+10 -8.00000000e+10 -8.99999990e+10 6.99999988e-01 -1.19999996e+11]]

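Note the two non-negative values: both 0.8 and 0.7 survive the filter, although top_k=1 should keep only the single largest logit (0.8).

A simple fix would be: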

-        sorted_indices_to_remove = jnp.where(indices_range > top_k, sorted_indices, 0)
+        sorted_indices_to_remove = jnp.where(indices_range >= top_k, sorted_indices, 0)
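
To show the off-by-one in isolation, here is a minimal NumPy sketch. It is not the code from sampling.py: `top_k_mask` and its `strict` flag are hypothetical names used only for illustration, and the mask value of -1e10 is an assumption. It ranks each logit (rank 0 is the largest) and masks by rank, so the only difference between the two variants is `>` versus `>=`.

```python
import numpy as np


def top_k_mask(logits, top_k, strict):
    """Mask all but the highest-ranked logits in each row.

    Hypothetical helper for illustration, not the repository's implementation.
    strict=True  uses `rank > top_k`  (keeps top_k + 1 values, the reported bug).
    strict=False uses `rank >= top_k` (keeps exactly top_k values).
    """
    # order[i, r] is the column index of the r-th largest logit in row i.
    order = np.argsort(-logits, axis=-1)
    # ranks[i, j] is the rank of logits[i, j]; rank 0 is the largest value.
    ranks = np.argsort(order, axis=-1)
    remove = ranks > top_k if strict else ranks >= top_k
    # Replace removed positions with a large negative value (assumed sentinel).
    return np.where(remove, -1e10, logits)


logits = np.array([[-2, -1, 0, 0.8, 0, 0.1, 0.3, 0.4, 0.5, 0.6, 0.7, -3]])

buggy = top_k_mask(logits, top_k=1, strict=True)
fixed = top_k_mask(logits, top_k=1, strict=False)

print((buggy > -1e9).sum())  # 2: both 0.8 and 0.7 survive
print((fixed > -1e9).sum())  # 1: only 0.8 survives
```

With ranks starting at 0, `rank > top_k` removes ranks top_k + 1 and above, so ranks 0 through top_k (that is, top_k + 1 values) are kept; `rank >= top_k` keeps ranks 0 through top_k - 1, which is exactly top_k values.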

kingoflolz commented 2 years ago

Thanks! Fixed in d2c2f59.