haidark / WaveletDeconv

Neural network layer code to implement wavelet deconvolutions

Test result on artificial data in Section 5.1 of the original paper #3

Open qideng7 opened 4 years ago

qideng7 commented 4 years ago

Hi,

This is wonderful work! I was exploring the test on artificial data (Section 5.1 of your original paper), but I couldn't reproduce the results shown in Figure 3, especially the last plot. My naive guess is that the gradients on the learnable filter widths in the first layer vanish. Do you have any suggestions for training on this test data? Based on the architecture description, "We train two networks on examples from each class and compare their performance. The baseline network is a 4 layer CNN with Max-pooling [21] ending with a single unit for classification. The other network replaces the first layer with a WD layer while maintaining the same number of parameters. Both networks are optimized with Adam [20] using a fixed learning rate of 0.001 and a batch size of 4.", I implemented this network:

# -*- coding: utf-8 -*-

import scipy
import scipy.signal
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, activations, initializers, constraints, regularizers
from tensorflow.keras.models import Sequential, model_from_json, load_model
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Convolution2D, MaxPool2D
from tensorflow.keras.initializers import Constant, RandomUniform, VarianceScaling

# generate dummy data
N = 100
numSamps = 1000
data = np.random.random((N, 1, numSamps)).astype('float32')
labels = np.random.random((N, 1)).astype('float32')

val_data = np.random.random((N, 1, numSamps)).astype('float32')
val_labels = np.random.random((N, 1)).astype('float32')

X = np.linspace(-100, 100+1, numSamps)

for i in range(data.shape[0]):
    pure0 = np.sin(0.5*X)
    pure1 = np.sin(1*X)
    pure2 = np.sin(5*X)
    noise = np.random.normal(0, 1, numSamps)
    sig = np.zeros(X.shape)
    # pick 2 divider points
    a = np.random.randint(N/5, numSamps/2+1)
    b = np.random.randint(a+N/5, 2*numSamps/3+1)
    if i <= data.shape[0]/2:        
        sig[:a] = pure0[:a]
        sig[a:b] = pure1[a:b]
        sig[b:] = pure2[b:]
        label = 0
    else:
        sig[:a] = pure2[:a]
        sig[a:b] = pure1[a:b]
        sig[b:] = pure0[b:]      
        label = 1
    sig = sig + noise
    data[i,:,:] = sig
    labels[i] = label
# generate validation data
for i in range(val_data.shape[0]):
    pure0 = np.sin(0.5*X)
    pure1 = np.sin(1*X)
    pure2 = np.sin(5*X)
    noise = np.random.normal(0, 1, numSamps)
    sig = np.zeros(X.shape)
    # pick 2 divider points
    a = np.random.randint(0, numSamps/2)
    b = np.random.randint(a, numSamps+1)
    if i <= val_data.shape[0]/2:        
        sig[:a] = pure0[:a]
        sig[a:b] = pure1[a:b]
        sig[b:] = pure2[b:]
        label = 0
    else:
        sig[:a] = pure2[:a]
        sig[a:b] = pure1[a:b]
        sig[b:] = pure0[b:]      
        label = 1
    sig = sig + noise
    val_data[i,:,:] = sig
    val_labels[i] = label

print('data_scales = {:.2f}, {:.2f}, {:.2f}'.format(2.*np.pi/0.5, 2.*np.pi/1., 2.*np.pi/5.))

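class Pos(constraints.Constraint):
    '''Constrain the weights to be strictly positive
    '''
    def __call__(self, p):
        # floor the widths at a small positive value; a width of exactly 0
        # would make gen_kernel divide by zero and produce NaNs
        return tf.maximum(p, 1e-3)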

class WaveletDeconvolution(layers.Layer):
    '''
    Deconvolutions of 1D signals using wavelets
    When using this layer as the first layer in a model,
    provide the keyword argument `input_shape`  as a
    (tuple of integers, e.g. (10, 128) for sequences
    of 10 vectors with dimension 128).

    # Example
    ```python
        # apply a set of 5 wavelet deconv widths to a sequence of 32 vectors with 10 timesteps
        model = Sequential()
        model.add(WaveletDeconvolution(5, padding='same', input_shape=(32, 10)))
        # now model.output_shape == (None, 32, 10, 5)
        # add a new conv2d on top
        model.add(Convolution2D(64, 3, 3, padding='same'))
        # now model.output_shape == (None, 64, 10, 5)
# Arguments
    nb_widths: Number of wavelet kernels to use
        (dimensionality of the output).
    kernel_length: The length of the wavelet kernels            
    init: Locked to a dyadic set of widths ([1, 2, 4, 8, 16, ...])
        name of initialization function for the weights of the layer
        (see [initializers](../initializers.md)),
        or alternatively, a function to use for weights initialization.
        This parameter is only relevant if you don't pass a `weights` argument.
    activation: name of activation function to use
        ( or alternatively, an elementwise function.)
        If you don't specify anything, no activation is applied
        (ie. "linear" activation: a(x) = x).
    weights: list of numpy arrays to set as initial weights.
    padding: one of `"valid"` or `"same"` (case-insensitive).
    strides: An integer or tuple/list of 2 integers,
        specifying the strides of the convolution
        along the height and width.
        Can be a single integer to specify the same value for
        all spatial dimensions.
    data_format: A string,
        one of `"channels_last"` or `"channels_first"`.
        The ordering of the dimensions in the inputs.
        `"channels_last"` corresponds to inputs with shape
        `(batch, height, width, channels)` while `"channels_first"`
        corresponds to inputs with shape
        `(batch, channels, height, width)`.
        It defaults to the `image_data_format` value found in your
        Keras config file at `~/.keras/keras.json`.
        If you never set it, then it will be "channels_last".
    use_bias: Boolean, whether the layer uses a bias vector.
    kernel_regularizer: Regularizer function applied to
        the `kernel` weights matrix
    bias_regularizer: Regularizer function applied to the bias vector
    activity_regularizer: Regularizer function applied to
        the output of the layer (its "activation").
    kernel_constraint: Constraint function applied to the kernel matrix
    bias_constraint: Constraint function applied to the bias vector

# Input shape
    if data_format is 'channels_first' then
        3D tensor with shape: `(batch_samples, input_dim, steps)`.
    if data_format is 'channels_last' then
        3D tensor with shape: `(batch_samples, steps, input_dim)`.

# Output shape
    if data_format is 'channels_first' then
        4D tensor with shape: `(batch_samples, input_dim, new_steps, nb_widths)`.
        `steps` value might have changed due to padding.
    if data_format is 'channels_last' then
        4D tensor with shape: `(batch_samples, new_steps, nb_widths, input_dim)`.
        `steps` value might have changed due to padding.
    '''

    def __init__(self, nb_widths, kernel_length=100,
                 init='uniform', activation='linear', weights=None,
                 padding='same', strides=1, data_format='channels_last', use_bias=True,
                 kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None,
                 kernel_constraint=None, bias_constraint=None,
                 input_shape=None, **kwargs):

        if padding.lower() not in {'valid', 'same'}:
            raise Exception('Invalid border mode for WaveletDeconvolution:', padding)
        if data_format.lower() not in {'channels_first', 'channels_last'}:
            raise Exception('Invalid data format for WaveletDeconvolution:', data_format)
        self.nb_widths = nb_widths
        self.kernel_length = kernel_length
        self.init = self.didactic #initializers.get(init, data_format='channels_first')
        self.activation = activations.get(activation)
        self.padding = padding
        self.strides = strides

        self.subsample = (strides, 1)

        self.data_format = data_format.lower()

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = Pos()
        self.bias_constraint = constraints.get(bias_constraint)

        self.use_bias = use_bias
        self.initial_weights = weights
        # forward input_shape to the base Layer; otherwise it is silently dropped
        if input_shape is not None:
            kwargs['input_shape'] = input_shape
        super(WaveletDeconvolution, self).__init__(**kwargs)

    def build(self, input_shape):
        # get dimension and length of input
        if self.data_format == 'channels_first':
            self.input_dim = input_shape[1]
            self.input_length = input_shape[2]
        else:
            self.input_dim = input_shape[2]
            self.input_length = input_shape[1]
        # initialize and define wavelet widths (one learnable width per wavelet)
        self.W_shape = (self.nb_widths,)
        # self.W = self.init(self.W_shape, name='{}_W'.format(self.name))
        # self.trainable_weights = [self.W]?
        # Constant(2.**np.arange(self.nb_widths)
        # Constant([1., 5., 12.]
        # note: this initializer provides exactly 3 values, so it only matches nb_widths=3
        self.W = self.add_weight(shape = self.W_shape,
                                 name = 'W',
                                 initializer = Constant([1., 4., 10.]),
                                 constraint = Pos())

        super(WaveletDeconvolution, self).build(input_shape)

    def call(self, x, mask=None):
        # shape of x is (batches, input_dim, input_len) if 'channels_first'
        # shape of x is (batches, input_len, input_dim) if 'channels_last'
        # we reshape x to channels first for computation
        if self.data_format == 'channels_last':
            x = tf.transpose(x, (0, 2, 1))

        #x = K.expand_dims(x, 2)  # add a dummy dimension for # rows in "image", now shape = (batches, input_dim, input_len, 1)

        # build the kernels to convolve each input signal with
        kernel_length = self.kernel_length
        T = (np.arange(0,kernel_length) - (kernel_length-1.0)/2).astype('float32')
        T2 = T**2
        # helper function to generate wavelet kernel for a given width
        # this generates the Mexican hat or Ricker wavelet. Can be replaced with other wavelet functions.
        def gen_kernel(w):
            w2 = w**2
            B = (3 * w)**0.5
            A = (2 / (B * (np.pi**0.25)))
            mod = (1 - (T2)/(w2))
            gauss = tf.exp(-(T2) / (2 * (w2)))
            kern = A * mod * gauss
            kern = tf.reshape(kern, [kernel_length, 1])
            return kern
        wav_kernels = []
        for i in range(self.nb_widths):
            kernel = gen_kernel(self.W[i])
            wav_kernels.append(kernel)
        wav_kernels = tf.stack(wav_kernels, axis=0)
        # kernel, _ = tf.map_fn(fn=gen_kernel, elems=self.W)
        wav_kernels = tf.expand_dims(wav_kernels, 0)
        # conv2d filter shape: (1, kernel_length, 1, nb_widths)
        wav_kernels = tf.transpose(wav_kernels,(0, 2, 3, 1))

        # reshape input so number of dimensions is first (before batch dim)
        x = tf.transpose(x, (1, 0, 2))
        def gen_conv(x_slice):
            x_slice = tf.expand_dims(x_slice,1) # shape (num_batches, 1, input_length)
            x_slice = tf.expand_dims(x_slice,2) # shape (num_batches, 1, 1, input_length)
            # tf.nn.conv2d expects padding as upper-case 'SAME' or 'VALID'
            return tf.nn.conv2d(x_slice, wav_kernels, strides=self.subsample, padding=self.padding.upper(), data_format='NCHW')
        outputs = []
        for i in range(self.input_dim):
            output = gen_conv(x[i,:,:])
            outputs.append(output)
        outputs = tf.stack(outputs, axis=0)
        # output, _ = tf.map_fn(fn=gen_conv, elems=x)
        outputs = tf.squeeze(outputs, 3)
        # now (batch, input_dim, new_steps, nb_widths)
        outputs = tf.transpose(outputs, (1, 0, 3, 2))
        if self.data_format == 'channels_last':
            outputs = tf.transpose(outputs,(0, 2, 3, 1))
        return outputs

    # def compute_output_shape(self, input_shape):
    #     out_length = conv_utils.conv_output_length(input_shape[2],
    #                                                self.kernel_length,
    #                                                self.padding,
    #                                                self.strides)
    #     return (input_shape[0], self.input_dim, out_length, self.nb_widths)

    def get_config(self):
        config = {'nb_widths': self.nb_widths,
                  'kernel_length': self.kernel_length,
                  'init': self.init.__name__,
                  'activation': self.activation.__name__,
                  'padding': self.padding,
                  'strides': self.strides,
                  'data_format': self.data_format,
                  'kernel_regularizer': self.kernel_regularizer.get_config() if self.kernel_regularizer else None,
                  'bias_regularizer': self.bias_regularizer.get_config() if self.bias_regularizer else None,
                  'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None,
                  'kernel_constraint': self.kernel_constraint.get_config() if self.kernel_constraint else None,
                  'bias_constraint': self.bias_constraint.get_config() if self.bias_constraint else None,
                  'use_bias': self.use_bias}
        base_config = super(WaveletDeconvolution, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def didactic(self, shape, name=None):
        x = 2**np.arange(shape).astype('float32')
        return tf.Variable(initial_value=x, name=name)

inp_shape = data.shape[1:]
model = Sequential()
model.add(WaveletDeconvolution(3, kernel_length=500, input_shape=inp_shape, padding='SAME', data_format='channels_first'))
model.add(Activation('tanh'))  # output shape: (batch, 1, len=1000, 3)
model.add(MaxPool2D((1,2)))

model.add(Convolution2D(3, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(MaxPool2D((1,2)))

model.add(Convolution2D(3, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(MaxPool2D((1,2)))

model.add(Convolution2D(3, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(MaxPool2D((1,2)))

# end convolutional layers

model.add(Flatten())
model.add(Dense(25, kernel_initializer=VarianceScaling(mode='fan_avg', distribution='uniform')))
model.add(Activation('relu'))

model.add(Dense(1, kernel_initializer=VarianceScaling(mode='fan_avg', distribution='uniform')))
model.add(Activation('sigmoid'))

optimizer_0 = tf.keras.optimizers.Adam(learning_rate=10.**-3)
model.compile(optimizer=optimizer_0, loss='binary_crossentropy')

num_epochs = 25
plt.figure(figsize=(6,6))
Widths = np.zeros((num_epochs, 3)).astype('float32')
for i in range(num_epochs):
    hWD = model.fit(data, labels, epochs=1, batch_size=4, validation_data=(val_data, val_labels), verbose=0)
    print('Epoch %3d | train_loss: %.4f | val_loss: %.4f' % (i+1, hWD.history['loss'][0], hWD.history['val_loss'][0]))
    # record the wavelet widths of the WD layer after each epoch
    Widths[i,:] = model.layers[0].weights[0].numpy()
    plt.plot(i, hWD.history['loss'][0], 'k.')
    plt.plot(i, hWD.history['val_loss'][0], 'r.')

plt.figure(figsize=(6,6))
for i in range(Widths.shape[1]):
    plt.plot(range(num_epochs), Widths[:,i])

plt.show()



![image](https://user-images.githubusercontent.com/35273269/91406467-2ad8a180-e7f7-11ea-9b8e-9c64032744b9.png)
![image](https://user-images.githubusercontent.com/35273269/91406534-2f9d5580-e7f7-11ea-925e-ed3c17cf8718.png)
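The `gen_kernel` helper inside `call` builds a Ricker (Mexican hat) wavelet, `A * (1 - T**2/w**2) * exp(-T**2 / (2*w**2))` with `A = 2 / (sqrt(3*w) * pi**0.25)`, and the layer then filters every input channel with one such kernel per width. For checking kernels or output shapes outside TensorFlow, here is a minimal NumPy sketch of that forward pass. It rests on a few assumptions: `ricker_kernel` and `wd_forward_channels_first` are hypothetical helper names, only the `channels_first` layout is covered, and `kernel_length` is assumed to be no longer than the signal.

```python
import numpy as np


def ricker_kernel(kernel_length, w):
    """Ricker (Mexican hat) wavelet of width w (hypothetical helper).

    Mirrors the formula in gen_kernel: T is measured in samples and
    centred on the middle of the kernel.
    """
    T = np.arange(kernel_length) - (kernel_length - 1.0) / 2
    A = 2 / (np.sqrt(3 * w) * np.pi ** 0.25)
    mod = 1 - T ** 2 / w ** 2
    gauss = np.exp(-T ** 2 / (2 * w ** 2))
    return A * mod * gauss


def wd_forward_channels_first(x, widths, kernel_length):
    """NumPy reference for the WD forward pass (hypothetical helper).

    x has shape (batch, input_dim, steps); the result has shape
    (batch, input_dim, steps, len(widths)), i.e. the channels_first output
    layout from the layer docstring with 'same' padding.
    Assumes kernel_length <= steps; otherwise np.convolve's 'same' mode
    returns kernel_length samples instead of steps.
    """
    batch, input_dim, steps = x.shape
    out = np.zeros((batch, input_dim, steps, len(widths)))
    for k, w in enumerate(widths):
        kern = ricker_kernel(kernel_length, w)
        for b in range(batch):
            for d in range(input_dim):
                out[b, d, :, k] = np.convolve(x[b, d], kern, mode='same')
    return out


# usage: a batch of 2 single-channel signals, 3 widths, odd kernel length
x = np.random.default_rng(0).normal(size=(2, 1, 200))
y = wd_forward_channels_first(x, [1.0, 4.0, 10.0], kernel_length=51)
print(y.shape)  # (2, 1, 200, 3)
```

Because the Ricker wavelet is symmetric, convolution and TensorFlow's cross-correlation give the same values here; with an even `kernel_length` such as 500, NumPy's `'same'` mode and TF's `'SAME'` padding can be offset by one sample.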
haidark commented 4 years ago

Hi @qideng7, that doesn't look like the same architecture I used. Can you try the NN architecture in https://github.com/haidark/WaveletDeconv/blob/master/testWD.py?

Thanks for reaching out.

qideng7 commented 4 years ago

Hi @haidark, thanks for the quick reply. Here is my implementation of the architecture you mentioned. I also think there might be two typos in https://github.com/haidark/WaveletDeconv/blob/master/testWD.py; please correct me if I'm mistaken:

  1. In the data-generation part, the scales used for the training and validation data differ: (`0.5*x`, `1*x`, `2*x`) vs. (`0.5*x`, `1*x`, `5*x`).
  2. In the network architecture, the number of filters in the deconvolution layer is set to 5, but only 3 scales are used to generate the data.

So when I reran it, I used the same scales, (`0.5*x`, `1*x`, `5*x`), to generate both the training and validation data, which means the scales (the periods 2π/ω) should be:

print('data_scales = {:.2f}, {:.2f}, {:.2f}'.format(2.*np.pi/0.5, 2.*np.pi/1., 2.*np.pi/5.))
data_scales = 12.57, 6.28, 1.26
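One thing that may be worth double-checking (an observation, not something established in this thread): the `data_scales` above are periods measured in units of `X`, whereas `T` in the layer's `call` counts samples, and with `X = np.linspace(-100, 100+1, numSamps)` one sample is about 0.2 units of `X`. The sketch below converts the periods to samples and, using the continuous-time spectrum of the Ricker wavelet from `gen_kernel` (proportional to `omega**2 * exp(-(sigma*omega)**2 / 2)`, which peaks at `omega = sqrt(2)/sigma`), estimates the width whose response peaks at each signal frequency. The variable names are illustrative.

```python
import numpy as np

numSamps = 1000
X = np.linspace(-100, 100 + 1, numSamps)  # same grid as the data generator
dx = X[1] - X[0]                          # units of X per sample (201/999)

omegas = np.array([0.5, 1.0, 5.0])        # angular frequencies of pure0, pure1, pure2
periods_x = 2.0 * np.pi / omegas          # the data_scales printed above
periods_samples = periods_x / dx          # the same periods, counted in samples

# A Ricker wavelet of width sigma is proportional to the negative second
# derivative of a Gaussian, so (continuous-time approximation) its spectrum is
# ~ omega**2 * exp(-(sigma*omega)**2 / 2), maximised at omega = sqrt(2)/sigma.
# Solving for sigma, with omega in radians per sample, gives the width whose
# response peaks at each signal frequency.
best_width_samples = np.sqrt(2.0) / (omegas * dx)

print(np.round(periods_x, 2))
print(np.round(periods_samples, 1))
print(np.round(best_width_samples, 2))
```

Whether the paper's scales are meant in samples or in units of `X` is not stated in this thread; the sketch only shows that the two readings lead to different target widths.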

I also changed the number of filters in the network to 3. The network I implemented (filter widths initialized to (1., 4., 10.)):

inp_shape = data.shape[1:]
model = Sequential()
model.add(WaveletDeconvolution(3, kernel_length=500, input_shape=inp_shape, padding='SAME', data_format='channels_first'))
model.add(Activation('tanh'))
model.add(Convolution2D(3, (3, 3), padding='same'))
model.add(Activation('relu'))
#end convolutional layers
model.add(Flatten())
model.add(Dense(25))
model.add(Activation('relu'))
model.add(Dense(1))
model.add(Activation('sigmoid'))
model.compile(optimizer='sgd', loss='binary_crossentropy')

Results: (two result plots; images not preserved)

Learned scales:

print(model.layers[0].W.numpy())
[ 1.033309   4.0004635 10.006781 ]
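To put a number on how far the widths moved during training, one can compare them with the initial values from `Constant([1., 4., 10.])` in `build()` and with the `data_scales` printed earlier. This is only arithmetic on the values shown in this thread; pairing learned widths with target scales by rank is an assumption made for illustration.

```python
import numpy as np

init_widths = np.array([1.0, 4.0, 10.0])                # Constant([1., 4., 10.]) in build()
learned = np.array([1.033309, 4.0004635, 10.006781])    # model.layers[0].W after training
data_scales = 2.0 * np.pi / np.array([0.5, 1.0, 5.0])   # 12.57, 6.28, 1.26 from the print above

# relative change of each width since initialization, in percent
pct_moved = 100.0 * (learned - init_widths) / init_widths
# remaining gap to the target scales, pairing both sets by rank (assumption)
gap = np.sort(data_scales) - np.sort(learned)

print(np.round(pct_moved, 3))
print(np.round(gap, 2))
```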