brandondube / prysm

physical optics: integrated modeling, phase retrieval, segmented systems, polynomials and fitting, sequential raytracing...
https://prysm.readthedocs.io/en/stable/
MIT License

Consider using Tensorflow or PyTorch as a gpu backend? #37

usryokousha closed this issue 3 years ago

usryokousha commented 3 years ago

I am doing some phase retrieval work for systems with large fields of view, and I typically end up wanting to recover the phase for multiple local regions. For astronomy and photography, working with tensors (stacks of regions) rather than individual 2D arrays could be very useful.

TensorFlow now includes conventional optimizers such as CG and L-BFGS (via tensorflow-probability). There are also some inherent advantages to using the library's automatic differentiation features.

brandondube commented 3 years ago

Could you provide an example of what is broken when you just replace prysm.mathops.engine.numpy with tensorflow or pytorch?

For what it's worth, in my opinion there are no inherent advantages to autodiff from tf/pytorch/etc. These libraries generally lack critical features for image-based wavefront sensing. PyTorch lacked complex number support (a hard requirement) for a long time, and TensorFlow takes longer to set up its tape machine than some reverse calculations take to run in their entirety. Array insertion and slicing ops are often forbidden as "nondifferentiable", which is true in general but not strictly so: with valid priors for these problems, they are differentiable. Some "batch updates" in the adjoint pass also cannot reuse work from the forward pass, because those libraries are blind to some of the "oracle knowledge" a skilled implementer has. For example, this tiny amount of code for a monochromatic, single-plane image-based wavefront sensing ("phase retrieval") problem runs more than 5x faster than the equivalent in any of JAX, PyTorch, or TensorFlow:

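import numpy as np


# ft_fwd / ft_rev are not shown in this thread; these definitions are an
# assumption. ft_rev must be the adjoint of ft_fwd. With norm="ortho" the FFT
# is unitary, so its adjoint is the orthonormal inverse FFT.
def ft_fwd(x):
    return np.fft.fft2(x, norm="ortho")


def ft_rev(x):
    return np.fft.ifft2(x, norm="ortho")


class NLOptPRModel:
    def __init__(self, amp, wvl, basis, data):
        self.amp = amp      # pupil amplitude, shape (n, n)
        self.wvl = wvl      # wavelength, in the same units as the phase
        self.basis = basis  # modal basis, shape (nmodes, n, n)
        self.D = data       # measured PSF intensity, shape (n, n)

    def update(self, x):
        # forward model: coefficients -> phase -> pupil field -> PSF -> loss
        phs = np.tensordot(self.basis, x, axes=(0, 0))
        W = (2 * np.pi / self.wvl) * phs
        g = self.amp * np.exp(1j * W)
        G = ft_fwd(g)
        I = np.abs(G)**2
        E = np.sum((I - self.D)**2)
        # cache intermediates so the reverse pass can reuse them
        self.W = W
        self.g = g
        self.G = G
        self.I = I
        self.E = E

    def fwd(self, x):
        self.update(x)
        return self.E

    def rev(self, x):
        # reverse (adjoint) pass: walk the forward chain backwards
        self.update(x)
        Ibar = 2 * (self.I - self.D)
        Gbar = 2 * Ibar * self.G
        gbar = ft_rev(Gbar)
        Wbar = 2 * np.pi / self.wvl * np.imag(gbar * np.conj(self.g))
        # project the phase gradient onto the modal basis
        abar = np.tensordot(self.basis, Wbar)
        self.Ibar = Ibar
        self.Gbar = Gbar
        self.gbar = gbar
        self.Wbar = Wbar
        self.abar = abar
        return abar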

(This code is (C) me; you should attribute it as such if you copy it.)

You can convert it from a low-dimensional modal estimator to an extremely high-dimensional elementwise, or "pixelwise", estimator by replacing phs = ... with phs = x.reshape(self.amp.shape) and returning Wbar instead of abar. Obviously, comment out the projection in the reverse pass as well. For a high-dimensional problem (~262k variables for 512x512) you will really want to prune all of the dead variables, i.e. those outside the pupil's support. 512x512 = 262,144, but at Q=2 there are only 51,445 pixels in the pupil, so removing ~80% of the parameter space is a huge benefit. It's a numpy one-liner (two-liner if you value clarity) to do that.
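A minimal sketch of the pruned pixelwise estimator, assuming a unitary FFT (`norm="ortho"`) for the forward and reverse transforms and a pupil support given by `amp > 0`. Only in-pupil pixels are variables: the forward pass scatters them into the full phase array, and the reverse pass gathers the gradient back out, which is the pruning step. The names `make_pupil` and `PixelwisePRModel` are hypothetical, not prysm API.

```python
import numpy as np


def make_pupil(n, q):
    """Circular pupil of diameter n / q samples on an n x n grid (hypothetical)."""
    c = np.arange(n) - n // 2
    xx, yy = np.meshgrid(c, c)
    return (np.hypot(xx, yy) <= n / (2 * q)).astype(float)


class PixelwisePRModel:
    """One phase variable per in-pupil pixel; pixels outside are pruned."""

    def __init__(self, amp, wvl, data):
        self.amp = amp
        self.wvl = wvl
        self.D = data
        self.mask = amp > 0          # pupil support = live variables

    def phase(self, x):
        phs = np.zeros(self.amp.shape)
        phs[self.mask] = x           # scatter live variables into the pupil
        return phs

    def loss_and_grad(self, x):
        k = 2 * np.pi / self.wvl
        g = self.amp * np.exp(1j * k * self.phase(x))
        G = np.fft.fft2(g, norm="ortho")
        I = np.abs(G) ** 2
        E = np.sum((I - self.D) ** 2)
        # reverse pass: same steps as the modal model, minus the projection
        Ibar = 2 * (I - self.D)
        Gbar = 2 * Ibar * G
        gbar = np.fft.ifft2(Gbar, norm="ortho")
        Wbar = k * np.imag(gbar * np.conj(g))
        return E, Wbar[self.mask]    # gather: gradient for live pixels only


n, q, wvl = 128, 2, 0.55
amp = make_pupil(n, q)
model = PixelwisePRModel(amp, wvl, data=np.zeros((n, n)))
nlive = int(model.mask.sum())
print(f"{nlive} of {n * n} variables live ({1 - nlive / (n * n):.0%} pruned)")
```

Returning the loss and gradient together from one method means the forward intermediates are computed once per evaluation; the mask indexing is the "one-liner" pruning in both directions.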

Making the BFGS updates on a GPU will be faster for a large problem, but you may be surprised how large the arrays must be for that to be worthwhile. L-BFGS-B from scipy does 512x512 single-plane image-based wavefront sensing in about 1.2 s on a commodity desktop computer, with 65 iterations in that time. 18 ms/iter is not that slow, and with one fwd and one rev eval per iter, <9 ms each does not particularly motivate the GPU. (It would be a different story if it were, say, 1 s.) It takes ~54 seconds to do 1000 iterations of the 262k-variable problem (52 ms/iter). I would posit you need something like four, maybe 16, million variables to motivate the GPU. And then you must have some extremely high-frequency phase features to solve for, which will require a special diversity that will probably take more of your time to identify than the computation does.
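A hypothetical timing harness for the ms/iteration figures above. It assumes SciPy's `scipy.optimize.minimize` with `method="L-BFGS-B"` and `jac=True`, so the objective returns `(loss, gradient)` from a single call and the forward pass is shared with the reverse pass. SciPy's Rosenbrock function stands in for the objective so the sketch runs on its own; a phase retrieval `loss_and_grad` would be dropped in instead.

```python
import time

import numpy as np
from scipy.optimize import minimize, rosen, rosen_der


def loss_and_grad(x):
    # stand-in objective; replace with a phase retrieval fwd/rev pair
    return rosen(x), rosen_der(x)


def timed_lbfgsb(fun_and_grad, x0, maxiter=1000):
    """Run L-BFGS-B and report wall time per iteration (hypothetical helper)."""
    t0 = time.perf_counter()
    res = minimize(fun_and_grad, x0, jac=True, method="L-BFGS-B",
                   options={"maxiter": maxiter})
    elapsed = time.perf_counter() - t0
    ms_per_iter = 1e3 * elapsed / max(res.nit, 1)
    return res, elapsed, ms_per_iter


x0 = np.full(10_000, 1.5)
res, elapsed, ms = timed_lbfgsb(loss_and_grad, x0)
print(f"{res.nit} iterations, {res.nfev} evaluations, "
      f"{elapsed:.2f} s total, {ms:.2f} ms/iter")
```

Comparing `res.nfev` with `res.nit` shows how many objective evaluations (each one fwd plus one rev) each iteration actually costs, which is the number to weigh against any GPU transfer overhead.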

Anyway, in short: if you care about performance, manual autograd trounces any library doing it for you.
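On the earlier point that array insertion and slicing ops are differentiable: with fixed, known indices, both are linear maps, so each has an exact adjoint. The adjoint of extracting a slice is inserting the cotangent into a zero array at the same location, and the adjoint of insertion is extracting that region. The function names below are hypothetical; a dot-product test checks each pair.

```python
import numpy as np


def crop_fwd(x, sl):
    """Forward: extract a rectangular region (slicing)."""
    return x[sl]


def crop_rev(ybar, sl, shape):
    """Adjoint of crop: scatter the cotangent back into a zero array."""
    xbar = np.zeros(shape, dtype=ybar.dtype)
    xbar[sl] = ybar
    return xbar


def pad_fwd(x, sl, shape):
    """Forward: insert x into a larger zero array (array insertion)."""
    y = np.zeros(shape, dtype=x.dtype)
    y[sl] = x
    return y


def pad_rev(ybar, sl):
    """Adjoint of pad: gather the cotangent from the inserted region."""
    return ybar[sl]


# dot-product test: <A x, y> == <x, A^T y> for a linear op A
rng = np.random.default_rng(0)
shape = (8, 8)
sl = (slice(2, 6), slice(1, 5))
x = rng.standard_normal(shape)
y = rng.standard_normal((4, 4))
print(np.isclose(np.vdot(crop_fwd(x, sl), y), np.vdot(x, crop_rev(y, sl, shape))))
print(np.isclose(np.vdot(pad_fwd(y, sl, shape), x), np.vdot(y, pad_rev(x, sl))))
```

Crop and pad are adjoints of each other, which is why a hand-written reverse pass can handle padding and ROI extraction with plain indexing.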

usryokousha commented 3 years ago

First, thank you for such a complete response to my question. I just happened to come upon it after running into your blog.

I'm working with images from a 27-megapixel CCD, which I break up into 256 x 256 ROIs and pad to 512 x 512 arrays.

I was interested in solving directly for Zernike coefficients, the same way you did in your blog, using TensorFlow as a backend. Unfortunately TensorFlow wasn't able to handle the inherent nonlinearity of the problem, and the gradient was far off from finite differences. Without a known (modeled) or measured PSF, I was propagating the pupil plane to the PSF plane and convolving with a calibration pattern (latent image). I then minimized the squared error between the camera image and my approximation. This convolution adds two more FFTs to the problem, which, combined with the backward pass, is quite demanding, especially when using L-BFGS.
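A minimal sketch of how the convolution could fit into a hand-written forward/reverse model, assuming circular (FFT-based) convolution with a known latent image, a unitary pupil-to-PSF FFT, and a random stand-in basis in place of Zernikes. The adjoint of convolving with the latent image is correlating with it, i.e. multiplying by the conjugate of its spectrum. `ConvPRModel`, `conv_fwd` and `conv_rev` are hypothetical names; `scipy.optimize.check_grad` compares the gradient against finite differences.

```python
import numpy as np
from scipy.optimize import check_grad


def conv_fwd(I, Lhat):
    """Circular convolution of the PSF with the latent image, via FFT."""
    return np.real(np.fft.ifft2(np.fft.fft2(I) * Lhat))


def conv_rev(Mbar, Lhat):
    """Adjoint of conv_fwd: circular correlation with the latent image."""
    return np.real(np.fft.ifft2(np.fft.fft2(Mbar) * np.conj(Lhat)))


class ConvPRModel:
    """Modal phase retrieval with a known latent image (hypothetical sketch)."""

    def __init__(self, amp, wvl, basis, latent, data):
        self.amp = amp
        self.wvl = wvl
        self.basis = basis
        self.Lhat = np.fft.fft2(latent)  # computed once, reused every call
        self.D = data

    def forward(self, x):
        k = 2 * np.pi / self.wvl
        phs = np.tensordot(self.basis, x, axes=(0, 0))
        g = self.amp * np.exp(1j * k * phs)
        G = np.fft.fft2(g, norm="ortho")
        M = conv_fwd(np.abs(G) ** 2, self.Lhat)  # modeled camera image
        return g, G, M

    def fwd(self, x):
        return np.sum((self.forward(x)[2] - self.D) ** 2)

    def rev(self, x):
        k = 2 * np.pi / self.wvl
        g, G, M = self.forward(x)
        Mbar = 2 * (M - self.D)
        Ibar = conv_rev(Mbar, self.Lhat)         # the one extra adjoint step
        Gbar = 2 * Ibar * G
        gbar = np.fft.ifft2(Gbar, norm="ortho")
        Wbar = k * np.imag(gbar * np.conj(g))
        return np.tensordot(self.basis, Wbar)


rng = np.random.default_rng(2)
n, nmodes, wvl = 32, 5, 0.6
c = np.arange(n) - n // 2
xx, yy = np.meshgrid(c, c)
amp = (np.hypot(xx, yy) <= n / 4).astype(float)
basis = 0.1 * rng.standard_normal((nmodes, n, n)) * amp  # stand-in for Zernikes
model = ConvPRModel(amp, wvl, basis, rng.random((n, n)), np.zeros((n, n)))
model.D = model.forward(0.5 * rng.standard_normal(nmodes))[2]  # synthetic image
x0 = np.zeros(nmodes)
print("finite-difference mismatch:", check_grad(model.fwd, model.rev, x0))
print("gradient norm:", np.linalg.norm(model.rev(x0)))
```

Precomputing the latent image's spectrum means the convolution costs one FFT pair in each direction; the mismatch printed by `check_grad` should be tiny relative to the gradient norm when the adjoint chain is right.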

I agree that a carefully constructed analytical gradient is faster any day, and it has served me well in the past. In a corporate environment, though, the convenience of auto-differentiation has been nice for exploring new ideas.

brandondube commented 3 years ago

So that I understand correctly, regarding:

"Unfortunately Tensorflow wasn't able to handle the inherent nonlinearity of the problem and the gradient was far off from finite differences."

Is this a tacit assertion that, if prysm's backend were tensorflow, tensorflow would be able to turn the crank on its tape machine? I'm still not sure what the problem is from the title of your issue...

For the rest of the conversation (particulars of fwd/rev modeling), please send me an email; it's off-topic for this issue (even if I opened that conversation!).

usryokousha commented 3 years ago

By no means was I suggesting that prysm with a TensorFlow backend would handle my nonlinearity issues any better than TensorFlow's auto-differentiation alone. As you mentioned, and from my own current experience, finding analytical gradients suited to the particular problem is in many cases the best solution.

I'm not an optical engineer by trade, but I have been dabbling in optics-related problems for a while now. My current project involves wavefront sensing and machine learning, and while browsing your library it seemed like a good opportunity to point out that machine learning libraries bring a certain degree of out-of-the-box performance (GPU support and multithreading), auto-differentiation, and a plethora of NumPy-compatible functions.

You are correct about the issues with TensorFlow's gradient tape, which makes implementing custom gradients a pain! It can be done with a function wrapper, but that is not the most elegant solution. In addition, there are issues with gradients when applying non-differentiable functions that can be sorted out easily with a hard-coded solution. I know you have put a lot of work into eking out performance from your library, and I wasn't suggesting that a trivial modification would bring substantial performance differences.

I think we can conclude that dropping in a machine learning backend isn't going to automatically solve a new set of problems, and the performance is not necessarily better in this space.

brandondube commented 3 years ago

Machine learning libraries are very powerful, but also very general purpose. There are many performance optimizations that you can make given knowledge of the whole program that are "unsafe," and so are not supported by ML libraries. There is some newer work that lets users define adjoint functions for their own functions, so you can write your own "unsafe" forward and reverse ops. That will help ameliorate the problem.

However, a lot of the libraries implemented all of their own algebra in C or C++ or whatever, which discards so much of the past 50+ years of optimization of the BLAS behind numpy, etc.

Additionally, I just find that cupy is such an incredibly superior implementation of GPU computing in python (if you try it, I think you'll find it knocks the socks off pytorch and tf: way faster, easier to use, more flexible... no competition) that you really need your own algodiff implementation that allows you to sub out numpy for cupy. But then you need a cupy-aware optimization routine. Those used to exist in chainer, but chainer is dead now :\
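One way to "sub out numpy for cupy" is to write each kernel against whichever array module matches its input. This sketch assumes CuPy is optional and falls back to NumPy when it is not installed; `cupy.get_array_module` is CuPy's helper for that dispatch, while `array_module`, `ft_fwd`, `ft_rev` and `dot_test` are hypothetical names (prysm's own backend mechanism is not shown here). It also illustrates a common adjoint pitfall: for the unnormalized `fft2`, the adjoint is `N * ifft2`, not the plain inverse.

```python
import numpy as np

try:
    import cupy as cp
except ImportError:  # CPU-only environment
    cp = None


def array_module(arr):
    """Return cupy for GPU arrays and numpy otherwise."""
    if cp is not None:
        return cp.get_array_module(arr)
    return np


def ft_fwd(g):
    xp = array_module(g)
    return xp.fft.fft2(g)


def ft_rev(Gbar):
    # adjoint (conjugate transpose) of the unnormalized fft2 is N * ifft2
    xp = array_module(Gbar)
    return xp.fft.ifft2(Gbar) * Gbar.size


def dot_test(rev=ft_rev, n=64, seed=0):
    """Relative mismatch of <F x, y> and <x, rev(y)>; ~0 when rev is F's adjoint."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    y = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
    lhs = np.vdot(ft_fwd(x), y)
    rhs = np.vdot(x, rev(y))
    return abs(lhs - rhs) / abs(lhs)


print(dot_test())                  # near machine precision
print(dot_test(rev=np.fft.ifft2))  # about 1 - 1/N: the inverse is off by N
```

The dot-product test here runs on NumPy arrays; the same check on the GPU would swap `np.vdot` for `cupy.vdot` and move the inputs with `cupy.asarray`.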

The heavily optimized parts of prysm are really just the polynomial computations. The pending version (v0.20) greatly simplifies prysm and makes it much easier to use for adjoint calculations as well as forward ones. It is really hard to implement adjoint calculations properly with the old API, due to a bunch of magic behind the scenes, but with the new one it is trivial.

I may revive iris and implement some state of the art phase retrieval and adjacent algorithms with autodiff in there, based on prysm. I haven't really decided.

One minor comment: nonlinear and not differentiable are different things, and invertible isn't even a requirement for reverse-mode calculations (the adjoint model, "analytic gradients"). y = x^2 is nonlinear, yet it is differentiable (and invertible for x ≥ 0). y = bindown(x) is not invertible, since many inputs bin to the same output, but it does have an adjoint operation that allows you to do reverse calculations to get analytic gradients. It is an important distinction to make; otherwise we may decide that some operations we could differentiate are non-differentiable!
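A minimal sketch of the `bindown` point, assuming bindown means summing non-overlapping `factor x factor` blocks of pixels (the thread does not define it, and prysm's own implementation may differ, e.g. by averaging). The forward op discards information, so it has no inverse, yet its adjoint is simply "copy each binned value back to every pixel in its block". The helper names are hypothetical.

```python
import numpy as np


def bindown(x, factor):
    """Sum non-overlapping factor x factor blocks; shape must divide evenly."""
    m, n = x.shape
    return x.reshape(m // factor, factor, n // factor, factor).sum(axis=(1, 3))


def bindown_adjoint(ybar, factor):
    """Adjoint of bindown: spread each binned cotangent over its block."""
    return np.repeat(np.repeat(ybar, factor, axis=0), factor, axis=1)


# two different inputs, one output: bindown has no inverse
print(bindown(np.eye(2), 2), bindown(np.full((2, 2), 0.5), 2))  # both [[2.]]

# dot-product test: <bindown(x), y> == <x, bindown_adjoint(y)>
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
y = rng.standard_normal((4, 4))
print(np.isclose(np.vdot(bindown(x, 2), y), np.vdot(x, bindown_adjoint(y, 2))))
```

In a reverse pass, the gradient with respect to a binned image is pushed back to full resolution with `bindown_adjoint`, and nothing ever needs to "un-bin" the forward result.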