pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[proposal] Use self.flatten instead of torch.flatten and when becomes possible derive ResNet from nn.Sequential (scripting+quantization is blocker), would simplify model surgery in the most frequent cases #3331

Open vadimkantorov opened 3 years ago

vadimkantorov commented 3 years ago

Currently, in https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L243:

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

If it instead used x = self.flatten(x), model surgery would be simpler: del model.avgpool, model.flatten, model.fc. In that case the class could also simply derive from nn.Sequential and use an OrderedDict to pass the submodules (like in https://discuss.pytorch.org/t/ux-mix-of-nn-sequential-and-nn-moduledict/104724/2?u=vadimkantorov), which would preserve checkpoint compatibility as well. The forward method could then be removed.
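A minimal sketch of what this proposal could look like, using a toy stand-in rather than torchvision's actual ResNet (the class name `TinySequentialResNet` and all layer sizes are hypothetical). The model derives from `nn.Sequential`, receives its submodules through an `OrderedDict`, and has no `forward` of its own. Parameter names come from the `OrderedDict` keys, so reusing the existing attribute names would keep `state_dict` keys unchanged (`nn.Flatten` holds no parameters), and a single `del` strips the head:

```python
from collections import OrderedDict

import torch
from torch import nn


class TinySequentialResNet(nn.Sequential):
    """Hypothetical ResNet-like model built as an nn.Sequential (no custom forward)."""

    def __init__(self, num_classes: int = 10):
        super().__init__(OrderedDict([
            ("conv1", nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)),
            ("bn1", nn.BatchNorm2d(16)),
            ("relu", nn.ReLU(inplace=True)),
            ("layer1", nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)),
            ("avgpool", nn.AdaptiveAvgPool2d((1, 1))),
            ("flatten", nn.Flatten(1)),  # a module instead of torch.flatten(x, 1)
            ("fc", nn.Linear(32, num_classes)),
        ]))


model = TinySequentialResNet().eval()
x = torch.randn(2, 3, 32, 32)
print(model(x).shape)  # torch.Size([2, 10])

# Model surgery: Sequential.forward simply iterates over the remaining submodules.
del model.avgpool, model.flatten, model.fc
print(model(x).shape)  # torch.Size([2, 32, 16, 16])

# state_dict keys follow the OrderedDict names, e.g. ['fc.weight', 'fc.bias'] for the head.
print([k for k in TinySequentialResNet().state_dict() if k.startswith("fc")])
```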

datumbox commented 3 years ago

Thanks for the proposal. Replacing torch.flatten with nn.Flatten seems like it won't break anything, but could you describe your use case a bit more and how it would help you achieve what you want?

Concerning a major refactoring to inherit from Sequential: this might have side effects on all the other models that depend on resnet (segmentation, object detection, etc.), so I'm not sure it can be done in a backward-compatible manner. Note also that replacing the forward is not something we can do due to how quantization works: https://github.com/pytorch/vision/blob/6116812d51c6e5524ef48f8b14c0293ed674ed55/torchvision/models/quantization/resnet.py#L93-L100

@fmassa Let me know if you see any problems with replacing flatten.

fmassa commented 3 years ago

I would be fine with replacing torch.flatten with nn.Flatten, although it would only simplify model surgery if we also made further modifications to the ResNet model as you suggested (making it inherit from nn.Sequential). Without this, just deleting model.fc etc. wouldn't be enough.

So from this perspective, there is limited value in replacing torch.flatten with nn.Flatten, since one is only syntactic sugar for the other.
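A quick check of the "syntactic sugar" point: `nn.Flatten` (whose default `start_dim` is 1) produces the same tensor as `torch.flatten(x, 1)` and has no parameters, so swapping one for the other should change neither outputs nor checkpoint keys. The input shape below is just an illustrative ResNet-18-sized pooled feature:

```python
import torch
from torch import nn

x = torch.randn(4, 512, 1, 1)  # e.g. the output of ResNet-18's avgpool
flatten = nn.Flatten()          # defaults: start_dim=1, end_dim=-1

assert torch.equal(flatten(x), torch.flatten(x, 1))
print(flatten(x).shape)                              # torch.Size([4, 512])
print(sum(p.numel() for p in flatten.parameters()))  # 0, so state_dict is unaffected
```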

vadimkantorov commented 3 years ago

Oh, I didn't realize:

# Ensure scriptability 
# super(QuantizableResNet,self).forward(x) 
# is not scriptable 

If this becomes fixed, I guess inheriting from nn.Sequential should be fine.

My real use case: I'm working with a codebase (https://github.com/ajabri/videowalk/blob/master/code/resnet.py#L41) that had to reimplement the ResNet forward because it wanted to do simple model surgery.

fmassa commented 3 years ago

@vadimkantorov model surgery in PyTorch generally requires re-writing forward or other parts of the model, except for nn.Sequential, so I would say it's a valid requirement to ask users to be a bit more verbose.

vadimkantorov commented 3 years ago

model surgery in PyTorch generally requires re-writing forward or other parts of the model, except for nn.Sequential

Of course. I understand that this is the case in general, but if it can be made simpler, I think it should be, even if a full rewrite is required in the general case. In this particular case, it seems that deriving from nn.Sequential is a no-go because of quantization scripting limitations, but once that's solved it would be nice to derive ResNet from nn.Sequential.

I'll rename the issue accordingly.

vadimkantorov commented 3 years ago

It seems that deriving from nn.Sequential would also be good from a DeepSpeed-compatibility standpoint: https://github.com/pytorch/pytorch/issues/51574

For the quantization case, couldn't it just insert modules that call quant/dequant at the beginning and the end of the Sequential? Then everything should work and no modification of forward would be needed.

vadimkantorov commented 3 years ago

They are already modules:

        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

One would just need to modify the constructor and insert them before conv1 and after layer4.
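A sketch of the stubs-as-submodules idea on a toy Sequential (the factory name `make_quantizable_sequential` and the layers are hypothetical, and this stops short of running `prepare()`/`convert()`). One assumption worth flagging: here `dequant` sits at the very end, after `fc`, because torchvision's `QuantizableResNet.forward` applies `self.dequant` to the output of `_forward_impl`, which already includes `fc`. Before conversion, `QuantStub` and `DeQuantStub` just return their input, so the float model runs unchanged without a custom `forward`:

```python
from collections import OrderedDict

import torch
from torch import nn
from torch.ao.quantization import DeQuantStub, QuantStub


def make_quantizable_sequential(num_classes: int = 10) -> nn.Sequential:
    """Hypothetical: a Sequential with quant/dequant stubs as ordinary submodules."""
    return nn.Sequential(OrderedDict([
        ("quant", QuantStub()),        # where float input becomes quantized after convert()
        ("conv1", nn.Conv2d(3, 16, kernel_size=3, padding=1)),
        ("relu", nn.ReLU()),
        ("avgpool", nn.AdaptiveAvgPool2d((1, 1))),
        ("flatten", nn.Flatten(1)),
        ("fc", nn.Linear(16, num_classes)),
        ("dequant", DeQuantStub()),    # back to float at the very end
    ]))


model = make_quantizable_sequential().eval()
# Before prepare()/convert(), the stubs are pass-throughs, so the float model runs as-is.
print(model(torch.randn(1, 3, 8, 8)).shape)  # torch.Size([1, 10])
```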

vadimkantorov commented 3 years ago

@fmassa Even if we just replace torch.flatten with self.flatten, it already simplifies model surgery to model.avgpool = model.flatten = model.fc = nn.Identity() (instead of del model.avgpool, model.flatten, model.fc).
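To make the `nn.Identity()` variant concrete, here is a toy comparison (both class names are hypothetical, not torchvision code). With `self.flatten` as a module, replacing the three tail modules returns the spatial feature map; with `torch.flatten` hard-coded in `forward`, the same surgery still runs but returns a flattened vector, which is easy to miss:

```python
import torch
from torch import nn


class HeadWithModuleFlatten(nn.Module):
    """Hypothetical ResNet-like tail whose forward calls self.flatten (the proposal)."""

    def __init__(self):
        super().__init__()
        self.layer4 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        return self.fc(self.flatten(self.avgpool(self.layer4(x))))


class HeadWithFunctionalFlatten(HeadWithModuleFlatten):
    """Same model, but flattening is hard-coded as torch.flatten (like the current ResNet)."""

    def forward(self, x):
        return self.fc(torch.flatten(self.avgpool(self.layer4(x)), 1))


x = torch.randn(2, 3, 4, 4)
for cls in (HeadWithModuleFlatten, HeadWithFunctionalFlatten):
    m = cls()
    m.avgpool = m.flatten = m.fc = nn.Identity()  # chained assignment: one shared Identity
    print(cls.__name__, tuple(m(x).shape))
# HeadWithModuleFlatten (2, 8, 4, 4)    <- the layer4 feature map, as intended
# HeadWithFunctionalFlatten (2, 128)    <- still flattened by the hard-coded torch.flatten
```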

fmassa commented 3 years ago

@vadimkantorov I think we should revisit how one does model surgery in PyTorch. In #3597 I have a prototype feature (which I'll be updating soon) that provides a much more robust and generic way of doing feature extraction.

It relies on FX, which quantization also relies on for extracting the graph, so I think this could be a good compromise that would address all your points.

vadimkantorov commented 3 years ago

I think deleting some upper layers is only tangentially related to intermediate feature extraction, as model.avgpool = model.flatten = model.fc = nn.Identity() is a very simple and understandable way of removing some layers (and doesn't involve a recent new technology).

I propose to still do self.flatten, even if you merge your new FX-based feature.

fmassa commented 3 years ago

@vadimkantorov the magic of the FX-based approach is that if you specify any part of the model as the output, you can actually delete the end of the model (removing both the compute and the parameters). So it is actually a generalization of what you are proposing.

vadimkantorov commented 3 years ago

I understand that it also solves this case :) But self.flatten is so simple, doesn't break backward compatibility, and doesn't require learning the new compiling functionality just for this simple goal :)

I think both are worthwhile :) (And maybe even figuring out how to derive from nn.Sequential for DeepSpeed benefits as well, but that's separate)

And FX would probably still forward through the "layers-to-be-removed", and the large fully connected layer can be quite costly.

fmassa commented 3 years ago

@vadimkantorov we merged https://github.com/pytorch/vision/pull/4302, which should provide a generic way of performing model surgery.

Could you give it a try and let us know whether it is enough for your use cases?

vadimkantorov commented 3 years ago

My use case is very simple: remove the average pooling and the last fc layer. I feel that having to plug in generic FX / rewriting for this is overkill.

fmassa commented 3 years ago

@vadimkantorov the thing you just mentioned can be done in one line:

from torchvision.models.feature_extraction import create_feature_extractor
model_surgery = create_feature_extractor(model, ['layer4'])

or, more generically:

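from torchvision.models.feature_extraction import get_graph_node_names
nodes, _ = get_graph_node_names(model)  # (train-mode nodes, eval-mode nodes)
# return_nodes expects a list (or dict) of node names, not a bare string
model_surgery = create_feature_extractor(model, [nodes[-3]])

vadimkantorov commented 3 years ago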

I don't even need the fc layer to run. Will it still run? Or will it do DCE (dead-code elimination)?

I mean, I understand that FX is also a way to achieve this, but just being able to do del backbone.fc or backbone.fc = nn.Identity() is a much simpler way to achieve this for simple models.

fmassa commented 3 years ago

The FC layer (and its parameters) will be removed from the graph by DCE and won't be executed, so this is taken care of for you.
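To illustrate the DCE point with plain `torch.fx`, here is a simplified sketch (the `Toy` model and the `truncate_at` helper are hypothetical, and this is not torchvision's `create_feature_extractor` implementation). The graph output is re-pointed to an intermediate node, `Graph.eliminate_dead_code()` removes the nodes nothing uses anymore, and `GraphModule.delete_all_unused_submodules()` drops `avgpool` and `fc` together with their parameters:

```python
import torch
import torch.fx as fx
from torch import nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.conv(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


def truncate_at(model: nn.Module, node_name: str) -> fx.GraphModule:
    """Hypothetical helper: make `node_name` the graph output and drop everything after it."""
    gm = fx.symbolic_trace(model)
    graph = gm.graph
    target = next(n for n in graph.nodes if n.name == node_name)
    old_output = next(n for n in graph.nodes if n.op == "output")
    graph.erase_node(old_output)       # the output node has no users, so it can be erased
    graph.output(target)               # new output node, appended at the end of the graph
    graph.eliminate_dead_code()        # avgpool/flatten/fc nodes now have no users
    gm.delete_all_unused_submodules()  # also frees the fc (and avgpool) submodules
    gm.recompile()
    return gm


model = Toy().eval()
truncated = truncate_at(model, "conv")
print(truncated.code)            # the generated forward only calls self.conv
print(hasattr(truncated, "fc"))  # False
```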

vadimkantorov commented 3 years ago

I mean, it's great to have a generic solution, but having an equivalent, much simpler way (when possible) that users already know how to use is also a win.

It will also probably be less debuggable. E.g., can I put a breakpoint in the transformed code?

fmassa commented 3 years ago

The transformed code can be printed etc, and is executed as standard Python code by the Python interpreter, so you can jump to every line of it in a Python debugger.

For example,

import torchvision.models.feature_extraction

m = torchvision.models.resnet18()
mm = torchvision.models.feature_extraction.create_feature_extractor(m, ['layer2'])
print(mm.code)

will give you:

def forward(self, x : torch.Tensor):
    conv1 = self.conv1(x);  x = None
    bn1 = self.bn1(conv1);  conv1 = None
    relu = self.relu(bn1);  bn1 = None
    maxpool = self.maxpool(relu);  relu = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
    layer1_0_bn1 = getattr(self.layer1, "0").bn1(layer1_0_conv1);  layer1_0_conv1 = None
    layer1_0_relu = getattr(self.layer1, "0").relu(layer1_0_bn1);  layer1_0_bn1 = None
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_relu);  layer1_0_relu = None
    layer1_0_bn2 = getattr(self.layer1, "0").bn2(layer1_0_conv2);  layer1_0_conv2 = None
    add = layer1_0_bn2 + maxpool;  layer1_0_bn2 = maxpool = None
    layer1_0_relu_1 = getattr(self.layer1, "0").relu(add);  add = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(layer1_0_relu_1)
    layer1_1_bn1 = getattr(self.layer1, "1").bn1(layer1_1_conv1);  layer1_1_conv1 = None
    layer1_1_relu = getattr(self.layer1, "1").relu(layer1_1_bn1);  layer1_1_bn1 = None
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_relu);  layer1_1_relu = None
    layer1_1_bn2 = getattr(self.layer1, "1").bn2(layer1_1_conv2);  layer1_1_conv2 = None
    add_1 = layer1_1_bn2 + layer1_0_relu_1;  layer1_1_bn2 = layer1_0_relu_1 = None
    layer1_1_relu_1 = getattr(self.layer1, "1").relu(add_1);  add_1 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(layer1_1_relu_1)
    layer2_0_bn1 = getattr(self.layer2, "0").bn1(layer2_0_conv1);  layer2_0_conv1 = None
    layer2_0_relu = getattr(self.layer2, "0").relu(layer2_0_bn1);  layer2_0_bn1 = None
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_relu);  layer2_0_relu = None
    layer2_0_bn2 = getattr(self.layer2, "0").bn2(layer2_0_conv2);  layer2_0_conv2 = None
    layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(layer1_1_relu_1);  layer1_1_relu_1 = None
    layer2_0_downsample_1 = getattr(getattr(self.layer2, "0").downsample, "1")(layer2_0_downsample_0);  layer2_0_downsample_0 = None
    add_2 = layer2_0_bn2 + layer2_0_downsample_1;  layer2_0_bn2 = layer2_0_downsample_1 = None
    layer2_0_relu_1 = getattr(self.layer2, "0").relu(add_2);  add_2 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(layer2_0_relu_1)
    layer2_1_bn1 = getattr(self.layer2, "1").bn1(layer2_1_conv1);  layer2_1_conv1 = None
    layer2_1_relu = getattr(self.layer2, "1").relu(layer2_1_bn1);  layer2_1_bn1 = None
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_relu);  layer2_1_relu = None
    layer2_1_bn2 = getattr(self.layer2, "1").bn2(layer2_1_conv2);  layer2_1_conv2 = None
    add_3 = layer2_1_bn2 + layer2_0_relu_1;  layer2_1_bn2 = layer2_0_relu_1 = None
    layer2_1_relu_1 = getattr(self.layer2, "1").relu(add_3);  add_3 = None
    return {'layer2': layer2_1_relu_1}

which is exactly the code that the Python interpreter executes.

vadimkantorov commented 3 years ago

Well, it's great to have, but it's not the same as just stepping into resnet.py :) jk

I agree this is a great, powerful feature that allows using a lot of backbones, and it's good that it unblocks this issue, but I don't see reasons not to improve the basic ResNet by moving flatten into an attribute and maybe turning the quantized ResNet into a Sequential.

fmassa commented 3 years ago

I think the key idea here is that we are providing a single, generic solution that handles many more use cases than what was possible before. Overriding modules with nn.Identity() was a hacky solution that didn't work most of the time (only inside nn.Sequential-style models), and it can lead to confusion / silent bugs in many cases, especially for newer users.

Having forward implemented for the models (even if it is a straight Sequential) can still benefit users, as it allows for exactly the things you have been advocating for in your past messages (namely seeing the execution code, more easily putting breakpoints / prints, etc.).

Still, users are free to use whatever they think best fits their needs.

vadimkantorov commented 3 years ago

I was of course not talking about newer users. These hacks are for someone who already understands the backbone model's source code and wants to avoid the boilerplate of copy-pasting forward. This is no more "advanced" than the recommended reinitialization of individual model components (e.g. model.roi_heads.box_predictor) found in https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

In-the-wild surgery by resetting model.fc = nn.Identity(): https://lernapparat.de/resnet-how-many-models/

vadimkantorov commented 3 years ago

One more reason for models to be nn.Sequential whenever possible: https://pytorch.org/docs/stable/checkpoint.html?highlight=checkpoint_sequential#torch.utils.checkpoint.checkpoint_sequential
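For reference, a minimal sketch of `torch.utils.checkpoint.checkpoint_sequential` on a toy MLP (the model is an assumption, and passing `use_reentrant=False` assumes a recent PyTorch version that accepts this argument). The function takes an `nn.Sequential` (or a list of modules) and splits it into segments whose inner activations are recomputed during backward. The current ResNet cannot simply be handed to it: it is not a Sequential, and `list(resnet.children())` would miss the `torch.flatten` call before `fc`.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 4),
)
x = torch.randn(8, 16)

# Split the 5 modules into 2 segments; activations inside checkpointed segments
# are not stored but recomputed during the backward pass.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
print(torch.allclose(out, model(x)))  # True: same result as a plain forward
```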