weecology / DeepForest

Python Package for Airborne RGB machine learning
https://deepforest.readthedocs.io/
MIT License

Allow users to modify input size #347

Open bw4sz opened 2 years ago

bw4sz commented 2 years ago
```python
from torchvision.models.detection.retinanet import RetinaNet

# load_backbone() is a helper defined elsewhere in the same module;
# only the .backbone attribute of the object it returns is used here.


def create_model(num_classes, nms_thresh, score_thresh, backbone=None):
    """Create a RetinaNet model.

    Args:
        num_classes (int): number of classes in the model
        nms_thresh (float): non-max suppression threshold for intersection-over-union [0, 1]
        score_thresh (float): minimum prediction score to keep during prediction [0, 1]
        backbone (nn.Module, optional): feature extractor; if None, the backbone of load_backbone() is used
    Returns:
        model: a pytorch nn module
    """
    if backbone is None:
        resnet = load_backbone()
        backbone = resnet.backbone

    model = RetinaNet(backbone=backbone, num_classes=num_classes)
    model.nms_thresh = nms_thresh
    model.score_thresh = score_thresh

    # Optionally allow anchor generator parameters to be created here
    # https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html

    return model
```

Input resizing is poorly documented in torchvision; it is handled in the detection transform:

https://github.com/pytorch/vision/blob/97e0ea9c6ebc454538b3fa505e1d199547b0feed/torchvision/models/detection/transform.py#L43
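To make the effect of those defaults concrete, here is a pure-Python approximation of the eval-time arithmetic in that transform: one scale factor brings the shorter side to `min_size` unless the longer side would exceed `max_size`, and batched images are then zero-padded to a multiple of 32. The helper names `resized_shape` and `padded_shape` are hypothetical and are not part of torchvision; training-time random choice among several `min_size` values is ignored.

```python
import math


def resized_shape(height, width, min_size=800, max_size=1333):
    """Approximate eval-time resize of torchvision's GeneralizedRCNNTransform."""
    # Shorter side goes to min_size, capped so the longer side stays <= max_size.
    scale = min(min_size / min(height, width), max_size / max(height, width))
    # interpolate(..., recompute_scale_factor=True) floors each output dimension.
    return math.floor(height * scale), math.floor(width * scale)


def padded_shape(height, width, size_divisible=32):
    """Single-image batch: pad each side up to a multiple of size_divisible."""
    return (math.ceil(height / size_divisible) * size_divisible,
            math.ceil(width / size_divisible) * size_divisible)
```

With the defaults, `resized_shape(400, 400)` gives `(800, 800)`, so a 400 px crop is upsampled by a factor of 2 before detection unless `min_size`/`max_size` are changed.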

Om-Doiphode commented 1 year ago

@bw4sz I would like to work on this issue. Can you please guide me on how to get started?

bw4sz commented 1 year ago

Thanks, but I didn't add the "good first issue" label for a reason: messing around with the backbone model probably isn't the best way to get started with the codebase.