tatsy / keras-generative

Deep generative networks, coded with Keras.

What is ConditionalDataset()? #11

Closed: LiangqunLu closed this issue 6 years ago

LiangqunLu commented 6 years ago

In the MNIST data loader, `load_data()` uses `ConditionalDataset`, which is not in the same folder. I am wondering what it is.

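```python
import numpy as np
import keras

from .datasets import ConditionalDataset

def load_data():
    # Only the training split is used; the test split is discarded
    (x_train, y_train), _ = keras.datasets.mnist.load_data()

    # Pad the 28x28 images to 32x32, add a channel axis, and scale to [0, 1]
    x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), 'constant', constant_values=0)
    x_train = (x_train[:, :, :, np.newaxis] / 255.0).astype('float32')

    # One-hot encode the digit labels (10 classes)
    y_train = keras.utils.to_categorical(y_train)
    y_train = y_train.astype('float32')

    datasets = ConditionalDataset()  # <- the class in question
    datasets.images = x_train
    datasets.attrs = y_train
    datasets.attr_names = [str(i) for i in range(10)]

    return datasets
```
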
tatsy commented 6 years ago

It is defined in `datasets/datasets.py`; see here: https://github.com/tatsy/keras-generative/blob/master/datasets/datasets.py#L16
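The linked file holds the real definition. Judging only from how `load_data()` uses it, a `ConditionalDataset` acts as a container that pairs each image with a conditioning vector through the `images`, `attrs`, and `attr_names` fields. The sketch below is a hypothetical stand-in, not the repository's class: the name `ConditionalDatasetSketch` and its `batch()` method are invented for illustration, and small synthetic arrays replace MNIST so it runs without a download.

```python
import numpy as np

class ConditionalDatasetSketch:
    """Hypothetical stand-in (not the repo's class): images plus per-sample attributes."""

    def __init__(self):
        self.images = None     # expected shape (N, H, W, C), float32
        self.attrs = None      # expected shape (N, num_attrs), e.g. one-hot labels
        self.attr_names = []   # a readable name for each attribute column

    def __len__(self):
        return len(self.images)

    def batch(self, indices):
        # Hypothetical helper: return images and their matching attribute rows
        return self.images[indices], self.attrs[indices]


# Fill it the same way load_data() does, using synthetic data in place of MNIST
ds = ConditionalDatasetSketch()
ds.images = np.zeros((5, 32, 32, 1), dtype='float32')
ds.attrs = np.eye(10, dtype='float32')[[0, 1, 2, 3, 4]]  # one-hot labels 0..4
ds.attr_names = [str(i) for i in range(10)]

x, c = ds.batch(np.array([1, 3]))
print(x.shape, c.shape)  # (2, 32, 32, 1) (2, 10)
```

Keeping images and attributes in one object lets a training loop draw matched (image, label) pairs with a single index array.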