SegFormer Segmentation #148

Closed HurairaCodes closed 1 year ago

HurairaCodes commented 1 year ago

So, I have trained a custom semantic segmentation model following this tutorial (SegFormer Semantic Segmentation Model), but the model is saved in .ckpt format. I have downloaded the checkpoint, but I run into issues when I try to run inference on images.

Can you please guide me on how to convert this file to a TensorFlow SavedModel, or how to perform inference after downloading the checkpoint file?

Your guidance would be highly appreciated. Thank you, and best regards.
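A PyTorch Lightning `.ckpt` is a pickled Python dict rather than a standalone model file: the weights live under `"state_dict"` next to bookkeeping entries such as `"epoch"`, `"global_step"` and `"optimizer_states"`, and each weight key is prefixed with the name of the LightningModule attribute that held the network. Before mapping keys onto a Hugging Face model, it can help to see which prefixes are actually present. A minimal sketch; the helper `summarize_prefixes` and the stand-in keys are hypothetical:

```python
from collections import Counter


def summarize_prefixes(state_dict, depth=2):
    """Count state_dict keys grouped by their first `depth` dot-separated parts.

    For a SegFormer held as `self.model` inside a LightningModule, keys would be
    expected to look like "model.segformer.encoder..." and "model.decode_head...".
    """
    counts = Counter()
    for key in state_dict:
        prefix = ".".join(key.split(".")[:depth])
        counts[prefix] += 1
    return counts


# Stand-in dict for illustration; with a real file, pass checkpoint["state_dict"].
example = {
    "model.segformer.encoder.layer_norm.0.weight": None,
    "model.segformer.encoder.layer_norm.0.bias": None,
    "model.decode_head.classifier.weight": None,
}
print(summarize_prefixes(example))
# Counter({'model.segformer': 2, 'model.decode_head': 1})
```

With a real checkpoint this would be called as `summarize_prefixes(torch.load(path, map_location="cpu", weights_only=False)["state_dict"])`; the common prefix it reveals is what has to be stripped before `load_state_dict`.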


SkalskiP commented 1 year ago

Hi @HurairaCODE 👋🏻! What exact error message do you get?

SkalskiP commented 1 year ago

I'm converting this issue into a discussion and moving it to our QA section.

HurairaCodes commented 1 year ago

Thanks for your reply. The issue is with the model checkpoint file: I have downloaded it, but when I try to run inference with an image as input, it doesn't work.

I have tried several approaches; the latest code I have is this:

```python
import torch
from torchvision.transforms import functional as F
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Load the saved model checkpoint
checkpoint_path = "/content/lightning_logs/version_0/checkpoints/epoch=3-step=40.ckpt"
checkpoint = torch.load(checkpoint_path)

# Extract the model state_dict
state_dict = checkpoint['state_dict']

# Create the model
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# Map the state_dict keys to match the current model's state_dict key names
new_state_dict = {}
for key in state_dict:
    if key.startswith("model."):
        new_key = key.replace("model.", "")
        new_state_dict[new_key] = state_dict[key]
    else:
        new_state_dict[key] = state_dict[key]

# Load the updated state_dict
model.load_state_dict(new_state_dict)

# Create the feature extractor
feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
feature_extractor.do_reduce_labels = False
feature_extractor.size = 128

# Load the test image
test_image_path = "/content/Snow-1/test/frame_000724_jpg.rf.3d7e01b9e220be5e596a863cfd8e5401.jpg"
test_image = Image.open(test_image_path)

# Preprocess the test image
encoded_inputs = feature_extractor(test_image, return_tensors="pt")
input_image = encoded_inputs['pixel_values']

# Perform inference
with torch.no_grad():
    outputs = model(input_image)
    logits = outputs.logits

# Post-process the predictions
upsampled_logits = torch.nn.functional.interpolate(
    logits,
    size=test_image.size[::-1],  # Upsample to the original image size
    mode="bilinear",
    align_corners=False,
)
probabilities = torch.softmax(upsampled_logits, dim=1)
predicted_labels = torch.argmax(probabilities, dim=1)

# Convert the predicted labels to an RGB image for visualization
color_map = {
    0: (0, 0, 0),    # Background
    1: (255, 0, 0),  # Class 1
}

def convert_labels_to_image(labels):
    vis_shape = labels.shape + (3,)
    vis = torch.zeros(vis_shape, dtype=torch.uint8)
    for label, color in color_map.items():
        vis[labels == label] = color
    return vis.permute(1, 2, 0)

predicted_image = convert_labels_to_image(predicted_labels[0])

# Overlay the predicted mask on the original image
overlay_img = Image.blend(test_image.convert("RGBA"), predicted_image, alpha=0.5)

# Display the result
```

I am getting the following error now:

```
RuntimeError                              Traceback (most recent call last)
 in <cell line: 26>()
     24
     25 # Load the updated state_dict
---> 26 model.load_state_dict(new_state_dict)
     27
     28 # Create the feature extractor

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/ in load_state_dict(self, state_dict, strict)
   2039
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                , "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for SegformerForSemanticSegmentation:
	size mismatch for decode_head.classifier.weight: copying a param with shape torch.Size([2, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([150, 256, 1, 1]).
	size mismatch for decode_head.classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([150]).
```
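The traceback pinpoints the cause: the checkpoint's classifier head has 2 output classes (`torch.Size([2, 256, 1, 1])`), while `nvidia/segformer-b0-finetuned-ade-512-512` is built with ADE20K's 150 classes. No missing or unexpected keys are reported, so every other weight already lines up, and the model only needs to be created with a matching number of labels before `load_state_dict`. A minimal sketch under these assumptions: the tutorial's LightningModule stored the Hugging Face model as `self.model` (the same `model.` prefix the code above strips), and the helper names `strip_prefix`, `infer_num_labels` and `load_finetuned_segformer` are hypothetical. `num_labels` and `ignore_mismatched_sizes` are standard `from_pretrained` keyword arguments in transformers.

```python
def strip_prefix(state_dict, prefix="model."):
    """Keep only keys that start with `prefix`, with that leading prefix removed.

    Slicing removes the prefix only at the start of each key, unlike
    str.replace, which would also rewrite any later occurrence.
    """
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def infer_num_labels(state_dict, head_key="decode_head.classifier.bias"):
    """Return the number of classes the checkpoint was trained with.

    The classifier bias is 1-D with one entry per class, so its length is
    the label count (2 in the traceback above).
    """
    for key, value in state_dict.items():
        if key.endswith(head_key):
            return len(value)
    raise KeyError(f"no key ending in {head_key!r}")


def load_finetuned_segformer(ckpt_path, base="nvidia/segformer-b0-finetuned-ade-512-512", prefix="model."):
    """Rebuild the fine-tuned SegFormer from a Lightning checkpoint (sketch)."""
    # Imported here so the pure-Python helpers above work without torch/transformers.
    import torch
    from transformers import SegformerForSemanticSegmentation

    # weights_only=False: Lightning checkpoints store more than tensors.
    # Only use it on files you trust, since it unpickles arbitrary objects.
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict = strip_prefix(checkpoint["state_dict"], prefix)

    # Create the model with the checkpoint's class count; ignore_mismatched_sizes
    # lets from_pretrained re-initialise the 150-class ADE20K head instead of failing.
    model = SegformerForSemanticSegmentation.from_pretrained(
        base,
        num_labels=infer_num_labels(state_dict),
        ignore_mismatched_sizes=True,
    )
    # Strict load: keys and shapes should now match, and the re-initialised
    # head is overwritten with the fine-tuned weights.
    model.load_state_dict(state_dict)
    model.eval()
    return model
```

Usage would be `model = load_finetuned_segformer("/content/lightning_logs/version_0/checkpoints/epoch=3-step=40.ckpt")`, after which the preprocessing and inference steps above can stay as they are. Calling `model.save_pretrained("some_dir")` then writes a regular Hugging Face directory that `from_pretrained` can reload without the key mapping. If a TensorFlow model is still required, transformers 4.x also provides `TFSegformerForSemanticSegmentation`, whose `from_pretrained(..., from_pt=True)` can read such a directory; check that your installed version still ships TensorFlow support.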

nataliameira commented 11 months ago

Hello @HurairaCODE,

Were you able to perform the image inference correctly? Can you help me with this?
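Beyond the loading error, the post-processing in the code above has two further problems that would surface once the model loads: `vis` is already shaped (H, W, 3), so `vis.permute(1, 2, 0)` reorders it to (W, 3, H), and `Image.blend` expects two PIL images of the same mode and size rather than a tensor. A minimal NumPy-only sketch of colorizing a label map and alpha-blending it onto the image; `PALETTE`, `colorize` and `overlay` are hypothetical names, and the palette simply mirrors the `color_map` above (class 0 black, class 1 red):

```python
import numpy as np

# Row i is the RGB colour for class id i (mirrors color_map: 0 -> black, 1 -> red).
PALETTE = np.array([[0, 0, 0], [255, 0, 0]], dtype=np.uint8)


def colorize(labels, palette=PALETTE):
    """Map an (H, W) array of class ids to an (H, W, 3) uint8 RGB image."""
    labels = np.asarray(labels)
    # Fancy indexing looks up one palette row per pixel; no axis permutation needed.
    return palette[labels]


def overlay(image_rgb, color_mask, alpha=0.5):
    """Alpha-blend two (H, W, 3) uint8 images: (1 - alpha) * image + alpha * mask."""
    image = np.asarray(image_rgb, dtype=np.float32)
    mask = np.asarray(color_mask, dtype=np.float32)
    if image.shape != mask.shape:
        raise ValueError(f"shape mismatch: {image.shape} vs {mask.shape}")
    blended = (1.0 - alpha) * image + alpha * mask
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
```

Because `predicted_labels` is already upsampled to the image size, the two arrays line up, and the result can be turned back into a PIL image with `Image.fromarray(overlay(np.asarray(test_image.convert("RGB")), colorize(predicted_labels[0].cpu().numpy())))`.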

HurairaCodes commented 11 months ago

Hello @nataliameira

No, I couldn't find a solution to this problem. I also contacted a Roboflow team member by email, and they were unable to help me.

I switched to YOLOv8 segmentation; it is much easier to train and use in projects.