lylyjy closed this issue 3 years ago.
@tf.function()
def train(self, data, log_images=False):
    self._strategy.experimental_run_v2(self._train, args=(data, log_images))
You're using a newer version of TF than the one listed in the README. The method experimental_run_v2() has since been renamed to run().
That said, I highly recommend giving DreamerV2 a try. The code base also supports Gaussian latents with a flag.