Jungsu-Yun / yolov5_deepsort_ros

Wrapped yolov5-deepsort to ROS
MIT License

load_classifier #8

Open · mowei951010 opened this issue 1 year ago

mowei951010 commented 1 year ago

ImportError: cannot import name 'load_classifier' from 'yolov5.utils.torch_utils'

swy767 commented 1 year ago

I think you can add the following function to yolov5/utils/torch_utils.py:

# torch_utils.py normally imports these already; they are listed so the snippet is self-contained
import torch
import torch.nn as nn
import torchvision


def load_classifier(name="resnet101", n=2):
    # Loads a pretrained torchvision model and reshapes its output to n classes
    model = torchvision.models.__dict__[name](pretrained=True)

    # ResNet model properties
    # input_size = [3, 224, 224]
    # input_space = 'RGB'
    # input_range = [0, 1]
    # mean = [0.485, 0.456, 0.406]
    # std = [0.229, 0.224, 0.225]

    # Reshape the final fully connected layer to n classes
    filters = model.fc.weight.shape[1]  # number of input features to model.fc
    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
    model.fc.out_features = n
    return model

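This error is most likely caused by a version mismatch: newer versions of yolov5 no longer define load_classifier in utils/torch_utils.py, while the code in this repository still imports it.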