hammerlab / flowdec

TensorFlow Deconvolution for Microscopy Data
Apache License 2.0

windowing and type conversion possible using `input_prep_fn`, `output_prep_fn`? #16

Open VolkerH opened 5 years ago

VolkerH commented 5 years ago

Hi,

I'm just getting into tensorflow so that I can modify flowdec to suit my needs. There are two things I am trying to achieve:

  1. do input and output type conversions on the GPU (currently I do this in numpy on the CPU and cProfile shows that the code is spending quite a bit of time on astype(np.*)).
  2. integrate a windowing/apodization function to reduce artifacts caused by discontinuities at the boundary.

While looking through the flowdec source code to see where I could add these things, I noticed the input_prep_fn and output_prep_fn stubs, and I am wondering whether I could somehow use them for the above-mentioned purposes.

However, in both cases I somehow need to allocate additional arrays (or "tensors"):

  1. For the input/output dtype conversion I will have to create additional arrays with the input/output dtypes (typically uint16).
  2. For the windowing function I would like to pass in a pre-computed windowing function that I multiply with the array.

I notice that inputs and outputs are passed in as dictionaries. So can I achieve these objectives by initializing the deconvolution object with some additional key/value pairs in the input/output dictionaries and passing in appropriate input_prep_ and output_prep_ functions, or do I need to modify the actual code in flowdec/restoration.py?

Some guidance on how to approach this would be highly appreciated.

eric-czech commented 5 years ago

If you still have the profiling results, can you tell where the astype conversion is happening (i.e. are you sure flowdec is doing it)? I don't recall there being anything in there that tries to manage types beyond the implicit cast that comes with specifying data array inputs like tf.placeholder(dtype=tf.float32, ...). I'm not sure if TF then does the conversion on a CPU for a non-float32 input, but you could try to specify an input_prep_fn like:

def input_prep_fn(tensor_name, tensor):
    # tensor_name will be one of 'data' or 'kernel'
    return tf.cast(tensor, tf.float32)

But I don't think this will change anything since that should be what happens without the explicit cast.

On the output side of things though, perhaps something like this will help?

def output_prep_fn(tensor_name, tensor, inputs):
    # This is only applied to single result tensor so tensor_name can be ignored
    return tf.cast(tf.clip_by_value(tensor, tf.uint16.min, tf.uint16.max), tf.uint16)

Those would then get passed in like fd_restoration.RichardsonLucyDeconvolver(..., input_prep_fn=input_prep_fn, output_prep_fn=output_prep_fn).
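
For what it's worth, here's a minimal end-to-end sketch of that wiring (the casts are just placeholders for whatever conversion you actually need, and I'm using the bars_25pct test volume as input):

import tensorflow as tf
from flowdec import data as fd_data
from flowdec import restoration as fd_restoration

def input_prep_fn(tensor_name, tensor):
    # tensor_name will be one of 'data' or 'kernel'
    return tf.cast(tensor, tf.float32)

def output_prep_fn(tensor_name, tensor, inputs):
    # Clip to the uint16 range before casting the float result back down
    return tf.cast(tf.clip_by_value(tensor, tf.uint16.min, tf.uint16.max), tf.uint16)

acq = fd_data.bars_25pct()
algo = fd_restoration.RichardsonLucyDeconvolver(
    3, input_prep_fn=input_prep_fn, output_prep_fn=output_prep_fn).initialize()
res = algo.run(fd_data.Acquisition(data=acq.data, kernel=acq.kernel), niter=10).data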

When the tensors are passed to those functions, any operations that you add would get appended to the TF graph and most likely run on a GPU. So for example, saying tensor * 5 in one of those functions looks like an immediate multiplication (as it would w/ a normal numpy array) but it's a declarative addition to the graph to be run later. Multiplying by an apodization array would do something similar where the whole array would get added to the graph as a constant and its multiplication with future inputs would be on the GPU.

Does that help? The type conversion and apodization are easy to express as tensorflow operations but the window function might be trickier. If you can tell me more about that I could probably help find what tensorflow functions would make it efficient. Either way though, I think all of the above should work using [input|output]_prep_fns since it's basically like altering the source code used to build the TF graph.

beniroquai commented 5 years ago

Thanks for opening the discussion. I've been following your work for quite a while now and I must say it's really impressive! I was curious to hear your thoughts on TensorFlow's handling of complex values. I'm currently working on a deconvolution for complex-valued holograms/E-fields, where I found some issues with the initialization of TF variables/constants. A simple example leads to a large difference in computational time:

import numpy as np
import tensorflow as tf
import time as time

mysize = 256
mymat = np.ones((mysize,mysize,mysize))
mymat_cmplx = mymat+1j*mymat

# converting complex numpy array to a tensor via tf.constant
t1 = time.time()
TF_mymat_cmplx = tf.constant(mymat_cmplx)
t2 = time.time()
print('Time is: '+str(t2-t1))

# converting complex numpy array to a tensor via tf.complex
t1 = time.time()
TF_mymat_cmplx = tf.complex(np.real(mymat_cmplx), np.imag(mymat_cmplx))
t2 = time.time()
print('Time is: '+str(t2-t1))

which gives:

Time is: 20.639596700668335
Time is: 3.303292989730835

Both versions should do the same thing, but the latter is approx. 6x faster. Any ideas? This was measured on Mac/Ubuntu/Windows with TF 1.13, on both GPU and CPU.

Another thing I found is that the tf.angle() function is not (yet) implemented on GPU, which causes expensive CPU<->GPU copying all the time.
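
A possible workaround I haven't benchmarked yet is to compute the phase as atan2 of the imaginary and real parts, which should stay on the GPU. Roughly:

import tensorflow as tf

z = tf.placeholder(tf.complex64, shape=[None, None, None])
# Mathematically equivalent to tf.angle(z), built from ops that have GPU kernels
phase = tf.atan2(tf.imag(z), tf.real(z))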

Another aspect is the long time it takes to run the first Session.run(). I guess TF does a set of pre-compilations in order to simplify the graph, but for a simple 3D convolution this already takes longer than the equivalent model in e.g. Matlab/numpy. Have you run into this problem, and if so do you have any suggestions for solving it? Somewhere I found people suggesting to switch off autotuning via an environment variable (os.environ["TF_CUDNN_USE_AUTOTUNE"] = "0"), but unfortunately this didn't show much of an improvement.

Best Benedict

eric-czech commented 5 years ago

Thanks @beniroquai ! It is odd that one initialization is so much slower, but depending on what you're working towards, a placeholder is probably a better starting point for building a graph anyhow (vs. building it with constants like mymat_cmplx). I think there are two ways that make the most sense to go about this sort of thing with TensorFlow:

  1. Do what I was doing with flowdec, where you assume the spin-up time for a single graph is high and hard to optimize, but kind of irrelevant as long as you hang on to the graph and reuse it a lot (i.e. reusing the Deconvolver instances instead of recreating them).
  2. Use eager execution. For those examples, you could try it with a tf.enable_eager_execution() somewhere at the beginning, and I bet you'll get similar times, or tf.constant may even be faster (see the sketch below).

I'd imagine option #2 is good if you want to use tensorflow much like you'd use numpy, where everything happens line by line. I think you can generally get better performance overall with option #1, but it's a bit harder to use and there's always that "first run" slowness. I don't think there is a good way around that outside of eager execution, but it's usually not too hard to specify inputs using placeholders instead of constants to make a reusable graph that will execute quickly after the first time (and maybe run some dummy values through once on startup if you've got a real-time use case?).
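
For what it's worth, a minimal sketch of trying option #2 on your timing example (assuming TF 1.x where eager execution is available; it has to be enabled before any other TF calls):

import time
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # must run once, before building any graph

mysize = 256
mymat = np.ones((mysize, mysize, mysize))
mymat_cmplx = mymat + 1j * mymat

t1 = time.time()
TF_mymat_cmplx = tf.constant(mymat_cmplx)  # executes immediately in eager mode
print('Time is: ' + str(time.time() - t1))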

VolkerH commented 5 years ago

Thanks for the detailed reply @eric-czech

If you have the profiling results still, can you tell where the astype conversion is happening (i.e. are you sure flowdec is doing it)?

This was a misunderstanding. The type conversions were not happening in flowdec; they were happening in my own batch code before and after each session run. As I noticed that this is quite slow, I was looking for ways to also do this with tensorflow, and while poring over your code I realized that this might be possible using the input_prep and output_prep functions.

Thanks for those input_prep_fn and output_prep_fn examples. I will start my experiments from there. What I did not understand when I originally asked is whether I would need to somehow declare tf.placeholders for these variables (as they have different data types) so that the graph knows about them and allocates space, and then somehow add these to the feed_dict that gets passed to the session. Your examples indicate I don't have to do anything like that.

Multiplying by an apodization array would do something similar where the whole array would get added to the graph as a constant and its multiplication with future inputs would be on the GPU.

Yes, that's what I was aiming for. My original question was essentially: how do I declare this apodization array so that the tensorflow graph knows about it? At the time I was looking at this code snippet from restoration.py...

        # Data and kernel should have shapes (z, height, width)
        datah = self._wrap_input(tf.placeholder(self.dtype, shape=[None] * self.n_dims, name='data'))
        kernh = self._wrap_input(tf.placeholder(self.dtype, shape=[None] * self.n_dims, name='kernel'))

...and wondering whether I would have to add an additional line for an apodization array. I'm still not quite clear on this. And would I then have to add it to the feed_dict? Currently, I seem to be passing a feed_dict like this when passing an Acquisition object:

    def to_feed_dict(self):
        return {'data': self.data, 'kernel': self.kernel}

So do I need to pass a feed_dict like {'data': my_data_array, 'kernel': my_kernel_array, 'apo': my_apo_array} in order to be able to use something like tensor * apo inside the input_prep_fn?

Also, from browsing through the tf documentation over the weekend, it appears that I could alternatively do something like tf.Variable(apod_np_array, name='my_apo_tf_array') so that the variable becomes part of the graph.

Apologies if that all sounds a bit confused, but having used tensorflow mainly through convenient wrappers such as keras, I still need to experiment a bit before I can formulate these questions more cohesively.

Does that help? The type conversion and apodization are easy to express as tensorflow operations but the window function might be trickier.

With the window function I basically meant multiplication by an apodization array (this could be a sine window or similar, which I can easily pre-compute using numpy, so nothing tricky there).
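
For concreteness, this is roughly the kind of pre-computation I have in mind (a hypothetical helper, just for illustration):

import numpy as np

def sine_window_3d(shape):
    # Separable sine (half-period) window, one 1D factor per axis
    axes = [np.sin(np.pi * (np.arange(n) + 0.5) / n) for n in shape]
    return (axes[0][:, None, None]
            * axes[1][None, :, None]
            * axes[2][None, None, :]).astype(np.float32)

apod = sine_window_3d((32, 64, 64))  # same shape as the volume to be deconvolved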

Finally there is a question about persistence. Currently I pass in the kernel each time I run a deconvolution, similar to this:

algo = fd_restoration.RichardsonLucyDeconvolver(data.ndim).initialize()
res = []
for v in volumes:
    res.append(algo.run(fd_data.Acquisition(data=v, kernel=kernel), niter=10).data)

As the kernel is the same for each algo.run, is there a way to persist the kernel on the GPU (maybe that's already happening)? Not sure whether there would be much speed-up.

beniroquai commented 5 years ago

Thank you very much for the comprehensive answer! I tried some of your suggestions, but I guess the first sess.run() is always slow since it compiles the graph. Saving and loading the graph might be an option; eager execution not really, but maybe TF 2.0 will make the difference?

Best Bene

eric-czech commented 5 years ago

Hey @VolkerH sorry for the delay, just getting back from vacation.

I see what you mean now and thanks for laying that out. I think you could safely avoid having to add new tf.Variable tensors and instead stick to adding an apodization array as a constant in the input_prep_fn since the type and shape declaration within the graph would be implicit in how you define that function (i.e. when you multiply the image tensor by the apodization matrix as a numpy array, Tensorflow would infer the type and shape of that numpy array before adding it as a tensor to the graph forever).

I hope that makes sense and if it helps further, here is an example that places both an apodization array and a PSF as constants on the graph instead of having to feed them in with placeholders:

import tensorflow as tf
import numpy as np
from flowdec import restoration as fd_restoration
from flowdec import data as fd_data

# Load Hollow Bars volume downsampled to 25%
acq = fd_data.bars_25pct()
print(acq.shape())
# > {'actual': (32, 64, 64), 'data': (32, 64, 64), 'kernel': (32, 64, 64)}

# Create a dummy apodization array
apod = np.ones_like(acq.data)

def input_prep_fn(tensor_name, tensor):
    # Multiply image tensor by apodization matrix (multiplication will convert `apod` to Tensor)
    if tensor_name.startswith('data'):
        return tensor * apod
    # Return psf as constant with explicit conversion to Tensor since the output
    # of these functions must be a Tensor 
    if tensor_name.startswith('kernel'):
        return tf.constant(acq.kernel)
    raise ValueError('Tensor %s not supported' % tensor_name)

# Initialize with above function
algo = fd_restoration.RichardsonLucyDeconvolver(3, input_prep_fn=input_prep_fn).initialize()

# Pass the kernel as a tiny array to ultimately be ignored
# * `algo` could then be reused on other images with both the PSF and apodization array as constants
res = algo.run(fd_data.Acquisition(data=acq.data, kernel=np.ones((1,1,1))), niter=25)

I was experimenting with this via nvprof since I was a little unclear myself whether pinned constants (a la this post) would persist across TensorFlow sessions, but it looks like they don't. In other words, the above doesn't solve the "persisting PSFs in GPU memory across multiple deconvolutions" problem since (apodization aside) it results in the same number of host-to-device and device-to-host transfers as the same code that doesn't make the PSF a constant. Presumably it at least makes some of the possibilities for an input_prep_fn clearer though.

I think the real trick to keeping anything in GPU memory for multiple images will be to change the API to not create a new TensorFlow session each time (which is what happens on every algo.run call). That should overlap a good bit with https://github.com/hammerlab/flowdec/issues/17, and at a quick glance it looks like most of the FFT operations support batch dimensions so it shouldn't be too hard to add support for multiple images and PSFs at some point with broadcasting to support the many images + one PSF case.
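
For example, here's a quick sanity check (outside of flowdec) that tf.fft3d only transforms the innermost three dimensions, so a leading dimension acts as a batch axis:

import numpy as np
import tensorflow as tf

batch = tf.placeholder(tf.complex64, shape=[None, 32, 64, 64])
batch_fft = tf.fft3d(batch)  # FFT over the last 3 dims; dim 0 is effectively a batch axis

with tf.Session() as sess:
    vols = np.random.rand(4, 32, 64, 64).astype(np.complex64)
    print(sess.run(batch_fft, feed_dict={batch: vols}).shape)  # (4, 32, 64, 64)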

Oh also, unfortunately I made a mistake in how I was setting references to the input placeholders that needed to be fixed for the above example to work, but I just pushed the change.

VolkerH commented 5 years ago

Thanks very much for your time and that example. This answers my original question (I guess the issue can be closed). I'll watch the discussion about batching in #17; that might bring some speed improvements. In addition to the host-to-GPU transfer of the PSF, I'm curious whether the FFT plan persists across sessions. If it doesn't, I would expect the batch implementation to bring quite a significant performance increase.