juglab / labkit-ui

Advanced Tool for Labeling And Segmentation
BSD 2-Clause "Simplified" License

SNT Integration: Pre-load LabkitProjectFrame with pre-defined labels #112

Open tferr opened 9 months ago

tferr commented 9 months ago

@maarzt ,

We've been hacking ways to integrate both Labkit and TWS into SNT. Right now we can import and train models from paths, but a more useful command would be to "Send neurites" to Labkit, so that the traced paths of selected neurites could be pre-loaded into a Labkit instance as labels.

This would allow for pixel-perfect labels along thin neurites, which, in our experience, significantly improves training relative to freehand annotations.

In TWS, we convert Paths into ROIs and feed those to its GUI at startup. It is a bit clunky because we need to use reflection, but it works quite well (we will probably open a PR there to formalize this feature). Labkit can hold much larger datasets, and it would be a pity to be restricted to TWS.

Do you have pointers on how to attempt this in Labkit? I've looked at LabelBrushController briefly, but did not find it useful. Did I miss something? I am convinced that a method that allows 'painting' programmatically would work. A method that converts poly/freehand-lines into labels would be best.

NB: We can convert selected paths into a digitized skeleton mask and train a model with it, but that is not interactive, and somewhat defeats what we are trying to achieve.
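For reference, the core of 'painting' a traced segment programmatically boils down to rasterizing it into integer pixel coordinates. A minimal, Labkit-independent sketch using Bresenham's line algorithm (all names here are hypothetical, not part of Labkit's or SNT's API):

```java
import java.util.ArrayList;
import java.util.List;

public class LineRasterizer {

    /** Returns the integer (x, y) pixels covered by the segment (x0,y0)-(x1,y1). */
    public static List<int[]> rasterize(int x0, int y0, int x1, int y1) {
        List<int[]> pixels = new ArrayList<>();
        int dx = Math.abs(x1 - x0), dy = -Math.abs(y1 - y0);
        int sx = x0 < x1 ? 1 : -1, sy = y0 < y1 ? 1 : -1;
        int err = dx + dy;
        while (true) {
            pixels.add(new int[] { x0, y0 });        // paint the current pixel
            if (x0 == x1 && y0 == y1) break;
            int e2 = 2 * err;
            if (e2 >= dy) { err += dy; x0 += sx; }   // step in x
            if (e2 <= dx) { err += dx; y0 += sy; }   // step in y
        }
        return pixels;
    }

    public static void main(String[] args) {
        // A segment covers max(|dx|, |dy|) + 1 pixels.
        System.out.println(rasterize(0, 0, 4, 2).size()); // prints 5
    }
}
```

Each returned coordinate could then be written into whatever per-pixel label structure the target tool exposes.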

maarzt commented 8 months ago

Ok, interesting. In Labkit, a collection of ROIs is called a "labeling". So if I understand you correctly, you would like to start Labkit with a pre-initialized "labeling". That is definitely possible. Here is an example for you:

import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.Arrays;

import javax.swing.JFrame;
import javax.swing.WindowConstants;

import ij.ImagePlus;
import net.imagej.ImgPlus;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccess;
import net.imglib2.img.VirtualStackAdapter;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.roi.labeling.LabelingType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import org.scijava.Context;
import sc.fiji.labkit.ui.SegmentationComponent;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.labeling.Label;
import sc.fiji.labkit.ui.labeling.Labeling;
import sc.fiji.labkit.ui.models.DefaultSegmentationModel;
import sc.fiji.labkit.ui.segmentation.SegmentationTool;
import sc.fiji.labkit.ui.segmentation.Segmenter;

/**
 * An example of how one can start a "custom" Labkit with a pre-initialized labeling.
 */
public class LabkitExample {

    public static void main(String... args) {
        // open an image and create a segmentation model
        Context context = new Context();
        ImagePlus imagePlus = new ImagePlus("https://imagej.nih.gov/ij/images/t1-head.zip");
        ImgPlus<?> image = VirtualStackAdapter.wrap(imagePlus);
        image.setChannelMinimum(0, 0 );     // set brightness and contrast for displaying the given image
        image.setChannelMaximum(0, 1000 );
        DefaultSegmentationModel segmentationModel =
            new DefaultSegmentationModel(context, new DatasetInputImage(image));

        // Initialize an empty labeling of correct size
        FinalInterval imageSize = new FinalInterval(image);
        Labeling labeling = Labeling.createEmpty(Arrays.asList("S1", "S2", "background"), imageSize);

        // add one straight line
        RandomAccess<LabelingType<Label>> randomAccess = labeling.randomAccess();
        LabelingType<Label> color = randomAccess.get().createVariable();
        color.add(labeling.getLabel("S1"));
        for (int i = 0; i < 200; i++) {
            randomAccess.setPosition(i, 0); // set X
            randomAccess.setPosition(i, 1); // set Y
            randomAccess.setPosition(i, 2); // set Z
            randomAccess.get().set( color );
        }

        // Create the gui components
        segmentationModel.imageLabelingModel().labeling().set(labeling);
        JFrame frame = new JFrame();
        SegmentationComponent segmentationComponent =
            new SegmentationComponent(frame, segmentationModel, false);
        frame.add(segmentationComponent);
        frame.setSize(1000, 600);
        frame.setDefaultCloseOperation( WindowConstants.DISPOSE_ON_CLOSE );
        frame.addWindowListener(new WindowAdapter() {

            @Override
            public void windowClosed(WindowEvent e) {
                showResults( segmentationModel, image );
                segmentationComponent.close(); // don't forget to close the segmentation component if you are done.
            }
        });
        frame.setVisible(true);
    }

    private static void showResults(DefaultSegmentationModel segmentationModel, ImgPlus<?> image) {
        // Compute and show the complete segmentation and probability map.
        Segmenter segmenter = segmentationModel.segmenterList().segmenters().get().get(0);
        SegmentationTool segmentationTool = new SegmentationTool(segmenter);
        ImgPlus<UnsignedByteType> segmentation = segmentationTool.segment(image);
        ImageJFunctions.show(segmentation).setDisplayRange(0, 3);
        ImgPlus<FloatType> probabilityMap = segmentationTool.probabilityMap(image);
        ImageJFunctions.show(probabilityMap).setDisplayRange(0, 1);
    }

}

The example starts Labkit on a given image and shows a labeling that is pre-initialized with a straight line.

Additionally, if you train a pixel classifier and then close the window, the results will be shown in two new windows.

@tferr I hope this example helps you!

tferr commented 8 months ago

Thanks @maarzt, this is really helpful. The detailed example should be straightforward to adapt.

Should I assume that the axis positions are indexed to the ImgPlus<T> Axes? I.e.:

final long zLen = image.dimension(image.dimensionIndex(Axes.Z));
final long cLen = image.dimension(image.dimensionIndex(Axes.CHANNEL));
final long tLen = image.dimension(image.dimensionIndex(Axes.TIME));
for (int i = 0; i < 200; i++)  {
    randomAccess.setPosition(i, image.dimensionIndex(Axes.X)); // set X
    randomAccess.setPosition(i, image.dimensionIndex(Axes.Y)); // set Y
    if (zLen > 1)
        randomAccess.setPosition(i, image.dimensionIndex(Axes.Z)); // set Z
    if (cLen > 1)
        randomAccess.setPosition(i, image.dimensionIndex(Axes.CHANNEL)); // set C
    if (tLen > 1)
        randomAccess.setPosition(i, image.dimensionIndex(Axes.TIME)); // set T
    randomAccess.get().set(color);
}

Or is there some other convention in which the axis index is always fixed, e.g. X=0, Y=1, Z=2, C=3, etc.?

maarzt commented 8 months ago

Labkit can deal with varying axis order, but yes, you need to set the axis types correctly in the ImgPlus<T>.

For the Labeling it's a different story: I think the axis order there is fixed to XYZT. The labeling has no channel axis; every channel shares the same "label".
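So, assuming the labeling order is fixed while the image order varies, mapping an image position to a labeling position amounts to dropping the channel component. A plain-Java sketch (no Labkit dependency; the channel index is hypothetical here and would come from `image.dimensionIndex(Axes.CHANNEL)` in practice):

```java
public class AxisMapping {

    /**
     * Removes the component at channelIndex from an image position;
     * pass -1 when the image has no channel axis.
     */
    public static long[] dropChannel(long[] imagePos, int channelIndex) {
        if (channelIndex < 0) return imagePos.clone();
        long[] labelingPos = new long[imagePos.length - 1];
        for (int d = 0, o = 0; d < imagePos.length; d++)
            if (d != channelIndex) labelingPos[o++] = imagePos[d];  // keep all non-channel axes
        return labelingPos;
    }

    public static void main(String[] args) {
        // XYCZ image position (10, 20, 1, 5) -> XYZ labeling position (10, 20, 5)
        long[] pos = dropChannel(new long[] { 10, 20, 1, 5 }, 2);
        System.out.println(java.util.Arrays.toString(pos)); // [10, 20, 5]
    }
}
```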

tferr commented 8 months ago

The labeling has no channel axis. Every channel has the same "label".

Why is that? Are multichannel images not formally supported? If you have an image with two fluorophores (A and B) and want to train the model with 3 classes ("foreground a", "foreground b", and "background"), how would one go about it? Indeed, I trigger an error when I use a 2D multichannel image: it says that the image has 2 dimensions(!) (and my labeling 3), so I guess the channel dimension has been simplified internally by Labkit!? Something akin to an RGB conversion?

Apart from that, I think I got everything working well with 'normal' grayscale images (2D or 3D, with or without a time axis). I will link the code here as soon as I have confirmed that everything is indeed working.

tferr commented 8 months ago

Multichannel support aside: the code is here, and I added a dedicated page to the documentation. Hopefully I got the main differences between Labkit and TWS right; do let me know otherwise.

maarzt commented 8 months ago

Your code looks good. Does it work as intended?

Regarding your earlier question about multichannel images: yes, the channel axis is treated specially, but the details are hard to explain. The LabelingType class that is used for the labeling does a lot of tricks. It is loosely similar to an RGB type, but rather than representing a color, it can represent a set of labels. So technically a pixel can be annotated as {"foreground"}, {"background"}, {"foreground", "background"}, or the empty set.
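To make the set semantics concrete, here is a rough plain-Java model (illustration only; the real LabelingType is a more elaborate imglib2 type backed by an index image, and this class is hypothetical):

```java
import java.util.HashSet;
import java.util.Set;

public class PixelLabels {

    /** Models one pixel's annotation as a set of label names. */
    public static Set<String> annotate(String... labels) {
        Set<String> pixel = new HashSet<>();   // a fresh pixel is the empty set
        for (String label : labels)
            pixel.add(label);                  // a pixel may carry several labels at once
        return pixel;
    }

    public static void main(String[] args) {
        System.out.println(annotate().isEmpty());                        // unlabeled pixel
        System.out.println(annotate("foreground", "background").size()); // both labels at once
    }
}
```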

I modified the example to show how to use Labkit on a two channel image:

import java.awt.event.WindowAdapter;
import java.awt.event.WindowEvent;
import java.util.Arrays;

import javax.swing.JFrame;
import javax.swing.WindowConstants;

import ij.ImagePlus;
import net.imagej.ImgPlus;
import net.imagej.axis.Axes;
import net.imglib2.Interval;
import net.imglib2.RandomAccess;
import net.imglib2.img.VirtualStackAdapter;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.roi.labeling.LabelingType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Intervals;
import net.imglib2.view.Views;
import org.scijava.Context;
import sc.fiji.labkit.ui.SegmentationComponent;
import sc.fiji.labkit.ui.inputimage.DatasetInputImage;
import sc.fiji.labkit.ui.labeling.Label;
import sc.fiji.labkit.ui.labeling.Labeling;
import sc.fiji.labkit.ui.models.DefaultSegmentationModel;
import sc.fiji.labkit.ui.segmentation.SegmentationTool;
import sc.fiji.labkit.ui.segmentation.Segmenter;
import sc.fiji.labkit.ui.utils.DimensionUtils;

/**
 * An example of how one can start a "custom" Labkit with a pre-initialized labeling.
 */
public class LabkitExample {

    public static void main(String... args) {
        // open an image and create a segmentation model
        Context context = new Context();
        ImagePlus imagePlus = new ImagePlus("https://imagej.net/ij/images/confocal-series.zip");
        ImgPlus<?> image = VirtualStackAdapter.wrap(imagePlus);
        DefaultSegmentationModel segmentationModel =
            new DefaultSegmentationModel(context, new DatasetInputImage(image));

        // Initialize an empty labeling of correct size
        int channelDimension = image.dimensionIndex(Axes.CHANNEL);
        Interval imageSize = DimensionUtils.intervalRemoveDimension(image, channelDimension);
        Labeling labeling = Labeling.createEmpty(Arrays.asList("S1", "S2", "background"), imageSize);

        // add some labels
        RandomAccess<LabelingType<Label>> randomAccess = labeling.randomAccess();
        LabelingType<Label> color = randomAccess.get().createVariable();
        color.clear();
        color.add(labeling.getLabel("S1"));
        Views.interval(labeling, Intervals.createMinSize(120, 205, 11, 30, 3, 5)).forEach(pixel -> pixel.set(color));
        color.clear();
        color.add(labeling.getLabel("S2"));
        Views.interval(labeling, Intervals.createMinSize(250, 139, 11, 3, 30, 5)).forEach(pixel -> pixel.set(color));
        color.clear();
        color.add(labeling.getLabel("background"));
        Views.interval(labeling, Intervals.createMinSize(34, 28, 11, 20, 4, 5)).forEach(pixel -> pixel.set(color));

        // Create the gui components
        segmentationModel.imageLabelingModel().labeling().set(labeling);
        JFrame frame = new JFrame();
        SegmentationComponent segmentationComponent =
            new SegmentationComponent(frame, segmentationModel, false);
        frame.add(segmentationComponent);
        frame.setSize(1000, 600);
        frame.setDefaultCloseOperation( WindowConstants.DISPOSE_ON_CLOSE );
        frame.addWindowListener(new WindowAdapter() {

            @Override
            public void windowClosed(WindowEvent e) {
                showResults( segmentationModel, image );
                segmentationComponent.close(); // don't forget to close the segmentation component if you are done.
            }
        });
        frame.setVisible(true);
    }

    private static void showResults(DefaultSegmentationModel segmentationModel, ImgPlus<?> image) {
        Segmenter segmenter = segmentationModel.segmenterList().segmenters().get().get(0);
        SegmentationTool segmentationTool = new SegmentationTool(segmenter);
        ImgPlus<UnsignedByteType> segmentation = segmentationTool.segment(image);
        ImageJFunctions.show(segmentation).setDisplayRange(0, 3);
        ImgPlus<FloatType> probabilityMap = segmentationTool.probabilityMap(image);
        ImageJFunctions.show(probabilityMap).setDisplayRange(0, 1);
    }

}

The biggest difference is in these two lines:

        int channelDimension = image.dimensionIndex(Axes.CHANNEL);
        Interval imageSize = DimensionUtils.intervalRemoveDimension(image, channelDimension);
tferr commented 8 months ago

@maarzt, thanks! I've incorporated this and tweaked the documentation. Everything seems to be working well in a couple of tests with 2D and 3D multichannel images. I consider the SNT <> Labkit bridge finalized, so feel free to close this. Thanks a lot for the thoughtful examples.