The EfficientDet workflow seems not to be ready yet.
Outdated comment :sweat_smile: It's going to be removed in upcoming commits
Actually, when you call efficientdet.model it already loads pretrained weights from COCO. This is what we grab by default; you can take a look there for all the sizes with pretrained weights.
If you want to load your own weights and change the head for fine-tuning, take a look here for insights on how to do it.
The key concept there is model.reset_head
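A minimal sketch of that idea (not the official recipe: the checkpoint path and my_class_map are placeholders, reset_head lives on the effdet network wrapped by the bench, and the exact num_classes/background handling may differ between icevision versions):

import torch
from icevision.all import *

# create the model with the head size the checkpoint was trained with
model = efficientdet.model(model_name="tf_efficientdet_lite0", num_classes=90, img_size=512)

# load your own weights (hypothetical path), then swap the head for the new task
state = torch.load("models/my_coco_checkpoint.pth", map_location="cpu")["model"]
model.model.load_state_dict(state)
model.model.reset_head(num_classes=len(my_class_map))  # my_class_map: the ClassMap for your dataset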
Thanks for the update!
For now I wrote this workaround function to load pretrained weights with matching parameter names and shapes:
@patch
def load_matching(self: Learner, model_name: str):
    "Load pretrained weights whose parameter names and shapes match the current model; skip (and report) the rest."
    this_model = self.model.state_dict()
    trained_model = torch.load(f'models/{model_name}.pth')['model']
    for (this_module, this_param), (loaded_module, loaded_param) in zip(this_model.items(), trained_model.items()):
        assert this_module == loaded_module, f'Models differ: {this_module}, {loaded_module}'
        if this_param.shape == loaded_param.shape:
            this_model[this_module] = loaded_param
        else:
            print(f'Weights not loaded: {this_module}: {this_param.shape=}, {loaded_param.shape=}')
    return self.model.load_state_dict(this_model)
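Usage looks something like this (the checkpoint name is just an example):

# loads every parameter whose name and shape match, prints the ones it skips (e.g. the head)
learn.load_matching('tf_efficientdet_lite0_coco')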
The EfficientDet workflow seems not to be ready yet.
Outdated comment :sweat_smile: It's going to be removed in upcoming commits
Actually, when you call efficientdet.model it already loads pretrained weights from COCO. This is what we grab by default; you can take a look there for all the sizes with pretrained weights.
Any ETA on those commits? I was trying to reproduce the mAP results from the source you provided, but the problem is that, apart from loading weights, the head gets modified and the parameter values don't carry over verbatim:
class_net.predict.conv_pw.weight model_loaded.size=torch.Size([810, 64, 1, 1]) model_current.size=torch.Size([819, 64, 1, 1])
class_net.predict.conv_pw.bias model_loaded.size=torch.Size([810]) model_current.size=torch.Size([819])
I guess this is because we add background to the ClassMap (810 = 9 anchors × 90 classes vs. 819 = 9 × 91 classes with background).
Following your advice I forced the number of classes to be 90 and I can observe an improvement in COCOMetric (it was nearly all zeros before). Model: tf_efficientdet_lite0
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.122
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.227
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.113
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.203
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.279
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.177
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.258
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.275
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.114
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.409
This, however, is still not close to the reported value of 33.6 AP @[ IoU=0.50:0.95 | area=all ].
I've made sure that I'm using the same size and config parameters as in the original (padding color, normalization).
One thing that I'm still trying to figure out is the bbox configuration. The tfms.A.Adapter uses a hardcoded Pascal VOC bbox format of xyxy, while a comment in this thread suggests the models were pretrained using yxyx.
I will try to change that parameter and re-run the validation tomorrow. (I'm doing both fastai and PL validations, same results.)
Or perhaps you have some other suggestion on how to obtain the same mAP in validation?
this thread suggests the models were pretrained using yxyx.
Model specific formatting happens inside the dataloader here (this is how we can be agnostic to any model implementation)
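Roughly, that model-specific formatting amounts to something like the following (a simplified sketch, not the actual icevision source; the helper name is made up and the BBox.xyxy accessor is assumed):

import torch

def boxes_to_effdet(bboxes):
    """Convert a record's boxes from xyxy (as stored in the records) to the
    yxyx order that the effdet implementation expects for its targets."""
    if len(bboxes) == 0:
        return torch.zeros((0, 4))
    xyxy = torch.tensor([list(b.xyxy) for b in bboxes], dtype=torch.float)
    return xyxy[:, [1, 0, 3, 2]]  # reorder columns to y_min, x_min, y_max, x_max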
I've made sure that I'm using the same size and config parameters as in the original (padding color, normalization).
This is all being done on COCO, right? Can you share the code? I can take a look and we can figure out what is missing to achieve the reported results.
Actually, let us make this clear: are you training from scratch on COCO or loading the model with the pre-trained weights and checking the mAP?
any ETA on those commits?
I meant a commit only to remove that comment :sweat_smile:, did you have something else in mind?
Thanks for pointing to the place in the dataloader. I'm trying to reproduce the results of the pretrained model to make sure that I'm using transfer learning correctly. All done on the COCO dataset downloaded from the official link, 90 classes, parsed with coco.parser.
One thing that surprised me: I changed the box order in the build_train_batch function (also used by build_valid_batch) from yxyx to xyxy, re-ran the evaluation metrics, and the results are the same! The only difference is that the loss increases in the modified run (presumably because the metric is computed from the model's predictions and the original records, while the target box order only affects the loss).
#default - yxyx
learn.validate()
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.122
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.227
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.113
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.203
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.279
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.177
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.258
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.275
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.114
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.409
(#2) [0.6418642997741699,0.12209737428903714] # ValLoss, COCOMetric
# modified - xyxy
learn.validate()
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.122
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.227
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.113
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.015
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.203
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.279
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.177
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.258
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.275
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.114
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.338
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.409
(#2) [2.570807933807373,0.12209737428903714] # ValLoss, COCOMetric
Maybe it's a problem with the COCOMetric again? Another hypothesis is a difference in the ordering of classes. For example, passing class_map to show_samples:
batch, samples = first(train_dl)
show_samples(
    samples[:6], class_map=class_map, ncols=3, denormalize_fn=denormalize_imagenet
)
throws an IndexError: list index out of range.
NOTE: COCO uses 1..90 numbering while in the class map the default is 0..89?
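A quick sanity check of that hypothesis (based purely on the note above, not verified against the icevision internals):

class_map = icedata.coco.class_map(background=None)
print(len(class_map))  # 90 entries, so valid indices are 0..89
# a raw annotation id of 90 used directly as an index overflows the list,
# which would produce exactly this IndexError in show_samples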
My search continues..
from icevision.all import *
import icedata
from fastai.vision.all import *
from fastai.callback.wandb import *
import wandb
from imports import *
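# script-style parameters (fastscript/fastcore Param annotations); adjust paths, batch size, etc. to your setup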
path: Param("Training dataset path", str) = Path.home()/'Datasets/image/coco/'
bs: Param("Batch size", int) = 8
log: Param("Log to wandb", bool) = False
num_workers: Param("Number of workers to use", int) = 4
resume: Param("Link to pretrained model", str) = None
name: Param('experiment name', str) = 'coco'
class_map = icedata.coco.class_map(background=None)
path = Path(path)
coco_train = icedata.coco.parser(
    img_dir=path / 'train2017',
    annotations_file=path / 'annotations/instances_train2017.json',
    mask=False)
coco_valid = icedata.coco.parser(
    img_dir=path / 'val2017',
    annotations_file=path / 'annotations/instances_val2017.json',
    mask=False)
train_records, *_ = coco_train.parse(data_splitter=SingleSplitSplitter(), cache_filepath=path/'train_cache')
valid_records, *_ = coco_valid.parse(data_splitter=SingleSplitSplitter(), cache_filepath=path/'valid_cache')
show_record(train_records[1], display_label=True)
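# transforms: resize/pad everything to 512 and normalize (ImageNet stats by default), matching the effdet preprocessing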
size = 512
aug_tfms = tfms.A.aug_tfms(
    size=size,
    shift_scale_rotate=tfms.A.ShiftScaleRotate(rotate_limit=(-15, 15)),
    pad=partial(tfms.A.PadIfNeeded, border_mode=0)
)
aug_tfms.append(tfms.A.Normalize())
train_tfms = tfms.A.Adapter(aug_tfms)
valid_tfms = tfms.A.Adapter([*tfms.A.resize_and_pad(size), tfms.A.Normalize()])
train_ds = Dataset(train_records, train_tfms)
valid_ds = Dataset(valid_records, valid_tfms)
metrics = [COCOMetric(print_summary=True)]
train_dl = efficientdet.train_dl(train_ds, batch_size=bs, num_workers=num_workers, shuffle=True)
valid_dl = efficientdet.valid_dl(valid_ds, batch_size=bs, num_workers=num_workers, shuffle=False)
model = efficientdet.model(model_name="tf_efficientdet_lite0", num_classes=len(class_map), img_size=size)
learn = efficientdet.fastai.learner(dls=[train_dl, valid_dl], model=model, metrics=metrics)
#default - yxyx
learn.validate()
I'm thinking this is something specific to EffDet + the COCO dataset, because I was able to train the EffDet model using the icevision workflow and obtain good mAP results.
I calculated mAP from the preds and records accumulated in the metric (before conversion to the COCO API) using another library and got nearly the same results; here is how I did it:
# BoundingBox, BBType and get_coco_summary come from the external metrics library mentioned above;
# preds/records are the lists accumulated by the metric
def raw_to_odm(preds, records):
    groundtruth_bbs = []
    detected_bbs = []
    for pred, record in zip(preds, records):
        image_name = record['filepath'].name
        for bbox, label in zip(record['bboxes'], record['labels']):
            bb = BoundingBox(image_name=image_name, class_id=label, coordinates=bbox.xywh)
            groundtruth_bbs.append(bb)
        for score, bbox, label in zip(pred['scores'], pred['bboxes'], pred['labels']):
            bb = BoundingBox(image_name=image_name, class_id=label, coordinates=bbox.xywh,
                             bb_type=BBType.DETECTED, confidence=score)
            detected_bbs.append(bb)
    return groundtruth_bbs, detected_bbs

groundtruth_bbs, detected_bbs = raw_to_odm(preds, records)
get_coco_summary(groundtruth_bbs, detected_bbs)
{'AP': 0.11017058248779392,
'AP50': 0.2117803965667843,
'AP75': 0.09722278750266831,
'APsmall': 0.005122517036826111,
'APmedium': 0.1955153436184948,
'APlarge': 0.26638203179285247,
'AR1': 0.15844023950095418,
'AR10': 0.22851812631946697,
'AR100': 0.2428790814116496,
'ARsmall': 0.07777300144455561,
'ARmedium': 0.33974585757257,
'ARlarge': 0.37168434508668985}
I'm still investigating the Effdet repo to see what kind of tweaking they do to achieve these mAP results. Any help is welcome!
Bingo! I figured it out.
When passing targets to the EfficientDet model, icevision keeps the original image sizes in the targets, and those values are then processed by effdet itself (the padding is not taken into account). I forced those values to be 512 (matching the padding I use) and the metric shows correct results now:
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.311
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.489
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.328
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.095
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.356
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.503
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.270
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.418
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.442
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.177
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.523
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.659
This also explains why I was able to train on my own dataset with aspect ratio 1 and no resize or padding used in the tfms.
def forward(self, *args, **kwargs):
    # args[1] is the target dict; force img_size to the padded input size
    args[1]['img_size'] = torch.full_like(args[1]['img_size'], 512.0)
    return self.model(*args, **kwargs)
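For context, one way the override above could be hooked in (just a sketch of the hack, with a made-up wrapper name; it assumes the learner's model is the bench returned by efficientdet.model):

import torch
from torch import nn

class FixedImgSizeBench(nn.Module):
    "Force the target img_size to the padded input size (512) before delegating to the wrapped effdet bench."
    def __init__(self, bench):
        super().__init__()
        self.model = bench

    def forward(self, *args, **kwargs):
        args[1]['img_size'] = torch.full_like(args[1]['img_size'], 512.0)
        return self.model(*args, **kwargs)

# model = FixedImgSizeBench(efficientdet.model(model_name="tf_efficientdet_lite0", num_classes=90, img_size=512))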
I will try to figure out a PR on how to fix this the right way.
Amazing work @potipot, I was going to investigate this right now but you already solved it! hahahah
So, if I understood correctly, the error is on this line? Here we are setting the image size without padding, but instead we need to set the image size with padding? If so, we just need to take the image size from the images tensor: images.shape[2:]
Another question, how was your experience using this other library for metrics? Is it a good replacement for pycocotools?
So, if I understood correctly, the error is on this line? Here we are setting the image size without padding, but instead we need to set the image size with padding? If so, we just need to take the image size from the images tensor: images.shape[2:]
I think this would be the place to insert it and rely directly on the input tensor. However, what they do in effdet is apply some box scaling and resizing to map the predictions back to the original target image shape, and only then calculate the metric.
Note the difference in parameters passed to _batch_detection in train and validation.
I'm not sure how this improves the results they get on COCO. I guess it can have an impact on the box size and whether it is classified as small or large, but otherwise?
Another question, how was your experience using this other library for metrics? Is it a good replacement for pycocotools?
It was way easier to implement but I think slightly slower. The API is self-explanatory, with a single call to get_coco_summary(groundtruth_bbs: list[BoundingBox], detected_bbs: list[BoundingBox]).
I could take a look later on how to speed it up. I expect there is room for improvement.
I could take a look later on how to speed it up. I expect there is room for improvement.
Do you think there is value in us trying to implement these metrics ourselves?
Note the difference in parameters passed to _batch_detection in train and validation
I still have to look deeper into this, but here is a conversation I had with Ross that can be helpful:
img_scale is used to move coordinates between what I think of as the 'model canvas', the img_size * img_size input image size of the model. Images are scaled down, maintaining aspect, to fit in that square, located at the origin (upper left corner); the rest is padded if the original image aspect is not square. img_scale stores the ratio needed to move the output coordinates of the model back to the original image coordinate space for the COCO evaluator. You can just set img_scale to 1 to not use it, and the image size values are used to crop bboxes to (img_size, img_size), if you want to handle the image sizing, scaling, and evaluation yourself.
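To make that concrete, a small worked example of what (as I read the explanation above) img_scale encodes; the numbers are purely illustrative:

import torch

# an original 640x426 image is scaled down (keeping aspect) to fit the 512x512 canvas
orig_h, orig_w, img_size = 426, 640, 512
img_scale = max(orig_h, orig_w) / img_size  # 640 / 512 = 1.25

# boxes predicted on the 512x512 canvas (xyxy) are scaled back for the COCO evaluator
boxes_canvas = torch.tensor([[32.0, 48.0, 256.0, 200.0]])
boxes_original = boxes_canvas * img_scale   # coordinates in the original 640x426 image

# setting img_scale to 1 skips this rescaling entirely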
Btw @potipot, have you joined our forum? If not, consider joining, we have lots of interesting discussions happening there =)
Note the difference in parameters passed to _batch_detection in train and validation
So, if we want to handle re-scaling ourselves, we have the option to set img_size and img_scale to None; from what I'm seeing, this will then not rescale the bboxes [1] [2].
Before we continue, can you try that and see if you still get the correct results?
I'm not sure how this improves the results they get on COCO. I guess it can have impact on the box size and whether it is classified as small or large but otherwise?
Yeaaah, I'm still not quite sure as well. What we have to be careful about is that when we call COCOMetric.accumulate, records and preds should have the same img_sizes and bboxes scaled accordingly. When effdet scales the predictions internally it might be messing this up.
Do you think there is value in us trying to implement these metrics ourselves?
I think yes, because the current conversion to the COCO API is quite convoluted; I had trouble understanding what was going on. Maybe we could try a more OO approach? The code would be much cleaner if we used the library I linked before.
Maybe we could try more OO approach?
For sure! The code for the conversion was implemented in a rush; a better way to implement this functionality is to use the visitor pattern in the RecordMixins. I opened this issue to keep track of that.
Btw, the same strategy can be used for the transforms; currently tfms.A.Adapter is also quite a mess.
The code would be much cleaner if we used the library I linked before.
The only disadvantage is that it's a bit slower, right? Probably because it's fully implemented in Python while pycocotools uses C.
COCOMetric gives correct results for validation with this change:
def build_valid_batch(records, batch_tfms=None):
    (images, targets), records = build_train_batch(
        records=records, batch_tfms=batch_tfms
    )
-   img_sizes = [(r["height"], r["width"]) for r in records]
-   targets["img_size"] = tensor(img_sizes, dtype=torch.float)
+   # passing the size of the transformed image to efficientdet, necessary for its own scaling and resizing, see
+   # https://github.com/rwightman/efficientdet-pytorch/blob/645d84a6f0cd837703f98f48179a06c354902515/effdet/bench.py#L100
+   targets["img_size"] = tensor([image.shape[-2:] for image in images], dtype=torch.float)
    targets["img_scale"] = tensor([1] * len(records), dtype=torch.float)
    return (images, targets), records
Effdet inference will be resolved by https://github.com/airctic/icevision/pull/630
New <Tutorial/Example>
Request for an example
What is the task? Object detection using transfer learning for the whole architecture. Are there some defined methods to load a fastai model and change its head to a different number of classes, similar to this?
I was able to run the Faster-RCNN example, using this example trained on the COCO dataset, and evaluate its mAP.
The EfficientDet workflow seems not to be ready yet. Has there been some update on that?
I was able to create an EfficientDet with a pretrained encoder and train it myself on COCO. I'm now trying to do transfer learning for a different number of classes. Loading the model through fastai, expectedly, throws an error:
Is this example for a specific model? EfficientDet
Is this example for a specific dataset? COCO transfer learning
Don't remove: Main issue for examples: #39