roholazandie closed 4 years ago
The stored patterns have to be static so that the layer operates purely as a pattern retrieval task. This blog post written by Johannes Brandstetter goes into more detail.
from modules import Hopfield
import torchvision
import matplotlib.pyplot as plt
train_mnist = torchvision.datasets.MNIST(
    '/hopfield-layers/dataset', train=True, download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]))
beta = 1.0
Y = train_mnist.data[0:5, :, :].reshape(-1, 28*28).unsqueeze(0).float()  # stored patterns
R = train_mnist.data[3, :, :].reshape(-1, 28*28).unsqueeze(0).float()  # state pattern
hopfield = Hopfield(
    scaling=beta,
    # do not project layer input
    state_pattern_as_static=True,
    stored_pattern_as_static=True,
    pattern_projection_as_static=True,
    # do not pre-process layer input
    normalize_stored_pattern=False,
    normalize_stored_pattern_affine=False,
    normalize_state_pattern=False,
    normalize_state_pattern_affine=False,
    normalize_pattern_projection=False,
    normalize_pattern_projection_affine=False,
    # do not post-process layer output
    disable_out_projection=True
)
# tuple of stored_pattern, state_pattern, pattern_projection
result = hopfield((Y, R, Y))
# the layer output has 784 elements, so it can be reshaped directly into an image
result = result.reshape(28, 28).detach().numpy()
plt.imshow(result, cmap='gray', interpolation='none')
plt.show()
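For intuition, with every projection static and all normalisation switched off, the layer call above should reduce to a single retrieval step, roughly `softmax(beta * R Y^T) Y`. The snippet below is only a sketch of that step in plain Python, kept free of dependencies. It is not the library's implementation, and `hopfield_retrieve`, `softmax` and the toy patterns are names and data made up for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def hopfield_retrieve(stored, state, beta=1.0):
    """One retrieval step: softmax(beta * <stored_i, state>) weighted sum of stored patterns.

    stored: list of N patterns, each a list of d floats
    state:  list of d floats (the query / state pattern)
    """
    # similarity of the state pattern to every stored pattern, scaled by beta
    scores = [beta * sum(x * s for x, s in zip(pattern, state)) for pattern in stored]
    weights = softmax(scores)
    # convex combination of the stored patterns
    return [
        sum(w * pattern[i] for w, pattern in zip(weights, stored))
        for i in range(len(state))
    ]


# toy data (hypothetical): three stored patterns and a corrupted copy of the first
stored = [
    [1.0, -1.0, 1.0, -1.0],
    [1.0, 1.0, -1.0, -1.0],
    [-1.0, -1.0, 1.0, 1.0],
]
state = [1.0, -1.0, 0.0, -1.0]  # third entry of pattern 0 is masked out

retrieved = hopfield_retrieve(stored, state, beta=8.0)
print([round(v) for v in retrieved])
```

A large `beta` pushes the softmax towards a one-hot selection of the best-matching stored pattern, while `beta = 0` just returns the average of all stored patterns. With raw MNIST pixel values in the range 0 to 255 the dot products are very large, so even `beta = 1.0` should behave like a hard selection.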
The pattern projections need to be static and the normalisations switched off, as @a-kore pointed out. For more information please have a look at the accompanying blog post and the enclosed code samples.
Closing issue.
I tried to use your Hopfield layer, following the example in the blog, to recover the original picture from MNIST, but it doesn't work. What am I doing wrong here? I followed all the steps.