NielsRogge / Transformers-Tutorials

This repository contains demos I made with the Transformers library by HuggingFace.
MIT License

AttributeError: 'function' object has no attribute 'items' while finetuning segment anything model #368

Closed: Snimm closed this issue 7 months ago

Snimm commented 7 months ago

Description:

I am encountering an issue while fine-tuning the Segment Anything Model (SAM) on a fabric defect dataset. I am closely following this tutorial and have shared my full code on Kaggle.


Code Snippet:

I am using a custom dataset class SAMDataset to load images and masks. Here's a snippet of the code:

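import numpy as np
from PIL import Image
from torch.utils.data import Dataset

# get_bounding_box(mask) is assumed to be defined as in the tutorial
class SAMDataset(Dataset):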
    def __init__(self, image_paths, target_paths, processor):
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.processor = processor

    def __getitem__(self, index):
        image = Image.open(self.image_paths[index])
        mask = Image.open(self.target_paths[index])
        mask = np.array(mask)
        prompt = get_bounding_box(mask)
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt")
        # remove batch dimension which the processor adds by default
        inputs = {k:v.squeeze(0) for k,v in inputs.items()}
        # add ground truth segmentation
        inputs["ground_truth_mask"] = mask
        return input  # typo: should be `inputs`; `input` is Python's builtin function

    def __len__(self):
        return len(self.image_paths)

from transformers import SamProcessor
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

train_dataset = SAMDataset(image_path_list, mask_path_list, processor)
example = train_dataset[0]
for k, v in example.items():
    print(k, v.shape)
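The snippet calls get_bounding_box without showing it. Below is a minimal sketch of such a helper, assuming it follows the tutorial's approach: take the min/max coordinates of the foreground pixels of a 2-D single-channel mask, then widen each edge by a few random pixels. The max_jitter and rng parameters are hypothetical additions for reproducibility, and a 3-channel (RGB) mask would need converting to a single channel first.

```python
import numpy as np


def get_bounding_box(ground_truth_map, max_jitter=20, rng=None):
    """Return [x_min, y_min, x_max, y_max] around the nonzero pixels of a 2-D mask.

    Sketch only: max_jitter and rng are hypothetical parameters, not from the tutorial.
    """
    rng = np.random.default_rng() if rng is None else rng
    # coordinates of every foreground (nonzero) pixel
    y_indices, x_indices = np.where(ground_truth_map > 0)
    if x_indices.size == 0:
        raise ValueError("mask has no foreground pixels; cannot build a box prompt")
    x_min, x_max = int(x_indices.min()), int(x_indices.max())
    y_min, y_max = int(y_indices.min()), int(y_indices.max())
    # widen each edge by 0..max_jitter pixels, clipped to the image bounds
    height, width = ground_truth_map.shape
    x_min = max(0, x_min - int(rng.integers(0, max_jitter + 1)))
    x_max = min(width, x_max + int(rng.integers(0, max_jitter + 1)))
    y_min = max(0, y_min - int(rng.integers(0, max_jitter + 1)))
    y_max = min(height, y_max + int(rng.integers(0, max_jitter + 1)))
    return [x_min, y_min, x_max, y_max]


# usage: a 64x64 mask with one rectangular defect region
mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:20, 30:40] = 1
print(get_bounding_box(mask, max_jitter=0))  # [30, 10, 39, 19]
```

With max_jitter=0 the box is the tight bounding box; the default jitter keeps the prompt from being pixel-perfect, which is the idea behind the perturbation in the tutorial.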

Error:

I am encountering an AttributeError when trying to print the keys and shapes of tensors in the example variable:

AttributeError                            Traceback (most recent call last)
Cell In[36], line 2
      1 example = train_dataset[0]
----> 2 for k,v in example.items():
      3   print(k,v.shape)

AttributeError: 'function' object has no attribute 'items'

Observation:

In the tutorial, the example variable is a dictionary of tensors, but in my code it appears to be a function object (<bound method Kernel.raw_input>). I suspect the issue might be related to passing two arrays (lists of image and mask paths) to the SAMDataset class instead of a dataset object, as the tutorial does.
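The Kernel.raw_input object is a strong hint. In the snippet, __getitem__ builds a dictionary named inputs but ends with return input. Python does not raise a NameError for that line, because input resolves to the builtin input function, which Jupyter/Kaggle kernels replace with the kernel's own raw_input. The sketch below reproduces that pattern with hypothetical toy classes (TypoDataset, FixedDataset) in plain Python; outside a notebook the returned object is the standard builtin, so the type name in the error message will differ from the one in the traceback.

```python
class TypoDataset:
    def __getitem__(self, index):
        inputs = {"pixel_values": [0.0, 1.0]}
        return input  # typo: resolves to the builtin `input` function, not the dict


class FixedDataset:
    def __getitem__(self, index):
        inputs = {"pixel_values": [0.0, 1.0]}
        return inputs  # returns the dict built above


example = TypoDataset()[0]
print(callable(example), hasattr(example, "items"))  # True False

try:
    example.items()
except AttributeError as err:
    # the exact type name in the message depends on the environment
    print("AttributeError:", err)

example = FixedDataset()[0]
for k, v in example.items():
    print(k, v)
```

Because the misspelled name silently falls back to a builtin, this kind of typo only surfaces later, when the caller tries to use the returned value as a dictionary.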


Question:

Is using two arrays instead of a dataset as input to the SAMDataset class causing this issue? I have checked all variables inside the class, and they are all of the correct type. Why am I encountering this error?
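On the first question: passing two parallel lists instead of a dataset object should not by itself cause this error, since a map-style PyTorch dataset only needs __getitem__ and __len__, and indexing two lists with the same index is a common pattern. The sketch below assumes the image and mask lists have equal length and uses hypothetical stand-ins (fake_load, fake_processor) in place of Image.open and SamProcessor, so it runs without PyTorch or transformers; for the same reason it omits the torch.utils.data.Dataset base class.

```python
def fake_processor(image, input_boxes=None):
    # hypothetical stand-in for SamProcessor(...); returns a dict of "tensors"
    return {"pixel_values": image, "input_boxes": input_boxes}


def fake_load(path):
    # hypothetical stand-in for Image.open / np.array
    return f"<data from {path}>"


class PairedPathDataset:
    def __init__(self, image_paths, target_paths, processor):
        # two lists indexed in lockstep: image_paths[i] pairs with target_paths[i]
        if len(image_paths) != len(target_paths):
            raise ValueError("image and mask lists must have the same length")
        self.image_paths = image_paths
        self.target_paths = target_paths
        self.processor = processor

    def __getitem__(self, index):
        image = fake_load(self.image_paths[index])
        mask = fake_load(self.target_paths[index])
        # placeholder box prompt; the real code would use get_bounding_box(mask)
        inputs = self.processor(image, input_boxes=[[[0, 0, 1, 1]]])
        inputs["ground_truth_mask"] = mask
        return inputs  # the dict, not the builtin `input`

    def __len__(self):
        return len(self.image_paths)


ds = PairedPathDataset(["img_0.png", "img_1.png"], ["mask_0.png", "mask_1.png"], fake_processor)
example = ds[1]
print(len(ds), sorted(example))  # 2 ['ground_truth_mask', 'input_boxes', 'pixel_values']
```

The length check is an optional safeguard: with parallel lists, a missing mask file would otherwise silently shift every later image/mask pair.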


Additional Information:

Any insights or suggestions would be greatly appreciated. Thank you!