This repository has been archived by the owner on Mar 17, 2021. It is now read-only.

NiftyNet layers are not compatible with tf.cond #413

Closed
danieltudosiu opened this issue Jun 27, 2019 · 1 comment
Comments

@danieltudosiu
Contributor

If you post your question on Stack Overflow, please explain:

  1. What you were trying to do (and why)
    I am trying to create a VAE-GAN in which the BatchNormalization layers of the Discriminator are not shared between the real and fake images, since shared batch statistics would be misleading.

  2. What happened (include command output)
    If I pass a BNLayer instance, which is callable, as tf.cond's true_fn and false_fn arguments, I get the following error: "TypeError: layer_op() missing 2 required positional arguments: 'inputs' and 'is_training'". This is actually correct behaviour: the tf.cond documentation says it returns "Tensors returned by the call to either true_fn or false_fn. If the callables return a singleton list, the element is extracted from the list." In other words, tf.cond calls true_fn and false_fn with no arguments, so every argument must already be bound before the callable is handed to tf.cond (for example, at layer creation). That is not the case in NiftyNet, where creating a layer and calling it on its inputs are split into two separate steps.

  3. What you expected to happen
    I would expect tf.cond to work with BNLayer, using syntax similar to this:

input_tensor = tf.cond(
    pred=<tf.bool tensor>,
    true_fn=tf.layers.batch_normalization(...),
    false_fn=tf.layers.batch_normalization(...)
)
  4. Step-by-step reproduction instructions
    Each of the following throws an error:
input_tensor = tf.cond(
    pred=<tf.bool tensor>,
    true_fn=niftynet.layer.bn.BNLayer(...),
    false_fn=niftynet.layer.bn.BNLayer(...)
)

Resulting error: TypeError: layer_op() missing 2 required positional arguments: 'inputs' and 'is_training'

input_tensor = tf.cond(
    pred=<tf.bool tensor>,
    true_fn=niftynet.layer.bn.BNLayer(...)(...),
    false_fn=niftynet.layer.bn.BNLayer(...)(...)
)

Resulting error: TypeError: true_fn must be callable.
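The calling-convention mismatch can be reproduced without TensorFlow. Below is a minimal, stdlib-only sketch that assumes only what the tf.cond documentation quoted above implies: `true_fn` and `false_fn` are called with no arguments. `cond` and `SplitLayer` are hypothetical stand-ins, not tf.cond or NiftyNet's BNLayer. The first two attempts reproduce the two errors reported in this issue; the third wraps each layer call in a zero-argument `lambda`, which binds `inputs` and `is_training` while still deferring the call until `cond` picks a branch.

```python
# Stdlib-only sketch. `cond` and `SplitLayer` are HYPOTHETICAL stand-ins for
# tf.cond and a NiftyNet-style layer; they are not the real APIs.


def cond(pred, true_fn, false_fn):
    """Copies one rule of tf.cond: each branch must be a callable that takes
    no arguments. (Graph-mode tf.cond builds both branches; this stand-in
    only calls the selected one.)"""
    if not callable(true_fn):
        raise TypeError("true_fn must be callable.")
    if not callable(false_fn):
        raise TypeError("false_fn must be callable.")
    return true_fn() if pred else false_fn()


class SplitLayer:
    """Hypothetical layer: configuration goes to the constructor, data goes
    to the call, mirroring the create-then-call split described above."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, inputs, is_training):
        # Stand-in for real work such as batch normalization.
        return [x * self.scale for x in inputs]


# Separate instances for real and fake images, so neither branch reuses the
# other's object (variable sharing itself is not modelled here).
real_bn = SplitLayer(scale=2.0)
fake_bn = SplitLayer(scale=3.0)
batch = [1.0, 2.0]

# 1) Passing the layer objects: cond calls them with no arguments.
direct_error = None
try:
    cond(True, real_bn, fake_bn)
except TypeError as err:
    direct_error = str(err)

# 2) Passing already-computed outputs: a list is not callable.
called_error = None
try:
    cond(True, real_bn(batch, True), fake_bn(batch, True))
except TypeError as err:
    called_error = str(err)

# 3) Deferring each call with a zero-argument lambda satisfies the rule.
is_real = False
output = cond(is_real,
              lambda: real_bn(batch, is_training=True),
              lambda: fake_bn(batch, is_training=True))

print(direct_error)  # ...missing 2 required positional arguments: 'inputs' and 'is_training'
print(called_error)  # true_fn must be callable.
print(output)        # [3.0, 6.0]
```

In real TensorFlow code the lambdas would close over tensors rather than Python lists; this sketch only checks the argument-passing rule, not how batch-normalization statistics behave inside a conditional.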

@danieltudosiu
Contributor Author

My bad, please delete.
