osh / KerasGAN

A couple of simple GANs in Keras

Upgrade to Keras2? #12

Open EladNoy opened 7 years ago

EladNoy commented 7 years ago

With Keras 1 now being deprecated, a Keras 2 version would be greatly appreciated. I tried converting it myself, but Keras 2 does not support the batchnorm `mode=2` option, so it will probably require some sort of workaround.

engharat commented 7 years ago

I was stuck on the same problem. I ended up writing a batchnorm version that always behaves like batchnorm `mode = 2`. You can easily edit the Keras file where BatchNormalization is defined and modify it so that it never uses the accumulated training statistics.
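Roughly, the idea is to make the layer ignore the learning phase and always normalize with the statistics of the current batch. A minimal sketch of that trick, assuming the stock `keras.layers.BatchNormalization` from Keras 2 (the class name here is illustrative, not the actual code I used; the full standalone layer is shared below):

```python
from keras.layers import BatchNormalization

class BatchNormMode2(BatchNormalization):
    """Batch normalization that always uses the statistics of the
    current batch, i.e. the old Keras 1 `mode=2` behaviour."""

    def call(self, inputs, training=None):
        # Force the training branch so the moving averages are never
        # used, neither at train nor at test time.
        return super(BatchNormMode2, self).call(inputs, training=True)
```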

frnk99 commented 7 years ago

Can you share the code, @engharat? Please.

engharat commented 7 years ago

Sure. Here is a link to the code: https://drive.google.com/open?id=0B0E8DCU-EnYRR2l3aV9oTkJORHc . Put the file in the same folder as your script and import it; then you can substitute any occurrence of the BatchNormalization layer in the generator / discriminator code with the BatchNormGAN layer.
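For example, a generator block would just swap the new layer in place of `BatchNormalization`. A minimal usage sketch, assuming the file is saved as `batchnorm_gan.py` next to the script (the file name and layer sizes are only illustrative):

```python
from keras.models import Sequential
from keras.layers import Dense, Activation, Reshape

from batchnorm_gan import BatchNormGAN  # hypothetical name for the shared file

# Illustrative generator head; only the BatchNormGAN line differs from
# a Keras 1 version that used BatchNormalization(mode=2).
generator = Sequential()
generator.add(Dense(128 * 7 * 7, input_dim=100))
generator.add(BatchNormGAN())
generator.add(Activation('relu'))
generator.add(Reshape((7, 7, 128)))
```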

Or, if you prefer, here is the full code:

```python
# -*- coding: utf-8 -*-
from __future__ import absolute_import

from keras.engine import Layer, InputSpec
from keras import initializers
from keras import regularizers
from keras import constraints
from keras import backend as K
from keras.legacy import interfaces


class BatchNormGAN(Layer):
    """Batch normalization layer (Ioffe and Szegedy, 2014).

    Normalize the activations of the previous layer at each batch,
    i.e. applies a transformation that maintains the mean activation
    close to 0 and the activation standard deviation close to 1.

    # Arguments
        axis: Integer, the axis that should be normalized
            (typically the features axis).
            For instance, after a `Conv2D` layer with
            `data_format="channels_first"`,
            set `axis=1` in `BatchNormGAN`.
        momentum: Momentum for the moving average.
        epsilon: Small float added to variance to avoid dividing by zero.
        center: If True, add offset of `beta` to normalized tensor.
            If False, `beta` is ignored.
        scale: If True, multiply by `gamma`.
            If False, `gamma` is not used.
            When the next layer is linear (also e.g. `nn.relu`),
            this can be disabled since the scaling
            will be done by the next layer.
        beta_initializer: Initializer for the beta weight.
        gamma_initializer: Initializer for the gamma weight.
        moving_mean_initializer: Initializer for the moving mean.
        moving_variance_initializer: Initializer for the moving variance.
        beta_regularizer: Optional regularizer for the beta weight.
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.

    # Input shape
        Arbitrary. Use the keyword argument `input_shape`
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Same shape as input.

    # References
        - [Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift](https://arxiv.org/abs/1502.03167)
    """

    @interfaces.legacy_batchnorm_support
    def __init__(self,
                 axis=-1,
                 momentum=0.99,
                 epsilon=1e-3,
                 center=True,
                 scale=True,
                 beta_initializer='zeros',
                 gamma_initializer='ones',
                 moving_mean_initializer='zeros',
                 moving_variance_initializer='ones',
                 beta_regularizer=None,
                 gamma_regularizer=None,
                 beta_constraint=None,
                 gamma_constraint=None,
                 **kwargs):
        super(BatchNormGAN, self).__init__(**kwargs)
        self.supports_masking = True
        self.axis = axis
        self.momentum = momentum
        self.epsilon = epsilon
        self.center = center
        self.scale = scale
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.moving_mean_initializer = initializers.get(moving_mean_initializer)
        self.moving_variance_initializer = initializers.get(moving_variance_initializer)
        self.beta_regularizer = regularizers.get(beta_regularizer)
        self.gamma_regularizer = regularizers.get(gamma_regularizer)
        self.beta_constraint = constraints.get(beta_constraint)
        self.gamma_constraint = constraints.get(gamma_constraint)

    def build(self, input_shape):
        dim = input_shape[self.axis]
        if dim is None:
            raise ValueError('Axis ' + str(self.axis) + ' of '
                             'input tensor should have a defined dimension '
                             'but the layer received an input with shape ' +
                             str(input_shape) + '.')
        self.input_spec = InputSpec(ndim=len(input_shape),
                                    axes={self.axis: dim})
        shape = (dim,)

        if self.scale:
            self.gamma = self.add_weight(shape=shape,
                                         name='gamma',
                                         initializer=self.gamma_initializer,
                                         regularizer=self.gamma_regularizer,
                                         constraint=self.gamma_constraint)
        else:
            self.gamma = None
        if self.center:
            self.beta = self.add_weight(shape=shape,
                                        name='beta',
                                        initializer=self.beta_initializer,
                                        regularizer=self.beta_regularizer,
                                        constraint=self.beta_constraint)
        else:
            self.beta = None
        # The moving statistics are created for weight compatibility with
        # BatchNormalization, but call() below never uses them.
        self.moving_mean = self.add_weight(
            shape=shape,
            name='moving_mean',
            initializer=self.moving_mean_initializer,
            trainable=False)
        self.moving_variance = self.add_weight(
            shape=shape,
            name='moving_variance',
            initializer=self.moving_variance_initializer,
            trainable=False)
        self.built = True

    def call(self, inputs, training=None):
        input_shape = K.int_shape(inputs)
        # Reduce over all axes except the feature axis.
        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[self.axis]

        # Always normalize with the statistics of the current batch,
        # regardless of the learning phase; this reproduces the old
        # Keras 1 `mode=2` behaviour. The moving averages created in
        # build() are never used.
        normed, mean, variance = K.normalize_batch_in_training(
            inputs, self.gamma, self.beta, reduction_axes,
            epsilon=self.epsilon)

        return normed

    def get_config(self):
        config = {
            'axis': self.axis,
            'momentum': self.momentum,
            'epsilon': self.epsilon,
            'center': self.center,
            'scale': self.scale,
            'beta_initializer': initializers.serialize(self.beta_initializer),
            'gamma_initializer': initializers.serialize(self.gamma_initializer),
            'moving_mean_initializer': initializers.serialize(self.moving_mean_initializer),
            'moving_variance_initializer': initializers.serialize(self.moving_variance_initializer),
            'beta_regularizer': regularizers.serialize(self.beta_regularizer),
            'gamma_regularizer': regularizers.serialize(self.gamma_regularizer),
            'beta_constraint': constraints.serialize(self.beta_constraint),
            'gamma_constraint': constraints.serialize(self.gamma_constraint)
        }
        base_config = super(BatchNormGAN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```

frnk99 commented 7 years ago

Thank you, @engharat!