Closed: FHof closed this 2 years ago
I've tested the tensorflow precision change and it didn't work.
I installed intel-tensorflow 2.6.0, which is tensorflow optimized for Intel CPUs. The tf.keras.backend.floatx
setting is ignored when I create a tensor, but the dtype parameter does work:
$ python3 -c 'import tensorflow as tf; tf.keras.backend.set_floatx("float64"); print(tf.keras.backend.floatx(), tf.constant(1.0))'
2021-10-28 13:01:00.818762: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-28 13:01:00.819572: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
float64 tf.Tensor(1.0, shape=(), dtype=float32)
$ python3 -c 'import tensorflow as tf; print(tf.constant(1.0, dtype=tf.float64))'
2021-10-28 13:00:06.997738: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-10-28 13:00:06.998519: I tensorflow/core/common_runtime/process_util.cc:146] Creating new thread pool with default inter op setting: 2. Tune using inter_op_parallelism_threads for best performance.
tf.Tensor(1.0, shape=(), dtype=float64)
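The two console runs above can be condensed into a small self-check. This is just a sketch; tf_default_dtype_demo is a made-up helper name, and the import is guarded so the snippet stays runnable when tensorflow is not installed:

```python
def tf_default_dtype_demo():
    """Reproduce the report: tf.keras.backend.set_floatx is ignored by
    tf.constant, while an explicit dtype argument is honored."""
    try:
        import tensorflow as tf  # guarded so the snippet runs anywhere
    except ImportError:
        return None
    tf.keras.backend.set_floatx("float64")
    a = tf.constant(1.0)                    # still float32: floatx is ignored
    b = tf.constant(1.0, dtype=tf.float64)  # explicit dtype is honored
    return a.dtype.name, b.dtype.name
```

With tensorflow installed this should return ('float32', 'float64'), matching the console output above.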
It also doesn't work with tensorflow 2.4.1 (tensorflow-gpu from conda-forge).
I enabled GitHub Issues in the repository settings.
I changed the set_precision docstring: it now links to issues about global floating-point precision in NumPy and TensorFlow.
The NumPy developers are against a global dtype, and TensorFlow has not yet implemented a global dtype configuration.
tf.keras.backend.floatx, which I tried previously, only affects Keras and not TensorFlow.
I indirectly tested JAX global precision on GPU via the tests from #10, so set_precision with JAX should work on both CPU and GPU.
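For reference, the JAX global switch involved here is the jax_enable_x64 flag. A minimal guarded sketch (the helper name enable_jax_x64 is made up, and it returns None when JAX is not installed):

```python
def enable_jax_x64():
    """Turn on JAX's global 64-bit mode; return the default dtype name."""
    try:
        from jax import config
        import jax.numpy as jnp
    except ImportError:
        return None
    config.update("jax_enable_x64", True)  # best set at startup, before arrays are made
    return jnp.ones(2).dtype.name  # 'float64' once x64 is enabled
```

Unlike TensorFlow's floatx, this flag changes the dtype of plain array constructors, which is why the global approach works for the JAX backend.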
Can I merge this?
Yes, _set_precision_torch should be private.
I moved the imports into the functions so that set_precision can work if torch is not installed.
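The lazy-import pattern can be sketched like this (names and branches are illustrative, not the actual code of this PR):

```python
def set_precision(backend, dtype="float64"):
    """Set the default float dtype for one backend.

    Imports happen inside each branch, so calling this for one backend
    does not require the other backends to be installed.
    """
    if backend == "torch":
        import torch  # only needed on this branch
        torch.set_default_dtype(getattr(torch, dtype))
    elif backend == "jax":
        from jax import config
        config.update("jax_enable_x64", dtype == "float64")
    elif backend == "tensorflow":
        import tensorflow as tf
        tf.keras.backend.set_floatx(dtype)  # affects Keras only, not raw tf ops
    else:
        raise ValueError(f"unknown backend: {backend!r}")
```

Calling set_precision("torch") imports torch only at that point, so environments with just one framework installed still work.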
~~I haven't tested whether it works for the tensorflow backend or how it behaves with GPU support.~~