ChristianMarzahl / ObjectDetection

Some experiments with object detection in PyTorch

How to change number of classes for already trained retinanet model #28

Open Monk5088 opened 2 years ago

Monk5088 commented 2 years ago

I have trained my RetinaNet model from the object detection library, and now I want to change the number of classes for the next dataset. I have found a way in PyTorch to change the classification head of RetinaNet to do this. Can anyone help me with how to do the same for retinanet.py from object-detection-fastai?

import math

import torch
from torchvision.models.detection import retinanet_resnet50_fpn_v2, RetinaNet_ResNet50_FPN_V2_Weights

num_classes = 3  # assumption: number of classes in the new dataset

weights = RetinaNet_ResNet50_FPN_V2_Weights.DEFAULT
# RetinaNet takes score_thresh (box_score_thresh is the Faster R-CNN name)
model = retinanet_resnet50_fpn_v2(weights=weights, score_thresh=0.7)

# replace classification layer
out_channels = model.head.classification_head.conv[0].out_channels
num_anchors = model.head.classification_head.num_anchors
model.head.classification_head.num_classes = num_classes

cls_logits = torch.nn.Conv2d(out_channels, num_anchors * num_classes, kernel_size=3, stride=1, padding=1)
torch.nn.init.normal_(cls_logits.weight, std=0.01)  # as per pytorch code
torch.nn.init.constant_(cls_logits.bias, -math.log((1 - 0.01) / 0.01))  # as per pytorch code
# assign cls head to model
model.head.classification_head.cls_logits = cls_logits
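
The bias initialisation in the snippet above can be read as "start every class prediction at a prior probability of 0.01". A minimal sketch of that arithmetic follows; the helper names `prior_to_bias` and `sigmoid` are made up for illustration and are not part of PyTorch or object-detection-fastai.

```python
import math


def prior_to_bias(prior: float) -> float:
    """Bias b such that sigmoid(b) == prior (hypothetical helper)."""
    return -math.log((1.0 - prior) / prior)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


bias = prior_to_bias(0.01)
print(round(bias, 3))             # -4.595
print(round(sigmoid(bias), 4))    # 0.01

# For comparison: a final bias of -4, as passed via final_bias in this thread
print(round(sigmoid(-4.0), 4))    # 0.018
```

Both values start the classifier with a small foreground probability, so a freshly created classification layer initially predicts close to "background" for every anchor.
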
ChristianMarzahl commented 2 years ago

Dear @Monk5088,

If your dataset changes, and with it the number of classes, the code should adapt to the new number of classes.

In the following example, the parameter n_classes defines the number of classes.

RetinaNet(encoder, n_classes=data.train_ds.c, n_anchors=3, sizes=[32], chs=8, final_bias=-4., n_conv=3)
Monk5088 commented 2 years ago

Dear @ChristianMarzahl, I have tried changing the number of classes while initialising RetinaNet, but when I call learner.load(), it throws a weight mismatch error. CODE:

batch_size = 64

do_flip = True
flip_vert = True 
max_rotate = 90 
max_zoom = 1.1 
max_lighting = 0.2
max_warp = 0.2
p_affine = 0.75 
p_lighting = 0.75 

tfms = get_transforms(do_flip=do_flip,
                      flip_vert=flip_vert,
                      max_rotate=max_rotate,
                      max_zoom=max_zoom,
                      max_lighting=max_lighting,
                      max_warp=max_warp,
                      p_affine=p_affine,
                      p_lighting=p_lighting)
train, valid = ObjectItemListSlide(train_images), ObjectItemListSlide(valid_images)
item_list = ItemLists(".", train, valid)
lls = item_list.label_from_func(lambda x: x.y, label_cls=SlideObjectCategoryList)
lls = lls.transform(tfms, tfm_y=True, size=patch_size)
data = lls.databunch(bs=batch_size, collate_fn=bb_pad_collate, num_workers=0).normalize()

Here the train dataset is as follows:

SlideLabelList (100 items)
x: ObjectItemListSlide
Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256),Image (3, 256, 256)
y: SlideObjectCategoryList
ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256),ImageBBox (256, 256)
Path: .

train_images is a list of object_detection_fastai.helper.wsi_loader.SlideContainer objects that are created using the following function:

def create_wsi_container(annotations_df: pd.DataFrame):

    container = []

    for image_name in tqdm(annotations_df["file_name"].unique()):

        image_annos = annotations_df[annotations_df["file_name"] == image_name]

        bboxes = [box   for box   in image_annos["box"]]
        labels = [label for label in image_annos["cat"]]

        container.append(SlideContainer(image_folder/image_name, y=[bboxes, labels], level=res_level,width=patch_size, height=patch_size, sample_func=sample_function))

    return container

CODE FOR LEARNER:

backbone = "ResNet34" #["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]

backbone_model = models.resnet18
if backbone == "ResNet34":
    backbone_model = models.resnet34
if backbone == "ResNet50":
    backbone_model = models.resnet50
if backbone == "ResNet101":
    backbone_model = models.resnet101
if backbone == "ResNet152":  # there is no resnet150; the deepest standard variant is resnet152
    backbone_model = models.resnet152

pre_trained_on_imagenet = False
encoder = create_body(backbone_model, pre_trained_on_imagenet, -2)

loss_function = "FocalLoss" 

if loss_function == "FocalLoss":
    crit = RetinaNetFocalLoss(anchors)

channels = 128 

final_bias = -4 

n_conv = 3 
model = RetinaNet(encoder, n_classes=3,
                  n_anchors=len(scales) * len(ratios),
                  sizes=[size[0] for size in sizes],
                  chs=channels, # number of channels in the head's conv layers
                  final_bias=final_bias,
                  n_conv=n_conv # number of hidden conv layers in each head
                  )
voc = PascalVOCMetric(anchors, patch_size, [str(i) for i in data.train_ds.y.classes[1:]])
learn = Learner(data, model, loss_func=crit, 
                callback_fns=[BBMetrics,ShowGraph,CSVLogger,partial(GradientClipping, clip=2.0)],metrics=[voc])
learn.load("PATH/to/.pth",strict=False) 

ERROR:

/usr/local/lib/python3.7/dist-packages/fastai/basic_train.py in load(self, file, device, strict, with_opt, purge, remove_module)
    271             model_state = state['model']
    272             if remove_module: model_state = remove_module_load(model_state)
--> 273             get_model(self.model).load_state_dict(model_state, strict=strict)
    274             if ifnone(with_opt,True):
    275                 if not hasattr(self, 'opt'): self.create_opt(defaults.lr, self.wd)

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1496         if len(error_msgs) > 0:
   1497             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1498                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1499         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1500 

RuntimeError: Error(s) in loading state_dict for RetinaNet:
    size mismatch for classifier.3.weight: copying a param with shape torch.Size([2, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([3, 128, 3, 3]).
    size mismatch for classifier.3.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([3]).
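
As the traceback shows, strict=False does not help here: it only tolerates missing or unexpected keys, while a key that exists in both places with a different shape still raises. One common workaround in plain PyTorch is to drop the mismatching tensors before loading. The sketch below shows the idea on a toy head; `load_matching_weights` and `tiny_head` are hypothetical names, and the toy model is only a stand-in, not the library's RetinaNet. With a fastai checkpoint the tensors would first have to be read from the saved file (the traceback suggests they sit under a 'model' key).

```python
import torch
from torch import nn


def load_matching_weights(model: nn.Module, checkpoint_state: dict) -> list:
    """Load only checkpoint tensors whose name and shape match the model.

    Returns the names of the checkpoint entries that were skipped.
    """
    model_state = model.state_dict()
    compatible, skipped = {}, []
    for name, tensor in checkpoint_state.items():
        if name in model_state and model_state[name].shape == tensor.shape:
            compatible[name] = tensor
        else:
            skipped.append(name)
    # strict=False tolerates the keys that were filtered out above
    model.load_state_dict(compatible, strict=False)
    return skipped


def tiny_head(n_classes: int, chs: int = 8) -> nn.Sequential:
    # Hypothetical stand-in for a classification head
    return nn.Sequential(
        nn.Conv2d(chs, chs, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(chs, n_classes, 3, padding=1),
    )


old_head = tiny_head(n_classes=2)  # plays the role of the trained 2-class model
new_head = tiny_head(n_classes=3)  # plays the role of the new 3-class model

skipped = load_matching_weights(new_head, old_head.state_dict())
print(skipped)  # ['2.weight', '2.bias']
```

The skipped layer keeps its fresh initialisation, so the model would still need fine-tuning on the new dataset before its class predictions are meaningful.
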
Monk5088 commented 2 years ago

I trained it on a 2-class dataset, but I want it to predict on a 3-class dataset. Both datasets share the same first and last class in the databunch, i.e., the first databunch I trained my RetinaNet on has the following classes: ['background', 'mitosis'], while the new dataset on which I need predictions contains the following classes: ['background', 'hard negative', 'mitosis']. So is it possible for my model to predict only mitosis for the new dataset? Thanks
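
One possible approach, offered as an assumption rather than something confirmed in this thread: if only mitosis predictions are needed, the model could be built with the original two classes so the checkpoint loads unchanged, and the predicted class indices remapped to the new dataset's class list afterwards. A minimal sketch of that remapping follows; `build_index_map` and the `predicted` list are made up for illustration, and how the library numbers predicted classes (with or without a background offset) is not covered here.

```python
old_classes = ["background", "mitosis"]
new_classes = ["background", "hard negative", "mitosis"]


def build_index_map(source: list, target: list) -> dict:
    """Map each class index in `source` to the index of the same name in `target`."""
    return {i: target.index(name) for i, name in enumerate(source) if name in target}


old_to_new = build_index_map(old_classes, new_classes)
print(old_to_new)  # {0: 0, 1: 2}

# Hypothetical class indices predicted by the 2-class model
predicted = [1, 1, 0, 1]
remapped = [old_to_new[i] for i in predicted]
print(remapped)  # [2, 2, 0, 2]
```

With this mapping the model never predicts 'hard negative'; it only reports mitosis detections under the index that class has in the new dataset.
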