rwightman / gen-efficientnet-pytorch

Pretrained EfficientNet, EfficientNet-Lite, MixNet, MobileNetV3 / V2, MNASNet A1 and B1, FBNet, Single-Path NAS
Apache License 2.0

[bug] cannot set num_features #21

Closed: mattans closed this issue 4 years ago

mattans commented 4 years ago

When I try:

    model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', in_chans=1, num_features=16)

or

    model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', in_chans=1, num_features=16, pretrained=False)

I get this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\t-maserr\AppData\Local\Continuum\anaconda3\envs\python37\lib\site-packages\torch\hub.py", line 359, in load
    model = entry(*args, **kwargs)
  File "C:\Users\t-maserr/.cache\torch\hub\rwightman_gen-efficientnet-pytorch_master\geffnet\gen_efficientnet.py", line 636, in efficientnet_b0
    'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
  File "C:\Users\t-maserr/.cache\torch\hub\rwightman_gen-efficientnet-pytorch_master\geffnet\gen_efficientnet.py", line 430, in _gen_efficientnet
    **kwargs,
TypeError: type object got multiple values for keyword argument 'num_features'

rwightman commented 4 years ago

@mattans Changing num_features from the top-level instantiation isn't currently intended. Did you mean to change the number of classes for the classifier (num_classes)?
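
The traceback suggests why the call fails. The error is raised at the `**kwargs,` line inside _gen_efficientnet, so that builder appears to compute its own num_features and then forward the caller's keyword arguments alongside it. The wording "type object got multiple values" matches what Python 3.7 reports when a type such as dict is called with the same keyword twice. The sketch below reproduces that pattern. `build_model_kwargs` is a hypothetical stand-in rather than the library's code, and `int(1280 * channel_multiplier)` is a simplified assumption for how the default width is computed.

```python
def build_model_kwargs(channel_multiplier=1.0, **kwargs):
    """Hypothetical stand-in for a _gen_XXX() builder (not geffnet's actual code).

    It computes num_features itself and then forwards the caller's **kwargs,
    so a caller-supplied num_features is passed to dict() a second time.
    """
    return dict(
        num_features=int(1280 * channel_multiplier),  # builder-computed width (simplified)
        stem_size=32,
        **kwargs,  # caller's extra kwargs, e.g. in_chans=1
    )


print(build_model_kwargs(in_chans=1))  # {'num_features': 1280, 'stem_size': 32, 'in_chans': 1}

try:
    build_model_kwargs(in_chans=1, num_features=16)
except TypeError as err:
    # Exact wording depends on the Python version ("type object ..." on 3.7).
    print(err)
```

Passing in_chans works because the builder does not set it. Passing num_features collides with the value the builder already supplies.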

If you want to change num_features, I'd recommend creating a new _gen_XXX() and model entrypoint function (or editing an existing one) and adding the necessary overrides for the num_features calculation. Note that this will currently render the pretrained weights invalid unless additional support is added for ignoring the weights of the impacted convolution.
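
A minimal sketch of the override suggested above, assuming a builder shaped like the one in the traceback. `build_model_kwargs_with_override` is hypothetical, and the default width `int(1280 * channel_multiplier)` is a simplification. The idea is to pop num_features out of kwargs before building the constructor arguments so it is only passed once. Because the pretrained weights sized by num_features (presumably the final head convolution) would no longer fit, the sketch refuses the combination rather than loading mismatched weights.

```python
def build_model_kwargs_with_override(channel_multiplier=1.0, pretrained=False, **kwargs):
    """Hypothetical _gen_XXX()-style builder that accepts a num_features override."""
    default_features = int(1280 * channel_multiplier)  # simplified default width
    # Remove num_features from kwargs so it is not passed twice to dict().
    num_features = kwargs.pop('num_features', default_features)
    if pretrained and num_features != default_features:
        # Pretrained weights were trained for the default width and would not fit.
        raise ValueError('num_features override is incompatible with pretrained weights')
    return dict(num_features=num_features, stem_size=32, **kwargs)


print(build_model_kwargs_with_override(in_chans=1, num_features=16))
# {'num_features': 16, 'stem_size': 32, 'in_chans': 1}
```

Raising an error is only one option. Skipping the affected weights when loading, as suggested above, would need additional support in the weight-loading code.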