odlgroup / odl

Operator Discretization Library https://odlgroup.github.io/odl/
Mozilla Public License 2.0

Add as_theano_layer #1098

Closed: adler-j closed this issue 7 years ago

adler-j commented 7 years ago

Similar to the new as_tensorflow_layer functionality, we should add the same for at least Theano.

Some short tests show that this works:


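import numpy as np
import theano
import theano.tensor


class ODLTheanoOp(theano.Op):
    # Equality and hashing are based on the wrapped operator, so Theano does
    # not treat ops wrapping different ODL operators as interchangeable when
    # merging graph nodes (assumes the ODL operator is hashable).
    __props__ = ('operator',)

    def __init__(self, operator):
        # The wrapped ODL operator is assumed to be linear, since grad and
        # R_op below rely on that.
        self.operator = operator

    def make_node(self, x):
        x = theano.tensor.as_tensor_variable(x)
        # Output: non-broadcastable tensor matching the operator range
        out_type = theano.tensor.TensorType(
            self.operator.range.dtype,
            [False] * len(self.operator.range.shape))
        return theano.Apply(self, [x], [out_type()])

    def perform(self, node, inputs, output_storage):
        x = inputs[0]
        z = output_storage[0]
        # Evaluate the ODL operator on the NumPy input, store a NumPy array
        z[0] = np.asarray(self.operator(x))

    def infer_shape(self, node, i0_shapes):
        return [self.operator.range.shape]

    def grad(self, inputs, output_grads):
        # ODL adjoints are taken w.r.t. weighted inner products, while Theano
        # needs the plain transpose: A^T = (w_dom / w_ran) * A^*.
        try:
            dom_weight = self.operator.domain.weighting.const
        except AttributeError:
            dom_weight = 1.0

        try:
            ran_weight = self.operator.range.weighting.const
        except AttributeError:
            ran_weight = 1.0

        scale = dom_weight / ran_weight
        return [ODLTheanoOp(scale * self.operator.adjoint)(output_grads[0])]

    def R_op(self, inputs, eval_points):
        # Jacobian-vector product: for a linear operator A this is A(v),
        # not the adjoint applied to v.
        if eval_points[0] is None:
            return [None]
        return [self(eval_points[0])]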

EDIT: Updated code
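The scale factor in grad follows from the weighted inner products ODL uses (for example cell volumes in a uniform_discr space), while Theano differentiates with respect to the plain Euclidean inner product. Below is a minimal NumPy check, assuming constant weightings and a random matrix standing in for a linear ODL operator; all names here (A, w_dom, w_ran, weighted_inner) are hypothetical. It verifies the weighted adjoint identity, that scale times the weighted adjoint recovers the transpose, and that a directional derivative (what R_op should return) is A v rather than the adjoint applied to v.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins: a linear operator A as a 5x3 matrix, and constant
# weightings on its domain and range (like cell volumes in uniform_discr).
w_dom, w_ran = 0.25, 2.0
A = rng.standard_normal((5, 3))


def weighted_inner(u, v, w):
    """Inner product <u, v>_w = w * sum(u * v) with a constant weight w."""
    return w * np.dot(u, v)


# Adjoint w.r.t. the weighted inner products:
# <A x, y>_ran = <x, A_adj y>_dom  =>  A_adj = (w_ran / w_dom) * A.T
A_adj = (w_ran / w_dom) * A.T

x = rng.standard_normal(3)
y = rng.standard_normal(5)
lhs = weighted_inner(A @ x, y, w_ran)
rhs = weighted_inner(x, A_adj @ y, w_dom)

# Theano needs the plain transpose, recovered as (w_dom / w_ran) * A_adj;
# this mirrors `scale * self.operator.adjoint` in grad.
scale = w_dom / w_ran
grad_matrix = scale * A_adj

# Vector-Jacobian product (grad) vs. Jacobian-vector product (R_op)
# for f(x) = A x, checked with central finite differences.
g = rng.standard_normal(5)   # incoming output gradient
v = rng.standard_normal(3)   # direction for R_op
eps = 1e-6

vjp = grad_matrix @ g
vjp_fd = np.array([
    (g @ (A @ (x + eps * e)) - g @ (A @ (x - eps * e))) / (2 * eps)
    for e in np.eye(3)
])

jvp = A @ v
jvp_fd = (A @ (x + eps * v) - A @ (x - eps * v)) / (2 * eps)
```

Using the weighted adjoint directly, without the scale factor, would make the gradient off by a factor w_ran / w_dom (8 in this toy setup), silently mis-scaling any gradient-based sampler or optimizer.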

kohr-h commented 7 years ago

Agreed. I have something like this lying around, and it works much like your code.

The main problem is that, as far as I can currently see, this is of limited use for deep learning. Theano itself only handles the low-level symbolic/numeric computation and the graph building and optimization. It has no high-level notions; for layers, inputs, outputs and so on, one relies on a library built on top of it, such as Blocks. Perhaps there's a simpler way, but it seems you need to wrap the Theano op again so that Blocks understands what it does.

adler-j commented 7 years ago

While I agree that we need to support "more stuff", simply having this op already helps in many cases.

Personally, I'm currently looking into Bayesian inverse problems, and @olivierverdier recommended pymc3 to me. To do hybrid (Hamiltonian) Monte Carlo there, we need full integration with Theano, since the sampler needs gradients of the log-probability.
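Why gradients matter: hybrid (Hamiltonian) Monte Carlo, and the NUTS variant that pymc3 builds on it, proposes moves by integrating Hamiltonian dynamics with a leapfrog scheme that evaluates the gradient of log p at every step. Below is a minimal NumPy sketch of plain HMC on a toy 1-D Gaussian target; the function name hmc_sample and its parameters are illustrative, and this is not pymc3's implementation.

```python
import numpy as np


def hmc_sample(log_p, grad_log_p, x0, n_samples, step_size=0.1,
               n_leapfrog=20, seed=0):
    """Plain HMC with a unit mass matrix (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    x = np.atleast_1d(np.asarray(x0, dtype=float))
    samples = []
    for _ in range(n_samples):
        p0 = rng.standard_normal(x.shape)      # fresh Gaussian momentum
        x_new, p = x.copy(), p0.copy()

        # Leapfrog integration: every step needs grad log p(x)
        p = p + 0.5 * step_size * grad_log_p(x_new)
        for i in range(n_leapfrog):
            x_new = x_new + step_size * p
            if i < n_leapfrog - 1:
                p = p + step_size * grad_log_p(x_new)
        p = p + 0.5 * step_size * grad_log_p(x_new)

        # Metropolis correction with H(x, p) = -log p(x) + |p|^2 / 2
        h_old = -log_p(x) + 0.5 * np.sum(p0 ** 2)
        h_new = -log_p(x_new) + 0.5 * np.sum(p ** 2)
        if np.log(rng.uniform()) < h_old - h_new:
            x = x_new
        samples.append(x.copy())
    return np.array(samples)


# Toy target: Gaussian with mean 2 and standard deviation 0.5
mean, std = 2.0, 0.5
samples = hmc_sample(
    log_p=lambda x: -0.5 * np.sum(((x - mean) / std) ** 2),
    grad_log_p=lambda x: -(x - mean) / std ** 2,
    x0=0.0,
    n_samples=3000,
)[500:]                                        # discard burn-in
```

In the tomography setting, grad_log_p is where the adjoint of the ray transform enters, which is exactly what the grad method of the Theano op supplies.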

With the Theano op above in place, Bayesian inverse problems become quite easy. For example, the following snippet performs fully Bayesian inversion with a TV regularizer.

import pymc3 as pm
import numpy as np
import odl
import tqdm
from util import ODLTheanoOp, MovingAverage, MovingStd  # local helper module, not part of ODL or pymc3
np.random.seed(0)

# Create ODL data structures
size = 128
space = odl.uniform_discr([-64, -64], [64, 64], [size, size],
                          dtype='float64')

geometry = odl.tomo.parallel_beam_geometry(space, num_angles=30)
operator = odl.tomo.RayTransform(space, geometry)
fbp_op = odl.tomo.fbp_op(operator)

true_phantom = odl.phantom.shepp_logan(space, modified=True)
data = operator(true_phantom)
sigma = np.mean(np.abs(data)) * 0.05
noisy_data = data + odl.phantom.white_noise(operator.range) * sigma
observed = noisy_data.asarray()
fbp_recon = fbp_op(noisy_data).asarray()

W = odl.Gradient(space)

callback = (odl.solvers.CallbackShow('current', step=10) &
            odl.solvers.CallbackShow('mean', step=10) * MovingAverage() &
            odl.solvers.CallbackShow('std', step=10) * MovingStd()) * space.element

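with pm.Model() as model:
    # Gaussian prior centered at the FBP reconstruction
    phantom = pm.Normal('phantom', mu=fbp_recon, sd=1.0,
                        shape=space.shape)

    # TV-type prior: Laplace distribution on the image gradient,
    # "observed" to be zero
    mu_reg = ODLTheanoOp(W)(phantom)
    regularizer = pm.Laplace('regularizer', mu=mu_reg, b=0.05,
                             observed=0)

    # Gaussian likelihood of the noisy ray transform data
    mu = ODLTheanoOp(operator)(phantom)
    obs = pm.Normal('obs', mu=mu, sd=sigma,
                    observed=observed)

    # NUTS with a small maximum tree depth to keep iterations cheap
    step = pm.NUTS(early_max_treedepth=3, max_treedepth=3)
    start = {'phantom': fbp_recon}
    # Sample and show the current sample, moving mean and moving std
    for trace in tqdm.tqdm(pm.iter_sample(10**5, step, start=start)):
        callback(trace['phantom', -1:][0])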
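Reading the model off the snippet (a sketch derived from the three distributions, not stated in the thread): with A the ray transform operator, y the observed data, the gradient W written as nabla, x_FBP the FBP reconstruction and b = 0.05, the log-posterior that NUTS samples from is, up to an additive constant:

```latex
\log p(x \mid y) = -\frac{1}{2\sigma^2}\,\lVert A x - y \rVert_2^2
  - \frac{1}{b}\,\lVert \nabla x \rVert_1
  - \frac{1}{2}\,\lVert x - x_{\mathrm{FBP}} \rVert_2^2 + \mathrm{const},
\qquad b = 0.05
```

The L1 norm sums the absolute values of both gradient components over all pixels, i.e. an anisotropic TV term, so the maximum a posteriori estimate is a TV-regularized least-squares solution with an additional quadratic pull toward the FBP image.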