Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

Semantic Segmentation target masks broken >0.7.5 #1489

Closed (by newzealandpaul, 1 year ago)

newzealandpaul commented 1 year ago

🐛 Bug

The switch to Albumentations in newer releases of lightning-flash seems to have broken the transformation of segmentation targets.

This is what I expect the masks to look like (screenshot of the code sample below running on 0.7.5):

[Screenshot: expected target masks on 0.7.5]

This is what they look like on the latest release (0.8.1):

[Screenshot: target masks on 0.8.1]

To Reproduce

Run the sample below with lightning-flash=0.7.5 and with lightning-flash=0.8.1, then compare the plotted target masks.

Code sample

import torch

import flash
from flash.core.data.utils import download_data
from flash.image import SemanticSegmentation, SemanticSegmentationData

import matplotlib.pyplot as plt
import numpy as np

# 1. Create the DataModule
# The data was generated with the CARLA self-driving simulator as part of the Kaggle Lyft Udacity Challenge.
# More info here: https://www.kaggle.com/kumaresanmanickavelu/lyft-udacity-challenge
# download_data(
#     "https://github.com/ongchinkiat/LyftPerceptionChallenge/releases/download/v0.1/carla-capture-20180513A.zip",
#     "./data",
# )

datamodule = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    transform_kwargs=dict(image_size=(256, 256)),
    num_classes=21,
    batch_size=4,
)

# 2. Build the task
model = SemanticSegmentation(
    backbone="mobilenetv3_large_100",
    head="fpn",
    num_classes=datamodule.num_classes,
)

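# 3. Plot the first n inputs (left column) next to their target masks (right column)
n = 3
fig, axarr = plt.subplots(ncols=2, nrows=n, figsize=(8, 4*n))

for batch in datamodule.train_dataloader():
    print(batch.keys())
    for i in range(n):
        segm = batch['target'][i]
        print(segm.shape)
        # Move the channel axis last (CHW -> HWC) so imshow can render the image
        img = np.rollaxis(batch['input'][i].numpy(), 0, 3)
        axarr[i, 0].imshow(img)
        axarr[i, 1].imshow(segm)
    break  # only the first batch is needed

plt.show()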
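
To compare the two releases numerically rather than only by eye, a small check on the target values can help. The following is a minimal sketch, not part of lightning-flash: `summarize_mask` is a hypothetical helper, and the two arrays are synthetic stand-ins rather than real output from either version. It assumes a valid target contains only whole-number class labels in the range [0, num_classes).

```python
import numpy as np


def summarize_mask(mask, num_classes):
    """Return basic statistics for one segmentation target mask.

    Hypothetical helper for debugging; not part of lightning-flash.
    """
    arr = np.asarray(mask)
    values = np.unique(arr)
    # Class labels should be whole numbers in [0, num_classes).
    is_integral = bool(np.all(values == np.round(values)))
    in_range = bool(
        values.size > 0 and values.min() >= 0 and values.max() < num_classes
    )
    return {
        "shape": arr.shape,
        "dtype": str(arr.dtype),
        "unique_values": values.tolist(),
        "valid_labels": is_integral and in_range,
    }


# Synthetic examples (assumptions for illustration, not real batches):
clean = np.array([[0, 7], [7, 20]], dtype=np.int64)
non_integral = np.array([[0.0, 3.5], [7.0, 20.0]], dtype=np.float32)

clean_summary = summarize_mask(clean, num_classes=21)
non_integral_summary = summarize_mask(non_integral, num_classes=21)
print(clean_summary)
print(non_integral_summary)
```

In the reproduction loop above, the same helper could be called as `summarize_mask(batch['target'][i].numpy(), datamodule.num_classes)` under each release, so that differences in shape, dtype, or label values show up directly in the printed summaries.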

Environment

noname202 commented 1 year ago

I can confirm this issue. I just spent a significant amount of time trying to figure out whether something was wrong with my code. Is there any estimate of when this will be fixed?

Borda commented 1 year ago

Hi @newzealandpaul @noname202, we are sorry for this bug, which seems to be very critical for segmentation... :( Would you be interested in trying to debug it? I believe that @ethanwharris could eventually give a hand... 🐰