tand826 / wsiprocess

Whole Slide Image (WSI) Processing Library for Histopathological / Cytopathological Machine Learning Tasks
Apache License 2.0
21 stars 3 forks source link

python bioformats is fast enough? #49

Open tand826 opened 1 year ago

tand826 commented 1 year ago

python-bioformats hangs unexpectedly and outputs no error message.

class DatasetWithBioformats(Dataset):
    """Dataset of tiles read from one slide through python-bioformats.

    The image is covered by a grid with a step of ``tile_size`` in both
    directions; tiles at the right/bottom border are clipped to the image
    size taken from the OME-XML metadata.
    """

    def __init__(self, wsi, tile_size, transform):
        """Start the JVM, read the image size and open an image reader.

        Args:
            wsi: Slide path, handed to bioformats unchanged.
            tile_size: Grid step and maximum tile width/height.
            transform: Callable applied to every tile returned by the reader.
        """
        javabridge.start_vm(class_path=bioformats.JARS)
        self.wsi = wsi
        self.tile_size = tile_size
        self.transform = transform
        metadata = bioformats.get_omexml_metadata(wsi)
        root = ET.fromstring(metadata)
        # The first element that carries SizeX / SizeY supplies the size.
        self.width = int(root.findall('.//*[@SizeX]')[0].attrib["SizeX"])
        self.height = int(root.findall('.//*[@SizeY]')[0].attrib["SizeY"])
        x_coords = range(0, self.width, tile_size)
        y_coords = range(0, self.height, tile_size)
        self.coords = list(product(x_coords, y_coords))
        self.reader = bioformats.ImageReader(wsi)

    def __del__(self):
        # __init__ may have failed before the reader was created; do not
        # raise a second error from the finalizer in that case.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
        javabridge.kill_vm()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it after applying ``transform``."""
        x, y = self.coords[idx]
        # Clip border tiles so the request never exceeds the image bounds.
        width = min(self.tile_size, self.width - x)
        height = min(self.tile_size, self.height - y)
        print(x, y, width, height)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, width, height))
        # Fixed: the attribute is ``transform`` (was misspelled ``transfrom``).
        return self.transform(tile)

aicsimageio wraps Bio-Formats. Does it work?

tand826 commented 1 year ago

aicsimageio does not provide APIs to stitch a mosaic.

tand826 commented 1 year ago

With jpype and loci_tools.jar, WSI data can be loaded with Bio-Formats:

import jpype
import numpy as np
import torch
from PIL import Image
from jpype.types import JByte
from copy import copy

# Demo: read one region of a slide through Bio-Formats (loci_tools.jar) via
# jpype and convert the raw bytes to a PIL image, a numpy array and a tensor.
jpype.startJVM(jpype.getDefaultJVMPath(), "-ea", "-Djava.class.path=./loci_tools.jar", "-Xmx1024m")
loci = jpype.JPackage('loci')

reader = loci.formats.ImageReader()
reader.setId("examples/CMU-1.ndpi")

h = 100
w = 120
x = 10000
y = 23000
byte = reader.openBytes(0, x, y, w, h)  # byte[] in java

# Copy the Java array into Python once and reuse it for every conversion.
raw = bytes(JByte[::](byte))

# with PIL
pil_image = Image.frombytes(mode="RGB", size=(w, h), data=raw)

# with numpy
array = np.frombuffer(raw, dtype=np.uint8).reshape(h, w, 3)

# with torch
# The PIL and numpy conversions above treat the buffer as interleaved RGB
# (h, w, 3), so reshape the same way and then move channels first.
# NOTE(review): assumes reader.isInterleaved() is true for this file - confirm.
# torch.frombuffer needs a writable buffer, hence the bytearray copy.
tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
tand826 commented 1 year ago
def speed_bioformats():
    """Benchmark tile reading through Bio-Formats (jpype) over the whole slide.

    Iterates a 512-pixel grid over the image named by the module-level
    ``wsi``, reads every tile with ``openBytes`` and converts it to a uint8
    tensor. Nothing is returned; throughput is reported by ``tqdm``.
    """
    tile_size = 512
    reader = jpype.JPackage('loci').formats.ImageReader()
    reader.setId(wsi)
    try:
        # Query the image size once instead of calling into Java per tile.
        size_x = reader.getSizeX()
        size_y = reader.getSizeY()
        xs = range(0, size_x, tile_size)
        ys = range(0, size_y, tile_size)

        for x, y in tqdm(product(xs, ys), total=len(xs) * len(ys)):
            # Clip border tiles to the image bounds.
            w, h = min(tile_size, size_x - x), min(tile_size, size_y - y)
            byte = reader.openBytes(0, x, y, w, h)  # byte[] in java
            # bytearray() already makes an independent, writable copy, so no
            # extra deepcopy is needed. Layout handled as interleaved RGB.
            # NOTE(review): assumes reader.isInterleaved() is true - confirm.
            torch.frombuffer(bytearray(JByte[::](byte)), dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
    finally:
        # Release the file handle even if a read fails midway.
        reader.close()

-> around 36.90it/s

tand826 commented 1 year ago
class DatasetWithBioformats(Dataset):
    """Dataset of tiles read from one slide through Bio-Formats via jpype.

    The image is covered by a grid with a step of ``tile_size``; tiles at
    the right/bottom border are clipped to the size reported by the reader.
    """

    def __init__(self, wsi, tile_size, transform):
        """Start the JVM if needed, open the slide and build the tile grid.

        Args:
            wsi: Slide path passed to ``ImageReader.setId``.
            tile_size: Grid step and maximum tile width/height.
            transform: Callable applied to each tile tensor, or ``None`` to
                return the tensor unchanged.
        """
        # startJVM fails when a JVM is already running in this process, so
        # only start one when needed (e.g. for a second dataset instance).
        if not jpype.isJVMStarted():
            jpype.startJVM(jpype.getDefaultJVMPath(), "-ea", "-Djava.class.path=./loci_tools.jar", "-Xmx1024m")

        self.wsi = wsi
        self.tile_size = tile_size
        self.transform = transform
        self.reader = jpype.JPackage('loci').formats.ImageReader()
        self.reader.setId(self.wsi)

        self.width = self.reader.getSizeX()
        self.height = self.reader.getSizeY()
        x_coords = range(0, self.width, tile_size)
        y_coords = range(0, self.height, tile_size)
        self.coords = list(product(x_coords, y_coords))

    def __del__(self):
        # __init__ may have failed before the reader was created.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
        jpype.shutdownJVM()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` as a channels-first uint8 tensor."""
        x, y = self.coords[idx]
        # Clip border tiles to the image bounds.
        w = min(self.tile_size, self.width - x)
        h = min(self.tile_size, self.height - y)
        byte = self.reader.openBytes(0, x, y, w, h)
        # Treat the bytes as interleaved RGB (h, w, 3), then move channels
        # first; a direct reshape to (3, h, w) would scramble the pixels.
        # NOTE(review): assumes reader.isInterleaved() is true - confirm.
        tensor = torch.frombuffer(bytearray(JByte[::](byte)), dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
        # Fixed: the attribute is ``transform`` (was ``transfrom``); a None
        # transform is accepted and skipped.
        if self.transform is not None:
            tensor = self.transform(tensor)
        return tensor

# Feed the dataset through a DataLoader with four worker processes and stop
# in the debugger after each batch so it can be inspected.
dataset = DatasetWithBioformats(wsi, tile_size, transform=None)
loader = DataLoader(dataset, num_workers=4, batch_size=batch_size)
for image in tqdm(loader):
    print(image.shape)
    import pdb
    pdb.set_trace()

-> _pickle.PicklingError: Can't pickle <java class 'java.lang.NoClassDefFoundError'>: import of module 'java.lang' failed

tand826 commented 1 year ago

Not with jpype, but with javabridge:

# Benchmark: read every tile of a slide with Bio-Formats through javabridge.
tile_size = 512
batch_size = 16
wsi = "examples/CMU-1.ndpi"
bioformats_path = "loci_tools.jar"
# NOTE(review): javabridge describes class_path as a list of paths - confirm
# that passing a bare string is handled as intended.
javabridge.start_vm(class_path=bioformats_path)

try:
    ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
    reader = ImageReader()
    try:
        reader.setId(wsi)
        # Query the image size once; it does not change between tiles.
        width = reader.getSizeX()
        height = reader.getSizeY()
        x_coords = list(range(0, width, tile_size))
        y_coords = list(range(0, height, tile_size))

        coords = list(product(x_coords, y_coords))
        for x, y in tqdm(coords, desc="Loading tiles"):
            # Clip border tiles to the image bounds.
            w = min(tile_size, width - x)
            h = min(tile_size, height - y)
            byte = reader.openBytes(0, x, y, w, h)
            # Bytes are handled as interleaved RGB, then moved channels-first.
            tensor = torch.tensor(byte, dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
    finally:
        # Release the file handle even if a read fails.
        reader.close()
finally:
    # Shut the JVM down so the interpreter can exit cleanly.
    javabridge.kill_vm()

-> 136.19it/s

tand826 commented 1 year ago

python-bioformats hangs unexpectedly and outputs no error message.

class DatasetWithBioformats(Dataset):
    """Dataset of tiles read from one slide through python-bioformats.

    The image is covered by a grid with a step of ``tile_size`` in both
    directions; tiles at the right/bottom border are clipped to the image
    size taken from the OME-XML metadata.
    """

    def __init__(self, wsi, tile_size, transform):
        """Start the JVM, read the image size and open an image reader.

        Args:
            wsi: Slide path, handed to bioformats unchanged.
            tile_size: Grid step and maximum tile width/height.
            transform: Callable applied to every tile returned by the reader.
        """
        javabridge.start_vm(class_path=bioformats.JARS)
        self.wsi = wsi
        self.tile_size = tile_size
        self.transform = transform
        metadata = bioformats.get_omexml_metadata(wsi)
        root = ET.fromstring(metadata)
        # The first element that carries SizeX / SizeY supplies the size.
        self.width = int(root.findall('.//*[@SizeX]')[0].attrib["SizeX"])
        self.height = int(root.findall('.//*[@SizeY]')[0].attrib["SizeY"])
        x_coords = range(0, self.width, tile_size)
        y_coords = range(0, self.height, tile_size)
        self.coords = list(product(x_coords, y_coords))
        self.reader = bioformats.ImageReader(wsi)

    def __del__(self):
        # __init__ may have failed before the reader was created; do not
        # raise a second error from the finalizer in that case.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
        javabridge.kill_vm()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it after applying ``transform``."""
        x, y = self.coords[idx]
        # Clip border tiles so the request never exceeds the image bounds.
        width = min(self.tile_size, self.width - x)
        height = min(self.tile_size, self.height - y)
        print(x, y, width, height)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, width, height))
        # Fixed: the attribute is ``transform`` (was misspelled ``transfrom``).
        return self.transform(tile)

aicsimageio wraps bioformats. does this work?

The javabridge version works if start_vm is run only after the dataset has been replicated into the forked worker processes.


class DatasetWithBioformats(Dataset):
    """Tile dataset that opens the slide on first item access.

    The constructor only stores plain values. The JVM and the reader are
    created by ``lazy_load`` on the first ``__getitem__`` call, so the
    length has to be supplied by the caller as ``coords_len``.
    """

    def __init__(self, wsi, tile_size, coords_len):
        """Store the configuration without starting the JVM.

        Args:
            wsi: Slide path.
            tile_size: Grid step, maximum tile size and resize target.
            coords_len: Number of tiles, returned by ``__len__``.
        """
        self.wsi = wsi
        self.tile_size = tile_size
        self.coords_len = coords_len
        self.transform = transforms.Resize((self.tile_size, self.tile_size))

    def lazy_load(self):
        """Start the JVM, open the reader and compute size and tile grid."""
        javabridge.start_vm(class_path=bioformats.JARS)
        # Fixed: use this instance's path, not a module-level ``wsi``.
        self.reader = bioformats.ImageReader(self.wsi)
        self.set_params()

    def __del__(self):
        if hasattr(self, "reader"):
            self.reader.close()
        javabridge.kill_vm()

    def __len__(self):
        """Return the tile count supplied at construction time."""
        return self.coords_len

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it resized, channels first."""
        if not hasattr(self, "reader"):
            self.lazy_load()
        x, y = self.coords[idx]
        # Fixed: clip with this instance's tile size, not a global one.
        w = min(self.tile_size, self.width - x)
        h = min(self.tile_size, self.height - y)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, w, h))
        tensor = torch.tensor(tile, dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
        return self.transform(tensor)

    def set_params(self):
        """Read the image size with a temporary reader and build the grid."""
        ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
        reader = ImageReader()
        reader.setId(self.wsi)
        self.width = reader.getSizeX()
        self.height = reader.getSizeY()
        self.coords = list(product(
            range(0, self.width, self.tile_size),
            range(0, self.height, self.tile_size)))
        reader.close()

In this script, coords_len could be computed after javabridge.start_vm, but once the VM has been started, something goes wrong in the replicated (forked) processes. So here I calculated coords_len beforehand in another script and saved it to a file.

tand826 commented 1 year ago

Workaround: load it in another process with concurrent.futures.ProcessPoolExecutor.

class DatasetWithBioformats(Dataset):
    """Tile dataset that computes its grid in a separate process.

    The tile coordinates are computed by ``read_coords`` inside a
    short-lived child process, so the constructor does not start a JVM in
    the calling process. The reader used for tile access is created by
    ``lazy_load`` on the first ``__getitem__`` call.
    """

    def __init__(self, wsi, tile_size):
        """Store the configuration and compute the tile grid.

        Args:
            wsi: Slide path.
            tile_size: Grid step, maximum tile size and resize target.
        """
        self.wsi = wsi
        self.tile_size = tile_size
        self.coords = self.get_coords()
        self.transform = transforms.Resize((self.tile_size, self.tile_size))

    def read_coords(self):
        """Start a JVM, read the image size and return the tile grid."""
        javabridge.start_vm(class_path=bioformats.JARS)
        ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
        reader = ImageReader()
        # Fixed: use this instance's path and tile size, not module globals.
        reader.setId(self.wsi)
        coords = list(product(
            range(0, reader.getSizeX(), self.tile_size),
            range(0, reader.getSizeY(), self.tile_size)))
        reader.close()
        javabridge.kill_vm()
        return coords

    def get_coords(self):
        """Run ``read_coords`` in a child process and return its result."""
        # A single task needs a single worker.
        with ProcessPoolExecutor(max_workers=1) as executor:
            future = executor.submit(self.read_coords)
        return future.result()

    def lazy_load(self):
        """Start the JVM, open the reader and read the image size."""
        javabridge.start_vm(class_path=bioformats.JARS)
        # Fixed: use this instance's path, not a module-level ``wsi``.
        self.reader = bioformats.ImageReader(self.wsi)
        self.set_params()

    def __del__(self):
        if hasattr(self, "reader"):
            self.reader.close()
        javabridge.kill_vm()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it resized, channels first."""
        if not hasattr(self, "reader"):
            self.lazy_load()
        x, y = self.coords[idx]
        # Fixed: clip with this instance's tile size, not a global one.
        w = min(self.tile_size, self.width - x)
        h = min(self.tile_size, self.height - y)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, w, h))
        tensor = torch.tensor(tile, dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
        return self.transform(tensor)

    def set_params(self):
        """Read the image size with a temporary reader and rebuild the grid."""
        ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
        reader = ImageReader()
        reader.setId(self.wsi)
        self.width = reader.getSizeX()
        self.height = reader.getSizeY()
        self.coords = list(product(
            range(0, self.width, self.tile_size),
            range(0, self.height, self.tile_size)))
        reader.close()

# Smoke run: pull every batch through four DataLoader workers and touch it.
dataset = DatasetWithBioformats(wsi=wsi, tile_size=tile_size)
loader = DataLoader(dataset=dataset, num_workers=4, batch_size=batch_size)
with tqdm(loader) as progress:
    for image in progress:
        # In-place add, only to use the loaded batch.
        image += 1

this worked.