Hi! I noticed that `tf.autodiff.ForwardAccumulator` currently issues a warning when it is used to compute a Jacobian-vector product (JVP) of a real-valued function with complex input. The result appears to be correct despite the warning, as the code example below demonstrates, so I'm not sure whether the warning is a historic artifact. I would appreciate it if someone could clarify why it exists.
The example below implements two Hessian-vector product (HVP) functions: one uses reverse-over-reverse mode with two nested `tf.GradientTape` instances (i.e., the gradient of the directional derivative), and the other uses forward-over-reverse mode with `tf.autodiff.ForwardAccumulator` and `tf.GradientTape`. Only the latter issues the warning `WARNING:tensorflow:The dtype of the watched primal must be floating (e.g. tf.float32), got tf.complex64`, which originates here: https://github.com/tensorflow/tensorflow/blame/d8ce9f9c301d021a69953134185ab728c1c248d3/tensorflow/python/eager/forwardprop.py#L399-L402.
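A short derivation of why the two HVPs can agree. It assumes (this is not stated in the report) that TensorFlow represents the gradient of a real-valued f with respect to z = x + iy as ∂f/∂x + i∂f/∂y, which is what the registered gradients of `tf.math.real` and `tf.math.imag` produce, and that `ForwardAccumulator` propagates a complex tangent v = a + ib as the real direction (a, b):

$$
g(z) = \nabla_x f + i\,\nabla_y f, \qquad v = a + i\,b
$$

$$
\mathrm{hvp}_{\text{rev-rev}}(z, v) = \nabla_z\,\mathrm{Re}\Big(\sum_k \overline{g_k(z)}\,v_k\Big) = \nabla_z\big(\nabla_x f \cdot a + \nabla_y f \cdot b\big) = (H_{xx} a + H_{xy} b) + i\,(H_{yx} a + H_{yy} b)
$$

$$
\mathrm{hvp}_{\text{fwd-rev}}(z, v) = \frac{d}{dt}\,g(z + t v)\Big|_{t=0} = (H_{xx} a + H_{xy} b) + i\,(H_{yx} a + H_{yy} b)
$$

Here $(H_{xy})_{kj} = \partial^2 f / \partial x_k \partial y_j$, and similarly for the other blocks of the real Hessian of f viewed as a function of $(x, y) \in \mathbb{R}^{2n}$. Under these assumptions both modes return that real Hessian applied to (a, b), with the x-block in the real part and the y-block in the imaginary part.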
Issue Type
Bug
Source
source
Tensorflow Version
2.8
Custom Code
Yes
OS Platform and Distribution
No response
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/Compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current Behaviour?
Standalone code to reproduce the issue
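A minimal sketch of the two HVP functions described above. The test function `f(z) = sum(x^2 * y + y^3)` with x = Re(z), y = Im(z), the input values, and the helper names `hvp_reverse_over_reverse` / `hvp_forward_over_reverse` are illustrative assumptions, not the original reproduction script. The reverse-over-reverse version differentiates the real directional derivative Re(sum(conj(grad) * v)).

```python
import numpy as np
import tensorflow as tf


def f(z):
    # Hypothetical real-valued test function of a complex input:
    # f(z) = sum(x^2 * y + y^3) with x = Re(z), y = Im(z).
    x = tf.math.real(z)
    y = tf.math.imag(z)
    return tf.reduce_sum(x * x * y + y * y * y)


def hvp_reverse_over_reverse(fn, z, v):
    # Gradient of the real directional derivative Re(sum(conj(grad) * v)).
    with tf.GradientTape() as outer:
        outer.watch(z)
        with tf.GradientTape() as inner:
            inner.watch(z)
            value = fn(z)
        grad = inner.gradient(value, z)
        directional = tf.math.real(tf.reduce_sum(tf.math.conj(grad) * v))
    return outer.gradient(directional, z)


def hvp_forward_over_reverse(fn, z, v):
    # JVP of the gradient in direction v. Watching a complex64 primal here
    # is what logs "The dtype of the watched primal must be floating".
    with tf.autodiff.ForwardAccumulator(primals=z, tangents=v) as acc:
        with tf.GradientTape() as tape:
            tape.watch(z)
            value = fn(z)
        grad = tape.gradient(value, z)
    return acc.jvp(grad)


# complex64 point and direction (illustrative values).
z = tf.complex(tf.constant([1.0, 2.0]), tf.constant([0.5, -1.0]))
v = tf.complex(tf.constant([1.0, 0.0]), tf.constant([0.0, 1.0]))

hvp_rr = hvp_reverse_over_reverse(f, z, v).numpy()
hvp_fr = hvp_forward_over_reverse(f, z, v).numpy()
print("reverse-over-reverse:", hvp_rr)
print("forward-over-reverse:", hvp_fr)
```

For this particular f, the analytic real Hessian-vector product at these inputs is `[1+2j, 4-6j]`, so comparing both printed arrays against it is a quick way to check whether the warning coincides with an actual error.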
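Relevant log output
`WARNING:tensorflow:The dtype of the watched primal must be floating (e.g. tf.float32), got tf.complex64`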