odlgroup / odl

Operator Discretization Library https://odlgroup.github.io/odl/
Mozilla Public License 2.0

Issues with 'as_tensorflow_layer' #1599

Open: jaweriaamjad opened this issue 3 years ago

jaweriaamjad commented 3 years ago

I am implementing a CT reconstruction algorithm using a deep neural network. As part of the algorithm, I need to compute forward-mode and reverse-mode gradients of the forward operator (here, the ray transform). I am using 'as_tensorflow_layer' to make the ODL operator part of the TensorFlow graph:

self.true = tf.placeholder(shape=[self.batch_size, self.image_space[0], self.image_space[1], self.data_pip.colors], dtype=tf.float32)
self.measurement = self.CT.tf_operator(self.true)
self.fbp = self.CT.tf_inverseoperator(self.measurement)
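In matrix terms, the graph above computes `measurement = A x` and `fbp = B A x`. The sketch below assumes the ray transform and the FBP are linear maps and uses random matrices `A` and `B` as stand-ins (they are not ODL objects). When `tf.gradients(y, x)` is called without `grad_ys`, it seeds the backward pass with a tensor of ones shaped like `y` and returns the vector-Jacobian product. For a linear map this is just the transpose (adjoint) applied to that ones tensor. `vjp_linear` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: A plays the ray transform, B the FBP operator.
n_pixels, n_data = 16, 24
A = rng.standard_normal((n_data, n_pixels))
B = rng.standard_normal((n_pixels, n_data))

x = rng.standard_normal(n_pixels)
measurement = A @ x      # forward projection
fbp = B @ measurement    # reconstruction from the projection


def vjp_linear(M, dy):
    """Reverse-mode gradient of y = M @ x, seeded with dy: returns M^T @ dy."""
    return M.T @ dy


# tf.gradients(y, x) with no grad_ys seeds dy with ones shaped like y.
grad_measurement = vjp_linear(A, np.ones(n_data))
grad_fbp = vjp_linear(B @ A, np.ones(n_pixels))

# Chain rule: the FBP gradient is pulled back through B first, then through A.
assert np.allclose(grad_fbp, A.T @ (B.T @ np.ones(n_pixels)))
```

The two `tf.gradients` calls differ structurally in where the ray-transform layer's incoming gradient `dy` comes from. For `self.measurement`, it is the default ones tensor built from the measurement's own shape. For `self.fbp`, it is whatever the FBP layer's gradient produces.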

When computing the reverse-mode gradients

self.measure_grads = tf.gradients(self.measurement, self.true)

I get the following error:

File "/scratch/uceejam/p3tf1x/lib64/python3.6/site-packages/odl/contrib/tensorflow/layer.py", line 169, in tensorflow_layer_grad_impl
    assert dy_shape[1:] == space_shape(odl_op.range) + (1,)
AssertionError

However,

self.fbp_grads = tf.gradients(self.fbp, self.true)

works fine. I am not sure why self.measure_grads does not work.
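The failing assertion compares the shape of the incoming gradient `dy`, minus its batch axis, against `space_shape(odl_op.range) + (1,)`. That is the operator's range shape followed by one trailing channel axis. The sketch below mimics this check in plain Python/NumPy. `check_dy_shape` is a hypothetical helper, and the sinogram shape `(180, 256)` is an assumed example that does not come from the issue. It shows that a tensor fails the comparison if the trailing channel axis is missing, if the channel count differs from 1, or if a dimension is unknown (`None`).

```python
import numpy as np


def check_dy_shape(dy_shape, range_shape):
    """Mimics the assertion: dy_shape[1:] == range_shape + (1,)."""
    return tuple(dy_shape[1:]) == tuple(range_shape) + (1,)


range_shape = (180, 256)  # assumed sinogram shape (angles, detector pixels)
batch = 4

ok = np.ones((batch,) + range_shape + (1,))          # (batch, angles, det, 1)
no_channel = np.ones((batch,) + range_shape)         # channel axis missing
multi_channel = np.ones((batch,) + range_shape + (3,))

print(check_dy_shape(ok.shape, range_shape))             # True
print(check_dy_shape(no_channel.shape, range_shape))     # False
print(check_dy_shape(multi_channel.shape, range_shape))  # False

# A statically unknown dimension (None) also fails a comparison like this.
print(check_dy_shape((batch, None, 256, 1), range_shape))  # False

# Adding or removing the trailing singleton channel axis.
fixed = no_channel[..., np.newaxis]
print(check_dy_shape(fixed.shape, range_shape))  # True
print(np.squeeze(fixed, axis=-1).shape)          # (4, 180, 256)
```

To see which of these cases applies, it may help to print `self.measurement.get_shape()` and compare it with the range shape of the ray transform.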

I am also unable to compute forward-mode gradients for the ODL TensorFlow layers. In TensorFlow, forward-mode gradients can be computed with the following function, which applies tf.gradients twice:

import tensorflow as tf  # TensorFlow 1.x graph mode

def fwd_gradients(ys, xs, d_xs):
    """Forward-mode pushforward analogous to the pullback defined by tf.gradients.
    With tf.gradients, grad_ys is the tensor being pulled back, and here d_xs is
    the tensor being pushed forward."""
    v = tf.ones(shape=ys.shape, dtype=ys.dtype)
    #v = tf.placeholder_with_default(v0, shape=ys.get_shape())  # dummy variable
    #v = tf.placeholder(ys.dtype, shape=ys.get_shape())
    g = tf.gradients(ys, xs, grad_ys=v)
    return tf.gradients(g, v, grad_ys=d_xs)[0]
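For a linear operator such as the ray transform, this double-backward trick can be checked by hand. Let `g = A^T v`. The map `v -> g` is itself linear with matrix `A^T`, so pulling `d_xs` back through it gives `(A^T)^T d_xs = A d_xs`. In other words, the forward-mode derivative is just the operator applied to the tangent. The NumPy sketch below verifies this numerically. A random matrix stands in for the ray transform (an assumption; it is not an ODL object), and `vjp` and `vjp_of_vjp` are hypothetical helper names.

```python
import numpy as np

rng = np.random.default_rng(1)

# Random matrix standing in for a linear forward operator (e.g. a ray transform).
n_in, n_out = 10, 7
A = rng.standard_normal((n_out, n_in))


def vjp(v):
    """Reverse mode of y = A @ x: pulls a range vector v back to A^T @ v."""
    return A.T @ v


def vjp_of_vjp(d_x):
    """Reverse mode of the linear map v -> vjp(v), seeded with d_x.

    The map v -> A^T v has Jacobian A^T, so its pullback of d_x is
    (A^T)^T @ d_x. The Jacobian is built column by column from vjp and
    then transposed.
    """
    J = np.column_stack([vjp(e) for e in np.eye(n_out)])  # shape (n_in, n_out), equals A^T
    return J.T @ d_x


d_x = rng.standard_normal(n_in)
jvp = vjp_of_vjp(d_x)

# The double-backward result equals the forward-mode derivative A @ d_x.
print(np.allclose(jvp, A @ d_x))  # True
```

One consequence, offered as a suggestion rather than a confirmed fix: for a linear ODL layer, the pushforward of `d_xs` could be obtained by applying the forward layer itself to `d_xs`, without differentiating twice. Separately, `tf.gradients` returns `None` when it finds no differentiable path from `v` to `g`. That would be consistent with the ODL layer's gradient op not being differentiable in TensorFlow, although this guess has not been checked against ODL's source.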

It works for all the other standard TensorFlow layers. However, for the ODL TensorFlow layers it returns None.

kkkk12123 commented 1 year ago

Hi, have you solved it? I got the same error when I tried to apply odl_op_layer_adjoint to my sinogram, and I have no idea how to fix it.