Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

Object detection example broken #1529

Closed lorinczszabolcs closed 1 year ago

lorinczszabolcs commented 1 year ago

🐛 Bug - Object detection example broken

The object detection example here is broken. I think the actual problem lies in package version incompatibilities. I installed the latest versions, so I assumed it would work out of the box.

To Reproduce

I used the original script shared here and also tried my own custom inputs, but the same error appears:

/lib/python3.8/site-packages/effdet/anchors.py", line 404, in batch_label_anchors
    box_targets[count:count + steps].view([feat_size[0], feat_size[1], -1]))
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
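The RuntimeError is a general PyTorch constraint rather than something specific to effdet: Tensor.view() only reinterprets existing memory, so it fails when the requested shape cannot be expressed with the tensor's current strides, while Tensor.reshape() falls back to a copy. The sketch below reproduces that mechanism on a transposed (non-contiguous) tensor. It is only an illustration and makes no claim about why box_targets ends up non-contiguous inside effdet; it assumes torch is installed.

```python
import torch

# A transposed tensor shares storage with the original but has swapped strides,
# so it is no longer contiguous in memory.
base = torch.arange(12).reshape(3, 4)
transposed = base.t()  # shape (4, 3), strides (1, 4)
assert not transposed.is_contiguous()

# view() must reuse the same memory without copying; flattening a transposed
# tensor cannot be done that way, so PyTorch raises a RuntimeError.
try:
    transposed.view(-1)
    view_failed = False
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

# reshape() returns a view when it can and copies when it cannot.
flat = transposed.reshape(-1)

# Equivalent explicit fix: make the data contiguous first, then view it.
flat_too = transposed.contiguous().view(-1)
assert torch.equal(flat, flat_too)
```

This is why the error message itself suggests .reshape(...): it trades a possible copy for not depending on memory layout.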

Code sample

For completeness, here is the original code:

import flash
from flash.core.data.utils import download_data
from flash.image import ObjectDetectionData, ObjectDetector

# 1. Create the DataModule
# Dataset Credit: https://www.kaggle.com/ultralytics/coco128
download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")

datamodule = ObjectDetectionData.from_coco(
    train_folder="data/coco128/images/train2017/",
    train_ann_file="data/coco128/annotations/instances_train2017.json",
    val_split=0.1,
    transform_kwargs={"image_size": 512},
    batch_size=4,
)

# 2. Build the task
model = ObjectDetector(head="efficientdet", backbone="d0", num_classes=datamodule.num_classes, image_size=512)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Detect objects in a few images!
datamodule = ObjectDetectionData.from_files(
    predict_files=[
        "data/coco128/images/train2017/000000000625.jpg",
        "data/coco128/images/train2017/000000000626.jpg",
        "data/coco128/images/train2017/000000000629.jpg",
    ],
    transform_kwargs={"image_size": 512},
    batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)

# 5. Save the model!
trainer.save_checkpoint("object_detection_model.pt")

Expected behavior

The training should not fail with the mentioned error.

Environment
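No environment details are listed here. Since the thread points to version incompatibilities, a sketch for collecting the relevant versions with only the standard library (importlib.metadata, Python 3.8+) may help. The package list is an assumption taken from the libraries named in this thread, and collect_versions is a hypothetical helper name.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Distributions mentioned in this thread (assumed list; adjust as needed).
PACKAGES = ["lightning-flash", "pytorch-lightning", "lightning", "torch", "effdet"]


def collect_versions(packages=PACKAGES):
    """Return {name: installed version or None}, plus the Python version."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None  # not installed in this environment
    return report


for name, ver in collect_versions().items():
    print(f"{name}: {ver or 'not installed'}")
```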

Borda commented 1 year ago

@lorinczszabolcs would you be interested in debugging it further and eventually sending a fix? 🦩

lorinczszabolcs commented 1 year ago

I retried it, and now a different error appears:

---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
<ipython-input-2-430f4efffa8e> in <cell line: 1>()
----> 1 import flash
      2 from flash.core.data.utils import download_data
      3 from flash.image import ObjectDetectionData, ObjectDetector
      4 
      5 # 1. Create the DataModule

2 frames
/usr/local/lib/python3.10/dist-packages/flash/core/data/utils.py in <module>
     20 import requests
     21 import urllib3
---> 22 from pytorch_lightning.utilities.apply_func import apply_to_collection
     23 from torch import nn
     24 from tqdm.auto import tqdm as tq

ModuleNotFoundError: No module named 'pytorch_lightning.utilities.apply_func'
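The traceback shows flash/core/data/utils.py importing pytorch_lightning.utilities.apply_func, a module the installed pytorch-lightning no longer provides. One way to confirm such a mismatch before importing flash is to probe for the module with importlib.util.find_spec. The helper module_available is a hypothetical name, and the sketch only reports availability; it does not fix the import.

```python
import importlib.util


def module_available(dotted_name: str) -> bool:
    """Return True if dotted_name can be found.

    find_spec imports the parent packages of a dotted name (but not the module
    itself), so a missing or non-package parent surfaces as an ImportError
    subclass; treat that as "not available".
    """
    try:
        return importlib.util.find_spec(dotted_name) is not None
    except ImportError:
        return False


# The import that fails in the traceback above.
print(module_available("pytorch_lightning.utilities.apply_func"))
```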

Unfortunately, I won't have time to look into it in detail, but it still looks like a package-version issue, possibly caused by the pytorch-lightning 2.0.0 (lightning 2.0.0) release.

Borda commented 1 year ago

possibly caused by pytorch-lightning 2.0.0 (lightning 2.0.0) release

Flash pins this dependency below 2.0.

lorinczszabolcs commented 1 year ago

When I run pip install lightning-flash in a clean environment, pytorch-lightning==2.0.2 gets installed. During installation pip prints "Collecting pytorch-lightning>=1.3.6 (from lightning-flash)", indicating that the declared requirement is only >=1.3.6, even though the pin in the repository says otherwise: https://github.com/Lightning-Universe/lightning-flash/blob/14c27555e9cba706a2f20caeaf787c0e116ef1f4/requirements.txt#L8

Any idea why that's happening?
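pip resolves against the metadata of the published release, not against requirements.txt on master. A sketch for reading the requirements that the installed release actually declares uses importlib.metadata.requires. The helper requirement_specs is hypothetical and does only light parsing (package name plus the remaining specifier/marker text); packaging.requirements.Requirement would be the more robust choice.

```python
import re
from importlib.metadata import PackageNotFoundError, requires


def _normalize(name: str) -> str:
    # PEP 503 normalization: case-insensitive; runs of -, _ and . are equivalent.
    return re.sub(r"[-_.]+", "-", name).lower()


def requirement_specs(requirements, package):
    """Return the specifier/marker text of every entry that names `package`."""
    target = _normalize(package)
    specs = []
    for req in requirements or []:  # requires() may return None
        match = re.match(r"\s*([A-Za-z0-9][A-Za-z0-9._-]*)(.*)", req)
        if match and _normalize(match.group(1)) == target:
            specs.append(match.group(2).strip())
    return specs


try:
    declared = requires("lightning-flash")
    print(requirement_specs(declared, "pytorch-lightning"))
except PackageNotFoundError:
    print("lightning-flash is not installed")
```

If the printed list contains only a lower bound such as ">=1.3.6", pip is free to pick a 2.x release, which would match the behaviour described above.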

Borda commented 1 year ago

Upon installing with pip install lightning-flash in a clean environment, it shows me that pytorch-lightning==2.0.2

Most likely this pin adjustment has not been released yet, so please install from source for now: pip install https://github.com/Lightning-Universe/lightning-flash/archive/refs/heads/master.zip
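In a notebook such as the Colab session in the traceback above, running pip through the kernel's own interpreter (sys.executable -m pip) ensures the source install lands in the same environment that imports flash. A minimal sketch; pip_install_command and install_from_source are hypothetical helpers, and install_from_source is defined but not called here.

```python
import subprocess
import sys

# Source archive suggested in this thread.
MASTER_ZIP = "https://github.com/Lightning-Universe/lightning-flash/archive/refs/heads/master.zip"


def pip_install_command(target: str) -> list:
    """Build a pip command bound to the currently running interpreter."""
    return [sys.executable, "-m", "pip", "install", target]


def install_from_source(target: str = MASTER_ZIP) -> None:
    # check_call raises CalledProcessError if pip exits with a non-zero status.
    subprocess.check_call(pip_install_command(target))
```

Calling install_from_source() performs the install; if flash was already imported in the session, restart the kernel so the new version is picked up.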

Borda commented 1 year ago

This should be fixed in https://github.com/Lightning-Universe/lightning-flash/releases/tag/0.8.2
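After upgrading, a quick check that the installed lightning-flash is at least the 0.8.2 release linked above. The sketch compares only the leading numeric MAJOR.MINOR.PATCH components, so pre-release tags are ignored (0.8.2rc1 would count as 0.8.2); packaging.version.Version is the more rigorous option. release_tuple and release_at_least are hypothetical helpers.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(ver: str) -> tuple:
    """Parse the leading MAJOR.MINOR.PATCH numbers; missing parts count as 0."""
    match = re.match(r"(\d+)(?:\.(\d+))?(?:\.(\d+))?", ver)
    if match is None:
        raise ValueError(f"unrecognised version string: {ver!r}")
    return tuple(int(part or 0) for part in match.groups())


def release_at_least(ver: str, minimum: str) -> bool:
    return release_tuple(ver) >= release_tuple(minimum)


try:
    installed = version("lightning-flash")
    print(installed, release_at_least(installed, "0.8.2"))
except PackageNotFoundError:
    print("lightning-flash is not installed")
```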