moliniao closed this issue 1 year ago.
What is the error? We have included the checkpoint; can you please try the included pretrained model and the accompanying evaluation code?
Hello Tianrun,
Could you tell us how to visualize the predicted masks from the fine-tuned models? If this code is already included in the GitHub repository, could you tell us where to find it? If not, that's also OK; I will write it myself. (I checked most of the code but may have missed it.)
I mean masks like the figures shown in your paper and on GitHub.
Thanks so much in advance! Kind regards.
You can try this in test.py.
Great, thanks :) I solved the problem.
Thank you! For multi-class segmentation, do we need to make any changes?
Hello, how do I use the model after training? My code is below; the mask I get after running it is the same no matter which image I pass in:

```python
import torch
import yaml
from PIL import Image
from torchvision import transforms

import models  # from the SAM-Adapter-PyTorch repository

modelWeight = "/Users/xxx/Desktop/sement/segment-adapter/SAM-Adapter-PyTorch/pretrained/model_epoch_last.pth"
imageTransform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
configPath = "/Users/xxx/Desktop/sement/segment-adapter/SAM-Adapter-PyTorch/configs/cod-sam-vit-b.yaml"
file1 = "/Users/xxx/Desktop/sement/segment-adapter/SAM-Adapter-PyTorch/camourflage_00103.jpg"
file2 = "/Users/xxx/Desktop/sement/segment-adapter/SAM-Adapter-PyTorch/camourflage_00001.jpg"

img1 = transforms.Resize((1024, 1024))(Image.open(file1).convert('RGB'))
img2 = transforms.Resize((1024, 1024))(Image.open(file2).convert('RGB'))
img1 = imageTransform(img1)
img2 = imageTransform(img2)
bs = torch.stack([img1, img2], dim=0)

with open(configPath, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

model = models.make(config['model'])
sam_checkpoint = torch.load(modelWeight, map_location=torch.device('cpu'))
model.load_state_dict(sam_checkpoint, strict=True)
model.eval()

masks = model.infer(bs)
sig = torch.sigmoid(masks)
single = torch.squeeze(masks, dim=0)
to_pil = transforms.ToPILImage()
mask_img = to_pil(single[1])
mask_img.show()
```
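One thing visible in the snippet above: `sig = torch.sigmoid(masks)` is computed but never used, and `single[1]` is taken from the raw output `masks`. `ToPILImage` treats a float tensor as values in [0, 1] (it multiplies by 255 and casts to 8-bit), so raw logits outside that range can come out as near-identical, saturated images. That may not be the whole story (it is also worth checking that the preprocessing matches what test.py does), but a minimal sketch of turning the output into viewable binary masks could look like this. It assumes `model.infer` returns unnormalised logits of shape (B, 1, H, W), as the indexing in the snippet suggests; the helper name `logits_to_mask_images` and the 0.5 threshold are hypothetical choices, not part of the repository.

```python
from typing import List

import torch
from PIL import Image


def logits_to_mask_images(logits: torch.Tensor, threshold: float = 0.5) -> List[Image.Image]:
    """Convert a batch of raw mask logits (B, 1, H, W) into 8-bit grayscale PIL images.

    Hypothetical helper: assumes `logits` are pre-sigmoid scores, one channel per image.
    """
    probs = torch.sigmoid(logits.detach().cpu())        # map logits to [0, 1]
    binary = (probs > threshold).to(torch.uint8) * 255  # 0 = background, 255 = foreground
    images = []
    for mask in binary:                # iterate over the batch dimension
        mask_2d = mask.squeeze(0)      # (1, H, W) -> (H, W)
        images.append(Image.fromarray(mask_2d.numpy()))  # 2-D uint8 array -> mode "L"
    return images
```

With the code above, something like `logits_to_mask_images(masks)[1].show()` would replace the last four lines. Thresholding gives a clean black/white mask; if you prefer a soft probability map, `Image.fromarray((probs[i, 0] * 255).to(torch.uint8).numpy())` is an alternative.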