NifTK / NiftyNet

[unmaintained] An open-source convolutional neural networks platform for research in medical image analysis and image-guided therapy
http://niftynet.io
Apache License 2.0

NiftyNet layers are not compatible with tf.cond #413

Closed: danieltudosiu closed this issue 5 years ago.

danieltudosiu commented 5 years ago

If you post your question on Stack Overflow, please explain:

  1. What you were trying to do (and why): I am trying to create a VAE-GAN in which the BatchNormalization layers of the Discriminator are not shared between the real and fake images, since shared batch statistics would be misleading.

  2. What happened (include command output): If I pass a BNLayer, which is callable, as tf.cond's true_fn and false_fn arguments, I get the following error: "TypeError: layer_op() missing 2 required positional arguments: 'inputs' and 'is_training'". This is actually correct, since the tf.cond documentation states that it returns "Tensors returned by the call to either true_fn or false_fn. If the callables return a singleton list, the element is extracted from the list." Because tf.cond calls true_fn and false_fn without arguments, all parameters of the callable would have to be supplied when the layer is created. That is not the case in NiftyNet, where creating a layer and calling it are split into two separate steps.

  3. What you expected to happen: I would expect tf.cond to work with BNLayer. However, the required syntax would be something similar to this:

    input_tensor = tf.cond(
        pred=<tf.bool tensor>,
        true_fn=tf.layers.batch_normalization(...),
        false_fn=tf.layers.batch_normalization(...)
    )
  4. Step-by-step reproduction instructions: either of the following throws an error.

    input_tensor = tf.cond(
        pred=<tf.bool tensor>,
        true_fn=niftynet.layer.bn.BNLayer(...),
        false_fn=niftynet.layer.bn.BNLayer(...)
    )

    Resulting error: TypeError: layer_op() missing 2 required positional arguments: 'inputs' and 'is_training'

    input_tensor = tf.cond(
        pred=<tf.bool tensor>,
        true_fn=niftynet.layer.bn.BNLayer(...)(...),
        false_fn=niftynet.layer.bn.BNLayer(...)(...)
    )

    Resulting error: TypeError: true_fn must be callable.
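Both failures in step 4 follow from tf.cond's calling convention: it invokes a branch with no arguments, so each branch must be a zero-argument callable. It cannot be a layer object that still expects `inputs` and `is_training`, and it cannot be the tensor such a layer has already returned. A common workaround is a closure such as `lambda: bn_real(x, is_training)`. The sketch below models this in plain Python so it runs without TensorFlow. `ToyLayer` and `cond` are hypothetical stand-ins, not NiftyNet or TensorFlow APIs. Unlike the real graph-mode tf.cond, which traces both branches, this model only calls the selected one.

```python
class ToyLayer:
    """Hypothetical stand-in for a NiftyNet-style layer: built first, called later."""

    def __init__(self, name):
        self.name = name

    def layer_op(self, inputs, is_training):
        # Tag the output with the layer name so we can see which branch ran.
        return (self.name, inputs, is_training)

    def __call__(self, *args, **kwargs):
        # Calling the object forwards its arguments to layer_op.
        return self.layer_op(*args, **kwargs)


def cond(pred, true_fn, false_fn):
    """Plain-Python model of tf.cond's calling convention (not TensorFlow)."""
    if not callable(true_fn):
        raise TypeError("true_fn must be callable.")
    if not callable(false_fn):
        raise TypeError("false_fn must be callable.")
    # The selected branch is invoked with no arguments.
    return true_fn() if pred else false_fn()


bn_real = ToyLayer("bn_real")  # separate statistics for real images
bn_fake = ToyLayer("bn_fake")  # separate statistics for fake images
x = [1.0, 2.0]

# Attempt 1: pass the layer objects. cond calls them with no arguments.
try:
    cond(True, bn_real, bn_fake)
except TypeError as err:
    print("attempt 1:", err)

# Attempt 2: pass already-computed outputs. They are not callable.
try:
    cond(True, bn_real(x, True), bn_fake(x, True))
except TypeError as err:
    print("attempt 2:", err)

# Workaround: wrap each call in a zero-argument lambda that closes over x.
is_real = True
out = cond(is_real, lambda: bn_real(x, True), lambda: bn_fake(x, True))
print(out)  # → ('bn_real', [1.0, 2.0], True)
```

In an actual TensorFlow graph, the same pattern would presumably read `tf.cond(pred, lambda: bn_real(x, is_training), lambda: bn_fake(x, is_training))`, where `bn_real` and `bn_fake` are two separately constructed `BNLayer` instances. Because graph-mode tf.cond builds both branches, both layers should create their own variables.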

danieltudosiu commented 5 years ago

My bad, please delete.