rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

GAN based training #229

Closed sachin-singh-12 closed 4 years ago

sachin-singh-12 commented 4 years ago

Hi, I wanted to train a network in GAN fashion in RETURNN, i.e. there would be two optimizers, discriminator_opt and generator_opt, which run over alternating steps. However, after digging through the complete pipeline in RETURNN, I found that this would require some major modifications in TFEngine and other files. I was wondering whether there is a neater way to achieve such training.

albertz commented 4 years ago

Why do you think major modifications are needed? All you usually need is to define the losses as usual and flip the gradient at some point, right? There is no specific layer for that currently, but you could just use EvalLayer and set "eval": "flip_gradient(source(0))". What else do you need?
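
As a layer in the network dict, that would look something like this (the layer name and input here are just for illustration):

"flip": {"class": "eval", "from": "generator", "eval": "flip_gradient(source(0))"},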

albertz commented 4 years ago

To be a bit more specific: you define a discriminator network D and use it twice (e.g. you can define it two times and use param sharing, or use the extra-net mechanism, or so). Once it gets its input from the generator, and there you put the flip gradient in between. The second time, it gets the real input.
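
A rough sketch of that with param sharing via reuse_params (untested; the "data:noise"/"data:real" extern-data keys and all layer names are made up for illustration, and D is reduced to a single linear layer):

network = {
    "generator": {"class": "linear", "activation": "tanh", "n_out": 11, "from": "data:noise"},
    # gradient reversal between G and D on the fake branch
    "flip": {"class": "eval", "eval": "flip_gradient(source(0))", "from": "generator"},
    "disc_fake": {"class": "linear", "activation": "sigmoid", "n_out": 1, "from": "flip"},
    # the same D on the real input, sharing its parameters with disc_fake
    "disc_real": {"class": "linear", "activation": "sigmoid", "n_out": 1,
                  "from": "data:real", "reuse_params": "disc_fake"},
    # losses on disc_fake/disc_real omitted here
}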

sachin-singh-12 commented 4 years ago

Thanks for the quick response @albertz. However, I think what you are suggesting is along the lines of the Domain-Adversarial Neural Network (DANN) by Ganin et al., where both the generator and the domain classifier are updated simultaneously in each training step, whereas in GANs the two updates are mutually exclusive, i.e. GANs are generally implemented with two separate optimizers, one for each network.

albertz commented 4 years ago

Whether they are trained simultaneously, or always alternating between one loss and the other, or any other scheme, is just a detail. You can adapt to any scheme you want; this is up to you. There is also the question of whether you want to alternate per mini-batch (train step), or per individual sequence (but then a single mini-batch could still contain both losses together). If this scheduling/alternating should be on the sequence level, then you just have to define the dataset accordingly. If it should be per mini-batch/step, then you could do that as part of the network logic, e.g. using CondLayer, or so. (There is always only a single session.run call, via TFEngine in RETURNN. But this is not really a restriction; it just means that the logic has to be part of the TF computation graph.)
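
In plain TF1 terms, the per-step alternation inside the graph would amount to something like this sketch (d_loss, g_loss and global_train_step stand for the obvious tensors; CondLayer would wrap the tf.cond for you):

# alternate which loss is active, based on step parity
train_d = tf.equal(global_train_step % 2, 0)  # True on even steps: train D
total_loss = tf.cond(train_d, lambda: d_loss, lambda: g_loss)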

sachin-singh-12 commented 4 years ago

@albertz, as you suggested, there can be two ways to update: one where we update the generator and the discriminator in alternate training steps, and a second where we update both networks in the same training step, but each network is updated while keeping the other constant, like this one: https://www.tensorflow.org/tutorials/generative/dcgan

  1. For the first case, one approach could be to make the generator and the discriminator conditionally trainable/untrainable, i.e. on even steps set generator(trainable=True), discriminator(trainable=False), and vice versa on odd steps. Another approach could be to use CondLayer, with the true branch creating a subnetwork where generator(trainable=True), discriminator(trainable=False), and the false branch creating a subnetwork with the reverse values. What's your comment on these two approaches?
  2. For the second case, we generally use two separate optimizers, one for each network, so I don't see how modifying the dataset would help.
albertz commented 4 years ago
  1. trainable is a bool which is fixed at static graph construction time, so you cannot change it later; thus that is not an option. But there are many ways to put the logic into the computation graph, like using tf.cond (CondLayer in RETURNN), or dynamically stopping the gradient (extend tf.stop_gradient or TFUtil.flip_gradient) and switching the loss, etc. (see the sketch after this list).

  2. There is still just a single optimizer, but you just define the loss in the usual way and have a flip_gradient in there between discriminator and generator. On the dataset side, you define how a single sequence/instance should look. The batch is built from sequences. E.g. a sequence could be a pair (noise_input (for the generator), real_input), but it also could be just noise_input or just real_input. (Add some classes if you want to make it conditional.)
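
The "dynamically stopping the gradient" idea from point 1 could look roughly like this in plain TF1 (a sketch; this helper is not part of RETURNN):

import tensorflow as tf

def stop_gradient_cond(x, stop):
    """Identity in the forward pass either way; blocks the gradient iff stop is True.
    x: any tensor; stop: scalar bool tensor, e.g. derived from the global train step."""
    return tf.cond(stop, lambda: tf.stop_gradient(x), lambda: tf.identity(x))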

I don't think it technically matters that you have only a single optimizer. Or why do you think so? Because of shared momentum? If you update them separately, i.e. only the generator or only the discriminator, there is no difference. And if you update them together (using flip_gradient), would it not actually be a good thing to share the momentum? I could imagine it would be much more unstable otherwise (with two separate optimizers). But in any case, is that so important?

Maybe, to make our discussion a bit more specific, can you point out a specific paper which you want to implement, and maybe also an existing example implementation?

sachin-singh-12 commented 4 years ago

I am trying to implement a variant of a vanilla DCGAN, like this one: https://www.tensorflow.org/tutorials/generative/dcgan. With a little modification, I was able to implement it in RETURNN; below is sample code for it (not exact). Thanks for all the pointers.

from __future__ import division
import logging
logging.getLogger('tensorflow').disabled = True
import tensorflow as tf
import sys
sys.path += ["."]  # Python 3 hack
import numpy
import better_exchook
better_exchook.replace_traceback_format_tb()

from Config import Config
from TFNetwork import *
from TFNetworkLayer import *
from Log import log
import TFUtil
TFUtil.debug_register_better_repr()
log.initialize(verbosity=[5])

n_batch, n_time, n_in, n_out = 3, 7, 11, 13
config = Config({
"extern_data": {
  "data": {"dim": n_in},
  "classes": {"dim": n_out, "sparse": True},
},
"debug_print_layer_output_template": True,
})
rnd = numpy.random.RandomState(42)

net = TFNetwork(config=config, train_flag=True)
net.construct_from_dict({
      "src": {"class": "linear", "activation": "tanh", "n_out": 13, "loss":"ce", "target": "classes"},
      "output": {"class": "softmax", "from": "src", "loss": "ce", "target": "classes"}
})

src_vars = list(net.layers['src'].params.values())
output_vars = list(net.layers['output'].params.values())

from TFUpdater import Updater
net.get_total_loss()  # constructs the loss ops as a side effect
# loss_layer is a custom modification to Updater (not in stock RETURNN):
# each updater only optimizes the loss of the given layer
src_updater = Updater(config=config, network=net, initial_learning_rate=0.1, loss_layer='src')
output_updater = Updater(config=config, network=net, initial_learning_rate=0.1, loss_layer='output')

session = tf.Session()
src_updater.set_trainable_vars(src_vars)
src_updater.init_optimizer_vars(session)
src_updater.set_learning_rate(value=src_updater.initial_learning_rate, session=session)
output_updater.set_trainable_vars(output_vars)
output_updater.init_optimizer_vars(session)
output_updater.set_learning_rate(value=output_updater.initial_learning_rate, session=session)
net.initialize_params(session)

in_v = rnd.normal(size=(n_batch, n_time, n_in)).astype("float32")
targets_v = rnd.randint(0, n_out, size=(n_batch, n_time)).astype("int32")
seq_lens_v = numpy.array([n_time, n_time - 1, n_time - 2])

feed_dict = {
  net.extern_data.data["data"].placeholder: in_v,
  net.extern_data.data["data"].size_placeholder[0]: seq_lens_v,
  net.extern_data.data["classes"].placeholder: targets_v,
  net.extern_data.data["classes"].size_placeholder[0]: seq_lens_v,
}
fetches = net.get_fetches_dict(with_summary=True, with_size=True)
src_fetches = fetches.copy()
output_fetches = fetches.copy()
src_fetches["optim_op"] = src_updater.get_optim_op()
output_fetches["optim_op"] = output_updater.get_optim_op()

# training (case 2): update both networks on the same batch in each step,
# each via its own optimizer, keeping the other network's variables constant
for i in range(10):
    src_results = session.run(feed_dict=feed_dict, fetches=src_fetches)
    output_results = session.run(feed_dict=feed_dict, fetches=output_fetches)
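
For the alternating scheme (case 1), the same setup could instead be driven like this (a sketch; it assumes the optim op increments net.global_train_step, as RETURNN's Updater does):

# alternate per step: even steps update 'src', odd steps update 'output'
for i in range(10):
    step = session.run(net.global_train_step)
    if step % 2 == 0:
        session.run(feed_dict=feed_dict, fetches=src_fetches)
    else:
        session.run(feed_dict=feed_dict, fetches=output_fetches)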