jayelm / gisting

Learning to Compress Prompts with Gist Tokens - https://arxiv.org/abs/2304.08467
Apache License 2.0

Why is there a parameter "offset"? #10

Closed: pengfeiwu1999 closed this issue 1 year ago

pengfeiwu1999 commented 1 year ago

https://github.com/jayelm/gisting/blob/acd78b49111db30c3a24c32f625b85ae59934585/src/gist_llama.py#L206

The apply_rotary_pos_emb() function does not accept an offset argument, does it? Its definition is:

```python
# rotate_half lives in the same transformers module as this function.
from transformers.models.llama.modeling_llama import rotate_half


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can squeeze them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

jayelm commented 1 year ago

The gist_offset parameter is needed because, when caching an instruction, the model needs to know the length of the cached instruction to apply the position embeddings correctly. (It's not needed for T5 due to T5's relative position embedding scheme). You can see how the position embeddings are shifted here:

https://github.com/jayelm/gisting/blob/acd78b49111db30c3a24c32f625b85ae59934585/src/gist_llama.py#L114-L125

apply_rotary_pos_emb does not require the gist offset argument because it transforms the cos and sin tensors which already have the offset applied.
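
To make the offset concrete, here is a minimal pure-Python sketch of LLaMA-style rotary embeddings. It is an illustration, not the repository's code: `rope_cos_sin`, `apply_rotary`, and the toy sizes are hypothetical, and it assumes the offset simply shifts the absolute positions of the uncached tokens by the length of the cached instruction. The shift happens when the cos/sin rows are built, so the function that applies the rotation takes no offset.

```python
import math


def rope_cos_sin(positions, dim, base=10000.0):
    # One frequency per pair of dimensions, duplicated across both halves of
    # the head dimension (the layout rotate_half expects in LLaMA-style RoPE).
    inv_freq = [base ** (-2 * i / dim) for i in range(dim // 2)]
    cos_rows, sin_rows = [], []
    for p in positions:
        angles = [p * f for f in inv_freq] * 2
        cos_rows.append([math.cos(a) for a in angles])
        sin_rows.append([math.sin(a) for a in angles])
    return cos_rows, sin_rows


def rotate_half(x):
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rotary(vectors, cos_rows, sin_rows):
    # No offset parameter: the positions are already baked into cos/sin.
    out = []
    for x, c, s in zip(vectors, cos_rows, sin_rows):
        r = rotate_half(x)
        out.append([xi * ci + ri * si for xi, ci, ri, si in zip(x, c, r, s)])
    return out


dim = 4
gist_offset = 3  # hypothetical length of the cached instruction
new_tokens = [[1.0, 0.5, -0.2, 0.3], [0.1, 0.2, 0.3, 0.4]]  # toy query/key vectors
n = len(new_tokens)

# Reference: the same tokens embedded inside the full, uncached sequence.
full_cos, full_sin = rope_cos_sin(range(gist_offset + n), dim)
reference = apply_rotary(new_tokens, full_cos[gist_offset:], full_sin[gist_offset:])

# Cached case: shift the positions by the offset when building cos/sin.
cos, sin = rope_cos_sin(range(gist_offset, gist_offset + n), dim)
shifted = apply_rotary(new_tokens, cos, sin)

# Without the shift, the tokens would be treated as starting at position 0.
unshifted = apply_rotary(new_tokens, *rope_cos_sin(range(n), dim))

print(shifted == reference, unshifted == reference)  # True False
```

Because the offset is consumed where the cos/sin tables are built, the rotation step itself can stay offset-free.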

Note the offset parameter is not used during standard training or evaluation, because we don't actually modify any sequence lengths—the model gets the entire instruction/input/output in one go, with attention masking used to control compression, and the position embeddings are correctly applied with the original instruction length.
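
The training-time setup described here can be sketched as a plain attention mask. This is a minimal illustration under stated assumptions (a decoder-only model and a single contiguous block of gist tokens); `gist_causal_mask` is a hypothetical helper, not the mask code in the repository. Compression comes entirely from the mask, and the position ids are just 0..n-1, which is why no offset is involved.

```python
def gist_causal_mask(seq_len, gist_start, gist_end):
    """mask[i][j] is True when query position i may attend to key position j.

    Assumes one contiguous run of gist tokens occupying [gist_start, gist_end).
    Attention is causal everywhere, except that tokens after the gist tokens
    cannot see anything before them, so the instruction has to be compressed
    into the gist tokens' activations.
    """
    return [
        [j <= i and not (i >= gist_end and j < gist_start) for j in range(seq_len)]
        for i in range(seq_len)
    ]


# Hypothetical layout: 3 instruction tokens, 1 gist token, 2 input/output tokens.
mask = gist_causal_mask(seq_len=6, gist_start=3, gist_end=4)
position_ids = list(range(6))  # ordinary positions; no offset during training

for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# x.....
# xx....
# xxx...
# xxxx..
# ...xx.
# ...xxx
```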

The offset parameter is only used in compress.py—the GistActivations class accepts a "gist offset" argument which records the length of the instruction before caching:

https://github.com/jayelm/gisting/blob/acd78b49111db30c3a24c32f625b85ae59934585/src/gist_llama.py#L633-L656

https://github.com/jayelm/gisting/blob/acd78b49111db30c3a24c32f625b85ae59934585/src/gist_caching.py#L112-L113
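
On the caching side, the extra bookkeeping the offset requires is remembering how long the instruction was. The dataclass below is a hypothetical stand-in, not the repository's `GistActivations`; `CachedPrompt` and `position_ids_for` are made-up names, and it assumes new tokens simply continue numbering from the recorded offset.

```python
from dataclasses import dataclass
from typing import Any, List


@dataclass
class CachedPrompt:
    """Hypothetical container for cached gist activations (not GistActivations)."""

    past_key_values: Any  # per-layer keys/values for the gist tokens (opaque here)
    gist_offset: int  # length of the instruction before it was replaced by the cache

    def position_ids_for(self, num_new_tokens: int) -> List[int]:
        # New tokens continue from the original instruction length, so their
        # rotary positions match what they would have been without caching.
        return list(range(self.gist_offset, self.gist_offset + num_new_tokens))


cache = CachedPrompt(past_key_values=None, gist_offset=12)
print(cache.position_ids_for(3))  # [12, 13, 14]
```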

pengfeiwu1999 commented 1 year ago

Thanks for the reply, but my question is this: the apply_rotary_pos_emb() function doesn't take an offset parameter, yet in the code I linked above (line 206 of /gisting/src/gist_llama.py) it is called with offset as an argument. Doesn't that cause an error?

pengfeiwu1999 commented 1 year ago

So when I run the LLaMA model, I get the error "apply_rotary_pos_emb() got an unexpected keyword argument 'offset'".

pengfeiwu1999 commented 1 year ago

The error I get during LLaMA training is:

```
  File "/data/wupf/gisting/src/gist_llama.py", line 206, in forward
    query_states, key_states = apply_rotary_pos_emb(
TypeError: apply_rotary_pos_emb() got an unexpected keyword argument 'offset'
```

jayelm commented 1 year ago

Are you using the version of transformers specified in requirements.txt, specifically commit fb366b9a? You'll see that the function signature is different:

https://github.com/huggingface/transformers/blob/fb366b9a/src/transformers/models/llama/modeling_llama.py#L136-L141

(and more generally if you run into other issues it could be because you're not using the package versions specified in requirements.txt)
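
One quick way to see which signature is installed is to inspect it rather than wait for the TypeError. The sketch below uses the standard-library `inspect.signature`; `accepts_keyword`, `apply_rotary_old`, and `apply_rotary_new` are hypothetical names, and the `offset` parameter of the old stand-in is an assumption inferred from the call at line 206 and the error above, not copied from commit fb366b9a.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` can be called with the keyword argument `name`."""
    for p in inspect.signature(func).parameters.values():
        if p.kind == inspect.Parameter.VAR_KEYWORD:
            return True
        if p.name == name and p.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Hypothetical stand-ins for the two signatures discussed in this thread.
def apply_rotary_old(q, k, cos, sin, offset=0):
    """Assumed shape of the function gist_llama.py line 206 expects."""
    return q, k


def apply_rotary_new(q, k, cos, sin, position_ids):
    """Same parameter list as the definition quoted at the top of this issue."""
    return q, k


print(accepts_keyword(apply_rotary_old, "offset"))  # True
print(accepts_keyword(apply_rotary_new, "offset"))  # False

try:
    apply_rotary_new(None, None, None, None, offset=0)
except TypeError as err:
    error_message = str(err)  # same kind of error as the traceback in this thread
    print(error_message)
```

Passing the real function, imported from `transformers.models.llama.modeling_llama`, to `accepts_keyword` checks the installed version the same way; `False` for `"offset"` means the call at line 206 will fail.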

pengfeiwu1999 commented 1 year ago

I tried to install transformers at commit fb366b9a, but I can't install it on my server. All other packages are installed according to requirements.txt except transformers.

jayelm commented 1 year ago

Unfortunately the codebase is only verified to work with commit fb366b9a. You might be able to get around this specific issue by just pasting in the apply_rotary_pos_emb function from the link above instead of importing it from modeling_llama, but I can't guarantee you won't run into additional issues.

pengfeiwu1999 commented 1 year ago

OK, I fixed it, thanks!

jayelm commented 1 year ago

Great!