Lightning-Universe / lightning-flash

Your PyTorch AI Factory - Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains
https://lightning-flash.readthedocs.io
Apache License 2.0

ImageEmbedder example not working #1260

Closed: sirtris closed this issue 2 years ago

sirtris commented 2 years ago

🐛 ImageEmbedder example not working

I tried to run the flash_examples/image_embedder.py script, but it crashes with the following error: AssertionError: Incorrect embedding shape: torch.Size([16, 8192]) but expected Nx128
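For readers unfamiliar with this kind of failure, here is a minimal sketch of the check that the message implies: the embedding batch must be two-dimensional, with a second dimension equal to latent_embedding_dim (128 in the code sample). The helper check_embedding_shape is hypothetical and works on plain tuples rather than torch.Size; it is an illustration only, not the actual Flash or VISSL code.

```python
from typing import Sequence


def check_embedding_shape(shape: Sequence[int], expected_dim: int) -> None:
    """Raise AssertionError unless `shape` is (N, expected_dim).

    Hypothetical helper for illustration; not part of Flash or VISSL.
    """
    ok = len(shape) == 2 and shape[1] == expected_dim
    assert ok, (
        f"Incorrect embedding shape: {tuple(shape)} but expected Nx{expected_dim}"
    )


# Shapes taken from the report: batch_size=16, latent_embedding_dim=128.
check_embedding_shape((16, 128), expected_dim=128)  # passes silently

try:
    # The shape from the error message: 8192 features instead of 128.
    check_embedding_shape((16, 8192), expected_dim=128)
except AssertionError as err:
    message = str(err)
    print(message)
```

Under this reading, the crash means the model produced 8192 features per image where the training strategy expected 128.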

To Reproduce

Run the example script (flash_examples/image_embedder.py).

Code sample

import torch
from torchvision.datasets import CIFAR10

import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageEmbedder

# 1. Download the data and prepare the datamodule
datamodule = ImageClassificationData.from_datasets(
    train_dataset=CIFAR10(".", download=True),
    batch_size=16,
)

# 2. Build the task
embedder = ImageEmbedder(
    backbone="resnet",
    training_strategy="barlow_twins",
    head="barlow_twins_head",
    pretraining_transform="barlow_twins_transform",
    training_strategy_kwargs={"latent_embedding_dim": 128},
    pretraining_transform_kwargs={"size_crops": [196]},
)

# 3. Create the trainer and pre-train the encoder
# use accelerator='ddp' when using GPU(s),
# i.e. flash.Trainer(max_epochs=3, gpus=1, accelerator='ddp')
trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
trainer.fit(embedder, datamodule=datamodule)

# 4. Save the model!
trainer.save_checkpoint("image_embedder_model.pt")

# 5. Download the downstream prediction dataset and generate embeddings
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_files(
    predict_files=[
        "data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg",
        "data/hymenoptera_data/predict/2039585088_c6f47c592e.jpg",
    ]
)
embeddings = trainer.predict(embedder, datamodule=datamodule)

# list of embeddings for images sent to the predict function
print(embeddings)

Expected behavior

Run without error.

Environment

THX

danacity commented 2 years ago

I tried to run it on Colab and got a different error (the traceback was collapsed to "13 frames"; the last frame is shown below):

/usr/local/lib/python3.7/dist-packages/flash/image/embedding/vissl/hooks.py in on_start(self, task)
     50         # get around vissl distributed training by setting MockTask flags
     51         num_nodes = lightning_module.trainer.num_nodes
---> 52         accelerators_ids = accelerator_connector(lightning_module.trainer).parallel_device_ids
     53         accelerator_per_node = len(accelerators_ids) if accelerators_ids is not None else 1
     54         task.world_size = num_nodes * accelerator_per_node

AttributeError: 'AcceleratorConnector' object has no attribute 'parallel_device_ids'
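The AttributeError means the connector object in this environment has no parallel_device_ids attribute. One defensive pattern for this situation is to look the attribute up with a default. The sketch below uses hypothetical stand-in classes (ConnectorWithoutIds, ConnectorWithIds) and a hypothetical world_size helper; it mirrors the arithmetic on lines 53 and 54 of the traceback, but it is not the fix that was shipped in the library.

```python
from typing import Optional, Sequence


class ConnectorWithoutIds:
    """Hypothetical stand-in for a connector lacking `parallel_device_ids`."""


class ConnectorWithIds:
    """Hypothetical stand-in for a connector that exposes the attribute."""

    parallel_device_ids = [0, 1]


def world_size(connector: object, num_nodes: int) -> int:
    # getattr with a default avoids the AttributeError seen in the traceback.
    ids: Optional[Sequence[int]] = getattr(connector, "parallel_device_ids", None)
    # Same fallback as in the traceback: one accelerator per node when ids is None.
    accelerators_per_node = len(ids) if ids is not None else 1
    return num_nodes * accelerators_per_node


size_without_ids = world_size(ConnectorWithoutIds(), num_nodes=1)
size_with_ids = world_size(ConnectorWithIds(), num_nodes=2)
print(size_without_ids, size_with_ids)
```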

ethanwharris commented 2 years ago

Hi @sirtris @Daniel-R-Armstrong, thanks for reporting these! Both should be fixed on the latest master. We will have a patch release later today that includes the fixes. I'll report back here once the release is out for you to try 😃

Thanks!

ethanwharris commented 2 years ago

Our release is out! @sirtris @Daniel-R-Armstrong This should now be working for you if you install flash with:

pip install 'lightning-flash[image]==0.7.2'

Also note that we've updated our docs to warn that multi-GPU self-supervised learning (SSL) training is not currently supported: https://lightning-flash.readthedocs.io/en/stable/reference/image_embedder.html

Hope that helps 😃
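To confirm which release is actually installed before re-running the example, a small check along these lines may help. It is a sketch that assumes Python 3.8+ (for importlib.metadata); installed_version and version_tuple are hypothetical helper names, and the version parsing is deliberately simplistic (it keeps only the leading digits of each dot-separated component).

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def version_tuple(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric part of each dot-separated component."""
    parts = []
    for component in version.split("."):
        match = re.match(r"\d+", component)
        if match is None:
            break  # stop at the first component with no leading digits
        parts.append(int(match.group()))
    return tuple(parts)


found = installed_version("lightning-flash")
if found is None:
    print("lightning-flash is not installed")
elif version_tuple(found) >= (0, 7, 2):
    print(f"lightning-flash {found} is at least 0.7.2")
else:
    print(f"lightning-flash {found} is older than 0.7.2")
```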

sirtris commented 2 years ago

Awesome, thanks! I gave it a quick run with Trainer(..., max_steps=16) and it seems to work.