microsoft / torchscale

Foundation Architecture for (M)LLMs
https://aka.ms/GeneralAI
MIT License

Introducing padding_mask to RetNet #85

Open: xtwigs opened this issue 7 months ago

xtwigs commented 7 months ago

Unlike the other architectures in this package, RetNet doesn't support padding, as far as I'm aware. I think the best place to introduce it is alongside the positional (decay) mask. Retention has no softmax, so we can't simply fill the padded positions with -inf and rely on them turning into zero weights.
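To see why the usual trick doesn't carry over directly, here is a minimal stdlib-only sketch (`softmax` and `decay_weights` are hypothetical helpers, not torchscale functions). In softmax attention, a score of -inf becomes a weight of exactly 0. Retention instead multiplies the scores by a decay mask gamma^(n-m). If that mask is computed as exp((n - m) * log gamma) with log gamma < 0 (assuming `self.decay` holds the per-head log gamma, as the `torch.exp(mask * self.decay[...])` line in the parallel code suggests), then a distance of +inf also yields exactly 0.

```python
import math


def softmax(scores):
    """Softmax over a list of floats; -inf entries receive weight 0."""
    finite_max = max(s for s in scores if s != float("-inf"))
    exps = [math.exp(s - finite_max) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def decay_weights(distances, gamma):
    """Retention-style decay exp(d * log(gamma)) = gamma ** d for each distance d.

    A distance of +inf (a masked position) gives exp(-inf) = 0 because log(gamma) < 0.
    """
    log_gamma = math.log(gamma)  # negative for 0 < gamma < 1
    return [math.exp(d * log_gamma) for d in distances]


print(softmax([1.0, float("-inf"), 2.0]))  # middle weight is 0.0
print(decay_weights([0, 1, float("inf")], 0.9))  # last weight is 0.0
```

The zero in the decay mask is multiplicative: it removes a position from the weighted sum, but nothing renormalizes the remaining weights the way softmax does, so any normalization has to be computed after masking.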

From my attempt, the parallel code would look something like the following (assuming left padding and a boolean padding_mask of shape (bsz, seq_len) that is True at padded positions):

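```python
# Parallel branch of RetNetRelPos.forward with a per-sequence padding mask.
# padding_mask: (bsz, slen) bool tensor, True at (left-)padded positions.
index = torch.arange(slen).to(self.decay)
sin = torch.sin(index[:, None] * self.angle[None, :])
cos = torch.cos(index[:, None] * self.angle[None, :])
# Causal mask: future positions get distance +inf, so exp(inf * decay) = 0 (decay < 0).
mask = torch.tril(torch.ones(slen, slen).to(self.decay))
mask = torch.masked_fill(
    index[:, None] - index[None, :], ~mask.bool(), float("inf")
)
# Mask padded *key* positions (columns): (1, slen, slen) -> (bsz, slen, slen).
# With left padding, padded query rows then contain only masked keys.
mask = torch.masked_fill(mask.unsqueeze(0), padding_mask.unsqueeze(1), float("inf"))
# (bsz, 1, slen, slen) * (num_heads, 1, 1) -> (bsz, num_heads, slen, slen)
mask = torch.exp(mask.unsqueeze(1) * self.decay[:, None, None])
mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()
# Fully padded rows are 0 / 0 = NaN; replace them with zeros.
mask = torch.nan_to_num(mask)
retention_rel_pos = ((sin, cos), mask)
```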
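For checking values without a model, here is a minimal pure-Python reference for one head and one left-padded sequence (the function `padded_decay_mask` and its `num_pad` argument are hypothetical). It builds the causal decay gamma^(n-m), drops padded key positions before the square-root row normalization, and leaves fully padded rows as zeros, which is the role `torch.nan_to_num` plays in the torch version.

```python
import math


def padded_decay_mask(slen, num_pad, gamma):
    """Decay mask for one head and one sequence with num_pad leading pad tokens.

    Entry [n][m] is gamma ** (n - m) for non-padded keys m <= n, else 0;
    each row is then divided by the square root of its sum.
    """
    mask = []
    for n in range(slen):
        row = [gamma ** (n - m) if num_pad <= m <= n else 0.0 for m in range(slen)]
        total = sum(row)
        # Fully padded rows sum to 0; keep them at 0 instead of computing 0/0.
        mask.append([w / math.sqrt(total) if total > 0 else 0.0 for w in row])
    return mask


for row in padded_decay_mask(3, 1, 0.5):
    print(row)
```

Because padded keys are removed before normalizing, the rows of the real tokens match the mask of the same sequence without padding: the decay effectively starts at the first non-pad token.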

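This means the mask is now built per batch element, with shape (bsz, num_heads, slen, slen), instead of being a single (num_heads, slen, slen) mask broadcast over the batch in the forward method.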

For the recurrent formulation, would masking the scaling factor accordingly work?

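```python
def recurrent_forward(
    self, qr, kr, v, decay, padding_mask=None, incremental_state=None
):
    # decay: (num_heads,) per-head decay factors.
    # padding_mask: bsz bools (shape (bsz,) or (bsz, 1)), True if the current token is padding.
    bsz = v.size(0)

    v = v.view(bsz, self.num_heads, self.head_dim, 1)
    kv = kr * v

    if "prev_key_value" in incremental_state:
        prev_kv = incremental_state["prev_key_value"]
        prev_scale = incremental_state["scale"]  # (bsz, num_heads)
        scale = prev_scale * decay + 1
        # A prev_scale of 0 (previous token was padding) drops prev_kv entirely.
        kv = prev_kv * (prev_scale.sqrt() * decay / scale.sqrt()).view(
            bsz, self.num_heads, 1, 1
        ) + kv / scale.sqrt().view(bsz, self.num_heads, 1, 1)
        # kv = prev_kv * decay.view(self.num_heads, 1, 1) + kv
    else:
        # Keep scale as (bsz, num_heads) from the first step on.
        scale = decay.new_ones(bsz, self.num_heads)

    if padding_mask is not None:
        # Zero the scale of sequences whose current token is padding,
        # so the next step discards the kv stored here.
        scale = scale.masked_fill(padding_mask.view(bsz, 1), 0)

    incremental_state["prev_key_value"] = kv
    incremental_state["scale"] = scale

    output = torch.sum(qr * kv, dim=3)
    return output
```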
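Whether a zeroed scale really discards the padded state can be checked with a scalar sketch (one head, head_dim = 1; `recurrent_retention` is a hypothetical helper, not a torchscale API). It uses the same update as `recurrent_forward`: scale_n = scale_{n-1} * gamma + 1 and kv_n = (kv_{n-1} * sqrt(scale_{n-1}) * gamma + k_n v_n) / sqrt(scale_n), with the stored scale set to 0 after a padded step.

```python
import math


def recurrent_retention(qs, ks, vs, gamma, pad=None):
    """Scalar recurrent retention with square-root scale normalization.

    pad[i] is True when token i is padding; its stored scale is zeroed,
    so the following step multiplies the previous kv by sqrt(0) = 0.
    """
    if pad is None:
        pad = [False] * len(qs)
    kv, scale, outputs = None, None, []
    for q, k, v, is_pad in zip(qs, ks, vs, pad):
        cur = k * v
        if kv is None:
            new_scale, kv = 1.0, cur
        else:
            new_scale = scale * gamma + 1.0
            kv = (kv * math.sqrt(scale) * gamma + cur) / math.sqrt(new_scale)
        scale = 0.0 if is_pad else new_scale
        outputs.append(q * kv)
    return outputs


padded = recurrent_retention(
    [9.0, 1.0, 2.0, 3.0], [9.0, 0.5, 1.5, -1.0], [9.0, 2.0, 1.0, 4.0], 0.9,
    pad=[True, False, False, False],
)
plain = recurrent_retention([1.0, 2.0, 3.0], [0.5, 1.5, -1.0], [2.0, 1.0, 4.0], 0.9)
print(padded[1:], plain)
```

With one leading pad token, the outputs for the real tokens match the unpadded run, and both agree with the parallel form q_n * sum_m gamma^(n-m) k_m v_m / sqrt(sum_m gamma^(n-m)).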

I'd appreciate some help with this; perhaps the authors have a better approach? @donglixp @sunyt32

sunyt32 commented 7 months ago

In `parallel_forward`, you can try setting the padded positions to 0 after `mask = mask / mask.sum(dim=-1, keepdim=True).sqrt()`. Your implementation also looks fine to me.

During inference, the padding token doesn't influence the subsequent encoding, so maybe simply skipping it is enough?
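One reading of "just skipping it" is to leave the recurrent state untouched whenever the current token is padding. A scalar sketch under that assumption follows (`retention_step` and `run_skipping_pads` are hypothetical helpers; in a batched torch implementation the skip would likely become a per-sequence select, for example with `torch.where`, rather than a Python branch).

```python
import math


def retention_step(state, q, k, v, gamma):
    """One scalar recurrent retention step; state is (kv, scale) or None."""
    cur = k * v
    if state is None:
        kv, scale = cur, 1.0
    else:
        prev_kv, prev_scale = state
        scale = prev_scale * gamma + 1.0
        kv = (prev_kv * math.sqrt(prev_scale) * gamma + cur) / math.sqrt(scale)
    return (kv, scale), q * kv


def run_skipping_pads(qs, ks, vs, pad, gamma):
    """Run the recurrence, leaving the state unchanged at padded tokens."""
    state, outputs = None, []
    for q, k, v, is_pad in zip(qs, ks, vs, pad):
        if is_pad:
            outputs.append(0.0)  # placeholder output; state is not updated
            continue
        state, out = retention_step(state, q, k, v, gamma)
        outputs.append(out)
    return outputs


print(run_skipping_pads([9.0, 1.0, 2.0], [9.0, 0.5, 1.5], [9.0, 2.0, 1.0],
                        [True, False, False], 0.9))
```

Skipping avoids storing padded kv at all, whereas zeroing the scale stores it and discards it one step later; with left padding both should produce the same outputs for the real tokens.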

xtwigs commented 7 months ago

Thank you for the quick reply. My reasoning behind the parallel code was that the decay should start from the first non-pad token rather than from an arbitrary decay**idx. I'll test both variants and see whether there's any meaningful difference.
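The two parallel variants can be compared on a single row of the mask. In this sketch, variant A drops padded keys before the square-root normalization (the approach described here), and variant B is one reading of the suggestion above: normalize over every causal position, padding included, then zero the padded entries. `decay_row` and its arguments are hypothetical.

```python
import math


def decay_row(n, gamma, num_pad, zero_pads_after_norm):
    """Row n of a one-head decay mask for a sequence with num_pad leading pads."""
    raw = [gamma ** (n - m) for m in range(n + 1)]
    if zero_pads_after_norm:
        # Variant B: normalize over all causal positions, then zero the padded ones.
        denom = math.sqrt(sum(raw))
        return [0.0 if m < num_pad else w / denom for m, w in enumerate(raw)]
    # Variant A: drop padded positions first, then normalize over the rest.
    kept = [0.0 if m < num_pad else w for m, w in enumerate(raw)]
    denom = math.sqrt(sum(kept))
    if denom == 0.0:
        return kept  # fully padded row stays all-zero
    return [w / denom for w in kept]


print(decay_row(3, 0.5, 2, zero_pads_after_norm=False))
print(decay_row(3, 0.5, 2, zero_pads_after_norm=True))
```

In variant B the padded terms still enlarge the denominator, so the real-token weights shrink and depend on how many pads precede them; variant A's weights do not depend on the amount of padding.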

For the recurrent_forward code, I believe I'm already ignoring the previous pad tokens: after a pad step, prev_scale is 0 and the new scale is 1, so the previous kv entry is multiplied by zero.

Would you be interested in merging this code into the torchscale package? If so, I'll fork the repo with the changes. Thank you for the help either way :)