huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[flax_llama] Why is the return value of the `create_sinusoidal_positions` truncated by `num_pos`? #29590

Open mantle2048 opened 5 months ago

mantle2048 commented 5 months ago


Reproduction

In the Flax LLaMA implementation, the array returned by `create_sinusoidal_positions` is sliced to `num_pos` along its last axis:

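import jax.numpy as jnp
import numpy as np

def create_sinusoidal_positions(num_pos, dim):
    # One inverse frequency per pair of feature dimensions: shape (dim // 2,)
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    # Outer product of positions and inverse frequencies: shape (num_pos, dim // 2)
    freqs = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")

    emb = np.concatenate((freqs, freqs), axis=-1)  # shape (num_pos, dim)
    # sin and cos stacked along the last axis: shape (num_pos, 1, 2 * dim)
    out = np.concatenate((np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1)
    return jnp.array(out[:, :, :num_pos]) # Why?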

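Why is the last axis, which holds the sin and cos features, sliced with `num_pos`, the number of positions? Is it meant to guard against the sequence length being too short?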
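To make the question concrete, here is a minimal NumPy-only sketch of what the slice does to the shape. It is not the library code: `sinusoidal_positions_full` is a hypothetical helper that repeats the computation above without the final slice and without the `jnp.array` conversion, and the sizes are small illustrative values rather than a real model configuration. Under those assumptions, the slice appears to be a no-op whenever `num_pos >= 2 * dim`, and to drop features (including all cos terms) when `num_pos` is smaller than `dim`.

```python
import numpy as np


def sinusoidal_positions_full(num_pos, dim):
    """Hypothetical helper: same computation as above, but without the final slice."""
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    # Same contraction as the "i , j -> i j" einsum in the quoted code.
    freqs = np.einsum("i,j->ij", np.arange(num_pos), inv_freq).astype("float32")
    emb = np.concatenate((freqs, freqs), axis=-1)  # (num_pos, dim)
    # sin features occupy the first `dim` columns, cos features the last `dim`.
    return np.concatenate(
        (np.sin(emb)[:, None, :], np.cos(emb)[:, None, :]), axis=-1
    )  # (num_pos, 1, 2 * dim)


# Case 1: num_pos >= 2 * dim, so slicing to num_pos changes nothing.
full_long = sinusoidal_positions_full(num_pos=16, dim=8)
sliced_long = full_long[:, :, :16]
print(full_long.shape, sliced_long.shape)  # (16, 1, 16) (16, 1, 16)

# Case 2: num_pos < dim, so only the first num_pos sin columns survive.
full_short = sinusoidal_positions_full(num_pos=4, dim=8)
sliced_short = full_short[:, :, :4]
print(full_short.shape, sliced_short.shape)  # (4, 1, 16) (4, 1, 4)
```

If typical configurations use a maximum position count much larger than twice the head dimension, only the first case would ever occur in practice, which may be why the slice goes unnoticed.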

Expected behavior

See above.

amyeroberts commented 5 months ago

cc @sanchit-gandhi

amyeroberts commented 4 months ago

Gentle ping @sanchit-gandhi

amyeroberts commented 3 months ago

Another ping @sanchit-gandhi

amyeroberts commented 1 month ago

Another another ping @sanchit-gandhi. Could you nominate someone to take over this whilst you're away?