nengo / nengo-dl

Deep learning integration for Nengo
https://www.nengo.ai/nengo-dl

Error when using BatchNormalization Layer in TensorNode #109

Closed: drasmuss closed this issue 4 years ago

drasmuss commented 4 years ago

Training (i.e., optimizing) a model that contains BatchNormalization layers inside TensorNodes raises an error. For example:

import tensorflow as tf
import nengo
import nengo_dl
import numpy as np

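with nengo.Network() as net:
    a = nengo.Node([0])
    # wrap a Keras BatchNormalization layer in a TensorNode
    b = nengo_dl.Layer(tf.keras.layers.BatchNormalization())(a)
    p = nengo.Probe(b)

with nengo_dl.Simulator(net) as sim:
    # learning rate 0: only the training graph needs to be built to trigger the error
    sim.compile(optimizer=tf.optimizers.SGD(0), loss=tf.losses.mse)
    # inputs and targets have shape (batch_size, n_steps, dimensions) = (1, 1, 1)
    sim.fit(np.ones((1, 1, 1)), np.ones((1, 1, 1)))  # raises the error below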

tensorflow.python.framework.errors_impl.InvalidArgumentError: {{node training_1/group_deps}} has inputs from different frames. The input {{node TensorGraph/while/iteration_0/SimTensorNodeBuilder/cond_2/Merge}} is in frame 'TensorGraph/while/while_context'. The input {{node loss/mul}} is in frame ''.

I suspect that using BatchNormalization layers inside any TensorFlow while loop produces the same error, but I haven't tried to build a minimal example yet.
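One relevant detail: in training mode a BatchNormalization layer does more than compute its output. It also updates its moving mean and variance, and those updates are extra ops that the training step has to run. The pure-Python sketch below illustrates the standard algorithm for a single feature; it is not Keras's or nengo-dl's code. The function names are hypothetical, and the defaults (momentum=0.99, epsilon=1e-3, moving mean 0, moving variance 1) are assumed to mirror Keras's defaults. Details such as how the batch variance is corrected may differ from the real implementation.

```python
import math


def batch_norm_train_step(batch, moving_mean, moving_var,
                          momentum=0.99, epsilon=1e-3, gamma=1.0, beta=0.0):
    """Training mode (hypothetical helper): normalize a 1-D batch with its own
    statistics and return (output, updated moving mean, updated moving variance)."""
    n = len(batch)
    mean = sum(batch) / n
    # biased (population) variance of the current batch
    var = sum((x - mean) ** 2 for x in batch) / n
    out = [gamma * (x - mean) / math.sqrt(var + epsilon) + beta for x in batch]
    # exponential moving averages: this state update only happens in training mode
    new_mean = momentum * moving_mean + (1 - momentum) * mean
    new_var = momentum * moving_var + (1 - momentum) * var
    return out, new_mean, new_var


def batch_norm_inference(batch, moving_mean, moving_var,
                         epsilon=1e-3, gamma=1.0, beta=0.0):
    """Inference mode (hypothetical helper): normalize with the stored moving
    statistics; no state is updated."""
    return [gamma * (x - moving_mean) / math.sqrt(moving_var + epsilon) + beta
            for x in batch]


if __name__ == "__main__":
    out, m, v = batch_norm_train_step([1.0], moving_mean=0.0, moving_var=1.0)
    print(out, m, v)
    print(batch_norm_inference([1.0], moving_mean=0.0, moving_var=1.0))
```

With the single input of 1 used in the reproduction above, training mode outputs 0 (the batch mean equals the input) and moves the moving mean from 0 toward 0.01, while inference mode normalizes with the stored statistics instead. Inside a TensorNode, these moving-average updates are presumably built within nengo-dl's simulation while loop. That would be consistent with the error reporting that the training op combines an input from the loop frame with one from outside it.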