Trying to optimize a model that contains BatchNormalization layers inside TensorNodes results in an error. For example:

    import tensorflow as tf
    import nengo
    import nengo_dl
    import numpy as np

    with nengo.Network() as net:
        a = nengo.Node([0])
        b = nengo_dl.Layer(tf.keras.layers.BatchNormalization())(a)
        p = nengo.Probe(b)

    with nengo_dl.Simulator(net) as sim:
        sim.compile(optimizer=tf.optimizers.SGD(0), loss=tf.losses.mse)
        sim.fit(np.ones((1, 1, 1)), np.ones((1, 1, 1)))

Running this fails with:

    tensorflow.python.framework.errors_impl.InvalidArgumentError: {{node training_1/group_deps}} has inputs from different frames. The input {{node TensorGraph/while/iteration_0/SimTensorNodeBuilder/cond_2/Merge}} is in frame 'TensorGraph/while/while_context'. The input {{node loss/mul}} is in frame ''.

I suspect that using BatchNormalization layers inside any TensorFlow while loop produces the same error, but I haven't tried to build a minimal example yet.
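For context, here is a minimal NumPy sketch of what a BatchNormalization layer does in training mode. It assumes the Keras defaults of `momentum=0.99` and `epsilon=1e-3` and normalizes over the batch axis only; `batch_norm_train_step` is a hypothetical helper, not part of Keras or nengo_dl. The point is that, besides normalizing, the layer also updates its moving mean and variance. These stateful updates are a plausible (unconfirmed) candidate for the ops that get created inside the while loop and then grouped with the loss outside it.

```python
import numpy as np


def batch_norm_train_step(x, gamma, beta, moving_mean, moving_var,
                          momentum=0.99, epsilon=1e-3):
    """Normalize a batch and return the updated moving statistics.

    Hypothetical illustration of training-mode batch normalization over
    axis 0 (the batch axis); not the Keras implementation itself.
    """
    # Statistics of the current batch (population variance, ddof=0).
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)
    # Normalize with the batch statistics, then scale and shift.
    y = gamma * (x - batch_mean) / np.sqrt(batch_var + epsilon) + beta
    # Stateful part: exponential moving averages later used at inference.
    new_mean = momentum * moving_mean + (1 - momentum) * batch_mean
    new_var = momentum * moving_var + (1 - momentum) * batch_var
    return y, new_mean, new_var


# Usage: a batch of two 1-d samples with mean 2 and variance 1.
x = np.array([[1.0], [3.0]])
y, m, v = batch_norm_train_step(x, gamma=1.0, beta=0.0,
                                moving_mean=np.zeros(1), moving_var=np.ones(1))
```

Here the moving mean moves from 0 to 0.02 after one step, while the output itself depends only on the current batch.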