lucidrains / rotary-embedding-torch

Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch
MIT License

Model hangs on eval #15

Open GarrettMerz opened 11 months ago

GarrettMerz commented 11 months ago

Hi! I'm running an encoder-decoder transformer with RoPE in the first self-attention layer of the encoder and decoder. I'm noticing that in the eval stage of my model, it hangs until my job times out, after about 7 epochs; when running without this package, i.e. with standard learnable positional encodings or Shaw-style relative position encodings, I do not see this behavior. Are there any obvious places in this package that might lead to a memory leak or something similar (e.g. does the register_buffer play nicely with torch.no_grad())?
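
For reference, I'm applying the rotary embedding to queries and keys roughly like this (a minimal sketch following the README usage; the head dimension and tensor shapes are placeholders, not my actual config):

```python
import torch
from rotary_embedding_torch import RotaryEmbedding

# rotary embedding applied over (part of) the per-head dimension
rotary_emb = RotaryEmbedding(dim = 32)

# (batch, heads, seq_len, head_dim), placeholder shapes
q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)

# rotate queries and keys before computing attention scores
q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)
```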

GarrettMerz commented 11 months ago

I'm wondering if this issue may be related: https://github.com/pytorch/pytorch/issues/20275

lucidrains commented 11 months ago

hey Garrett at Madison! beautiful city, still have fond memories of it (worked at Epic Systems for a year right out of college)

yup, i think i may have an idea of what's wrong, will throw in a fix in an hour

lucidrains commented 11 months ago

@GarrettMerz basically i'm incorrectly caching keyed by the sequence length; it should cache the frequencies for the longest sequence length seen and slice them for any subsequent calls with shorter sequences
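
in other words, something like this pattern (a simplified toy sketch of the intended behavior, not the actual code in the repo; the class and attribute names are made up):

```python
import torch
from torch import nn

class RotaryFreqCache(nn.Module):
    """toy illustration (made-up name): cache rotary frequencies for the
    longest sequence length seen so far and slice for shorter requests,
    instead of re-deriving a separate cache entry per distinct length."""

    def __init__(self, dim, theta = 10000.):
        super().__init__()
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.cached_freqs = None
        self.cached_len = 0

    def forward(self, seq_len):
        if self.cached_freqs is None or seq_len > self.cached_len:
            t = torch.arange(seq_len, device = self.inv_freq.device).float()
            self.cached_freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            self.cached_len = seq_len
        # shorter (or equal) requests just slice the longest cached table
        return self.cached_freqs[:seq_len]
```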

lucidrains commented 11 months ago

@GarrettMerz want to give 0.5.0 a try and see if it still hangs?

GarrettMerz commented 10 months ago

Updating this with results: this largely seems to fix things. I still see hanging behavior in cases where the max output length is large and the model does not produce an EOS token before hitting that max length, i.e. if it gets stuck outputting something nonsensical like "++++++++++...", which can happen in early epochs: the length of this bad output is then cached, which causes RoPE to slow down a lot. Capping the max output length at a reasonable size generally seems to mitigate this, which is a good enough fix for now.

lucidrains commented 10 months ago

@GarrettMerz sounds good, as long as it does not hang anymore

best with your research and life out in the midwest

GarrettMerz commented 10 months ago

I may need to reopen this; it seems that things are still hanging! I'm going to investigate more to figure out when specifically it happens. I'm going to use the relative position encoding in the encoder only (not the decoder) and see if that helps at all.

lucidrains commented 10 months ago

hmm, yea, i'll wait for more info from your end

you are the only one reporting this

lucidrains commented 10 months ago

@GarrettMerz could you try turning off the cache altogether? https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py#L82 just to confirm that it is indeed caused by the freqs caching and not something on your end?

jozhang97 commented 8 months ago

Hi, I was also encountering this issue with the latest commit (v0.5.3). On my end, I can confirm that this is caused by caching. Setting cache_if_possible=False worked for me.
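
Concretely, the workaround was just to construct the embedding with caching disabled (minimal sketch; dim here is a placeholder, and I'm assuming the cache_if_possible flag is passed to the constructor as in my setup):

```python
from rotary_embedding_torch import RotaryEmbedding

# disable the frequency cache entirely: frequencies are recomputed on every
# call, which avoids the hang at the cost of a little redundant compute
rotary_emb = RotaryEmbedding(dim = 32, cache_if_possible = False)
```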

GarrettMerz commented 7 months ago

I still seem to be encountering this issue in some cases; I'm investigating more to confirm it's not an implementation problem.

lunixbochs commented 2 months ago

I'm seeing consistent hangs with 8-GPU accelerate-based DDP. It hangs on NCCL comms after the first step when I'm using 3+ GPUs. My sequence lengths differ both per step and per accelerator. There's no issue if the cache is disabled, or with 2 GPUs, or if I use a fixed sequence length.

lucidrains commented 2 months ago

@lunixbochs I see! thank you for this info

I'll try standardizing the cache to the same tensor shape across devices and ping you to give it a try when it's done
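
something roughly along these lines, conceptually (a toy sketch of one way to keep the cached tensor the same shape on every rank; the class name and fixed max length are just assumptions for illustration, not what will land in the library):

```python
import torch
from torch import nn

class FixedShapeFreqCache(nn.Module):
    """toy illustration (made-up name): precompute rotary frequencies up to a
    fixed maximum length once, so every DDP rank holds an identically shaped
    table regardless of the per-rank sequence lengths it actually sees."""

    def __init__(self, dim, max_seq_len = 4096, theta = 10000.):
        super().__init__()
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # persistent = False keeps the table out of the state dict
        self.register_buffer('freqs', freqs, persistent = False)

    def forward(self, seq_len):
        # every rank slices from the same fixed-shape table
        return self.freqs[:seq_len]
```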

lunixbochs commented 2 months ago

the rope cache in fairseq seems to work fine: https://github.com/facebookresearch/fairseq/blob/920a548ca770fb1a951f7f4289b4d3a0c1bc226f/fairseq/modules/rotary_positional_embedding.py#L28

lucidrains commented 2 months ago

@lunixbochs sounds good, I'll take a look only if I can't figure it out

lucidrains commented 2 months ago

@lunixbochs want to try 0.8.2 and see if it resolves your issue?

lunixbochs commented 2 months ago

just updated, not hanging for me anymore. thanks!

lucidrains commented 2 months ago

if you see anything wrong with the loss curve, let me know; i made some risky changes