syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.

Changelog of official implementation #10

Closed · donglixp closed this 1 year ago

donglixp commented 1 year ago

Thanks for the well-written package! RetNet's official implementation has had several updates; see https://github.com/microsoft/unilm/blob/master/retnet/README.md#changelog .

syncdoth commented 1 year ago

It's an honor to get the acknowledgment from the author himself! I'm definitely planning to reference the official implementation more in the next release. Thanks for the pointers!

syncdoth commented 1 year ago

Update: I'm currently on track to re-implement based on the torchscale version.

Parallel and recurrent are done, but I'm having some issues with chunkwise.

Problem Description:

I'll leave a note on what prevents the chunkwise forward from being equivalent to the parallel or recurrent forward. (Actually, the same problem is present in the official torchscale code too.)

The decays are normalized in RetNetRelPos, as in here. However, this leads to decays that are scaled differently from those in the parallel or recurrent forward.

In parallel, the decay mask with slen=4 looks something like this:

(A)

```
[1    0   0  0]
[r    1   0  0]
[r^2  r   1  0]
[r^3  r^2 r  1]
```

and after dividing by the scale, it becomes

(B)

```
[1    0   0  0]   /  [1                ]
[r    1   0  0]   /  [r + 1            ]
[r^2  r   1  0]   /  [r^2 + r + 1      ]
[r^3  r^2 r  1]   /  [r^3 + r^2 + r + 1]
```
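
For concreteness, here is a minimal numeric sketch of how (A) and (B) are built (illustrative only, not the repo's code, using a made-up scalar decay rate r for a single head):

```python
import torch

r = 0.9      # illustrative decay rate of one retention head
slen = 4

# (A): lower-triangular decay matrix, entry (n, m) = r**(n - m) for n >= m, else 0
idx = torch.arange(slen)
decay_mask = (r ** (idx[:, None] - idx[None, :]).float()) * (idx[:, None] >= idx[None, :])

# (B): divide each row by its sum (1, r + 1, r^2 + r + 1, ...), mirroring the
# normalization done in RetNetRelPos
scale = decay_mask.sum(dim=-1, keepdim=True)
normalized_decay = decay_mask / scale

print(decay_mask)
print(normalized_decay)
```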

For chunkwise with slen=4 and chunk_size=2, the per-chunk decay_mask is the following after scaling:

(C)

```
[1    0]  /  [1    ]
[r    1]  /  [r + 1]
```

Since cross_decay = r^2, in the cross-chunk aggregation $KV_{i} = KV_{i-1} \cdot D + KV_{i}$ (with $D$ = cross_decay = $r^2$), the decay applied to the previous chunk becomes

(D)

```
  [r   + 1  ]  /  [r + 1]  *   [ KV_{i-1} ] * r^2   +    [r + 1]  /  [r + 1]  *  [  KV_i  ]

= [r^3 + r^2]  /  [r + 1]  *   [ KV_{i-1} ]         +    [r + 1]  /  [r + 1]  *  [  KV_i  ]

= [r^3 + r^2, r + 1] / [r + 1]   *   [[ KV_{i-1} ]
                                      [   KV_i   ]]

= [r^3   r^2  r   1] / [r + 1]   *   [[   KV_0   ]
                                      [   KV_1   ]
                                      [   KV_2   ]
                                      [   KV_3   ]]
```

Compare (B) with (D) and notice that, while the row vector has the correct exponents, the scale dividing it should be sum(r**i for i in range(4)), but it is only sum(r**i for i in range(2)).
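
A quick numeric check of this mismatch (illustrative only, again with a scalar r), comparing the scale applied to the last row in (B) versus (D):

```python
r = 0.9  # illustrative decay rate

# Decay row applied to the last position (slen=4): [r^3, r^2, r, 1]
row = [r**3, r**2, r, 1.0]

# (B) parallel: divide by the full cumulative sum 1 + r + r^2 + r^3
parallel_scale = sum(r**i for i in range(4))
parallel_row = [x / parallel_scale for x in row]

# (D) chunkwise (chunk_size=2): same exponents, but the scale only covers
# the current chunk, i.e. 1 + r
chunk_scale = sum(r**i for i in range(2))
chunkwise_row = [x / chunk_scale for x in row]

print(parallel_row)
print(chunkwise_row)
print([a / b for a, b in zip(chunkwise_row, parallel_row)])  # constant ratio per row
```

The two rows differ only by a constant per-row factor, which is exactly the kind of rescaling that the group norm discussion in the next comment ends up absorbing when eps is 0.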

hyunwoongko commented 11 months ago

Hi @syncdoth, thanks for the great implementation! I think the difference between chunkwise and the other forwards comes from the group norm eps. I checked that when layernorm_eps is 0, all results are the same. And the paper's author said that after training, the validation PPL is almost the same, so it's not a big problem.

Hi @N0r9st, the incorrect results come from the group norm eps. If eps=0, the chunkwise representation is mathematically the same as the parallel one. You can try that.

Another reason is that the initialization of Retention is small, which amplifies the difference. However, after training, the validation PPL will be almost the same.

You can check this issue: https://github.com/microsoft/torchscale/issues/77
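
A minimal sketch of why the eps value matters, using torch.nn.functional.group_norm on random data (the shapes and the rescaling constant are made up for illustration): with eps=0, group norm is invariant to a positive rescaling of its input, so a constant per-row scale difference like the one above cancels out.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)  # (batch, channels, length) -- illustrative shapes only
c = 0.25                   # arbitrary positive rescaling, standing in for the decay-scale mismatch

# With eps = 0, group norm is invariant to rescaling its input by a positive constant.
out_ref = F.group_norm(x, num_groups=4, eps=0.0)
out_scaled = F.group_norm(c * x, num_groups=4, eps=0.0)
print(torch.allclose(out_ref, out_scaled, atol=1e-6))   # True

# With a nonzero eps, the invariance is only approximate, which is where the small
# parallel-vs-chunkwise output difference comes from.
out_eps = F.group_norm(c * x, num_groups=4, eps=1e-5)
print(torch.allclose(out_ref, out_eps, atol=1e-6))       # typically False
```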

hyunwoongko commented 11 months ago

```python
import torch

# Imports from this repo (assumed module paths; adjust to your local layout).
from retnet.configuration_retnet import RetNetConfig
from retnet.modeling_retnet import MultiScaleRetention, RetNetRelPos

if __name__ == '__main__':
    bsz = 1
    seq_len = 6
    hidden_size = 8

    inputs = torch.randn(bsz, seq_len, hidden_size)
    attention_mask = None

    config = RetNetConfig(
        vocab_size=51200,
        initializer_range=0.02,
        is_decoder=True,
        pad_token_id=1,
        eos_token_id=1,
        output_retentions=False,
        use_cache=True,
        forward_impl='parallel',
        activation_fn="swish",
        dropout=0.0,  # dropout probability
        activation_dropout=0.0,  # dropout probability after activation in FFN.
        decoder_embed_dim=hidden_size,  # decoder embedding dimension
        decoder_value_embed_dim=hidden_size,  # decoder value embedding dimension
        decoder_ffn_embed_dim=hidden_size,  # decoder embedding dimension for FFN
        decoder_layers=1,  # num decoder layers
        decoder_retention_heads=4,  # num decoder retention heads
        decoder_normalize_before=True,  # apply layer_norm before each decoder block
        embedding_layer_norm=False,  # add layer_norm to embedding
        no_scale_embedding=False,  # if True, dont scale embeddings
        recurrent_chunk_size=3,
        use_glu=True,  # use GLU instead of FFN
        z_loss_coeff=0.0,  # coefficient for z loss: TODO: 1e-4
        use_lm_decay=False,
        deepnorm=False,
        subln=False,
        layer_norm_eps=0,
        tie_word_embeddings=True,
    )

    attn = MultiScaleRetention(config)
    attn.eval()
    pos = RetNetRelPos(config)

    pos_mode = "parallel"
    pos_out = pos.forward(
        slen=seq_len, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
    )
    parallel_outputs = attn.forward(
        hidden_states=inputs, retention_mask=attention_mask, forward_impl=pos_mode, rel_pos=pos_out, use_cache=True,
    )[0]

    pos_mode = "recurrent"
    recurrent_outputs = []
    past_key_value = None
    for i in range(seq_len):
        pos_out = pos.forward(
            slen=i, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
        )
        if attention_mask is not None:
            attn_out = attn.forward(
                hidden_states=inputs[:, i:i + 1], retention_mask=attention_mask[:, i:i + 1],
                forward_impl=pos_mode, rel_pos=pos_out, past_key_value=past_key_value, use_cache=True
            )
        else:
            attn_out = attn.forward(
                hidden_states=inputs[:, i:i + 1], retention_mask=None,
                forward_impl=pos_mode, rel_pos=pos_out, past_key_value=past_key_value, use_cache=True
            )
        past_key_value = attn_out[1]
        recurrent_outputs.append(attn_out[0])
    recurrent_outputs = torch.cat(recurrent_outputs, dim=1)

    pos_mode = "chunkwise"
    pos_out = pos.forward(
        slen=seq_len, retention_mask=attention_mask, forward_impl=pos_mode, get_decay_scale=True,
        recurrent_chunk_size=config.recurrent_chunk_size
    )
    chunked_outputs = attn.forward(
        hidden_states=inputs, retention_mask=None,
        forward_impl=pos_mode, rel_pos=pos_out, use_cache=True,
    )[0]

    print("parallel", parallel_outputs)
    print("========================================")
    print("recurrent", recurrent_outputs)
    print("========================================")
    print("chunkwise", chunked_outputs)
parallel tensor([[[ 0.0661,  0.0118, -0.0007, -0.0371,  0.0423, -0.0271, -0.0077,
           0.0454],
         [ 0.1802, -0.0137,  0.0067,  0.0878,  0.0504,  0.0784,  0.0541,
           0.0457],
         [ 0.1028,  0.0411,  0.0235, -0.0271,  0.0276, -0.0670, -0.0984,
           0.0819],
         [ 0.0462,  0.0080,  0.1094,  0.0696, -0.0110, -0.1273, -0.0945,
           0.0626],
         [ 0.0629, -0.0184,  0.1094,  0.0507, -0.0508, -0.0019, -0.0762,
           0.0003],
         [-0.0948,  0.0072, -0.1295, -0.0596,  0.0233,  0.0751,  0.0730,
          -0.0602]]], grad_fn=<UnsafeViewBackward0>)
========================================
recurrent tensor([[[ 0.0661,  0.0118, -0.0007, -0.0371,  0.0423, -0.0271, -0.0077,
           0.0454],
         [ 0.1802, -0.0137,  0.0067,  0.0878,  0.0504,  0.0784,  0.0541,
           0.0457],
         [ 0.1028,  0.0411,  0.0235, -0.0271,  0.0276, -0.0670, -0.0984,
           0.0819],
         [ 0.0462,  0.0080,  0.1094,  0.0696, -0.0110, -0.1273, -0.0945,
           0.0626],
         [ 0.0629, -0.0184,  0.1094,  0.0507, -0.0508, -0.0019, -0.0762,
           0.0003],
         [-0.0948,  0.0072, -0.1295, -0.0596,  0.0233,  0.0751,  0.0730,
          -0.0602]]], grad_fn=<CatBackward0>)
========================================
chunkwise tensor([[[ 0.0661,  0.0118, -0.0007, -0.0371,  0.0423, -0.0271, -0.0077,
           0.0454],
         [ 0.1802, -0.0137,  0.0067,  0.0878,  0.0504,  0.0784,  0.0541,
           0.0457],
         [ 0.1028,  0.0411,  0.0235, -0.0271,  0.0276, -0.0670, -0.0984,
           0.0819],
         [ 0.0462,  0.0080,  0.1094,  0.0696, -0.0110, -0.1273, -0.0945,
           0.0626],
         [ 0.0629, -0.0184,  0.1094,  0.0507, -0.0508, -0.0019, -0.0762,
           0.0003],
         [-0.0948,  0.0072, -0.1295, -0.0596,  0.0233,  0.0751,  0.0730,
          -0.0602]]], grad_fn=<UnsafeViewBackward0>)
```

syncdoth commented 11 months ago

@hyunwoongko that's true! I can also confirm that setting groupnorm_eps to 0 or a small number (e.g., 1e-15) removes the differences in the outputs.

One remaining problem is that the kv_cache from the chunkwise forward is still not exactly the same as the recurrent one (group norm doesn't affect this). I'm sure there's a way around it; I just haven't thought about it enough yet 😊
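
For anyone who wants to verify this, a hypothetical comparison helper (not part of the repo) can diff the cache returned by the chunkwise forward against the one accumulated by the recurrent loop in the script above, without assuming its exact internal structure:

```python
import torch

def caches_close(a, b, atol=1e-5):
    """Recursively compare two cache objects (tensors, dicts, lists/tuples).

    Hypothetical helper for illustration; the exact cache layout is defined in
    this repo's modeling code.
    """
    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
        return a.shape == b.shape and torch.allclose(a, b, atol=atol)
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(caches_close(a[k], b[k], atol) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(caches_close(x, y, atol) for x, y in zip(a, b))
    return a == b

# Reusing `attn`, `inputs`, `pos_out` (chunkwise rel_pos), and the recurrent
# `past_key_value` from the script above; index 1 of the forward output is
# assumed to hold the cache, as in the recurrent loop.
chunk_out = attn.forward(
    hidden_states=inputs, retention_mask=None,
    forward_impl="chunkwise", rel_pos=pos_out, use_cache=True,
)
print("chunkwise cache matches recurrent cache:", caches_close(chunk_out[1], past_key_value))
```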