tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

tfp.stats.histogram cannot be compiled by XLA? #1758

Open yellowdolphin opened 1 year ago

yellowdolphin commented 1 year ago

Summary of problem

Applying tfp.stats.histogram to the data inside a tf.keras model breaks XLA compilation. The example code (linked below) works on CPU and GPU, but under a TPU strategy it raises:

InvalidArgumentError: 9 root error(s) found.
  (0) INVALID_ARGUMENT: {{function_node __inference_train_function_3630}} Input 1 to node `sequential/lambda/histogram/count_integers/map/while/bincount/Bincount` with op Bincount must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

     [[{{node sequential/lambda/histogram/count_integers/map/while/bincount/Bincount}}]]

     [[sequential/lambda/histogram/count_integers/map/while]]
     [[TPUReplicate/_compile/_10395939635067805685/_4]]
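The key phrase in the error is that input 1 to the Bincount op (its `size` argument, i.e. the length of the output) must be a compile-time constant. Below is a minimal pure-Python sketch of that distinction, not TensorFlow, TFP or XLA code. The function names are hypothetical, the inputs are assumed to be non-negative integer bin indices, and clipping out-of-range values into the edge bins is an illustrative choice, not necessarily what tfp.stats.histogram does. In the first function the output length depends on the data. In the second it is fixed by a constant chosen before any data is seen.

```python
# Pure-Python illustration (hypothetical helpers, not TFP's implementation).
# Assumes inputs are non-negative integer bin indices.

def bincount_dynamic(values):
    """Output length is max(values) + 1, so it is only known at run time."""
    size = max(values) + 1 if values else 0
    counts = [0] * size
    for v in values:
        counts[v] += 1
    return counts


def bincount_static(values, num_bins):
    """Output length is num_bins, a constant fixed before the data arrives.

    Out-of-range values are clipped into the first/last bin (an
    illustrative choice) so the output length never changes.
    """
    counts = [0] * num_bins
    for v in values:
        idx = min(max(v, 0), num_bins - 1)
        counts[idx] += 1
    return counts


if __name__ == "__main__":
    # The dynamic version's output length changes with the data.
    print(bincount_dynamic([0, 2, 2]))    # length 3
    print(bincount_dynamic([0, 5]))       # length 6
    # The static version always returns num_bins entries.
    print(bincount_static([0, 2, 2], 4))  # length 4
    print(bincount_static([0, 5], 4))     # length 4
```

Only the second form has an output shape that is known before the data is seen, which is the property XLA's shape requirement asks for.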

Reproducible example

https://colab.research.google.com/drive/1g9yHihhmcAcwEeE80wWPwyI8W6D6BGfx?usp=sharing

jonas-eschle commented 7 months ago

Hi @yellowdolphin, did you try to dig into this a bit? That would be very helpful! As the error suggests, it looks to me like count_integers, and the bincount call inside it, is causing the trouble. Can you compile these separately?