microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

RetNet: Check consistency of each forward mode #54

Closed mmorinag127 closed 1 year ago

mmorinag127 commented 1 year ago

Hello authors,

I'm really happy to see this great work! I have one question (or request) about the consistency of the outputs from the different forward modes. I have been comparing the three outputs using the simple code below.

import torch
from ret_net import MultiScaleRetention, RetNetRelPos

def test_msr():
    seq_len = 16
    dim = 16
    B = 3
    n_heads = 4
    chunk_size = 4
    # deterministic input (instead of torch.rand) so the comparison is reproducible
    x = torch.arange(B * seq_len * dim).view(B, seq_len, dim)
    x = x / x.max()

    xpos = RetNetRelPos(dim, n_heads, chunk_size)
    layer = MultiScaleRetention(dim, n_heads)

    # parallel
    output_p, _ = layer(x, xpos(seq_len, False, False), False, None)

    # recurrent
    output_r = []
    incremental_state = {}
    for idx in range(seq_len):
        rpos = xpos(idx+1, True, False)
        xi = x[:, idx, :].unsqueeze(1)
        out_r, incremental_state = layer(xi, rpos, False, incremental_state)
        output_r.append(out_r)
    output_r = torch.concat(output_r, dim=1)

    # chunkwise
    output_c, _ = layer(x, xpos(seq_len, False, True), True, None)

    check_diff('parallel  - recurrent', output_p, output_r)
    check_diff('parallel  - chunkwise', output_p, output_c)
    check_diff('recurrent - chunkwise', output_r, output_c)

def check_diff(name, A, B, eps=1e-6):
    D = A - B
    print(name, torch.sum(torch.abs(D)))   # total absolute difference
    idx = torch.abs(D) < eps               # True where the two outputs agree within eps
    print(idx[0, :, 0])
    print()

And I got the following result.

parallel  - recurrent tensor(6.4814e-07, grad_fn=<SumBackward0>)
tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True])

parallel  - chunkwise tensor(5.6081, grad_fn=<SumBackward0>)
tensor([ True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False])

recurrent - chunkwise tensor(5.6081, grad_fn=<SumBackward0>)
tensor([ True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False])

As you can see, the parallel and recurrent outputs agree well, but the chunkwise output matches neither the parallel nor the recurrent output from the second chunk onward. Could you give me a hint to understand this?
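
(For what it's worth, a small check like the one below, appended at the end of test_msr, pinpoints where the mismatch starts; the 1e-6 threshold is the same eps used in check_diff.)

    # reuses output_p / output_c from test_msr above
    bad = (output_p - output_c).abs().amax(dim=(0, 2)) > 1e-6   # one bool per position
    print('mismatching positions:', bad.nonzero(as_tuple=True)[0])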

(I have already pulled the latest main branch)

Thanks a lot, Masahiro Morinaga

mmorinag127 commented 1 year ago

Note: to simplify the problem, I changed the original MultiScaleRetention code as below.

class MultiScaleRetention(nn.Module):
    # ... (unchanged parts omitted) ...
    def forward(
        self,
        x,
        rel_pos,
        chunkwise_recurrent=False,
        incremental_state=None
    ):
        bsz, tgt_len, _ = x.size()
        (sin, cos), inner_mask = rel_pos

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        g = self.g_proj(x)

        k *= self.scaling
        q = q.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)
        k = k.view(bsz, tgt_len, self.num_heads, self.key_dim).transpose(1, 2)

        qr = theta_shift(q, sin, cos)
        kr = theta_shift(k, sin, cos)

        if incremental_state is not None:
            output, incremental_state = self.recurrent_forward(qr, kr, v, inner_mask, incremental_state)
        elif chunkwise_recurrent:
            output = self.chunk_recurrent_forward(qr, kr, v, inner_mask)
        else:
            output = self.parallel_forward(qr, kr, v, inner_mask)

        # group_norm, gating and output projection are disabled to simplify the comparison
        # output = self.group_norm(output)
        output = output.reshape(bsz, tgt_len, self.head_dim * self.num_heads)
        # output = self.gate_fn(g) * output
        # output = self.out_proj(output)
        return output, incremental_state

sunyt32 commented 1 year ago

I notice that you commented out the following line:

#output = self.group_norm(output)

Keeping it is essential for the identity among the different forms. Besides, to check consistency we usually set eps=0 in group_norm. The inconsistency comes from the combination of small weight values and eps; the small values at initialization are no longer an issue after training.
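
(A minimal sketch of the effect, assuming the different forward modes produce pre-norm outputs that differ only by a rescaling, which RMSNorm/GroupNorm removes exactly when eps=0 but only approximately when eps>0. The rms_norm helper below is an illustration, not the torchscale implementation.)

import torch

def rms_norm(x, eps):
    # RMSNorm over the last dim, no learned scale
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 8)
c = 1e-3  # stand-in for a small pre-norm scale difference between modes

# eps = 0: the norm is exactly scale-invariant, so both versions agree
print(torch.allclose(rms_norm(x, 0.0), rms_norm(c * x, 0.0), atol=1e-6))   # True

# eps > 0: the invariance is only approximate, and the error is largest
# when the pre-norm values are small (e.g. with small initial weights)
print(torch.allclose(rms_norm(x, 1e-5), rms_norm(c * x, 1e-5), atol=1e-6)) # False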

mmorinag127 commented 1 year ago

Thanks a lot for your comment. I see. Now I understand what's going on.

XintianHan commented 11 months ago

> Thanks a lot for your comment. I see. Now I understand what's going on.

Hi. Did you check the consistency again? I ran your code and fixed the group_norm issue, but there is still a large difference between the parallel forward and the chunkwise forward.

mmorinag127 commented 11 months ago

Hi, I confirmed that the three outputs are exactly the same once group_norm is applied.

bin123apple commented 11 months ago

> Thanks a lot for your comment. I see. Now I understand what's going on.
>
> Hi. Did you check the consistency again? I ran your code and fixed the group_norm issue, but there is still a large difference between the parallel forward and the chunkwise forward.

You should use

    self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=0, elementwise_affine=False))  # check consistency

instead of

    self.group_norm = MultiwayWrapper(args, RMSNorm(self.head_dim, eps=args.layernorm_eps, elementwise_affine=False))

Then all of the outputs will be True.

sunyt32 commented 11 months ago

Same question. When eps=0 the consistency is exact, but training stability becomes a problem. Luckily, RMSNorm is much more stable than LayerNorm, so the eps can be much smaller.

Dao007forever commented 10 months ago

The paper says that GroupNorm replaces LayerNorm, but the code uses RMSNorm. It looks like the goal is to normalize each head separately? Could you clarify that part?

donglixp commented 10 months ago

> The paper says that GroupNorm replaces LayerNorm, but the code uses RMSNorm. It looks like the goal is to normalize each head separately? Could you clarify that part?

@Dao007forever We used LayerNorm in the experiments of our arXiv paper. In our latest experiments we found RMSNorm to be more stable, especially with respect to the LayerNorm eps, so we switched the default to RMSNorm.
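
(For readers following along: a minimal sketch of per-head normalization in this setting, assuming the retention output has shape (bsz, tgt_len, num_heads, head_dim) as in the simplified forward() shown earlier. The rms_norm helper is an illustration, not the torchscale implementation.)

import torch

def rms_norm(x, eps=0.0):
    # RMSNorm over the last dim, no learned scale
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

# hypothetical shapes matching the simplified forward() above
bsz, tgt_len, num_heads, head_dim = 3, 16, 4, 4
retention_out = torch.randn(bsz, tgt_len, num_heads, head_dim)

# normalizing over the last dim normalizes each head independently, i.e. a
# GroupNorm with one group per head, minus the mean subtraction that
# LayerNorm/GroupNorm would also apply
normed = rms_norm(retention_out)
output = normed.reshape(bsz, tgt_len, num_heads * head_dim)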