dorarad / gansformer

Generative Adversarial Transformers
MIT License

Conditional training #12

Closed: chokyungjin closed this issue 3 years ago

chokyungjin commented 3 years ago

Hi, I want to condition the model on a label during training, but I don't know where to start.

How can I do this?

dorarad commented 3 years ago

Hi! :) Sorry for the long delay in responding. I hope to get back to you within about 2 days, when I'll go over all open issues!

chokyungjin commented 3 years ago

Okay, I'll wait.

dorarad commented 3 years ago

Hi! To provide a condition to the generator, you need to add an input variable here, similar to labels_in, with shape [batch_size, label_size], where label_size is the size of the condition: https://github.com/dorarad/gansformer/blob/main/training/network.py#L802.

In the current flow of the code, this condition is:

  1. passed to G_mapping, which concatenates it with the normally sampled latents z and maps them to new latents w (https://github.com/dorarad/gansformer/blob/main/training/network.py#L1036)

  2. carried by these w (the intermediate latent space, called dlatents in the code) into G_synthesis, where they guide the generation process by changing the scales and biases of the image-grid features as the grid is upsampled from the initial 4×4 map to the final resolution×resolution output (https://github.com/dorarad/gansformer/blob/main/training/network.py#L1429)
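The concatenation step in item 1 can be sketched in NumPy as below. This is an illustrative sketch, not the repository's G_mapping: it assumes latents of shape [batch_size, latents_num, latent_size] (GANformer uses several latent components), labels in the [batch_size, label_size] layout described above, and `toy_mapping` is a hypothetical stand-in for the real mapping network.

```python
import numpy as np

def concat_condition(z, labels):
    """Append the condition to every latent component.

    z:      [batch_size, latents_num, latent_size]
    labels: [batch_size, label_size]
    return: [batch_size, latents_num, latent_size + label_size]
    """
    latents_num = z.shape[1]
    # Copy each sample's label to all of its latent components
    tiled = np.repeat(labels[:, np.newaxis, :], latents_num, axis=1)
    return np.concatenate([z, tiled], axis=-1)

def toy_mapping(x, weights, slope=0.2):
    """Hypothetical mapping MLP: dense layers with leaky ReLU, applied per component."""
    for w in weights:
        x = x @ w
        x = np.where(x > 0, x, slope * x)
    return x

rng = np.random.default_rng(0)
batch_size, latents_num, latent_size, label_size = 4, 17, 32, 2
z = rng.standard_normal((batch_size, latents_num, latent_size))
labels = rng.standard_normal((batch_size, label_size))

x = concat_condition(z, labels)  # (4, 17, 34)
weights = [rng.standard_normal((latent_size + label_size, latent_size)) * 0.1,
           rng.standard_normal((latent_size, latent_size)) * 0.1]
w = toy_mapping(x, weights)      # (4, 17, 32): the intermediate latents w
```

The first dense layer absorbs the extra label_size columns, so w comes out with the usual latent width.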

To feed in such conditions, e.g. from a file, you'll need to add a function here, similar to get_random_labels_tf: https://github.com/dorarad/gansformer/blob/main/training/dataset.py#L169.
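A minimal NumPy sketch of the file-based variant, assuming the conditions are stored as a [num_images, label_size] array saved with np.save. `load_labels` and `get_random_labels_np` are hypothetical names; this is a plain NumPy analogue, not the repository's get_random_labels_tf.

```python
import numpy as np

def load_labels(path):
    """Load a [num_images, label_size] condition table saved with np.save."""
    labels = np.load(path).astype(np.float32)
    if labels.ndim != 2:
        raise ValueError(f"expected a 2-D label array, got shape {labels.shape}")
    return labels

def get_random_labels_np(labels, batch_size, rng):
    """Sample batch_size conditions (with replacement) from the table."""
    idx = rng.integers(0, labels.shape[0], size=batch_size)
    return labels[idx]

table = np.array([[0.0, 1.0], [0.5, 0.5], [1.0, 0.0]], dtype=np.float32)
batch = get_random_labels_np(table, batch_size=8, rng=np.random.default_rng(1))
```

Sampling with replacement keeps the batch size independent of the number of stored conditions.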

I'm actually actively working on extending the model for conditional generation and plan to release it in about a month, with data and pretrained models, so that may be of interest to you!

Please let me know if you have any further questions!

chokyungjin commented 3 years ago

Thanks! Before you replied, I had already started conditioning on the label as a numerical value rather than a one-hot vector. I changed this code (https://github.com/dorarad/gansformer/blob/main/training/network.py#L1036) as follows:

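if label_size:
    with tf.variable_scope("LabelConcat"):
        # Learned projection from the label space to the latent space: [label_size, latent_size]
        w = tf.get_variable("weight", shape = [label_size, latent_size], initializer = tf.initializers.random_normal())
        # Embed the labels, then tile the embedding across all latents_num latent components
        l = tf.tile(tf.expand_dims(tf.matmul(labels_in, w), axis = 1), (1, latents_num, 1))
        # Add (rather than concatenate) the label embedding to the latents
        x = tf.add(x, l)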
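For reference, the shape arithmetic of this additive variant can be reproduced in NumPy. This is a sketch for checking shapes, not the TensorFlow code itself; the sizes (label_size = 2, latents_num = 17, latent_size = 32) are inferred from the layer summary later in this thread, where LabelConcat has a (2, 32) weight, i.e. 64 parameters.

```python
import numpy as np

def add_label_embedding(x, labels, weight):
    """Additive conditioning: project the labels and add them to every latent.

    x:      [batch_size, latents_num, latent_size]
    labels: [batch_size, label_size]
    weight: [label_size, latent_size]
    """
    emb = labels @ weight                        # [batch_size, latent_size]
    emb = emb[:, np.newaxis, :]                  # [batch_size, 1, latent_size]
    return x + np.tile(emb, (1, x.shape[1], 1))  # mirrors tf.tile(..., (1, latents_num, 1))

rng = np.random.default_rng(0)
label_size, latents_num, latent_size = 2, 17, 32
weight = rng.standard_normal((label_size, latent_size))
x = rng.standard_normal((4, latents_num, latent_size))
labels = rng.standard_normal((4, label_size))
out = add_label_embedding(x, labels, weight)
```

Unlike concatenation, the output keeps the latent width (32), so downstream layers need no changes, and an all-zero label leaves the latents untouched.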

It seems to work, but how do I pass conditions in generate.py? I'm looking forward to your conditional generation code release!

Thanks again.

dorarad commented 3 years ago

So the change is that you add the label embeddings instead of concatenating them? Yes, that should work fine too! To feed in labels_in in generate.py, you simply need to pass one more argument, changing https://github.com/dorarad/gansformer/blob/main/generate.py#L27 to:

images = Gs.run(latents, labels, truncation_psi = truncation_psi, minibatch_size = batch_size, verbose = True)[0]

where labels is your input condition. No further changes are needed in that case: the run function supports a variable number of inputs. If you provide fewer inputs than the model expects, run feeds the missing ones with zeros; if you provide more than expected, it simply ignores the extras.
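The input-count rule described above can be illustrated with a small NumPy helper. `align_inputs` is hypothetical and only mimics the described behaviour; it is not dnnlib's implementation, and the `input_shapes` format (per-input shapes with None as the batch dimension) is an assumption.

```python
import numpy as np

def align_inputs(inputs, input_shapes):
    """Zero-fill missing inputs and drop extra ones, as described for Gs.run.

    inputs:       list of arrays; the first one fixes the batch size
    input_shapes: expected shapes with the batch dimension given as None,
                  e.g. [[None, 17, 32], [None, 2]]
    """
    batch_size = inputs[0].shape[0]
    aligned = list(inputs[:len(input_shapes)])  # extra inputs are ignored
    for shape in input_shapes[len(aligned):]:   # missing inputs become zeros
        aligned.append(np.zeros([batch_size] + list(shape[1:]), dtype=np.float32))
    return aligned

latents = np.random.randn(8, 17, 32)
shapes = [[None, 17, 32], [None, 2]]
only_latents = align_inputs([latents], shapes)  # labels filled with zeros
```

This is why an unconditional call with latents alone still works on a conditional model: the labels simply default to zeros.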

Hope this helps! I'll post an update in this thread as soon as the new code is released!

chokyungjin commented 3 years ago

I tried something similar and got the following error:

def run(model, gpus, output_dir, images_num, truncation_psi, batch_size, ratio):
    print("Loading networks...")
    os.environ["CUDA_VISIBLE_DEVICES"] = gpus                   # Set GPUs
    tflib.init_tf()                                             # Initialize TensorFlow
    G, D, Gs = load_networks(model)                             # Load pre-trained network
    Gs.print_layers()                                           # Print network details
    print("Generate images...")
    latents = np.random.randn(images_num, *Gs.input_shape[1:])  # Sample latent vectors
    labels = '0'
    images = Gs.run(latents, labels, truncation_psi = truncation_psi, minibatch_size = batch_size, verbose = True)[0]

The first layers of my Gs model are:

Gs                      Params  OutputShape   WeightShape
ltnt_emb/emb            512     (16, 32)      (16, 32)
G_mapping/LabelConcat   64      (?, 17, 32)   (2, 32)

And this is the error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1365, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1350, in _run_fn
    target_list, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1443, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: -input rank(-0) <= split_dim < input rank (0), but got 0
         [[{{node Gs/_Run/split_1}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "generate.py", line 49, in <module>
    main()
  File "generate.py", line 46, in main
    run(**vars(args))
  File "generate.py", line 27, in run
    images = Gs.run(latents, labels, truncation_psi = truncation_psi, minibatch_size = batch_size, verbose = True)[0]
  File "/project/gansformer/dnnlib/tflib/network.py", line 490, in run
    mb_out = tf.get_default_session().run(out_expr, dict(zip(in_expr, mb_in)))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1180, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1359, in _do_run
    run_metadata)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py", line 1384, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: -input rank(-0) <= split_dim < input rank (0), but got 0
         [[node Gs/_Run/split_1 (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1748) ]]

Original stack trace for 'Gs/_Run/split_1':
  File "generate.py", line 49, in <module>
    main()
  File "generate.py", line 46, in main
    run(**vars(args))
  File "generate.py", line 27, in run
    images = Gs.run(latents, labels, truncation_psi = truncation_psi, minibatch_size = batch_size, verbose = True)[0]
  File "/project/gansformer/dnnlib/tflib/network.py", line 447, in run
    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
  File "/project/gansformer/dnnlib/tflib/network.py", line 447, in <listcomp>
    in_split = list(zip(*[tf.split(x, num_gpus) for x in in_expr]))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/array_ops.py", line 1684, in split
    axis=axis, num_split=num_or_size_splits, value=value, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_array_ops.py", line 9898, in split
    "Split", split_dim=axis, value=value, num_split=num_split, name=name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/op_def_library.py", line 794, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 3357, in create_op
    attrs, op_def, compute_device)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 3426, in _create_op_internal
    op_def=op_def)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1748, in __init__
    self._traceback = tf_stack.extract_stack()

dorarad commented 3 years ago

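I believe the labels are initialized incorrectly. The traceback shows that Gs.run splits every input along its first (batch) axis with tf.split, and labels = '0' is a scalar (rank 0) with no such axis, hence the error. Try the following: labels = np.zeros((images_num, label_size)), where label_size is the value you used in your new generator model (the pretrained models in the repository have label_size = 0, as they are unconditional).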
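A small sketch of building and validating the label batch before calling Gs.run, so that a malformed input such as the string '0' fails early with a readable message. `make_labels` and `check_labels` are hypothetical helpers, and label_size = 2 is inferred from the (2, 32) LabelConcat weight in the layer summary above.

```python
import numpy as np

def make_labels(images_num, label_size, values=None):
    """Return a float32 [images_num, label_size] label batch.

    values: optional single condition of length label_size, copied to every image.
    """
    labels = np.zeros((images_num, label_size), dtype=np.float32)
    if values is not None:
        labels[:] = np.asarray(values, dtype=np.float32)  # broadcast over the batch
    return labels

def check_labels(labels, images_num, label_size):
    """Raise a clear error instead of letting tf.split fail inside Gs.run."""
    labels = np.asarray(labels)
    if labels.shape != (images_num, label_size):
        raise ValueError(f"labels must have shape {(images_num, label_size)}, got {labels.shape}")
    return labels

labels = make_labels(images_num=16, label_size=2, values=[0.0, 1.0])
```

The resulting labels array can then be passed to Gs.run in place of '0', exactly as in the earlier generate.py change.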

chokyungjin commented 3 years ago

Thanks! One more question: what is the difference between the G and Gs networks? Your comment says G: generator, D: discriminator, Gs: generator moving-average (higher quality images), but I don't understand it.

dorarad commented 3 years ago

Let's say we are training a model: at every step the generator's weights get updated, and G holds the most recent values. However, a useful technique in deep learning, Polyak averaging (https://paperswithcode.com/method/polyak-averaging), additionally stores an exponential moving average (EMA) of the weights as they get updated during training. Gs stores these EMA values, and is therefore expected to give better results.
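A minimal sketch of the idea in NumPy, assuming the weights are kept in a dictionary of arrays and using a decay beta = 0.9 chosen only so the effect is visible within a few steps (training code typically uses a value much closer to 1). `ema_update` is a hypothetical helper, not the repository's code for updating Gs.

```python
import numpy as np

def ema_update(ema_weights, weights, beta):
    """Exponential moving average: ema <- beta * ema + (1 - beta) * current."""
    return {name: beta * ema_weights[name] + (1.0 - beta) * w
            for name, w in weights.items()}

# Toy "training": the live weight (what G holds) jumps between +1 and -1,
# while the averaged copy (what Gs holds) settles near their mean, 0.
ema = {"w": np.array([0.0])}
for step in range(200):
    current = {"w": np.array([1.0 if step % 2 == 0 else -1.0])}
    ema = ema_update(ema, current, beta=0.9)
```

After the loop, the latest live weight is -1.0 while the averaged weight is about -0.05: the EMA smooths out step-to-step noise, which is why Gs often, though not always, produces cleaner images.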

chokyungjin commented 3 years ago

So does Gs always return better results? For example:

img1 = Gs.run(latents, labels, truncation_psi = truncation_psi, minibatch_size = batch_size, verbose = True)[0]
img2 = G.run(latents, labels, randomize_noise = False, minibatch_size = batch_size, return_dlatents = True)[0]

With my model and dataset, the img2 results look better to me than img1.

dorarad commented 3 years ago

That's not a strict rule, but Gs images are usually better than G's!

chokyungjin commented 3 years ago

Thanks a million :)

chokyungjin commented 3 years ago

Oh, one more thing: how can I train the duplex transformer? I can't find an argument for enabling duplex.

dorarad commented 3 years ago

Duplex consists of two aspects, enabled with --kmeans and --g-img2ltnt. However, I've still noticed some training issues with --g-img2ltnt in the new refactored code, so in the meantime I recommend using --gansformer-default, which gives you the simplex + k-means version of the model.

dorarad commented 3 years ago

Please let me know if you have further questions or if I can close the issue!

chokyungjin commented 3 years ago

Yes! I'll be waiting for the release of your one-hot conditioning code!