pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

ResNeXt model type(s) #853

Closed barrh closed 5 years ago

barrh commented 5 years ago

ResNeXt was recently introduced into the codebase. The current implementation takes advantage of its similarities with ResNet by extending the ResNet class to support the ResNeXt topology as well. While the models are similar, they are not the same. The result of the current implementation is that type(resnext50_32x4d()) == type(resnet50()), which is misleading IMO.

  1. ResNeXt models should have a unique type. It can inherit from other topologies (e.g. ResNet), but must not be the same type.

  2. This point is more general and applies to more than just ResNeXt: type(resnext50_32x4d()) == type(resnext101_32x8d()); shouldn't those be different types?
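
The pattern being criticized can be illustrated with a minimal stand-in sketch (the class and factory names mirror torchvision's, but the bodies here are simplified placeholders, not the real implementation):

```python
# Sketch: one shared class parametrized by constructor arguments, so
# Python reports the same type for every variant.

class ResNet:
    def __init__(self, layers, groups=1, width_per_group=64):
        self.layers = layers
        self.groups = groups
        self.width_per_group = width_per_group

def resnet50():
    return ResNet(layers=[3, 4, 6, 3])

def resnext50_32x4d():
    return ResNet(layers=[3, 4, 6, 3], groups=32, width_per_group=4)

print(type(resnext50_32x4d()) == type(resnet50()))  # True
```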

fmassa commented 5 years ago

More than just the type being the same, there was actually a bug in the ResNeXt implementation, which has been fixed in https://github.com/pytorch/vision/pull/852

About your points:

1 - I agree about ResNeXt having its own class. I'm actually thinking about another refactor of it.

2 - I'm less sure about the need for this. Should all possible variations of a model, which can be arbitrarily parametrized, each have their own class?

barrh commented 5 years ago

2 - In principle, yes. type(A) == type(B) should hold only when it makes sense to call the "copy constructor" (i.e. load_state_dict()) of A on B, AND of B on A.
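
The criterion above can be sketched without torch by comparing parameter names and shapes, which is essentially what a strict load_state_dict() checks (the helper and the shapes below are illustrative, not the real layer shapes):

```python
# Hypothetical check for the mutual-loadability criterion: two models
# could share a type only if each state dict loads into the other.

def mutually_loadable(state_a, state_b):
    """True if the dicts have the same parameter names and shapes."""
    return (state_a.keys() == state_b.keys()
            and all(state_a[k] == state_b[k] for k in state_a))

# Same key, but the grouped convolution gives ResNeXt a different shape.
resnet50_shapes  = {"layer1.0.conv2.weight": (64, 64, 3, 3)}
resnext50_shapes = {"layer1.0.conv2.weight": (128, 4, 3, 3)}

print(mutually_loadable(resnet50_shapes, resnet50_shapes))   # True
print(mutually_loadable(resnet50_shapes, resnext50_shapes))  # False
```

By this criterion, ResNet-50 and ResNeXt-50 would fail the test and therefore warrant distinct types.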

rwightman commented 5 years ago

So, not going to step into the debate about how things 'should' be, but this issue is not specific to ResNet vs ResNeXt, or even to PyTorch, is it? For models with numerous variants that are parameterized at construction, this is pretty much the norm across the standard model repos associated with most frameworks, with the exception of maybe Keras?

Is a ResNeXt-50 not closer to a ResNet-50 than a ResNet-50 is to a ResNet-34? ;)

fmassa commented 5 years ago

@barrh one can rephrase your point by saying that we would need a ResNet50_1000classes and a ResNet50_21classes, because loading a state_dict with a different number of classes will fail.

In general, I agree with the core idea. But instead of parametrizing on the instance, we should instead parametrize on the class. So if anyone wants to save their model, then in order to retrieve it we need the class name, the constructor arguments, and the state dict:

checkpoint = torch.load(...)

# model_factory is ResNet, or VGG, or Inception, etc.: a class
# that inherits from nn.Module
model_factory = models_catalog[checkpoint['class_name']]

# 'arguments' is a dict with the constructor args
model = model_factory(**checkpoint['arguments'])

# as before
model.load_state_dict(checkpoint['state_dict'])

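
For illustration, the saving side of this pattern could look like the sketch below. The checkpoint layout ('class_name', 'arguments', 'state_dict') follows the comment above; the registry, the save_checkpoint helper, and the toy ResNet class are stand-ins, not torchvision code:

```python
# Stand-in model class with a torch-like state_dict interface.
class ResNet:
    def __init__(self, layers, groups=1):
        self.layers = layers
        self.groups = groups

    def state_dict(self):
        # Stand-in for the real tensor-valued state dict.
        return {"layers": self.layers, "groups": self.groups}

    def load_state_dict(self, sd):
        self.layers, self.groups = sd["layers"], sd["groups"]

# Hypothetical registry mapping class names to classes.
models_catalog = {"ResNet": ResNet}

def save_checkpoint(model, arguments):
    return {
        "class_name": type(model).__name__,
        "arguments": arguments,
        "state_dict": model.state_dict(),
    }

args = {"layers": [3, 4, 6, 3], "groups": 32}
checkpoint = save_checkpoint(ResNet(**args), args)

# Round trip: rebuild the model from the checkpoint alone.
model = models_catalog[checkpoint["class_name"]](**checkpoint["arguments"])
model.load_state_dict(checkpoint["state_dict"])
print(type(model).__name__)  # ResNet
```

The point is that the checkpoint alone carries everything needed to reconstruct the model, without requiring a distinct class per variant.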
fmassa commented 5 years ago

While this is a valid point, I don't quite agree that each class should provide a unique type of state_dict.

I'm closing this issue then, but feel free to reopen it if you disagree.