j-min / MoChA-pytorch

PyTorch Implementation of "Monotonic Chunkwise Attention" (ICLR 2018)

implementation of `safe_cumprod` #2

Open bo-son opened 5 years ago

bo-son commented 5 years ago

`cumprod` in the MoChA paper is defined to be exclusive, while `safe_cumprod` in this repo is not. Shouldn't it be:

def safe_cumprod(self, x, exclusive=False):
    """Numerically stable cumulative product via cumulative sum in log-space."""
    bsz = x.size(0)
    # Clamp to avoid log(0); the inputs are probabilities, so they are already <= 1.
    logsum = torch.cumsum(torch.log(torch.clamp(x, min=1e-20, max=1)), dim=1)
    if exclusive:
        # Shift right by one step: prepend log(1) = 0 and drop the last entry,
        # so position j holds the product over x[:, :j] (the empty product is 1).
        logsum = torch.cat([torch.zeros(bsz, 1).to(logsum), logsum], dim=1)[:, :-1]
    return torch.exp(logsum)

And in the `soft()` function of `MonotonicAttention`:

cumprod_1_minus_p = self.safe_cumprod(1 - p_select, exclusive=True)
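
For context, here is a small self-contained sketch (the names `soft_monotonic_attention`, `p_select`, `previous_alpha` and the toy tensors are illustrative, not the repo's API) of how an exclusive `safe_cumprod` slots into the parallel soft-alignment recurrence from Raffel et al. (2017), which MoChA builds on:

import torch

def safe_cumprod(x, exclusive=False, eps=1e-20):
    # Cumulative product computed as a cumulative sum in log-space for stability.
    logsum = torch.cumsum(torch.log(torch.clamp(x, min=eps, max=1.0)), dim=1)
    if exclusive:
        # Shift right: position j holds the product over x[:, :j]; the empty product is 1.
        pad = torch.zeros(x.size(0), 1, dtype=logsum.dtype, device=logsum.device)
        logsum = torch.cat([pad, logsum], dim=1)[:, :-1]
    return torch.exp(logsum)

def soft_monotonic_attention(p_select, previous_alpha, eps=1e-10):
    # Parallel form of the expected monotonic alignment (Raffel et al., 2017):
    # alpha_j = p_j * prod_{k<j}(1 - p_k) * sum_{k<=j} alpha_prev_k / prod_{l<k}(1 - p_l)
    cumprod_1_minus_p = safe_cumprod(1 - p_select, exclusive=True)
    return p_select * cumprod_1_minus_p * torch.cumsum(
        previous_alpha / torch.clamp(cumprod_1_minus_p, min=eps), dim=1)

# Toy check of inclusive vs. exclusive behaviour:
p = torch.tensor([[0.9, 0.5, 0.3]])
print(safe_cumprod(1 - p))                  # inclusive: ~[[0.1, 0.05, 0.035]]
print(safe_cumprod(1 - p, exclusive=True))  # exclusive: ~[[1.0, 0.1, 0.05]]

The exclusive product starts at 1 (the empty product), which is what the paper's prod_{k<j}(1 - p_k) term requires; the inclusive version would wrongly fold (1 - p_j) itself into position j.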
Cescfangs commented 5 years ago

@bo-son I think you're right