loudinthecloud / pytorch-ntm

Neural Turing Machines (NTM) - PyTorch Implementation
BSD 3-Clause "New" or "Revised" License

Corrected the width of the circular convolution adjustment #4

Closed. JulesGM closed this 6 years ago.

JulesGM commented 6 years ago

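The code pads the weight vector with 2 wrapped-around elements on each side before the convolution. A shift kernel of width 3 only needs 1 element on each side. With 1 element per side, the convolution output already has the same length as the weight vector, so the `c[1:-1]` trim is no longer needed.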
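A width-3 kernel centred on position `i` only reaches `i - 1` and `i + 1`, which is why one wrapped element per side is enough. The sketch below checks this without PyTorch. It assumes plain Python lists, and the helper names (`valid_correlate`, `circular_correlate_direct`, `circular_correlate_padded`) are hypothetical and not part of pytorch-ntm. Like `F.conv1d`, it computes a cross-correlation. It compares both padding widths against a direct modular-index formula.

```python
def valid_correlate(t, s):
    """'Valid' sliding-window cross-correlation (what F.conv1d does with no padding)."""
    m = len(s)
    return [sum(s[j] * t[i + j] for j in range(m)) for i in range(len(t) - m + 1)]


def circular_correlate_direct(w, s):
    """Reference: out[i] = sum_j s[j] * w[(i + j - k) % n], with k = (len(s) - 1) // 2."""
    n, k = len(w), (len(s) - 1) // 2
    return [sum(s[j] * w[(i + j - k) % n] for j in range(len(s))) for i in range(n)]


def circular_correlate_padded(w, s, pad):
    """Wrap `pad` elements of w onto each side, then slide the kernel over the result."""
    t = (w[-pad:] + w + w[:pad]) if pad else list(w)
    return valid_correlate(t, s)


w = [1, 2, 3, 4, 5]
s = [1, 10, 100]  # width-3 shift kernel, so k = 1

reference = circular_correlate_direct(w, s)
one_per_side = circular_correlate_padded(w, s, pad=1)  # corrected version: len(w) outputs
two_per_side = circular_correlate_padded(w, s, pad=2)  # original version: len(w) + 2 outputs

print(reference)                        # [215, 321, 432, 543, 154]
print(one_per_side == reference)        # True
print(len(two_per_side))                # 7
print(two_per_side[1:-1] == reference)  # True
```

With 2 elements per side, the first and last outputs are extra and must be dropped. With 1 per side, the output length matches `w` directly. More generally, a kernel of odd width `2k + 1` needs `k` wrapped elements on each side.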

Tested with the following script. It checks that both versions produce identical outputs on random inputs:

```python
import torch
from torch.nn import functional as F
from random import randint


def _convolve_original(w, s):
    """Circular convolution implementation (original: pads 2 elements per side)."""
    assert s.size(0) == 3
    t = torch.cat([w[-2:], w, w[:2]])
    c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
    # Padding by 2 yields len(w) + 2 outputs, so the first and last are dropped.
    return c[1:-1]


def _convolve_new(w, s):
    """Circular convolution implementation (corrected: pads 1 element per side)."""
    assert s.size(0) == 3
    t = torch.cat([w[-1:], w, w[:1]])
    c = F.conv1d(t.view(1, 1, -1), s.view(1, 1, -1)).view(-1)
    # Padding by 1 yields exactly len(w) outputs, so no trimming is needed.
    return c


for _ in range(10000):
    N = randint(10, 1000)
    # Uniform samples in [0, 1), as torch.nn.init.uniform_ would give on a zero tensor.
    # Plain tensors replace the deprecated Variable wrapper and nn.init.uniform alias.
    w = torch.rand(N)
    s = torch.rand(3)

    assert torch.equal(_convolve_original(w, s), _convolve_new(w, s))
```
JulesGM commented 6 years ago

Ok, I fixed the title. I will do the same with the other PR.