oarriaga / STN.keras

Implementation of spatial transformer networks (STNs) in Keras 2 with TensorFlow as the backend.
MIT License

How to use this code to recognize MNIST images of size 28x28? #12

Status: Open. Opened by ghost 5 years ago.

ghost commented 5 years ago

Hi, thanks for providing the code. I ran it and the results looked good, but when I tried to adapt it to recognize MNIST images of size 28x28, I ran into some problems. I would really appreciate any help. Here is my code:

```python
import numpy as np
import keras.backend as K
from keras.datasets import mnist
from keras.optimizers import Adam
from src.models import STN
import matplotlib.pyplot as plt
import keras as k

num_classes = 10

# Load MNIST (28x28 grayscale digits) and keep the first 50,000 training images.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train[:50000]
y_train = y_train[:50000]

# Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
x_train = x_train.reshape((x_train.shape[0], 28, 28, 1))
x_test = x_test.reshape((x_test.shape[0], 28, 28, 1))

# One-hot encode the labels for categorical_crossentropy.
y_test = k.utils.to_categorical(y_test, num_classes)
y_train = k.utils.to_categorical(y_train, num_classes)

# Build the STN for 28x28 inputs, resampling each image to 14x14.
model = STN(input_shape=(28, 28, 1), sampling_size=(14, 14))
model.compile(loss='categorical_crossentropy', optimizer=Adam())

# Backend function returning the output of the bilinear sampling layer,
# i.e. the transformed image, for a given input batch.
input_image = model.input
output_STN = model.get_layer('bilinear_interpolation_1').output
STN_function = K.function([input_image], [output_STN])

num_epochs = 3
batch_size = 10
model.fit(x_train, y_train, batch_size=batch_size, epochs=num_epochs)

# Show the first two training images, each followed by its transformed version.
image_result = STN_function([x_train[:10]])
for i in range(2):
    plt.imshow(x_train[i].reshape(28, 28), cmap='gray')
    plt.show()
    image = np.squeeze(image_result[0][i])
    plt.imshow(image, cmap='gray')
    plt.show()
```
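One thing worth checking: `mnist.load_data()` returns `uint8` pixel values in the range 0 to 255, and the script above feeds them to the network without rescaling. A common preprocessing step is to convert the images to `float32` and divide by 255 before training. The sketch below uses only NumPy, with random arrays standing in for the real MNIST data. The helper names `prepare_images` and `one_hot` are hypothetical, and whether the unscaled input explains the problems reported here is an assumption that this thread does not confirm.

```python
import numpy as np


def prepare_images(x):
    """Scale uint8 pixels from [0, 255] to float32 in [0, 1] and add a channel axis."""
    x = x.astype("float32") / 255.0
    return x.reshape((x.shape[0], x.shape[1], x.shape[2], 1))


def one_hot(labels, num_classes):
    """Mirror what keras.utils.to_categorical does for integer labels."""
    encoded = np.zeros((len(labels), num_classes), dtype="float32")
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded


# Stand-in for mnist.load_data(): random uint8 "images" with MNIST's shape.
rng = np.random.default_rng(0)
x_raw = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
y_raw = np.array([3, 0, 9, 1, 7])

x = prepare_images(x_raw)  # shape (5, 28, 28, 1), values in [0, 1]
y = one_hot(y_raw, 10)     # shape (5, 10), one 1.0 per row
```

With the real data, this would mean calling `x_train = prepare_images(x_train)` (and the same for `x_test`) in place of the plain `reshape` calls.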

Here is the result: instead of a transformed digit, the STN output is completely black. [Two screenshots of the output]
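An all-black output is what you would expect if the predicted affine transform sends every sampling point outside the input image. The sketch below is a generic NumPy illustration of affine grid generation followed by bilinear sampling, following the STN formulation. It is not the repository's `BilinearInterpolation` layer. The function names `affine_grid` and `bilinear_sample`, the coordinate convention (normalized [-1, 1] coordinates with corners aligned), and the zero padding outside the image are all assumptions. With the identity transform the output reproduces the input. With a large translation, every sample falls outside the image, so the result is all zeros, which displays as black.

```python
import numpy as np


def affine_grid(theta, out_h, out_w):
    """Map a regular output grid through a 2x3 affine matrix.

    Coordinates are normalized to [-1, 1]; theta maps target (output)
    coordinates to source (input) coordinates.
    """
    xs = np.linspace(-1.0, 1.0, out_w)
    ys = np.linspace(-1.0, 1.0, out_h)
    x_t, y_t = np.meshgrid(xs, ys)  # each of shape (out_h, out_w)
    target = np.stack([x_t.ravel(), y_t.ravel(), np.ones(x_t.size)])  # (3, N)
    source = theta @ target  # (2, N)
    return source[0].reshape(out_h, out_w), source[1].reshape(out_h, out_w)


def bilinear_sample(image, theta, out_size):
    """Sample a single-channel image at affine-transformed positions.

    Neighbours that fall outside the input image read as 0 (zero padding),
    so a badly scaled or translated theta produces a black output.
    """
    h, w = image.shape
    x_s, y_s = affine_grid(np.asarray(theta, dtype=np.float64), *out_size)
    # Convert normalized [-1, 1] coordinates to pixel positions.
    x = (x_s + 1.0) * (w - 1) / 2.0
    y = (y_s + 1.0) * (h - 1) / 2.0
    x0 = np.floor(x).astype(int)
    y0 = np.floor(y).astype(int)
    x1, y1 = x0 + 1, y0 + 1

    def pixel(yy, xx):
        # Read pixels, returning 0 wherever the index is out of bounds.
        inside = (xx >= 0) & (xx < w) & (yy >= 0) & (yy < h)
        values = np.zeros(xx.shape, dtype=np.float64)
        values[inside] = image[yy[inside], xx[inside]]
        return values

    wa = (x1 - x) * (y1 - y)  # weight of the top-left neighbour
    wb = (x1 - x) * (y - y0)  # bottom-left
    wc = (x - x0) * (y1 - y)  # top-right
    wd = (x - x0) * (y - y0)  # bottom-right
    return (wa * pixel(y0, x0) + wb * pixel(y1, x0)
            + wc * pixel(y0, x1) + wd * pixel(y1, x1))


img = np.arange(28 * 28, dtype=np.float64).reshape(28, 28) / (28 * 28)
identity = [[1, 0, 0], [0, 1, 0]]
shifted = [[1, 0, 10.0], [0, 1, 0]]  # translate far outside the image

same = bilinear_sample(img, identity, (28, 28))   # reproduces img
black = bilinear_sample(img, shifted, (28, 28))   # all zeros
small = bilinear_sample(img, identity, (14, 14))  # downsampled to 14x14
```

This failure mode is also why STN localization networks are commonly initialized so that their final layer outputs the identity transform at the start of training.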

What do I need to change when switching to a different dataset? Also, the loss is stuck at around 2.3000 after 3 epochs of training.
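A loss that settles near 2.30 has a simple interpretation for 10 classes. The categorical cross-entropy of a prediction that assigns probability 1/10 to every class is -ln(1/10) = ln 10 ≈ 2.3026. A loss at this level therefore means the network is producing roughly the same uniform distribution for every image, so it has not yet learned anything useful. The NumPy sketch below checks this arithmetic. The `categorical_crossentropy` helper is written here for illustration and is not the Keras function, although for one-hot targets it computes the same quantity without Keras's clipping.

```python
import numpy as np


def categorical_crossentropy(y_true, y_pred):
    """Mean cross-entropy between one-hot targets and predicted probabilities."""
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))


num_classes = 10
y_true = np.eye(num_classes)  # one example of each digit, one-hot encoded
uniform = np.full((num_classes, num_classes), 1.0 / num_classes)

loss = categorical_crossentropy(y_true, uniform)
print(round(loss, 4))  # 2.3026, i.e. ln(10)
```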