constantinpape / torch-em

Deep-learning based semantic and instance segmentation for 3D Electron Microscopy and other bioimage analysis problems based on pytorch.
MIT License


torch-em

Deep-learning based semantic and instance segmentation for 3D Electron Microscopy and other bioimage analysis problems based on PyTorch. Any feedback is highly appreciated; just open an issue!

Highlights:

Design:

# train a 2d U-Net for foreground and boundary segmentation of nuclei
# using data from https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip

import torch
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader

model = UNet2d(in_channels=1, out_channels=2)

# transform to go from instance segmentation labels
# to foreground/background and boundary channel
label_transform = torch_em.transform.BoundaryTransform(
    add_binary_target=True, ndim=2
)

# training and validation data loader
data_path = "./dsb"  # the training data will be downloaded and saved here
train_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform
)
val_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform
)

# the trainer object that handles the training details
# the model checkpoints will be saved in "checkpoints/dsb-boundary-model"
# the tensorboard logs will be saved in "logs/dsb-boundary-model"
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
    device=torch.device("cuda")
)
trainer.fit(iterations=5000)

# export the trained model to the bioimage.io model format
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model

# load one of the test images to use as the reference image
# and crop it to a shape that is guaranteed to fit the network
test_im = imageio.imread(glob(f"{data_path}/test/images/*.tif")[0])[:256, :256]

export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
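
To make the label transform in the example above more concrete, here is a minimal NumPy sketch of the idea: turning an instance label image into a foreground channel and a boundary channel. This is not torch_em's implementation. The function name `boundary_and_foreground`, the channel order (foreground first, boundaries second) and the 4-neighbor boundary definition are assumptions made for illustration.

```python
import numpy as np


def boundary_and_foreground(labels: np.ndarray) -> np.ndarray:
    """Return a (2, H, W) float32 target from a 2D instance label image.

    Channel 0 marks foreground (label > 0); channel 1 marks pixels that
    have at least one 4-neighbor with a different label.
    Hypothetical helper for illustration, not part of torch_em.
    """
    labels = np.asarray(labels)
    foreground = labels > 0

    boundary = np.zeros(labels.shape, dtype=bool)
    # label changes between vertically adjacent pixels mark both pixels
    diff_rows = labels[1:, :] != labels[:-1, :]
    boundary[1:, :] |= diff_rows
    boundary[:-1, :] |= diff_rows
    # label changes between horizontally adjacent pixels mark both pixels
    diff_cols = labels[:, 1:] != labels[:, :-1]
    boundary[:, 1:] |= diff_cols
    boundary[:, :-1] |= diff_cols

    return np.stack([foreground, boundary]).astype("float32")


# a single 2x2 object in the middle of a 4x4 image
labels = np.array([
    [0, 0, 0, 0],
    [0, 1, 1, 0],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
])
target = boundary_and_foreground(labels)
print(target.shape)  # (2, 4, 4)
```

The two output channels correspond to `out_channels=2` in the model definition: the network predicts one map for foreground and one for boundaries.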

For a more in-depth example, check out one of the example notebooks:

Installation

From mamba

mamba is a drop-in replacement for conda, but much faster. While the steps below may also work with conda, using mamba is highly recommended. Follow the mamba installation instructions to set it up.

You can install torch_em from conda-forge:

mamba install -c conda-forge torch_em

Please check out pytorch.org for more information on how to install a PyTorch version compatible with your system.

From source

It's recommended to set up a conda environment for using torch_em. Two conda environment files are provided: environment_cpu.yaml for a pure CPU set-up and environment_gpu.yaml for a GPU set-up. If you want to use the GPU version, make sure to set the correct CUDA version for your system in the environment file by modifying the corresponding line.

You can set up a conda environment using one of these files like this:

mamba env create -f <ENV>.yaml -n <ENV_NAME>
mamba activate <ENV_NAME>
pip install -e .

where <ENV>.yaml is either environment_cpu.yaml or environment_gpu.yaml.
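
After either installation route, a quick check from Python can confirm that the packages are visible in the active environment. The snippet below is a small sketch using only the standard library; the helper name `is_importable` is made up for this example and is not part of torch_em.

```python
import importlib.util


def is_importable(module_name: str) -> bool:
    """Return True if a top-level module can be found in the active environment."""
    return importlib.util.find_spec(module_name) is not None


# report whether the packages needed for the examples are available
for name in ("torch", "torch_em"):
    status = "found" if is_importable(name) else "missing"
    print(f"{name}: {status}")
```

If `torch_em` is reported as missing, the most likely cause is that the environment created above has not been activated.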

Features

Command Line Scripts

A command line interface for training, prediction and conversion to the bioimage.io modelzoo format will be installed with torch_em:

For more details, run <COMMAND> -h for any of these commands. The folder scripts/cli contains some examples of how to use the CLI.

Note: this functionality was recently added and is not fully tested.

Research Projects using torch-em