chrisdonahue / wavegan

WaveGAN: Learn to synthesize raw audio with generative adversarial networks
MIT License

Rewriting your code in Keras #37

Closed: redraven984 closed this issue 5 years ago

redraven984 commented 5 years ago

Hi Chris, I'm trying to rewrite WaveGAN in Keras. Can you tell me where I might be going wrong with the dimensions here?

Generator:

```python
from tensorflow.keras.layers import Input, Dense, Reshape, Activation, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD

def defineGen(Gin, d = 1, lr = 1e-3):

    shapes = [d*x for x in [256,16,8,4,2,1]]

    # project the latent vector and reshape it into a small feature map
    x = Dense(shapes[0])(Gin)
    x = Reshape((1,16,16))(x)
    x = Activation('relu')(x)

    # transposed convolutions (stride 1, 'same' padding)
    x = Conv2DTranspose(25,(shapes[1],shapes[2]),padding='same')(x)
    x = Activation('relu')(x)

    x = Conv2DTranspose(25,(shapes[2],shapes[3]),padding='same')(x)
    x = Activation('relu')(x)

    x = Conv2DTranspose(25,(shapes[3],shapes[4]),padding='same')(x)
    x = Activation('relu')(x)

    x = Conv2DTranspose(25,(shapes[4],shapes[5]),padding='same')(x)
    x = Activation('relu')(x)

    x = Conv2DTranspose(25,(shapes[5],1),padding='same')(x)
    G_out = Activation('tanh')(x)

    G = Model(inputs=[Gin],outputs=G_out)
    optimizer = SGD(learning_rate=lr)

    G.compile(loss = 'binary_crossentropy',optimizer=optimizer)

    return G, G_out

G_in1 = Input(shape=[None,100])
G, G_out = defineGen(G_in1)
G.summary()
```
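One way to see where the generator's dimensions go wrong is to trace the output length by hand. As we read the WaveGAN paper, with model size d (64 in the paper's experiments), z passes through a Dense layer with 256·d units, is reshaped to 16 time steps × 16d channels, and then goes through five one-dimensional transposed convolutions (kernel 25, stride 4, padding 'same'). Each of these multiplies the length by 4, giving 16 × 4^5 = 16384 samples with 1 channel. The snippet above uses Conv2DTranspose with the default stride of 1. With padding='same', that keeps the 16-wide grid from Reshape((1,16,16)) unchanged under the default channels_last layout, so the output never grows toward 16384 samples, and the last layer emits 25 channels instead of 1. The sketch below is a pure-Python shape calculator that assumes Keras's output-length rule for transposed convolutions (no output_padding, dilation 1). `transpose_conv_len` and `wavegan_generator_shapes` are hypothetical helper names.

```python
def transpose_conv_len(length, kernel, stride, padding):
    """Output length of a Keras-style transposed convolution (no output_padding, dilation 1)."""
    if padding == "same":
        return length * stride
    if padding == "valid":
        return length * stride + max(kernel - stride, 0)
    raise ValueError(f"unknown padding: {padding}")


def wavegan_generator_shapes(d=64, kernel=25, stride=4):
    """Trace (length, channels) through the WaveGAN generator layout (assumed from the paper).

    z (100) -> Dense(16 * 16d) -> Reshape((16, 16d)) -> five transposed convs.
    """
    length, channels = 16, 16 * d
    trace = [(length, channels)]
    # channel counts shrink 16d -> 8d -> 4d -> 2d -> d -> 1 (mono audio)
    for out_channels in (8 * d, 4 * d, 2 * d, d, 1):
        length = transpose_conv_len(length, kernel, stride, "same")
        trace.append((length, out_channels))
    return trace


print(wavegan_generator_shapes(64))
# Stride 1 with 'same' padding never upsamples, as in the question's Conv2DTranspose stack:
print(transpose_conv_len(16, 8, 1, "same"))
```

In tf.keras 2.3 and later (as far as we know), the matching layer is `Conv1DTranspose(filters, 25, strides=4, padding='same')`. On older Keras versions, a common workaround is a Conv2DTranspose applied to a `(length, 1, channels)` reshape.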

Discriminator:

```python
from tensorflow.keras.layers import Input, Conv1D, LeakyReLU, Reshape, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD

def defineDisc(Din, d = 1, lr = 1e-3):
    shapes = [d*x for x in [1,2,4,8,16]]

    x = Conv1D(25,kernel_size=(shapes[0]))(Din)
    x = LeakyReLU(alpha=0.1)(x)

    # phase shuffle - not implemented yet

    x = Conv1D(25,(shapes[1]),strides=4)(x)
    x = LeakyReLU(alpha=0.1)(x)

    # phase shuffle - not implemented yet

    x = Conv1D(25,(shapes[2]),strides=4)(x)
    x = LeakyReLU(alpha=0.1)(x)

    # phase shuffle - not implemented yet

    x = Conv1D(25,(shapes[3]),strides=4)(x)
    x = LeakyReLU(alpha=0.1)(x)

    # phase shuffle - not implemented yet

    x = Conv1D(25,(shapes[4]),strides=4)(x)
    x = LeakyReLU(alpha=0.1)(x)

    # Reshape needs a tuple, hence the trailing comma
    x = Reshape((256,))(x)

    Dout = Dense(256)(x)

    D = Model(inputs=[Din],outputs = Dout)
    dopt = SGD(learning_rate=lr)  # optimizer built from the lr argument
    D.compile(loss="binary_crossentropy", optimizer=dopt)

    return D, Dout

Din = Input(shape=[16384])
D, D_out = defineDisc(Din)
D.summary()
```
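The same kind of trace explains the discriminator. Conv1D expects 3-D input (batch, steps, channels), so the input would need to be `Input(shape=(16384, 1))`. Conv1D's first argument is `filters` and its second is `kernel_size`, and the default padding is 'valid'. That means `Conv1D(25, shapes[i])` builds 25 filters with kernels of width 1 to 16. As we read the paper, WaveGAN's layout is the reverse: kernel 25, stride 4, padding 'same', and d, 2d, 4d, 8d, 16d filters. The pure-Python sketch below assumes Keras's Conv1D output-length rule (dilation 1). `conv_len` and `trace_discriminator` are hypothetical helper names.

```python
def conv_len(length, kernel, stride, padding):
    """Output length of a Keras-style Conv1D (dilation 1)."""
    if padding == "valid":
        return (length - kernel) // stride + 1
    if padding == "same":
        return -(-length // stride)  # ceil(length / stride) in integer arithmetic
    raise ValueError(f"unknown padding: {padding}")


def trace_discriminator(layers, length=16384):
    """layers: (filters, kernel, stride, padding) tuples. Returns (length, channels) per layer."""
    trace = []
    for filters, kernel, stride, padding in layers:
        length = conv_len(length, kernel, stride, padding)
        trace.append((length, filters))
    return trace


# The question's stack: 25 filters, kernels 1..16, default 'valid' padding.
asker = [(25, 1, 1, "valid"), (25, 2, 4, "valid"), (25, 4, 4, "valid"),
         (25, 8, 4, "valid"), (25, 16, 4, "valid")]
# Assumed WaveGAN layout with d = 64: kernel 25, stride 4, 'same', filters d..16d.
d = 64
wavegan = [(f, 25, 4, "same") for f in (d, 2 * d, 4 * d, 8 * d, 16 * d)]

asker_trace = trace_discriminator(asker)
wavegan_trace = trace_discriminator(wavegan)
print(asker_trace)    # last layer: 60 steps x 25 channels = 1500 values, not 256
print(wavegan_trace)  # last layer: 16 steps x 16d channels
```

In the question's version, 1500 values reach `Reshape((256,))`, so the reshape cannot succeed. In the WaveGAN layout, the final 16 × 16d map flattens to 256·d values (16384 for d = 64). A `Flatten()` followed by a single-unit `Dense(1)` logit avoids hard-coding that size.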
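The "phase shuffle - not implemented yet" placeholders refer to the WaveGAN discriminator trick that, as described in the paper, shifts each layer's activations in time by a random integer between -n and n and fills the vacated samples by reflection. Below is a minimal pure-Python sketch on one sequence of time steps, where each element may be a per-step list of channels. It assumes numpy/TF-style "reflect" padding, which mirrors the sequence without repeating the edge sample. A radius of 2 is the value we believe the WaveGAN setup uses, but treat it as an assumption. `reflect_pad`, `shift_with_reflect` and `phase_shuffle` are hypothetical names. In Keras this would typically become a custom layer that is active only during training.

```python
import random


def reflect_pad(seq, left, right):
    """Pad a list in 'reflect' mode (mirror without repeating the edge element)."""
    return seq[left:0:-1] + seq + seq[-2:-2 - right:-1]


def shift_with_reflect(seq, phase):
    """Shift seq by `phase` steps (positive = later in time), filling gaps by reflection."""
    pad_l = max(phase, 0)
    pad_r = max(-phase, 0)
    padded = reflect_pad(seq, pad_l, pad_r)
    # keep a window of the original length so the output shape is unchanged
    return padded[pad_r:pad_r + len(seq)]


def phase_shuffle(seq, rad, rng=random):
    """Apply one random shift drawn uniformly from [-rad, rad] (inclusive)."""
    if rad >= len(seq):
        raise ValueError("rad must be smaller than the sequence length")
    phase = rng.randint(-rad, rad)
    return shift_with_reflect(seq, phase)


print(shift_with_reflect([1, 2, 3, 4, 5], 2))   # shifted right, left edge reflected
print(shift_with_reflect([1, 2, 3, 4, 5], -2))  # shifted left, right edge reflected
print(phase_shuffle([1, 2, 3, 4, 5], 2, random.Random(0)))
```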
chrisdonahue commented 5 years ago

Going to need more information. What's the error you're experiencing?