ml-jku / hopfield-layers

Hopfield Networks is All You Need
https://ml-jku.github.io/hopfield-layers/

Trying out the retrieval for Hopfield #12

Closed: roholazandie closed this issue 4 years ago

roholazandie commented 4 years ago

I tried to use your Hopfield layer, following the example in the blog post, to retrieve an original MNIST image, but it doesn't work. What am I doing wrong here? I followed all the steps.


```python
from modules import Hopfield
import torchvision
import matplotlib.pyplot as plt

train_mnist = torchvision.datasets.MNIST('/hopfield-layers/dataset', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

beta = 1.0

# .data holds the raw uint8 images (pixel values 0-255); the transform above is
# only applied when indexing the dataset, e.g. train_mnist[0]
Y = train_mnist.data[0:5, :, :].reshape(-1, 28*28).unsqueeze(0).float()  # stored patterns, shape (1, 5, 784)
R = train_mnist.data[3, :, :].reshape(-1, 28*28).unsqueeze(0).float()    # state pattern, shape (1, 1, 784)

hopfield = Hopfield(
    input_size=784,                          # R
    hidden_size=784,
    stored_pattern_size=784,                 # Y
    pattern_projection_size=784,             # Y
    scaling=beta,
    pattern_projection_as_connected=True)  # Eq. (32)

# tuple of stored_pattern, state_pattern, pattern_projection
result = hopfield((Y, R, Y))

# reshape the single output pattern back into a 28x28 image
result = result.reshape(28, 28).detach().numpy()

plt.imshow(result, cmap='gray', interpolation='none')
plt.show()
```

a-kore commented 4 years ago

The patterns have to be passed through unchanged (static, with no learned projections, no normalisation and no output projection) so that the layer performs pure pattern retrieval. This blog post written by Johannes Brandstetter goes into more detail.
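For reference, the two computations side by side, assuming the paper's notation (R: state patterns, Y: stored patterns, W_Q, W_K, W_V: learned projection matrices). With projections enabled, the layer computes roughly

$$
Z = \operatorname{softmax}\!\left(\beta\, R W_Q W_K^{\top} Y^{\top}\right) Y W_K W_V ,
$$

and with every projection static (identity) this reduces to the retrieval update

$$
Z = \operatorname{softmax}\!\left(\beta\, R Y^{\top}\right) Y .
$$

If a row of R equals one stored pattern and β is large enough, the softmax is (nearly) one-hot and Z returns that stored pattern. With untrained W_Q, W_K, W_V it cannot.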

```python
from modules import Hopfield
import torchvision
import matplotlib.pyplot as plt

train_mnist = torchvision.datasets.MNIST('/hopfield-layers/dataset', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))

beta = 1.0

# .data holds the raw uint8 images (pixel values 0-255); the transform is not applied here
Y = train_mnist.data[0:5, :, :].reshape(-1, 28*28).unsqueeze(0).float()  # stored patterns, shape (1, 5, 784)
R = train_mnist.data[3, :, :].reshape(-1, 28*28).unsqueeze(0).float()    # state pattern, shape (1, 1, 784)

hopfield = Hopfield(
    scaling=beta,

    # do not project layer input
    state_pattern_as_static=True,
    stored_pattern_as_static=True,
    pattern_projection_as_static=True,

    # do not pre-process layer input
    normalize_stored_pattern=False,
    normalize_stored_pattern_affine=False,
    normalize_state_pattern=False,
    normalize_state_pattern_affine=False,
    normalize_pattern_projection=False,
    normalize_pattern_projection_affine=False,

    # do not post-process layer output
    disable_out_projection=True
)

# tuple of stored_pattern, state_pattern, pattern_projection
result = hopfield((Y, R, Y))

# reshape the retrieved pattern back into a 28x28 image
result = result.reshape(28, 28).detach().numpy()

plt.imshow(result, cmap='gray', interpolation='none')
plt.show()
```

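To see why the static settings matter, here is a minimal NumPy sketch (not the library's implementation) that compares the static retrieval rule, xi_new = Y^T softmax(beta * Y xi), with a version that sends the patterns through randomly initialised projection matrices, which is roughly what the untrained layer in the first snippet did. Random ±1 vectors stand in for the flattened MNIST images, and `hopfield_retrieve`, `W_q`, `W_k` and `W_v` are hypothetical names used only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def softmax(z):
    # numerically stable softmax over a 1-D vector
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def hopfield_retrieve(Y, xi, beta):
    """Hypothetical helper: one static update xi_new = Y^T softmax(beta * Y xi).

    Y:  (N, d) stored patterns, one per row
    xi: (d,)   state (query) pattern
    """
    p = softmax(beta * (Y @ xi))  # (N,) weights over the stored patterns
    return p @ Y                  # (d,) weighted combination of stored patterns


# stand-ins for five flattened 28x28 images (random +/-1 patterns)
d = 784
Y = rng.choice([-1.0, 1.0], size=(5, d))
xi = Y[3].copy()

# static layer: no projections, so the output is pattern 3 itself
static_out = hopfield_retrieve(Y, xi, beta=1.0)

# untrained layer (sketch): hypothetical random projections W_q, W_k, W_v
W_q, W_k, W_v = (rng.normal(0.0, d ** -0.5, size=(d, d)) for _ in range(3))
p = softmax(1.0 * ((Y @ W_k) @ (xi @ W_q)))
projected_out = p @ (Y @ W_v)

print(np.allclose(static_out, Y[3]))     # True
print(np.allclose(projected_out, Y[3]))  # False
```

Without training, the projected output is a mixture of randomly transformed patterns, so it cannot look like the stored digit. The real layer additionally applies normalisation and an output projection, which this sketch leaves out.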
bschaefl commented 4 years ago

All pattern projections need to be static and all normalisations switched off, as @a-kore pointed out. For more information, please have a look at the accompanying blog post and the code samples it contains.
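In the snippets above, `scaling` plays the role of β. The following minimal NumPy sketch uses the same static update rule, with random ±1 stand-ins for images and hypothetical helper names, to show two behaviours worth knowing when experimenting. A query with half of its entries blanked out still retrieves the full stored pattern when β is large enough, while a very small β returns roughly the average of all stored patterns instead of any single one.

```python
import numpy as np


def softmax(z):
    # numerically stable softmax over a 1-D vector
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()


def hopfield_retrieve(Y, xi, beta):
    # hypothetical helper: one update xi_new = Y^T softmax(beta * Y xi), one pattern per row of Y
    return softmax(beta * (Y @ xi)) @ Y


rng = np.random.default_rng(0)
Y = rng.choice([-1.0, 1.0], size=(5, 784))  # stand-ins for five binarized images

# query: pattern 3 with its second half (last 392 entries) blanked out
masked = Y[3].copy()
masked[392:] = 0.0

restored = hopfield_retrieve(Y, masked, beta=1.0)   # large enough: one-hot weights
averaged = hopfield_retrieve(Y, masked, beta=1e-6)  # tiny: near-uniform weights

print(np.allclose(restored, Y[3]))                       # True: full pattern recovered
print(np.allclose(averaged, Y.mean(axis=0), atol=1e-2))  # True: close to the mean pattern
```

How large β needs to be depends on the scale of the inputs. Raw 0-255 pixel values, as used above, produce much larger dot products than normalised ones, which presumably is why `beta = 1.0` already separates the stored images there.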

Closing issue.