tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

NotImplementedError: The adjoint sensitivity method does not support complex dtypes. #1696

Open sriharikrishna opened 1 year ago

sriharikrishna commented 1 year ago

I cannot differentiate through a call to tfp.math.ode.DormandPrince().solve() when its arguments are complex; doing so raises a NotImplementedError. The error can be reproduced with the code below (based on #1372):

import tensorflow_probability as tfp
import tensorflow as tf

# Integration interval.
t_i = 0.
t_f = 2.

# Complex initial state and system matrix for the linear ODE y' = A y.
y0 = tf.constant([1.0, 9.0], dtype=tf.complex128)
A = tf.constant([[0, 1.0], [-100.0, 0]], dtype=tf.complex128)

def ode_fn(t, y):
    return tf.linalg.matvec(A, y)

with tf.GradientTape(persistent=True) as tape:
    tape.watch(y0)
    results = tfp.math.ode.DormandPrince().solve(
        ode_fn, t_i, y0, solution_times=[t_i, t_f])
    y_out = results.states[-1]  # state at t_f

# Jacobian of the final state with respect to the initial state.
jac = tape.jacobian(y_out, y0, experimental_use_pfor=False)

The relevant part of the error message is:

/usr/local/lib/python3.8/dist-packages/tensorflow_probability/python/math/ode/base.py in error_if_complex(dtype)
    249         def error_if_complex(dtype):
    250           if dtype_util.is_complex(dtype):
--> 251             raise NotImplementedError('The adjoint sensitivity method does '
    252                                       'not support complex dtypes.')
    253 

NotImplementedError: The adjoint sensitivity method does not support complex dtypes.

Will this capability be added soon?

The code works properly when the inputs are:

y0 = tf.constant([1.0, 9.0], dtype=tf.float64)
A = tf.constant([[0, 1.0], [-100.0, 0]], dtype=tf.float64)
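Because the float64 path differentiates correctly, one possible workaround until complex support is added is to rewrite the complex system as an equivalent real one of twice the size, with state z = [Re y; Im y]. The sketch below illustrates the idea in NumPy only; the helper names `complex_to_real`, `real_to_complex`, `realify_ode_fn` and `realify_matrix` are hypothetical and not part of TFP. It checks that the real-valued right-hand side reproduces the complex one, both for the matrix from this issue and for a random complex matrix.

```python
import numpy as np

# Hypothetical helpers (not TFP API) for representing a complex ODE as a real one.


def complex_to_real(y):
    """Stack real and imaginary parts: complex (n,) -> float (2n,)."""
    return np.concatenate([y.real, y.imag])


def real_to_complex(z):
    """Inverse of complex_to_real: float (2n,) -> complex (n,)."""
    n = z.shape[0] // 2
    return z[:n] + 1j * z[n:]


def realify_ode_fn(ode_fn):
    """Wrap a complex right-hand side f(t, y) as a real one g(t, z), z = [Re y; Im y]."""
    def real_ode_fn(t, z):
        return complex_to_real(ode_fn(t, real_to_complex(z)))
    return real_ode_fn


def realify_matrix(A):
    """Real 2n x 2n matrix for the linear case y' = A y.

    With P = Re A and Q = Im A, A (u + i v) = (P u - Q v) + i (Q u + P v),
    so [[P, -Q], [Q, P]] acting on [u; v] gives [Re(A y); Im(A y)].
    """
    P, Q = A.real, A.imag
    return np.block([[P, -Q], [Q, P]])


# Matrix and initial state from the issue.
A = np.array([[0, 1.0], [-100.0, 0]], dtype=np.complex128)
y0 = np.array([1.0, 9.0], dtype=np.complex128)


def ode_fn(t, y):
    return A @ y


z0 = complex_to_real(y0)              # float64, shape (4,)
real_ode_fn = realify_ode_fn(ode_fn)  # float64 in, float64 out

# Both realified forms reproduce the complex right-hand side.
assert np.allclose(real_to_complex(real_ode_fn(0.0, z0)), ode_fn(0.0, y0))
assert np.allclose(realify_matrix(A) @ z0, real_ode_fn(0.0, z0))

# A genuinely complex matrix and state, to exercise the imaginary parts.
rng = np.random.default_rng(0)
B = rng.normal(size=(3, 3)) + 1j * rng.normal(size=(3, 3))
w = rng.normal(size=3) + 1j * rng.normal(size=3)
assert np.allclose(realify_matrix(B) @ complex_to_real(w), complex_to_real(B @ w))
```

To use this with TFP, the same wrapper would presumably be written with `tf.math.real`, `tf.math.imag`, `tf.complex` and `tf.concat`, and the resulting float64 state passed to `DormandPrince().solve`; that step is an untested assumption here. The derivatives obtained this way are taken with respect to the real and imaginary parts separately rather than forming a complex Jacobian.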

dkweiss31 commented 8 months ago

I'm wondering about this issue myself. I would be happy to take a stab at implementing this if the experts think it is doable for a relative TensorFlow novice.