roboflow / notebooks

Examples and tutorials on using SOTA computer vision models and techniques. Learn everything from old-school ResNet, through YOLO and object-detection transformers like DETR, to the latest models like Grounding DINO and SAM.
https://roboflow.com/models

SegFormer Segmentation #148

Closed: HurairaCodes closed this issue 1 year ago.

HurairaCodes commented 1 year ago

Search before asking

Notebook name

https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/train-segformer-segmentation-on-custom-data.ipynb

Bug

I have trained a custom semantic segmentation model by following this tutorial (SegFormer Semantic Segmentation Model), but the trained model is saved in .ckpt format. I have downloaded the checkpoint, but I run into issues when I try to run inference on images.

Can you please guide me on how to convert this file to a TensorFlow SavedModel, or how to perform inference after downloading the checkpoint file?

Your guidance would be highly appreciated. Thank you, and best regards.
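A PyTorch Lightning `.ckpt` file is a regular `torch.save` dictionary: the trained weights sit under its `"state_dict"` key, with each key prefixed by the attribute name the LightningModule used for the wrapped model (the script later in this thread assumes that prefix is `model.`). The helper below is a hypothetical sketch, not part of the notebook, for looking inside the file before loading it. It assumes the SegFormer head is stored under a key ending in `decode_head.classifier.weight`, whose first dimension is the number of classes the model was trained on.

```python
import torch


def inspect_lightning_checkpoint(path):
    """Hypothetical helper: summarize what a Lightning checkpoint contains.

    Returns the top-level keys, the number of tensors in the state_dict,
    the distinct first components of the state_dict keys (e.g. "model"),
    and the class count read from the SegFormer classifier weight,
    whose shape is [num_labels, hidden_size, 1, 1].
    """
    # The file comes from our own training run, so it is trusted;
    # weights_only=False is needed on torch>=2.6, where it defaults to True.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    state_dict = checkpoint["state_dict"]

    prefixes = sorted({key.split(".", 1)[0] for key in state_dict})
    classifier_key = next(
        key for key in state_dict if key.endswith("decode_head.classifier.weight")
    )
    return {
        "top_level_keys": sorted(checkpoint),
        "num_tensors": len(state_dict),
        "prefixes": prefixes,
        "num_labels": state_dict[classifier_key].shape[0],
    }
```

Running it on `/content/lightning_logs/version_0/checkpoints/epoch=3-step=40.ckpt` should show whether the keys really carry a `model` prefix and how many classes the head was trained for; judging from the size-mismatch error reported later in this thread, that count would be 2.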

Environment

Google Colab

Minimal Reproducible Example

No response

Additional

No response

Are you willing to submit a PR?

github-actions[bot] commented 1 year ago

👋 Hello @HurairaCODE, thank you for leaving an issue on Roboflow Notebooks.

🐞 Bug reports

If you are filing a bug report, please be as detailed as possible. This will help us more easily diagnose and resolve the problem you are facing. To learn more about contributing, check out our Contributing Guidelines.

If you require support with custom code that is not part of Roboflow Notebooks, please reach out on the Roboflow Forum or on the GitHub Discussions page associated with this repository.

💬 Get in touch

Do you have more questions about Roboflow that we haven't responded to yet? Feel free to ask them on the Roboflow Discuss forum. Our developer advocates and community team actively respond to questions there.

To ask questions about Notebooks, head over to the GitHub Discussions section of this repository.

SkalskiP commented 1 year ago

Hi @HurairaCODE 👋🏻 ! What is the exact error message you get?

SkalskiP commented 1 year ago

I'm converting this issue into a discussion and moving it to our Q&A section.

HurairaCodes commented 1 year ago

Thanks for your reply. The issue is with the model checkpoint file: I have downloaded it, but when I try to run inference with an image as input, it doesn't work.

I have tried several approaches; the latest version of my code is this:

```python
import torch
from torchvision.transforms import functional as F
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

# Load the saved model checkpoint
checkpoint_path = "/content/lightning_logs/version_0/checkpoints/epoch=3-step=40.ckpt"
checkpoint = torch.load(checkpoint_path)

# Extract the model state_dict
state_dict = checkpoint['state_dict']

# Create the model
model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")

# Map the state_dict keys to match the current model's state_dict key names
new_state_dict = {}
for key in state_dict:
    if key.startswith("model."):
        new_key = key.replace("model.", "")
        new_state_dict[new_key] = state_dict[key]
    else:
        new_state_dict[key] = state_dict[key]

# Load the updated state_dict
model.load_state_dict(new_state_dict)

# Create the feature extractor
feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
feature_extractor.do_reduce_labels = False
feature_extractor.size = 128

# Load the test image
test_image_path = "/content/Snow-1/test/frame_000724_jpg.rf.3d7e01b9e220be5e596a863cfd8e5401.jpg"
test_image = Image.open(test_image_path)

# Preprocess the test image
encoded_inputs = feature_extractor(test_image, return_tensors="pt")
input_image = encoded_inputs['pixel_values']

# Perform inference
with torch.no_grad():
    outputs = model(input_image)
    logits = outputs.logits

# Post-process the predictions
upsampled_logits = torch.nn.functional.interpolate(
    logits,
    size=test_image.size[::-1],  # Upsample to the original image size
    mode="bilinear",
    align_corners=False
)
probabilities = torch.softmax(upsampled_logits, dim=1)
predicted_labels = torch.argmax(probabilities, dim=1)

# Convert the predicted labels to an RGB image for visualization
color_map = {
    0: (0, 0, 0),    # Background
    1: (255, 0, 0),  # Class 1
}

def convert_labels_to_image(labels):
    vis_shape = labels.shape + (3,)
    vis = torch.zeros(vis_shape, dtype=torch.uint8)
    for label, color in color_map.items():
        vis[labels == label] = color
    return vis.permute(1, 2, 0)

predicted_image = convert_labels_to_image(predicted_labels[0])

# Overlay the predicted mask on the original image
overlay_img = Image.blend(test_image.convert("RGBA"), predicted_image, alpha=0.5)

# Display the result
overlay_img.show()
```

I am getting the following error now:

```
RuntimeError                              Traceback (most recent call last)
 in <cell line: 26>()
     24 
     25 # Load the updated state_dict
---> 26 model.load_state_dict(new_state_dict)
     27 
     28 # Create the feature extractor

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   2039 
   2040         if len(error_msgs) > 0:
-> 2041             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   2042                                self.__class__.__name__, "\n\t".join(error_msgs)))
   2043         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for SegformerForSemanticSegmentation:
    size mismatch for decode_head.classifier.weight: copying a param with shape torch.Size([2, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([150, 256, 1, 1]).
    size mismatch for decode_head.classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([150]).
```
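The mismatch is confined to the classification head: the checkpoint's `decode_head.classifier` has 2 output channels (one per class in this dataset), while `from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")` builds the 150-class ADE20K head. A likely fix, sketched below, is to create the model with `num_labels` set to the trained class count and `ignore_mismatched_sizes=True`, and only then load the fine-tuned weights over it. This assumes the checkpoint came from the notebook's LightningModule with the weights stored under a `model.` prefix (the error lists only size mismatches, not missing or unexpected keys, which suggests that prefix is right). `load_finetuned_segformer` and `strip_prefix` are hypothetical helper names, not notebook or library functions.

```python
import torch
from transformers import SegformerForSemanticSegmentation


def strip_prefix(state_dict, prefix="model."):
    """Drop a leading prefix (e.g. the LightningModule attribute name) from every key.

    Unlike key.replace(prefix, ""), this only touches the start of each key,
    so a later occurrence of the substring inside a key is left alone.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_finetuned_segformer(checkpoint_path, num_labels,
                             base="nvidia/segformer-b0-finetuned-ade-512-512"):
    """Hypothetical helper: rebuild the fine-tuned SegFormer from a Lightning .ckpt."""
    # Size the decode head for the fine-tuned classes instead of ADE20K's 150;
    # ignore_mismatched_sizes skips the 150-class classifier weights in `base`
    # (the new head is randomly initialised, then overwritten below).
    model = SegformerForSemanticSegmentation.from_pretrained(
        base, num_labels=num_labels, ignore_mismatched_sizes=True
    )
    # Trusted file from our own training run; weights_only=False keeps the
    # Lightning extras (hyper-parameters, optimizer state) loadable on torch>=2.6.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    model.load_state_dict(strip_prefix(checkpoint["state_dict"]))
    model.eval()  # switch off dropout / drop-path before inference
    return model
```

With the path from this issue, the call would be `load_finetuned_segformer("/content/lightning_logs/version_0/checkpoints/epoch=3-step=40.ckpt", num_labels=2)`; the `2` comes from the checkpoint's classifier shape `[2, 256, 1, 1]` in the error above. Also passing `id2label`/`label2id` to `from_pretrained` would give the classes readable names, but is not required for loading.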

nataliameira commented 11 months ago

Hello @HurairaCODE,

Were you able to perform the image inference correctly? Can you help me with this?

HurairaCodes commented 11 months ago

Hello @nataliameira

No, I couldn't find a solution to this problem. I also contacted a Roboflow team member by email, but they were unable to help me.

I switched to YOLOv8 segmentation; it is much easier to train and to use in projects.
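For anyone who gets past the loading step: the post-processing in the script earlier in this thread has two further problems that would only surface once `load_state_dict` succeeds. `convert_labels_to_image` already builds an `(H, W, 3)` array, so the extra `permute(1, 2, 0)` scrambles the axes, and `Image.blend` needs two PIL images with the same mode and size rather than an RGBA image and a tensor. The sketch below assumes `model` is a fine-tuned `SegformerForSemanticSegmentation` in eval mode and `processor` is the same feature extractor (image processor) configuration used during training; `predict_labels` and `overlay_labels` are hypothetical helper names.

```python
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image


def predict_labels(model, processor, image):
    """Return an (H, W) array of class ids at the original image resolution."""
    image = image.convert("RGB")
    pixel_values = processor(image, return_tensors="pt")["pixel_values"]
    with torch.no_grad():
        # return_dict=True guarantees a ModelOutput with .logits, even if the
        # model config was created with return_dict=False.
        outputs = model(pixel_values=pixel_values, return_dict=True)
    logits = outputs.logits  # (1, num_labels, h, w) at 1/4 of the input size
    # PIL reports (width, height); interpolate expects (height, width).
    upsampled = F.interpolate(
        logits, size=image.size[::-1], mode="bilinear", align_corners=False
    )
    # argmax over the class dimension; a softmax first would not change it.
    return upsampled.argmax(dim=1)[0].cpu().numpy()


def overlay_labels(image, labels, color_map, alpha=0.5):
    """Colour each class id with color_map and blend the result over the image."""
    colored = np.zeros(labels.shape + (3,), dtype=np.uint8)  # already (H, W, 3)
    for label, color in color_map.items():
        colored[labels == label] = color
    mask_image = Image.fromarray(colored)  # uint8 (H, W, 3) -> RGB image
    # Image.blend requires two PIL images with the same mode and size.
    return Image.blend(image.convert("RGB"), mask_image, alpha=alpha)
```

With the names from the original script, usage would look like `overlay_labels(test_image, predict_labels(model, feature_extractor, test_image), color_map).show()`. Keeping the mask as a NumPy array until the final `Image.fromarray` call avoids mixing tensors and PIL images.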