tand826 / wsiprocess

Whole Slide Image (WSI) Processing Library for Histopathological / Cytopathological Machine Learning Tasks
Apache License 2.0
23 stars 3 forks source link

python bioformats is fast enough? #49

Open tand826 opened 1 year ago

tand826 commented 1 year ago

python-bioformats hangs unexpectedly and prints no error message.

class DatasetWithBioformats(Dataset):
    """Tile dataset that reads a whole slide image through python-bioformats.

    The slide is cut into a grid of ``tile_size`` x ``tile_size`` tiles (tiles
    on the right/bottom edge are clipped to the image bounds); each dataset
    item is one tile.
    """

    def __init__(self, wsi, tile_size, transform):
        """Start the JVM, read the slide size and build the tile grid.

        Args:
            wsi: Path of the slide handed to bioformats.
            tile_size: Edge length of a tile in pixels.
            transform: Callable applied to every tile, or ``None`` to return
                the tile unchanged.
        """
        javabridge.start_vm(class_path=bioformats.JARS)
        self.wsi = wsi
        self.tile_size = tile_size
        self.transform = transform
        metadata = bioformats.get_omexml_metadata(wsi)
        root = ET.fromstring(metadata)
        # The first element carrying a SizeX / SizeY attribute is used as the
        # image size.
        self.width = int(root.findall('.//*[@SizeX]')[0].attrib["SizeX"])
        self.height = int(root.findall('.//*[@SizeY]')[0].attrib["SizeY"])
        # Top-left corner of every tile, iterated x-major.
        self.coords = list(product(
            range(0, self.width, tile_size),
            range(0, self.height, tile_size)))
        self.reader = bioformats.ImageReader(wsi)

    def __del__(self):
        # __init__ may have failed before the reader was created, so only
        # close it when it exists.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
        javabridge.kill_vm()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it, transformed when a transform is set."""
        x, y = self.coords[idx]
        # Clip the tile so it never extends past the image border.
        width = min(self.tile_size, self.width - x)
        height = min(self.tile_size, self.height - y)
        print(x, y, width, height)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, width, height))
        if self.transform is None:
            return tile
        # Was ``self.transfrom`` (typo), which raised AttributeError.
        return self.transform(tile)

aicsimageio wraps bioformats. does this work?

tand826 commented 1 year ago

aicsimageio does not provide APIs to stitch mosaic tiles.

tand826 commented 1 year ago

With jpype and loci_tools.jar, WSI data can be loaded through Bio-Formats:

import jpype
import numpy as np
import torch
from PIL import Image
from jpype.types import JByte
from copy import copy

# Start a JVM with the Bio-Formats jar on the class path and open the slide.
jpype.startJVM(jpype.getDefaultJVMPath(), "-ea", "-Djava.class.path=./loci_tools.jar", "-Xmx1024m")
loci = jpype.JPackage('loci')

reader = loci.formats.ImageReader()
reader.setId("examples/CMU-1.ndpi")

# Region to read: top-left corner (x, y) and size (w, h) in pixels.
h = 100
w = 120
x = 10000
y = 23000
byte = reader.openBytes(0, x, y, w, h)  # byte[] in java

# Convert the Java byte[] to Python bytes once and reuse it below.
raw = bytes(JByte[::](byte))

# with PIL
pil_image = Image.frombytes(mode="RGB", size=(w, h), data=raw)

# with numpy
array = np.frombuffer(raw, dtype=np.uint8).reshape(h, w, 3)

# with torch
# The PIL and numpy conversions above treat the buffer as interleaved
# (h, w, 3), so reshape the same way and then move channels first.
# torch.frombuffer needs a writable buffer, hence the bytearray copy.
tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
tand826 commented 1 year ago
def speed_bioformats():
    """Benchmark sequential 512x512 tile reads through jpype + Bio-Formats.

    Opens the module-level ``wsi`` path, walks the tile grid under a tqdm
    progress bar and converts every tile to a channels-first uint8 tensor.
    The reader is closed when the loop ends or fails.
    """
    tile_size = 512
    reader = jpype.JPackage('loci').formats.ImageReader()
    reader.setId(wsi)
    try:
        # Query the image size once instead of on every tile.
        width, height = reader.getSizeX(), reader.getSizeY()
        xs = range(0, width, tile_size)
        ys = range(0, height, tile_size)

        for x, y in tqdm(product(xs, ys), total=len(xs) * len(ys)):
            # Clip edge tiles to the image bounds.
            w, h = min(tile_size, width - x), min(tile_size, height - y)
            byte = reader.openBytes(0, x, y, w, h)  # byte[] in java
            # bytearray() already makes a fresh writable copy, so no deepcopy
            # is needed. The buffer is assumed interleaved (h, w, 3), as in
            # the PIL/numpy conversions of the same data - TODO confirm for
            # other formats.
            tensor = torch.frombuffer(
                bytearray(JByte[::](byte)), dtype=torch.uint8
            ).reshape(h, w, 3).permute(2, 0, 1)
    finally:
        reader.close()

-> around 36.90it/s

tand826 commented 1 year ago
class DatasetWithBioformats(Dataset):
    """Tile dataset that reads a whole slide image with jpype + Bio-Formats.

    NOTE(review): the Java reader is created in ``__init__``, so the dataset
    cannot be pickled or shared with DataLoader worker processes; see the
    PicklingError reported for ``num_workers=4``.
    """

    def __init__(self, wsi, tile_size, transform):
        """Start the JVM, open the slide and build the tile grid.

        Args:
            wsi: Path of the slide passed to the Bio-Formats reader.
            tile_size: Edge length of a tile in pixels.
            transform: Callable applied to each tile tensor, or ``None`` to
                return the tensor unchanged.
        """
        jpype.startJVM(jpype.getDefaultJVMPath(), "-ea", "-Djava.class.path=./loci_tools.jar", "-Xmx1024m")

        self.wsi = wsi
        self.tile_size = tile_size
        self.transform = transform
        self.reader = jpype.JPackage('loci').formats.ImageReader()
        self.reader.setId(self.wsi)

        self.width = self.reader.getSizeX()
        self.height = self.reader.getSizeY()
        # Top-left corner of every tile, iterated x-major.
        self.coords = list(product(
            range(0, self.width, tile_size),
            range(0, self.height, tile_size)))

    def __del__(self):
        # __init__ may have failed before the reader existed.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
        jpype.shutdownJVM()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` as a channels-first uint8 tensor."""
        x, y = self.coords[idx]
        # Clip edge tiles to the image bounds.
        w = min(self.tile_size, self.width - x)
        h = min(self.tile_size, self.height - y)
        byte = self.reader.openBytes(0, x, y, w, h)
        # The buffer is assumed interleaved (h, w, 3) - TODO confirm for the
        # formats in use - so reshape that way before moving channels first.
        tensor = torch.frombuffer(
            bytearray(JByte[::](byte)), dtype=torch.uint8
        ).reshape(h, w, 3).permute(2, 0, 1)
        if self.transform is None:
            return tensor
        # Was ``self.transfrom`` (typo), which raised AttributeError.
        return self.transform(tensor)

# Feed the dataset through a 4-worker DataLoader and stop in the debugger
# after printing the shape of each batch.
loader = DataLoader(
    DatasetWithBioformats(wsi, tile_size, transform=None),
    batch_size=batch_size,
    num_workers=4,
)
for image in tqdm(loader):
    print(image.shape)
    import pdb; pdb.set_trace()

-> _pickle.PicklingError: Can't pickle <java class 'java.lang.NoClassDefFoundError'>: import of module 'java.lang' failed

tand826 commented 1 year ago

not with jpype, with javabridge

# Sequential tile-read benchmark using javabridge and the raw Bio-Formats
# ImageReader class.
tile_size = 512
batch_size = 16
wsi = "examples/CMU-1.ndpi"
bioformats_path = "loci_tools.jar"
javabridge.start_vm(class_path=bioformats_path)

try:
    ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
    reader = ImageReader()
    reader.setId(wsi)
    try:
        width = reader.getSizeX()
        height = reader.getSizeY()
        x_coords = list(range(0, width, tile_size))
        y_coords = list(range(0, height, tile_size))

        coords = list(product(x_coords, y_coords))
        for x, y in tqdm(coords, desc="Loading tiles"):
            # Clip edge tiles; the cached width/height avoid two extra JNI
            # calls per tile.
            w = min(tile_size, width - x)
            h = min(tile_size, height - y)
            byte = reader.openBytes(0, x, y, w, h)
            tensor = torch.tensor(byte, dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
    finally:
        reader.close()
finally:
    # Release the JVM even if reading fails.
    javabridge.kill_vm()

-> 136.19it/s

tand826 commented 1 year ago

python-bioformats hangs unexpectedly and prints no error message.

class DatasetWithBioformats(Dataset):
    """Tile dataset that reads a whole slide image through python-bioformats.

    The slide is cut into a grid of ``tile_size`` x ``tile_size`` tiles (tiles
    on the right/bottom edge are clipped to the image bounds); each dataset
    item is one tile.
    """

    def __init__(self, wsi, tile_size, transform):
        """Start the JVM, read the slide size and build the tile grid.

        Args:
            wsi: Path of the slide handed to bioformats.
            tile_size: Edge length of a tile in pixels.
            transform: Callable applied to every tile, or ``None`` to return
                the tile unchanged.
        """
        javabridge.start_vm(class_path=bioformats.JARS)
        self.wsi = wsi
        self.tile_size = tile_size
        self.transform = transform
        metadata = bioformats.get_omexml_metadata(wsi)
        root = ET.fromstring(metadata)
        # The first element carrying a SizeX / SizeY attribute is used as the
        # image size.
        self.width = int(root.findall('.//*[@SizeX]')[0].attrib["SizeX"])
        self.height = int(root.findall('.//*[@SizeY]')[0].attrib["SizeY"])
        # Top-left corner of every tile, iterated x-major.
        self.coords = list(product(
            range(0, self.width, tile_size),
            range(0, self.height, tile_size)))
        self.reader = bioformats.ImageReader(wsi)

    def __del__(self):
        # __init__ may have failed before the reader was created, so only
        # close it when it exists.
        reader = getattr(self, "reader", None)
        if reader is not None:
            reader.close()
        javabridge.kill_vm()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it, transformed when a transform is set."""
        x, y = self.coords[idx]
        # Clip the tile so it never extends past the image border.
        width = min(self.tile_size, self.width - x)
        height = min(self.tile_size, self.height - y)
        print(x, y, width, height)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, width, height))
        if self.transform is None:
            return tile
        # Was ``self.transfrom`` (typo), which raised AttributeError.
        return self.transform(tile)

aicsimageio wraps bioformats. does this work?

The javabridge version works if start_vm is called only after the dataset has been replicated into the forked worker processes.


class DatasetWithBioformats(Dataset):
    """Tile dataset that starts the JVM lazily, on the first ``__getitem__``.

    Nothing Java-related is touched in ``__init__``, so the JVM and the reader
    are created in whichever process first reads a tile. The number of tiles
    must therefore be supplied up front as ``coords_len``.
    """

    def __init__(self, wsi, tile_size, coords_len):
        """Store the configuration without starting the JVM.

        Args:
            wsi: Path of the slide.
            tile_size: Edge length of a tile in pixels; also the output size
                of the resize transform.
            coords_len: Precomputed number of tiles, returned by ``len()``.
        """
        self.wsi = wsi
        self.tile_size = tile_size
        self.coords_len = coords_len
        self.transform = transforms.Resize((self.tile_size, self.tile_size))

    def lazy_load(self):
        """Start the JVM, open the reader and compute the tile grid."""
        javabridge.start_vm(class_path=bioformats.JARS)
        # Use the instance's own path rather than a module-level ``wsi``.
        self.reader = bioformats.ImageReader(self.wsi)
        self.set_params()

    def __del__(self):
        # Only tear down the JVM if this instance actually started one.
        if hasattr(self, "reader"):
            self.reader.close()
            javabridge.kill_vm()

    def __len__(self):
        """Return the tile count given at construction time."""
        return self.coords_len

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it resized to ``tile_size`` squared."""
        if not hasattr(self, "reader"):
            self.lazy_load()
        x, y = self.coords[idx]
        # Clip edge tiles to the image bounds (instance tile size, not a
        # module-level global).
        w = min(self.tile_size, self.width - x)
        h = min(self.tile_size, self.height - y)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, w, h))
        tensor = torch.tensor(tile, dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
        return self.transform(tensor)

    def set_params(self):
        """Read the slide size and build the list of tile origins."""
        ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
        reader = ImageReader()
        try:
            reader.setId(self.wsi)
            self.width = reader.getSizeX()
            self.height = reader.getSizeY()
        finally:
            # Close the temporary reader even if the slide cannot be opened.
            reader.close()
        self.coords = list(product(
            range(0, self.width, self.tile_size),
            range(0, self.height, self.tile_size)))

In this script, coords_len could be computed after javabridge.start_vm, but once the VM has been started in the parent process, something goes wrong in the replicated (forked) worker processes. So I calculated coords_len beforehand in another script and saved it to a file.

tand826 commented 1 year ago

Workaround: compute the coordinates in a separate process with concurrent.futures.ProcessPoolExecutor, so the JVM is never started in the parent process.

class DatasetWithBioformats(Dataset):
    """Tile dataset that keeps the JVM out of the parent process.

    The tile grid is computed in a short-lived child process, so the process
    that builds the DataLoader never starts a JVM itself. The reader is then
    opened lazily on the first ``__getitem__`` call.
    """

    def __init__(self, wsi, tile_size):
        """Store the configuration and compute the tile grid out of process.

        Args:
            wsi: Path of the slide.
            tile_size: Edge length of a tile in pixels; also the output size
                of the resize transform.
        """
        self.wsi = wsi
        self.tile_size = tile_size
        self.coords = self.get_coords()
        self.transform = transforms.Resize((self.tile_size, self.tile_size))

    def read_coords(self):
        """Start a JVM, read the slide size and return the tile origins.

        Intended to run in a child process (see ``get_coords``); the JVM is
        killed again before returning.
        """
        javabridge.start_vm(class_path=bioformats.JARS)
        try:
            ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
            reader = ImageReader()
            try:
                # Instance attributes, not module-level ``wsi``/``tile_size``.
                reader.setId(self.wsi)
                return list(product(
                    range(0, reader.getSizeX(), self.tile_size),
                    range(0, reader.getSizeY(), self.tile_size)))
            finally:
                reader.close()
        finally:
            javabridge.kill_vm()

    def get_coords(self):
        """Run ``read_coords`` in a separate process and return its result."""
        # A single worker is enough for the one task submitted here.
        with ProcessPoolExecutor(max_workers=1) as executor:
            return executor.submit(self.read_coords).result()

    def lazy_load(self):
        """Start the JVM, open the reader and read the slide size."""
        javabridge.start_vm(class_path=bioformats.JARS)
        self.reader = bioformats.ImageReader(self.wsi)
        self.set_params()

    def __del__(self):
        # Only tear down the JVM if this instance actually started one.
        if hasattr(self, "reader"):
            self.reader.close()
            javabridge.kill_vm()

    def __len__(self):
        """Return the number of tiles in the grid."""
        return len(self.coords)

    def __getitem__(self, idx):
        """Read tile ``idx`` and return it resized to ``tile_size`` squared."""
        if not hasattr(self, "reader"):
            self.lazy_load()
        x, y = self.coords[idx]
        # Clip edge tiles to the image bounds.
        w = min(self.tile_size, self.width - x)
        h = min(self.tile_size, self.height - y)
        tile = self.reader.read(c=0, rescale=False, XYWH=(x, y, w, h))
        tensor = torch.tensor(tile, dtype=torch.uint8).reshape(h, w, 3).permute(2, 0, 1)
        return self.transform(tensor)

    def set_params(self):
        """Read the slide size and rebuild the list of tile origins."""
        ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
        reader = ImageReader()
        try:
            reader.setId(self.wsi)
            self.width = reader.getSizeX()
            self.height = reader.getSizeY()
        finally:
            # Close the temporary reader even if the slide cannot be opened.
            reader.close()
        self.coords = list(product(
            range(0, self.width, self.tile_size),
            range(0, self.height, self.tile_size)))

# Smoke test: pull every batch through a 4-worker DataLoader and touch the
# tensor so the data is actually used.
dataset = DatasetWithBioformats(wsi, tile_size)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=4,
)
with tqdm(loader) as t:
    for image in t:
        image += 1

this worked.

CDPDisk commented 4 days ago

Hi, have you tried comparing the speed of javabridge with ImageReader against OpenSlide when reading WSIs? I tried reading images that are as large as possible, for example 20480*20480, and compared their speed. But I found that javabridge is not faster, which I presume is due to the large amount of debug logging being output. Do you have a way to turn off these logs? I have mentioned it here: https://github.com/CellProfiler/python-bioformats/issues/171#issuecomment-2506473895

tand826 commented 4 days ago

@CDPDisk Hi, thank you for your comment. It's been a while, and I don't remember how to control the log outputs, but there are a few ideas to find the answer.

I'm curious about your problem. If you find the answer, or some could be helpful for you, can you be kind enough to share the result with me?

CDPDisk commented 3 days ago

@CDPDisk Hi, thank you for your comment. It's been a while, and I don't remember how to control the log outputs, but there are a few ideas to find the answer.

  • Is the log output really the core of the problem? How about using %%time in jupyter or any measuring ways to compare the specific part of the codes?
  • There is a overhead of starting java vm with javabridge. Can you measure the time after the overhead?
  • If you want to read images in parallel, the limit of one JVM per process might slow down reading. Can you try reading only one 20480x20480 image?

I'm curious about your problem. If you find the answer, or some could be helpful for you, can you be kind enough to share the result with me?

I calculated the time after overhead in the new code, but that didn't make an absolute difference in the results. At the moment I'm not doing parallel reads, both are using only one process as well as reading only 1 20480 x 20480 image.

I'm interested in your issue because you've mentioned that you achieved a read speed of 136.19it/s. But I can't achieve this with javabridge. I found out that I don't have this speed when reading with 512 size patches, and reading 512 size on SSD took about 60s for 1600 random patches, or 26.67it/s, which is not at all the same as your result. So I suspect the log output is seriously affecting my speed.

tand826 commented 3 days ago

@CDPDisk I tried to reproduce the speed above by suppressing the log output and got 137.17it/s. Does this reproduce the speed in your environment?

def supress_log():
    """Set the level of the SLF4J root logger to logback's WARN.

    Looks up the root logger through ``org.slf4j.LoggerFactory`` and calls its
    ``setLevel`` with ``ch.qos.logback.classic.Level.WARN`` via javabridge.
    """
    # https://github.com/CellProfiler/python-bioformats/issues/137#issuecomment-802313393
    # https://github.com/pskeshu/microscoper/blob/master/microscoper/io.py#L141-L162
    root_name = javabridge.get_static_field(
        "org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
    root_logger = javabridge.static_call(
        "org/slf4j/LoggerFactory",
        "getLogger",
        "(Ljava/lang/String;)Lorg/slf4j/Logger;",
        root_name)
    warn_level = javabridge.get_static_field(
        "ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")

    javabridge.call(
        root_logger, "setLevel", "(Lch/qos/logback/classic/Level;)V", warn_level)

def speed_javabridge():
    """Benchmark sequential tile reads through javabridge + Bio-Formats.

    Uses the module-level ``wsi`` and ``tile_size``. Logging is reduced with
    ``supress_log`` before reading. The reader is closed and the JVM killed
    even if a read fails.
    """
    javabridge.start_vm(class_path=bioformats.JARS)
    try:
        supress_log()
        ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
        reader = ImageReader()
        reader.setId(wsi)
        try:
            # Query the image size once and reuse it for the grid and clipping.
            width = reader.getSizeX()
            height = reader.getSizeY()
            xs = range(0, width, tile_size)
            ys = range(0, height, tile_size)

            for x, y in tqdm(list(product(xs, ys)), desc="javabridge"):
                # Clip edge tiles to the image bounds.
                w, h = min(tile_size, width - x), min(tile_size, height - y)
                byte = reader.openBytes(0, x, y, w, h)
                tensor = torch.tensor(byte, dtype=torch.uint8).reshape(
                    h, w, 3).permute(2, 0, 1)
        finally:
            reader.close()
    finally:
        javabridge.kill_vm()

For reference, I also tried the following, but none of them showed a good result.

CDPDisk commented 6 hours ago

@CDPDisk I tried to reproduce the speed above by suppressing the log output and got 137.17it/s. Does this reproduce the speed in your environment?

def supress_log():
    """Set the level of the SLF4J root logger to logback's WARN.

    Looks up the root logger through ``org.slf4j.LoggerFactory`` and calls its
    ``setLevel`` with ``ch.qos.logback.classic.Level.WARN`` via javabridge.
    """
    # https://github.com/CellProfiler/python-bioformats/issues/137#issuecomment-802313393
    # https://github.com/pskeshu/microscoper/blob/master/microscoper/io.py#L141-L162
    root_name = javabridge.get_static_field(
        "org/slf4j/Logger", "ROOT_LOGGER_NAME", "Ljava/lang/String;")
    root_logger = javabridge.static_call(
        "org/slf4j/LoggerFactory",
        "getLogger",
        "(Ljava/lang/String;)Lorg/slf4j/Logger;",
        root_name)
    warn_level = javabridge.get_static_field(
        "ch/qos/logback/classic/Level", "WARN", "Lch/qos/logback/classic/Level;")

    javabridge.call(
        root_logger, "setLevel", "(Lch/qos/logback/classic/Level;)V", warn_level)

def speed_javabridge():
    """Benchmark sequential tile reads through javabridge + Bio-Formats.

    Uses the module-level ``wsi`` and ``tile_size``. Logging is reduced with
    ``supress_log`` before reading. The reader is closed and the JVM killed
    even if a read fails.
    """
    javabridge.start_vm(class_path=bioformats.JARS)
    try:
        supress_log()
        ImageReader = javabridge.JClassWrapper("loci.formats.ImageReader")
        reader = ImageReader()
        reader.setId(wsi)
        try:
            # Query the image size once and reuse it for the grid and clipping.
            width = reader.getSizeX()
            height = reader.getSizeY()
            xs = range(0, width, tile_size)
            ys = range(0, height, tile_size)

            for x, y in tqdm(list(product(xs, ys)), desc="javabridge"):
                # Clip edge tiles to the image bounds.
                w, h = min(tile_size, width - x), min(tile_size, height - y)
                byte = reader.openBytes(0, x, y, w, h)
                tensor = torch.tensor(byte, dtype=torch.uint8).reshape(
                    h, w, 3).permute(2, 0, 1)
        finally:
            reader.close()
    finally:
        javabridge.kill_vm()

For reference, I also tried the following, but none of them showed a good result.

  • download the javabridge with the latest version of Apr 2023
  • inserted some python-bioformats based codes
  • used torch.utils.data.Dataset to run in parallel ( -> 18.85it/s for batch_size=16, 4.24it/s for batch_size=64; counted in tiles per second, both are much faster than 136it/s)
  • referenced a openslide based code in wiki, to reproduce almost the same speed.

Thanks for the code. It did stop the log output, but I found it still wasn't as fast as OpenSlide. So I tried different things, mainly with OpenSlide, including reading different files, different sizes (512, 5120 and 10240, each then subdivided into 512 patches), different types of disks (SSD and HDD), and different ways of reading (in the type column, dz is DeepZoom from the wiki and openslide is openslide.read_region), and measured the read speed in it/s for 512-size patches. I will share the result file. record.csv

These are some conclusions based on my understanding, but I don't know whether they are correct:

This is my test code:

from openslide import OpenSlide
import openslide
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from openslide.deepzoom import DeepZoomGenerator
from tqdm import tqdm
from einops import rearrange
from pathlib import Path
import time
import pandas as pd

class DatasetWithOpenSlide(Dataset):
    """Benchmark dataset that reads random tiles from a slide with OpenSlide.

    Each item is a stack of 512x512 patches cut from one ``tile_size`` tile.
    Tiles are read either through ``DeepZoomGenerator`` (``type='dz'``) or
    through ``OpenSlide.read_region`` (``type='openslide'``).
    """

    # Edge length of the patches every tile is cut into.
    PATCH_SIZE = 512

    def __init__(self, wsi, type, tile_size):
        """Open the slide and set up the DeepZoom tiling.

        Args:
            wsi: Path of the slide file.
            type: ``'dz'`` or ``'openslide'``; selects the read method.
            tile_size: Edge length in pixels of the region read per item.
        """
        self.wsi_path = wsi
        self.wsi = openslide.OpenSlide(wsi)
        self.tile_size = tile_size
        self.dz = DeepZoomGenerator(
            self.wsi,
            tile_size=self.tile_size,
            overlap=0)
        # Tile grid of the last (deepest) DeepZoom level.
        self.tiles_x, self.tiles_y = self.dz.level_tiles[-1]
        self.deepest_layer = self.dz.level_count - 1
        self.type = type

    def __len__(self):
        """Return the number of tiles at the deepest DeepZoom level."""
        return self.tiles_x * self.tiles_y

    def __getitem__(self, idx):
        """Read one random tile and return it as a stack of 512x512 patches.

        The requested ``idx`` is ignored on purpose: a random tile index is
        drawn on every call so that the benchmark measures random access.

        Raises:
            ValueError: If ``self.type`` is neither ``'dz'`` nor
                ``'openslide'``.
        """
        idx = np.random.randint(0, self.tiles_x * self.tiles_y)
        y, x = divmod(idx, self.tiles_x)
        if self.type == 'dz':
            try:
                img = np.array(self.dz.get_tile(self.deepest_layer, (x, y)))
            except (ValueError, openslide.OpenSlideError):
                # Unreadable tile: fall back to a black placeholder with the
                # same dtype as real image data.
                img = np.zeros(
                    (self.PATCH_SIZE, self.PATCH_SIZE, 3), dtype=np.uint8)
        elif self.type == 'openslide':
            # NOTE(review): read_region output is presumably RGBA (4 channels)
            # while DeepZoom tiles are RGB - confirm before mixing both types.
            img = np.array(self.wsi.read_region(
                (x * self.tile_size, y * self.tile_size),
                0,
                (self.tile_size, self.tile_size)))
        else:
            raise ValueError('type should be dz or openslide')

        tensor = torch.tensor(img)
        if self.tile_size == self.PATCH_SIZE:
            return tensor.unsqueeze(0)
        try:
            return self._split_into_patches(tensor, self.PATCH_SIZE)
        except RuntimeError:
            # Best effort, as before: hand back the unsplit tile.
            return tensor

    @staticmethod
    def _split_into_patches(img, patch):
        """Zero-pad an (H, W, C) tensor and cut it into (N, patch, patch, C)."""
        # Pad only up to the next multiple of ``patch``; an already aligned
        # side gets no padding (previously it gained a full extra patch).
        pad_h = (-img.shape[0]) % patch
        pad_w = (-img.shape[1]) % patch
        # F.pad lists the last dimension first: (C, C, W, W, H, H).
        img = torch.nn.functional.pad(
            img, (0, 0, 0, pad_w, 0, pad_h), mode='constant', value=0)
        # unfold appends the window axis at the end: (nh, nw, C, ph, pw).
        tiles = img.unfold(0, patch, patch).unfold(1, patch, patch)
        channels = tiles.shape[2]
        return tiles.permute(0, 1, 3, 4, 2).reshape(-1, patch, patch, channels)

    def __getstate__(self):
        """Return the picklable state, leaving out the open slide handles."""
        # Work on a copy so that pickling does not strip ``wsi`` and ``dz``
        # from the live object.
        state = self.__dict__.copy()
        state.pop('wsi', None)
        state.pop('dz', None)
        return state

    def __setstate__(self, state):
        """Restore the state and reopen the slide from ``wsi_path``."""
        state['wsi'] = openslide.OpenSlide(state['wsi_path'])
        state['dz'] = DeepZoomGenerator(
            state['wsi'],
            tile_size=state['tile_size'],
            overlap=0)
        self.__dict__.update(state)

if __name__ == '__main__':
    # Benchmark driver: measures 512-patch read speed for every combination of
    # tile size, worker count, disk, slide file and read method, and writes
    # one row per combination to record.csv.
    batch_size = 1
    # Slide directories per storage type (machine-specific paths).
    wsi_dirs = {
        'HDD': Path(r'E:/FUSCC2_data/read_test/'),
        'SSD': Path(r'C:/Users/Administrator.DESKTOP-JADQ414/Desktop/data_example/swinmil/read_test'),
    }
    record = {'file': [],
              'disk': [],
              'tile_size': [],
              'num_workers': [],
              'speed': [],
              'type': []}

    def speed_DataLoaderWithOpenSlide(wsi_path, reader_type, tile_size,
                                      num_worker, test_disk):
        """Time one pass over a slide and append the result to ``record``."""
        con = 0
        dataset = DatasetWithOpenSlide(wsi_path, type=reader_type, tile_size=tile_size)
        start_time = time.time()
        data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_worker)
        for img in tqdm(data_loader):
            # With batch_size=1, dimension 1 is the number of 512 patches.
            con += img.shape[1]
        # Compute the speed once so the printed and the recorded value agree.
        speed = con / (time.time() - start_time)

        print(f'file:{wsi_path.name}, disk: {test_disk}, tile_size:{tile_size}, num_workers: {num_worker}, speed: {speed}')
        record['file'].append(wsi_path.name)
        record['disk'].append(test_disk)
        record['tile_size'].append(tile_size)
        record['num_workers'].append(num_worker)
        record['type'].append(reader_type)
        record['speed'].append(speed)

    for tile_size in [512, 512*10, 512*20]:
        for num_worker in [1, 4]:
            for test_disk, wsi_dir in wsi_dirs.items():
                # Files of BOTH disks are benchmarked; previously this loop
                # sat inside the SSD branch only.
                for wsi_path in wsi_dir.glob('*'):
                    for reader_type in ['dz', 'openslide']:
                        # Run the measurement for every combination instead of
                        # once after all loops have finished.
                        speed_DataLoaderWithOpenSlide(
                            wsi_path, reader_type, tile_size,
                            num_worker, test_disk)

    pd.DataFrame(record).to_csv('record.csv', index=False)