lucidrains / rotary-embedding-torch

Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch

caching frequency results in RuntimeError: Trying to backward through the graph a second time #18

Closed wren93 closed 4 months ago

wren93 commented 9 months ago

Hi, thank you very much for this handy rotary embedding library. I ran into the runtime error below when the rotary embedding module tried to read the cached frequencies at the second loss.backward() of training (i.e. the second iteration). I'm using Hugging Face accelerate.

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
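As far as I can tell, the general failure mode in PyTorch looks something like the hypothetical sketch below (not this library's actual code): a tensor cached during the first iteration is still attached to that iteration's autograd graph, so reusing it in the second iteration's loss makes backward() revisit the already-freed graph.

```python
import torch

# hypothetical minimal sketch of the failure mode (not this library's code):
# a tensor cached in iteration 0 stays attached to iteration 0's autograd graph,
# so reusing it in iteration 1 makes backward() walk back into the freed graph.

pos = torch.arange(8, dtype = torch.float32)
scale = torch.nn.Parameter(torch.ones(1))  # stand-in for anything upstream that requires grad
cached = None

for step in range(2):
    if cached is None:
        cached = pos * scale   # grad_fn belongs to step 0's graph
    loss = cached.sum()
    loss.backward()            # step 1 raises: Trying to backward through the graph a second time
```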

I tried commenting out the caching part so the frequencies are always recomputed in forward(), and that worked fine:

```python
@autocast(enabled = False)
def forward(
    self,
    t: Tensor,
    seq_len = None,
    offset = 0
):
    # should_cache = (
    #     not self.learned_freq and \
    #     exists(seq_len) and \
    #     self.freqs_for != 'pixel'
    # )

    # if (
    #     should_cache and \
    #     exists(self.cached_freqs) and \
    #     (offset + seq_len) <= self.cached_freqs.shape[0]
    # ):
    #     return self.cached_freqs[offset:(offset + seq_len)]

    freqs = self.freqs

    freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
    freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

    # if should_cache:
    #     self.tmp_store('cached_freqs', freqs)

    return freqs
```

Am I using the library incorrectly, or could this be a bug? Thanks in advance.

lucidrains commented 9 months ago

@wren93 hey Weiming

thanks for reporting

i cannot seem to repro this, but what happens if you change the last commented-out line to `self.tmp_store('cached_freqs', freqs.detach())` in your training setup? also try detaching the cached read, i.e. `return self.cached_freqs[offset:(offset + seq_len)].detach()`
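applied to your snippet, the suggestion would look roughly like this (just a sketch, untested against your setup):

```python
@autocast(enabled = False)
def forward(
    self,
    t: Tensor,
    seq_len = None,
    offset = 0
):
    should_cache = (
        not self.learned_freq and \
        exists(seq_len) and \
        self.freqs_for != 'pixel'
    )

    if (
        should_cache and \
        exists(self.cached_freqs) and \
        (offset + seq_len) <= self.cached_freqs.shape[0]
    ):
        # detach the cached slice so it cannot reach back into an old graph
        return self.cached_freqs[offset:(offset + seq_len)].detach()

    freqs = self.freqs

    freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
    freqs = repeat(freqs, '... n -> ... (n r)', r = 2)

    if should_cache:
        # store a detached copy so later iterations never backprop through this one
        self.tmp_store('cached_freqs', freqs.detach())

    return freqs
```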

lucidrains commented 9 months ago

@wren93 a small reproducible script would help

lucidrains commented 4 months ago

hmm, i think you are the only one with this issue, but please reopen if you have more information