First, many thanks for your implementation! It seems that the _get_D function

def _get_D(self, sequence_length):
    # D[n, m] = gamma ** (n - m) if n >= m else 0
    D = torch.zeros((sequence_length, sequence_length), requires_grad=False)
    for n in range(sequence_length):
        for m in range(sequence_length):
            if n >= m:
                D[n, m] = self.gamma ** (n - m)
    return D
gets really slow for long sequence lengths, resulting in very low GPU utilization.
By changing to the style below it gets much better. I'm not sure if it's perfectly correct, but for gamma < 1 it seems all good.
def _get_D(self, sequence_length):
    n = torch.arange(sequence_length).unsqueeze(1)  # column vector of row indices
    m = torch.arange(sequence_length).unsqueeze(0)  # row vector of column indices
    # Broadcast self.gamma ** (n - m), masking so entries with n < m become 0.
    # For gamma < 1 the negative exponents overflow to inf, and inf * 0 = nan.
    D = (self.gamma ** (n - m)) * (n >= m).float()
    # Fill the nan entries with 0
    D[D != D] = 0
    return D
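For what it's worth, here's a standalone sanity check I'd use to compare the two constructions (the free-function names and the gamma/length values are just for illustration, not from the repo). Clamping the exponent at 0 before taking the power sidesteps the inf/nan issue entirely, so the nan-fill step isn't needed:

```python
import torch

def get_D_loop(gamma, sequence_length):
    # Original O(L^2) Python-loop construction (hypothetical free-function port).
    D = torch.zeros((sequence_length, sequence_length))
    for n in range(sequence_length):
        for m in range(sequence_length):
            if n >= m:
                D[n, m] = gamma ** (n - m)
    return D

def get_D_vectorized(gamma, sequence_length):
    # Vectorized construction; clamping the exponent at 0 avoids the
    # inf that gamma ** (n - m) produces for n < m when gamma < 1,
    # so the mask multiply never sees inf * 0 = nan.
    n = torch.arange(sequence_length).unsqueeze(1)
    m = torch.arange(sequence_length).unsqueeze(0)
    exponent = (n - m).clamp(min=0)
    return (gamma ** exponent) * (n >= m).float()

if __name__ == "__main__":
    gamma, L = 0.9, 128  # illustrative values only
    assert torch.allclose(get_D_loop(gamma, L), get_D_vectorized(gamma, L))
```

In my understanding the clamp-based variant should also generalize to gamma >= 1, since no intermediate value ever overflows.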