syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

1. Bug fix. 2. Add fast long retention implementation #25

Open veya2ztn opened 1 year ago

veya2ztn commented 1 year ago
  1. Fix a bug for the mode that takes inputs_embeds rather than input_ids:

    if inputs_embeds is None:
        inputs_embeds = self.forward_embedding(input_ids, forward_impl, inputs_embeds, past_key_values)
    else:
        if forward_impl == 'recurrent':
            inputs_embeds = inputs_embeds[:, -1:]
  2. Add a fixed sequence length argument for the case where the inputs are (additional_token, past_kv):

    if fixed_seq_len:
        slen = fixed_seq_len
  3. Cache the fixed retnet_rel_pos so it does not need to be regenerated at runtime (see the caching sketch after this list).

  4. Add a fast retention implementation for when the sequence length >> D**2; see https://github.com/veya2ztn/fast_retention and the complexity sketch after this list.
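
A simple way to realize 3. is to memoize the relative-position/decay tensors by sequence length and forward mode, so generation reuses them instead of rebuilding them at every step. A minimal sketch of the idea (RelPosCache and the builder callable are illustrative names, not the PR's actual code):

    import torch

    class RelPosCache:
        # Memoize fixed relative-position/decay tensors keyed by
        # (seq_len, forward_impl); `builder` is whatever callable
        # actually constructs them in the model.
        def __init__(self, builder):
            self.builder = builder
            self._cache = {}

        def get(self, seq_len: int, forward_impl: str):
            key = (seq_len, forward_impl)
            if key not in self._cache:
                self._cache[key] = self.builder(seq_len, forward_impl)
            return self._cache[key]

    # Toy builder: a per-position decay vector, computed once and then reused.
    cache = RelPosCache(lambda slen, impl: torch.arange(slen, dtype=torch.float))
    rel_pos = cache.get(16, 'parallel')   # built on first call
    rel_pos = cache.get(16, 'parallel')   # served from the cache afterwards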
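
For context on 4.: the parallel retention form materializes an L x L decay matrix per head (O(L^2 * d) time, O(L^2) memory), while the recurrent form from the RetNet paper carries only a d x d state (O(L * d^2)), so a dedicated long-sequence path pays off once L grows well beyond d^2. A minimal single-head sketch of the two forms (not the actual fast_retention kernels; scaling and group norm are omitted):

    import torch

    def parallel_retention(q, k, v, gamma):
        # q, k, v: (L, d). Builds the full (L, L) decay-masked score matrix.
        L = q.shape[0]
        idx = torch.arange(L)
        decay = gamma ** (idx[:, None] - idx[None, :]).clamp(min=0).float()
        decay = torch.tril(decay)              # D[n, m] = gamma**(n - m) for n >= m
        return (q @ k.T * decay) @ v           # O(L^2 * d) time, O(L^2) memory

    def recurrent_retention(q, k, v, gamma):
        # Same output, but with a running (d, d) state: O(L * d^2) time, O(d^2) memory.
        d = q.shape[1]
        state = torch.zeros(d, d)
        outs = []
        for n in range(q.shape[0]):
            state = gamma * state + k[n][:, None] * v[n][None, :]   # S_n = gamma * S_{n-1} + k_n^T v_n
            outs.append(q[n] @ state)                               # o_n = q_n S_n
        return torch.stack(outs)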

5.1 I set the use_glu default to False, for consistency with the old code. 5.2 The layer norm setting in the FFN seems wrong: self.embed_dim should be ffn_dim.

if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None

Anyway, I rolled it back to self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
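
The reason ffn_dim is the right size: the sub-layer norm is applied to the fc1 activation, whose last dimension is ffn_dim, before fc2 projects back down to embed_dim. A minimal sketch of that forward path (simplified from the torchscale-style FFN; the exact names in this repo may differ):

    import torch
    import torch.nn as nn

    class FFNSketch(nn.Module):
        def __init__(self, embed_dim, ffn_dim, layernorm_eps=1e-6, subln=True):
            super().__init__()
            self.fc1 = nn.Linear(embed_dim, ffn_dim)
            self.fc2 = nn.Linear(ffn_dim, embed_dim)
            # The normalized tensor has ffn_dim features, so the norm must be
            # sized ffn_dim; LayerNorm(embed_dim) fails whenever ffn_dim != embed_dim.
            self.ffn_layernorm = nn.LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None

        def forward(self, x):
            x = torch.nn.functional.gelu(self.fc1(x))   # (..., ffn_dim)
            if self.ffn_layernorm is not None:
                x = self.ffn_layernorm(x)
            return self.fc2(x)                          # (..., embed_dim)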

syncdoth commented 1 year ago

Other than some formatting and refactoring issues, I love the fast-retention implementation! I was hoping to get into that. Thanks for your work!

pkpro commented 1 year ago

Will this be merged?

syncdoth commented 1 year ago

There are some code styling issues and some things I don't understand fully. I think it's best to keep it in its own branch for now.