roboflow / notebooks

Examples and tutorials on using SOTA computer vision models and techniques. Learn everything from old-school ResNet, through YOLO and object-detection transformers like DETR, to the latest models like Grounding DINO and SAM.
https://roboflow.com/models

issues with Segformer loading the trained model checkpoint file. #143

Closed yashmewada9618 closed 1 year ago

yashmewada9618 commented 1 year ago

Search before asking

Notebook name

train-segformer-segmentation-on-custom-data.ipynb

Bug

Hi, I am trying to reproduce the SegFormer notebook's output on my machine, and I adapted the code for my own dataset (RUGD). Training completed without errors, and the .ckpt file was saved.

However, I am unable to load the checkpoint file.

Below is my edited __init__() function of the SegformerFinetuner class. To the best of my knowledge, the error seems to be in the inheritance implementation (correct me if I am wrong).

After these changes to the code, I got the error below.

Traceback (most recent call last):
  File "inf_roboflow.py", line 349, in <module>
    segformer_finetuner = SegformerFinetuner(id2label=id2label,lable2id=label2id,color_map=color_map).load_from_checkpoint(path)
  File "/home/yash/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 139, in load_from_checkpoint
    return _load_from_checkpoint(  # type: ignore[return-value]
  File "/home/yash/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 188, in _load_from_checkpoint
    return _load_state(cls, checkpoint, strict=strict, **kwargs)
  File "/home/yash/.local/lib/python3.8/site-packages/pytorch_lightning/core/saving.py", line 234, in _load_state
    obj = cls(**_cls_kwargs)
TypeError: __init__() missing 2 required positional arguments: 'id2label' and 'lable2id'

Also, the approach below doesn't work: it looks for a model_state_dict key in the checkpoint file, whereas PyTorch Lightning saves the weights under state_dict, and simply changing the key name doesn't seem to help.

model = SegformerForSemanticSegmentation.from_pretrained(
    "nvidia/mit-b0",
    num_labels=24,
    id2label=id2label,
    label2id=label2id,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load("/media/yash/T7/Fine_Tune_Segformer/checkpoints/b0/epoch_15.pth")
print("[+] checkpoint: ", list(checkpoint.keys()))
exit()
# print(checkpoint['model_state_dict'])
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

model.eval()
Thank you for your help.
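
For readers hitting the same key mismatch, here is a minimal sketch of the remapping involved. It assumes the Lightning checkpoint keeps the weights under `state_dict` and that each key carries the attribute name as a prefix (`model.` here, because the class stores the network in `self.model`). The helper name `extract_submodule_state_dict` is hypothetical, and plain lists stand in for tensors so the example runs without PyTorch.

```python
def extract_submodule_state_dict(checkpoint, prefix="model."):
    """Return one submodule's weights from a Lightning-style checkpoint dict.

    Hypothetical helper: keeps only the keys that start with `prefix`
    and strips that prefix so the keys match the bare model's names.
    """
    state_dict = checkpoint["state_dict"]
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy checkpoint with made-up values; a real one would hold tensors.
checkpoint = {
    "epoch": 40,
    "state_dict": {
        "model.decode_head.classifier.weight": [0.1, 0.2],
        "model.decode_head.classifier.bias": [0.0],
    },
}

weights = extract_submodule_state_dict(checkpoint)
print(sorted(weights))
```

With a real checkpoint, `checkpoint` would come from `torch.load(path, map_location="cpu")`, and the returned dict would be passed to `model.load_state_dict(...)`. Renaming only the top-level key is not enough if the inner keys still carry the `model.` prefix.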

Environment

Minimal Reproducible Example

segformer_finetuner = SegformerFinetuner(id2label,label2id,color_map=color_map).load_from_checkpoint(path)

SegformerFinetuner class __init__ function.

def __init__(self, id2label, lable2id, train_dataloader=None, val_dataloader=None, test_dataloader=None, metrics_interval=100,color_map=None):

        super(SegformerFinetuner, self).__init__()
        self.id2label = id2label
        self.metrics_interval = metrics_interval
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.test_dl = test_dataloader
        self.color_seg = color_map
        self.label2id = lable2id

        self.num_classes = len(id2label.keys())

        print("-----------------------------------------------------------------")
        print("[+] Number of Classes: ", self.num_classes)
        print("[+] id2labels: ",self.id2label)
        print("[+] lables2id: ",self.label2id)
        print("-----------------------------------------------------------------")

        self.model = SegformerForSemanticSegmentation.from_pretrained(
            "nvidia/segformer-b0-finetuned-ade-512-512",
            return_dict=False,
            num_labels=self.num_classes,
            id2label=self.id2label,
            label2id=self.label2id,
            ignore_mismatched_sizes=True,
        )
        self.train_mean_iou = load_metric("mean_iou")
        self.val_mean_iou = load_metric("mean_iou")
        self.test_mean_iou = load_metric("mean_iou")
        self.save_hyperparameters()

Main loop

if __name__ == "__main__":
    color_map = pd.read_csv('/media/yash/T7/Fine_Tune_Segformer/RUGD_sample-data/RUGD_annotation-colormap.txt', sep=" ", header=None)
    color_map.columns = ["label_idx", "label", "R", "G", "B"]
    color_map.head()

label2id = {label: id for id, label in enumerate(color_map.label)}
id2label = {id: label for id, label in enumerate(color_map.label)}
print("---------------------------------------------------------")
print("[+] id2label: ",id2label)
print("[+] length of id2 labels: ", len(id2label))
id2color = {id: [R,G,B] for id, (R,G,B) in enumerate(zip(color_map.R, color_map.G, color_map.B))}
print("[+] id2 RGB color: ", id2color)
print("[+] Color of id2 tree label: ", id2color[label2id["tree"]])
print("---------------------------------------------------------")

del id2color[0]

id2color = {id-1: color for id, color in id2color.items()}
print("[+] id2color: ",id2color)

del id2label[0]

label2id = {label: id-1 for id, label in id2label.items()}
id2label = {id-1: label for id, label in id2label.items()}
print("[+] id2label: ",id2label)

root_dir_imgs = "/media/yash/T7/Fine_Tune_Segformer/RUGD_sample-data/split_imgs" 
mask_dir = "/media/yash/T7/Fine_Tune_Segformer/RUGD_sample-data/split_anns"

feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
feature_extractor.do_reduce_labels = False
feature_extractor.size = 128

train_dataset = SemanticSegmentationDataset(f"{root_dir_imgs}/train/",f"{mask_dir}/train/", feature_extractor)
val_dataset = SemanticSegmentationDataset(f"{root_dir_imgs}/val/",f"{mask_dir}/val/", feature_extractor)
test_dataset = SemanticSegmentationDataset(f"{root_dir_imgs}/test/",f"{mask_dir}/test/", feature_extractor)

print("-----------------------------------------------------------------")
encoded_inputs = train_dataset[0]
print("[+] Input encoder pixel size: ",encoded_inputs["pixel_values"].shape)
print("[+] Input encoder lables size: ",encoded_inputs["labels"].shape)
print("[+] Unique encoder lables: ",encoded_inputs["labels"].squeeze().unique())
print("-----------------------------------------------------------------")

batch_size = 128
num_workers = 10
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,pin_memory=True)

path = "/media/yash/T7/Fine_Tune_Segformer/Scripts/lightning_logs/version_35/checkpoints/epoch=40-step=1927.ckpt"
segformer_finetuner = SegformerFinetuner(id2label,label2id,color_map=color_map).load_from_checkpoint(path)

Additional

No response

Are you willing to submit a PR?

github-actions[bot] commented 1 year ago

👋 Hello @yashmewada9618, thank you for leaving an issue on Roboflow Notebooks.

🐞 Bug reports

If you are filing a bug report, please be as detailed as possible. This will help us more easily diagnose and resolve the problem you are facing. To learn more about contributing, check out our Contributing Guidelines.

If you require support with custom code that is not part of Roboflow Notebooks, please reach out on the Roboflow Forum or on the GitHub Discussions page associated with this repository.

💬 Get in touch

Do you have more questions about Roboflow that we haven't responded to yet? Feel free to ask them on the Roboflow Discuss forum. Our developer advocates and community team actively respond to questions there.

To ask questions about Notebooks, head over to the GitHub Discussions section of this repository.

yashmewada9618 commented 1 year ago

So, it turns out I was calling load_from_checkpoint() incorrectly. The correct way to do it is segformer_finetuner = SegformerFinetuner(id2label=id2label,lable2id=label2id,color_map=color_map).load_from_checkpoint(path,id2label=id2label,lable2id=label2id). With that change, I was able to load the saved checkpoint and run predictions from it.
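
To illustrate why passing the arguments to load_from_checkpoint() matters, here is a toy sketch of the pattern visible in the traceback, where the loader ends in `cls(**_cls_kwargs)`. `ToyFinetuner` is a hypothetical stand-in, not Lightning's actual implementation; the `hyper_parameters` key, the empty saved arguments, and the label values are assumptions for the example. The `lable2id` spelling follows the signature used in this issue.

```python
class ToyFinetuner:
    """Toy stand-in for a LightningModule; not the real Lightning class."""

    def __init__(self, id2label, lable2id, color_map=None):
        self.id2label = id2label
        self.label2id = lable2id
        self.color_map = color_map

    @classmethod
    def load_from_checkpoint(cls, checkpoint, **kwargs):
        # Start from whatever init arguments the checkpoint recorded,
        # then let explicit keyword arguments override them.
        init_kwargs = dict(checkpoint.get("hyper_parameters", {}))
        init_kwargs.update(kwargs)
        # A fresh object is built here; any instance the method was
        # called on is never consulted.
        return cls(**init_kwargs)


id2label = {0: "dirt", 1: "tree"}
label2id = {"dirt": 0, "tree": 1}
checkpoint = {"hyper_parameters": {}}  # assumed: no init arguments recorded

# Calling through an instance does not forward that instance's arguments.
instance = ToyFinetuner(id2label=id2label, lable2id=label2id)
try:
    instance.load_from_checkpoint(checkpoint)
    failed = False
except TypeError:
    failed = True  # required arguments missing, as in the traceback

# Supplying the arguments to the loader itself lets construction succeed.
restored = ToyFinetuner.load_from_checkpoint(
    checkpoint, id2label=id2label, lable2id=label2id
)
print(failed, restored.id2label)
```

Because the loader is a classmethod, calling it directly on the class, as in SegformerFinetuner.load_from_checkpoint(path, id2label=id2label, lable2id=label2id), should behave the same without constructing a throwaway instance first.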

SkalskiP commented 1 year ago

Hi @yashmewada9618 👋🏻! It is awesome to hear that you were able to solve your issue. 💜