SHI-Labs / OneFormer

OneFormer: One Transformer to Rule Universal Image Segmentation, arxiv 2022 / CVPR 2023
https://praeclarumjj3.github.io/oneformer
MIT License
1.41k stars 128 forks source link

Correct dataset format to fine-tune with Hugging Face? #53

Open nikolaydyankov opened 1 year ago

nikolaydyankov commented 1 year ago

Hi, first of all thank you for sharing your awesome work.

I'm trying to fine-tune the model for instance segmentation with a custom dataset that I have locally in COCO format. The issue that I'm having is that I don't know how exactly to convert the segmentation polygon masks to pixel_values and task_inputs that the model's forward function expects.

This is my data loader script:

import datasets
import os
from pycocotools.coco import COCO
from pathlib import Path

class COCODataset(datasets.GeneratorBasedBuilder):
    """``datasets`` builder for a local COCO-format instance-segmentation set.

    Reads ``annotations/instances_train.json`` and
    ``annotations/instances_val.json`` under ``self.config.data_dir`` and
    loads image files from the ``images`` sub-directory of the same root.
    Each example holds the raw image and its list of COCO annotation records.
    """

    def _info(self):
        """Return the dataset metadata and the per-example feature schema."""
        return datasets.DatasetInfo(
            description="COCO dataset",
            features=datasets.Features({
                # "pixel_values": ...
                # "task_inputs": ...
                "image": datasets.Image(),
                "annotations": datasets.Sequence({
                    "id": datasets.Value("int32"),
                    "image_id": datasets.Value("int32"),
                    "category_id": datasets.Value("int32"),
                    # COCO defines ``area`` as a float (polygon areas are
                    # fractional). float64 also represents every value the
                    # previous int32 field could hold without rounding.
                    "area": datasets.Value("float64"),
                    "iscrowd": datasets.Value("int32"),
                    # [x, y, width, height] in the COCO convention.
                    "bbox": datasets.Sequence(datasets.Value("float32")),
                    "attributes": {
                        "occluded": datasets.Value("bool"),
                    },
                    # One inner list of flattened x/y coordinates per polygon.
                    # NOTE(review): RLE-encoded (dict) segmentations would not
                    # fit this schema - confirm the data is polygon-only.
                    "segmentation": datasets.Sequence(datasets.Sequence(datasets.Value("float32"))),
                })
            }),
        )

    def _split_generators(self, dl_manager):
        """Resolve the train/val annotation files and declare the two splits.

        The resolved annotation path is handed to ``_generate_examples``
        through the ``images`` keyword argument.
        """
        instances_train_path = dl_manager.download(os.path.join(self.config.data_dir, "annotations/instances_train.json"))
        instances_val_path = dl_manager.download(os.path.join(self.config.data_dir, "annotations/instances_val.json"))

        return [
            datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"images": instances_train_path}),
            datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"images": instances_val_path}),
        ]

    def _generate_examples(self, images):
        """Yield ``(image_id, example)`` pairs for one split.

        Args:
            images: Path to the COCO ``instances_*.json`` annotation file of
                the split (despite the name, this is not a list of images).

        Yields:
            The COCO image id as key and a dict with the encoded image bytes
            under ``"image"`` and the raw annotation records under
            ``"annotations"``.
        """
        coco = COCO(images)

        for image_id in coco.imgs:
            image = coco.loadImgs(image_id)[0]
            annotations = coco.loadAnns(coco.getAnnIds(image_id))

            # Load the image content as bytes
            image_path = os.path.join(self.config.data_dir, "images", image["file_name"])
            image_content = Path(image_path).read_bytes()

            yield image_id, {
                "image": image_content,
                "annotations": annotations,
                # "pixel_values": ...,
                # "task_inputs": ...
            }

I know that I'm supposed to use OneFormerProcessor, but the examples provided are only for inference and don't specify how to process input masks. What exactly am I supposed to do in the _generate_examples method? Any tips are greatly appreciated!

Just for reference, here is my train script as well:

import numpy as np
import evaluate
from transformers import OneFormerForUniversalSegmentation, TrainingArguments, Trainer
import datasets
import os

# Resolve all paths relative to this script so it can be launched from any cwd.
script_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.join(script_dir, "..", "data/datasets/archviz-600-v2-coco")

# Build the dataset through the local loader script.
ds = datasets.load_dataset(os.path.join(script_dir, "dataset_loader.py"), data_dir=data_dir)

print("Length of train dataset:", len(ds['train']))
print("Length of validation dataset:", len(ds['validation']))

model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_cityscapes_swin_large")
# By default the Trainer drops every dataset column that is not a parameter of
# model.forward(). The loader only yields "image" and "annotations", so all
# columns were removed and batching failed with
# "IndexError: Invalid key: ... is out of bounds for size 0".
# NOTE: keeping the columns only removes that error; the examples must still
# be converted to the model's inputs (e.g. with OneFormerProcessor in a
# transform or data collator) before training can actually run.
training_args = TrainingArguments(
    output_dir=os.path.join(script_dir, 'output'),
    evaluation_strategy="epoch",
    remove_unused_columns=False,
)
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score an evaluation pass with the module-level ``metric``.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class is
    the arg-max over the last axis of the logits.
    """
    scores, references = eval_pred
    predicted = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted, references=references)

# Wire the model, arguments, both dataset splits and the metric callback
# into a single Trainer instance.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
)

# Start fine-tuning; evaluation cadence comes from ``training_args``.
trainer.train()

And this is the output:

Length of train dataset: 472
Length of validation dataset: 118

/usr/local/lib/python3.8/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
  0%|                                                                                                                                                              

| 0/177 [00:00<?, ?it/s]Traceback (most recent call last):
  File "oneformer-hugging/train.py", line 32, in <module>
    trainer.train()
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1662, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.8/dist-packages/transformers/trainer.py", line 1899, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 635, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/dataloader.py", line 679, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/usr/local/lib/python3.8/dist-packages/torch/utils/data/_utils/fetch.py", line 56, in fetch
    data = self.dataset.__getitems__(possibly_batched_index)
  File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2782, in __getitems__
    batch = self.__getitem__(keys)
  File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2778, in __getitem__
    return self._getitem(key)
  File "/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py", line 2762, in _getitem
    pa_subtable = query_table(self._data, key, indices=self._indices if self._indices is not None else None)
  File "/usr/local/lib/python3.8/dist-packages/datasets/formatting/formatting.py", line 578, in query_table
    _check_valid_index_key(key, size)
  File "/usr/local/lib/python3.8/dist-packages/datasets/formatting/formatting.py", line 531, in _check_valid_index_key
    _check_valid_index_key(int(max(key)), size=size)
  File "/usr/local/lib/python3.8/dist-packages/datasets/formatting/formatting.py", line 521, in _check_valid_index_key
    raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
IndexError: Invalid key: 375 is out of bounds for size 0
  0%|