Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

Instance Segmentation Example Broken #1505

Closed gatordevin closed 1 year ago

gatordevin commented 1 year ago

🐛 Bug

The provided instance segmentation example is not working. I tried both the version in the docs and the latest example from GitHub. I also noticed that the link for the dataset was outdated and did not return valid data, so I found the link to the new data on their website and used that instead. The model appears to train, but prediction returns empty results, as if nothing was detected. This suggests that either the pre-trained weights or the model output is broken.

https://github.com/Lightning-AI/lightning-flash/blob/cf969bcbab349c027f208168973110544c672358/flash_examples/instance_segmentation.py

To Reproduce

My environment is Windows 11. I followed these install steps in a virtual environment.

The Python version should be 3.8.15:

mamba create -n flash python==3.8.15
mamba activate flash

mamba install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install mmdet
pip install mmcv

Install the right versions of torchmetrics, pytorch-lightning, and setuptools:

pip install lightning-flash[image]
pip install pre-commit
pip install pysolotools

pip install icevision[all]

pip install sahi==0.10.7

pip install icedata

Many newer library versions appeared to have issues with Lightning Flash, IceVision, or both. It took me a while to find older package versions that were still compatible with everything.
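Since the problem may depend on exact package versions, here is a minimal sketch for recording what is actually installed in an environment like this one. It uses only the standard library (`importlib.metadata`, available from Python 3.8). The list of distribution names is an assumption based on the install steps above and may need adjusting, for example if a package is published under a different distribution name.

```python
from importlib import metadata

# Distribution names assumed from the install steps in this report;
# adjust the list if your environment uses different package names.
PACKAGES = [
    "torch",
    "torchvision",
    "torchaudio",
    "mmdet",
    "mmcv",
    "lightning-flash",
    "pytorch-lightning",
    "torchmetrics",
    "setuptools",
    "icevision",
    "icedata",
    "sahi",
    "pysolotools",
    "pre-commit",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        if version is None:
            print(f"{name}: not installed")
        else:
            print(f"{name}=={version}")
```

The output is a short pinned list that can be pasted into the issue, which is easier to compare across machines than a full `pip freeze`.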

Code sample

from functools import partial

import flash
from flash.core.utilities.imports import example_requires
from flash.image import InstanceSegmentation, InstanceSegmentationData

example_requires("image")

import icedata  # noqa: E402

# 1. Create the DataModule
data_dir = icedata.pets.load_data()

datamodule = InstanceSegmentationData.from_icedata(
    train_folder=data_dir,
    val_split=0.1,
    transform_kwargs=dict(image_size=(256, 256)),
    parser=partial(icedata.pets.parser, mask=True),
    batch_size=4,
)

# 2. Build the task
model = InstanceSegmentation(
    head="mask_rcnn",
    backbone="resnet18_fpn",
    num_classes=datamodule.num_classes,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1, accelerator='gpu', devices=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = InstanceSegmentationData.from_files(
    predict_files=[
        str(data_dir / "images/yorkshire_terrier_9.jpg"),
        str(data_dir / "images/yorkshire_terrier_12.jpg"),
        str(data_dir / "images/yorkshire_terrier_13.jpg"),
    ],
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("instance_segmentation_model.pt")

Expected behavior

I would expect this to output proper data, as shown in the tutorial in the docs. There should be detections present even after 1 epoch of transfer learning.
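To check for "empty results" without reading the printed output by eye, a small helper along these lines could count detections in whatever `trainer.predict` returns. This is only a sketch: the helper `count_detections` is hypothetical, and the assumption that each per-image prediction is a dict holding a list-like `"bboxes"` entry is a guess, not something confirmed by the Flash docs. Inspect one prediction first and change `key` to match the real structure.

```python
def count_detections(predictions, key="bboxes"):
    """Recursively count items stored under `key` in nested lists and dicts.

    Assumption: detections are stored in a list-like value under `key`
    (default "bboxes"). The actual key depends on the library's output format.
    """
    if isinstance(predictions, dict):
        total = 0
        for name, value in predictions.items():
            if name == key:
                total += len(value)  # one entry per detected object
            else:
                total += count_detections(value, key)
        return total
    if isinstance(predictions, (list, tuple)):
        return sum(count_detections(item, key) for item in predictions)
    return 0  # scalars, strings, and other leaves hold no detections


# Synthetic example shaped like "batches of per-image predictions".
fake_predictions = [
    [{"bboxes": []}, {"bboxes": [[0, 0, 10, 10]]}],
    [{"preds": {"bboxes": [[1, 2, 3, 4], [5, 6, 7, 8]]}}],
]
total_detections = count_detections(fake_predictions)
```

A total of zero on real predictions would match the behavior described in this report.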

ahmetgunduz commented 1 year ago

Do we have any update on this?

ahmetgunduz commented 1 year ago

@gatordevin can you please share your pip freeze output here? I would like to downgrade my packages as well.

Borda commented 1 year ago

Tested with the latest release, https://github.com/Lightning-Universe/lightning-flash/releases/tag/0.8.2, and everything seems to run without issues.