clij / clijpy

GPU-accelerated image processing using CLIJ via pyimagej in python
https://clij.github.io/clijpy
BSD 3-Clause "New" or "Revised" License

Multiple issues with demo #2

Open guiwitz opened 4 years ago

guiwitz commented 4 years ago

Hi,

first of all, thanks for the effort of making CLIJ available in Python. I tried it in Fiji, was amazed at the speed increase, and wanted to see whether I could use it in Python as well. Installation worked without problems on OSX (10.13.6), but it then took me a while to get code to run properly. All of this is based on the demo examples provided. I ran into a series of issues, but I think they all come from the fact that I'm probably not importing the right libraries. Anyway, I'll describe the issues and explain how I solved them.

  1. The examples all import the CLIJx class. But it does not seem to exist. The only thing that works for me is CLIJPY. I also tried CLIJ2 which exists but doesn't seem to have the same methods.
  2. When attempting to copy data into an array I get that 'net.haesleinhuepf.clijpy.CLIJPY' object has no attribute 'copy' so I just used the push method instead.
  3. clijx.blur doesn't exist, so I used clijx.op.blur
  4. When calling the blur method like this: clijx.op.blur(input, blurred, 5, 5, 0); it says there is no method matching my arguments. I eventually found out that it works if I wrap the numbers in Float: clijx.op.blur(input, blurred, Float(5), Float(5), Float(0));

With all these changes I get the following code, which works and produces a blurred image:

# init pyimagej to get access to jar files
import imagej
ij = imagej.init('/Applications/Fiji.app/')
ij.getVersion()

# load some image data
from skimage import io
sk_img = io.imread('https://samples.fiji.sc/blobs.png')

# init clijpy to get access to the GPU
from jnius import autoclass

CLIJpy = autoclass('net.haesleinhuepf.clijpy.CLIJPY')
clijx = CLIJpy.getInstance();

Float = autoclass('java.lang.Float')
Int = autoclass('java.lang.Integer')

# convert an array to an ImageJ2 img:
import numpy as np
np_arr = np.array(sk_img)
ij_img = ij.py.to_java(np_arr)

# push the image to the GPU
input8 = clijx.push(ij_img)

# reserve memory for output, same size and type as input
blurred = clijx.create(input8.getDimensions());

# blur the image
clijx.op.blur(input8, blurred, Float(10), Float(10), Float(0));

# pull image back from GPU
ij_img_result = clijx.pull(blurred);
# convert to numpy/python
np_arr_result = ij.py.rai_to_numpy(ij_img_result);

I'm just curious where I went completely wrong. When doing a few tests I noticed that while the blurring is indeed much faster with clijpy than the classic skimage, what takes a lot of time is the conversion done by rai_to_numpy. Is that expected? Unfortunately this last step is so slow that it seems to make the effective gain in time of the GPU processing vanish.

I hope there's a simple explanation for these issues and maybe there's a way to make the numpy conversion faster!
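One way to check where the time actually goes is to time each stage separately. The following is a minimal sketch using only the standard library. `timed` is a hypothetical helper, and the two workloads are placeholders standing in for the GPU steps (push, blur, pull) and for the rai_to_numpy conversion; they are not CLIJ calls.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Record the wall-clock duration of the enclosed block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results[label] = time.perf_counter() - start


timings = {}

# Placeholder for the GPU part (push, blur, pull).
with timed("gpu_processing", timings):
    total = sum(i * i for i in range(100_000))

# Placeholder for the conversion back to numpy.
with timed("conversion", timings):
    as_list = list(range(100_000))

for label, seconds in timings.items():
    print(f"{label}: {seconds * 1000:.2f} ms")
```

Timing the stages one by one makes it easy to see whether the conversion or the processing dominates the total run time.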

haesleinhuepf commented 4 years ago

Hey @guiwitz ,

thanks for testing and thanks for the feedback! Regarding CLIJx versus CLIJ2 and CLIJPY: your timing was probably just unlucky. I removed CLIJPY recently (the day before you created this issue), because everything now lives in CLIJx. All you need to do to get it to work on your system is update your Fiji. I assume the clij2 update site is activated in your Fiji?

All the other API-related issues you mentioned should then go away as well. The goal of the API changes was to make the interface the same in ImageJ/Fiji, Matlab and Python. All methods should be accessible via CLIJx.getInstance().method();, or in short clijx.method(). Let me know if it works after updating Fiji!

> what takes a lot of time is the conversion done by rai_to_numpy. Is that expected? Unfortunately this last step is so slow that it seems to make the effective gain in time of the GPU processing vanish.

Regarding the speed issue with rai_to_numpy: I think it's related to the internals of pyimagej. I'll start a discussion with the developers on their repo. In the meantime, feel free to try the little workaround I just implemented.

Instead of calling

# pull image back from GPU
ij_img_result = clijx.pull(blurred);
# convert to numpy/python
np_arr_result = ij.py.rai_to_numpy(ij_img_result);

try this:

def clijx_pull(buffer):
    import numpy
    numpy_image = numpy.zeros([buffer.getWidth(), buffer.getHeight(), buffer.getDepth()])
    wrapped = ij.py.to_java(numpy_image);
    clijx.pullToRAI(buffer, wrapped);
    return numpy_image

# pass the GPU buffer you want to pull (e.g. blurred from the example above)
np_arr_result = clijx_pull(labelled_without_edges)

This should improve performance, as demonstrated [here](https://github.com/clij/clijpy/blob/master/python/clijpy_demo.ipynb).

Let me know if it works and thanks again for the support!

Cheers, Robert

guiwitz commented 4 years ago

Hi @haesleinhuepf,

Yes, everything works now (I still had a manually installed version of clijpy that, of course, wasn't getting updated)! With 3D blurring I got a 2x speed-up on my laptop's simple GPU. The conversion function works well too. However, since I'm using it to analyse a time-lapse, I split it up: I create the numpy array and the wrapped array just once and then call clijx.pullToRAI(current_image) in a loop, to avoid re-creating that empty numpy array over and over.
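The split-up version can be sketched roughly as follows. This is only an illustration: `make_reusable_puller` is a hypothetical helper name, `ij` and `clijx` are passed in as arguments, and the two-argument `pullToRAI(buffer, wrapped)` call is taken from the workaround above. It also assumes, as that workaround does, that `ij.py.to_java` wraps the numpy array without copying it.

```python
import numpy as np


def make_reusable_puller(ij, clijx, shape):
    """Allocate the numpy array and its Java wrapper once and return a pull function."""
    numpy_image = np.zeros(shape)
    # Assumption: the wrapper shares memory with numpy_image.
    wrapped = ij.py.to_java(numpy_image)

    def pull(buffer):
        # Writes the GPU buffer into the pre-allocated array.
        clijx.pullToRAI(buffer, wrapped)
        # The same array is returned on every call.
        return numpy_image

    return pull
```

In a time-lapse loop, `pull = make_reusable_puller(ij, clijx, shape)` would be called once before the loop and `pull(current_image)` once per frame. Because every call overwrites the same array, a frame that has to be kept needs an explicit `.copy()`.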

I have a few additional questions but will open new issues to keep things clear. Just a small one that I'm adding here. Why do you need these three lines:

input8 = clijx.push(ij_img)
input = clijx.create(input8.getDimensions())
clijx.copy(input8, input)

Isn't the clijx.push function already copying the contents of ij_img into input8? I never use this clijx.copy() function and it works fine (see e.g. this notebook).

Thanks for your help ! Cheers, Guillaume

haesleinhuepf commented 4 years ago

Hey @guiwitz ,

> Isn't the clijx.push function already copying the contents of ij_img into input8?

That's a good question! The create method has a second parameter where you can specify the image pixel type. If it is not specified, Float is used by default. Thus, these three lines of code perform a pixel-type conversion, from whatever type the input has to Float. I added comments to the notebook to make this clearer:

# push the image to the GPU
input8 = clijx.push(ij_img)

# convert it to Float
input = clijx.create(input8.getDimensions()) # generates a Float image
# actual conversion
clijx.copy(input8, input)
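As a CPU-side analogy only (this is plain numpy, not CLIJ code), the same "allocate a Float image, then copy into it" pattern can be sketched like this. The small array is made-up example data standing in for the pushed 8-bit image.

```python
import numpy as np

# Stand-in for the 8-bit image that was pushed to the GPU.
input8 = np.array([[0, 128], [255, 64]], dtype=np.uint8)

# Analogous to create(...): same dimensions, but a float pixel type.
input_float = np.zeros(input8.shape, dtype=np.float32)

# Analogous to copy(...): values are converted as they are written.
np.copyto(input_float, input8)

print(input8.dtype, "->", input_float.dtype)
```

The pixel values stay the same; only the type of the destination image differs from that of the source.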

Thanks for the good questions! These are the ones improving documentation! :-)

Cheers, Robert