Jamie-Stirling / RetNet

An implementation of "Retentive Network: A Successor to Transformer for Large Language Models"
MIT License

_get_D function very slow for long sequence #7

Closed ZuowenWang0000 closed 11 months ago

ZuowenWang0000 commented 11 months ago

First, many thanks for your implementation!

It seems that the _get_D function

def _get_D(self, sequence_length):
    # D[n,m] = gamma ** (n - m) if n >= m else 0
    D = torch.zeros((sequence_length, sequence_length), requires_grad=False)
    for n in range(sequence_length):
        for m in range(sequence_length):
            if n >= m:
                D[n, m] = self.gamma ** (n - m)
    return D

gets really slow for long sequence lengths, resulting in very low GPU utilization.

Changing it to the vectorized style below makes it much faster. I'm not sure it's perfectly correct, but for gamma < 1 it seems fine.

def _get_D(self, sequence_length):
    n = torch.arange(sequence_length).unsqueeze(1)
    m = torch.arange(sequence_length).unsqueeze(0)

    # Broadcast gamma ** (n - m) and mask out entries where n < m.
    # For gamma < 1, gamma ** (n - m) overflows to inf when m is much larger
    # than n; multiplying inf by the 0 mask then produces NaN there.
    D = (self.gamma ** (n - m)) * (n >= m).float()
    # Replace those NaNs with 0 (NaN != NaN, so this selects exactly the NaNs)
    D[D != D] = 0

    return D
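As a side note, the NaN can be avoided entirely by clamping the exponent to 0 on the masked positions before exponentiating, so the overflow never happens. A minimal standalone sketch (function names here are hypothetical, not part of the repo), checked against the original double loop:

```python
import torch

def get_decay_matrix(sequence_length, gamma):
    # Vectorized D[n, m] = gamma ** (n - m) for n >= m, else 0.
    n = torch.arange(sequence_length).unsqueeze(1)
    m = torch.arange(sequence_length).unsqueeze(0)
    # Clamp negative exponents to 0 so gamma ** (n - m) never overflows;
    # the mask then zeroes those entries, with no NaN to clean up.
    exponent = (n - m).clamp(min=0)
    return (gamma ** exponent) * (n >= m).float()

def get_decay_matrix_loop(sequence_length, gamma):
    # Original O(L^2) Python-loop version, kept as a reference for testing.
    D = torch.zeros((sequence_length, sequence_length))
    for n in range(sequence_length):
        for m in range(sequence_length):
            if n >= m:
                D[n, m] = gamma ** (n - m)
    return D
```

The clamped exponent leaves the unmasked (n >= m) entries untouched, so the result matches the loop version exactly while staying finite for any sequence length.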
Jamie-Stirling commented 11 months ago

Thanks for this. I've tested this and the tests in /src/tests.py all pass.

I've committed this change now.