vadimkantorov opened this issue 3 years ago
Thanks for the proposal. Replacing `torch.flatten` with `nn.Flatten` seems like it won't break anything, but could you describe your use case a bit more and how it will help you achieve what you want?
Concerning doing a major refactoring and inheriting from Sequential, this might have side effects on all the other models that depend on ResNet (segmentation, object detection, etc.), so I'm not sure if it can be done in a backward-compatible manner. Note also that replacing the `forward` is not something we can do, due to how quantization works: https://github.com/pytorch/vision/blob/6116812d51c6e5524ef48f8b14c0293ed674ed55/torchvision/models/quantization/resnet.py#L93-L100
@fmassa Let me know if you see any problems with replacing flatten.
I would be fine replacing `torch.flatten` with `nn.Flatten`, although it would only simplify model surgery if we were to make further modifications to the ResNet model as you suggested (making it inherit from `nn.Sequential`). Without this, `del model.fc` etc. alone wouldn't be enough.
So from this perspective there is limited value in replacing `torch.flatten` with `nn.Flatten`; one is only syntactic sugar for the other.
Oh, I didn't realize:
```python
# Ensure scriptability
# super(QuantizableResNet,self).forward(x)
# is not scriptable
```
If this gets fixed, I guess inheriting from `nn.Sequential` should be fine.
My real use case: I'm working with a codebase (https://github.com/ajabri/videowalk/blob/master/code/resnet.py#L41) that had to reimplement the ResNet forward because it wants to do simple model surgery.
@vadimkantorov model surgery in PyTorch generally requires re-writing `forward` or other parts of the model, except for `nn.Sequential`, so I would say it's a valid requirement to ask the users to be a bit more verbose
> model surgery in PyTorch generally requires re-writing `forward` or other parts of the model, except for `nn.Sequential`
Of course. I understand that in general that's the case, but if it can be made simpler, I think it should be, even if in the general case a full rewrite is required. In this particular case, it seems that, because of quantization scripting limitations, deriving from `nn.Sequential` is a no-go, but once that's solved it would be nice to derive ResNet from `nn.Sequential`.
I'll rename this issue accordingly
It seems that deriving from `nn.Sequential` would also be good from a DeepSpeed-compatibility standpoint: https://github.com/pytorch/pytorch/issues/51574
For the quantization case, couldn't it just insert quant/dequant modules at the beginning and end of the Sequential? Then everything should work and no modification of `forward` would be needed.
They are already modules:
```python
self.quant = torch.quantization.QuantStub()
self.dequant = torch.quantization.DeQuantStub()
```
One would just need to modify the constructor and insert them before `conv1` and after `layer4`.
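A minimal sketch of what that could look like (hypothetical, not the actual torchvision code; the toy backbone is illustrative):

```python
import torch
import torch.nn as nn

# Hypothetical sketch: if the model were an nn.Sequential, the quantization
# stubs could be ordinary first/last submodules instead of a forward override.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, 3),  # stand-in for conv1..layer4
    nn.ReLU(),
)
quantizable = nn.Sequential(
    torch.quantization.QuantStub(),    # inserted before the first layer
    *backbone,
    torch.quantization.DeQuantStub(),  # inserted after the last layer
)

# Before conversion the stubs are pass-through, so this runs as usual.
out = quantizable(torch.randn(1, 3, 8, 8))
print(out.shape)  # torch.Size([1, 8, 6, 6])
```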
@fmassa Even if we just replace `torch.flatten` with `self.flatten`, it already simplifies model surgery: `model.avgpool = model.flatten = model.fc = nn.Identity()` (instead of `del model.avgpool, model.flatten, model.fc`).
@vadimkantorov I think we should revisit how one does model surgery in PyTorch. In #3597 I have a prototype feature (which I'll be updating soon) that provides a much more robust and generic way of doing feature extraction.
It relies on FX (and quantization also relies on FX for extracting the graph), so I think this could be a good compromise that addresses all your points.
I think deleting some upper layers is only tangentially related to intermediate feature extraction: `model.avgpool = model.flatten = model.fc = nn.Identity()` is a very simple and understandable way of removing some layers (and doesn't involve a recent new technology).
I propose to still do `self.flatten`, even if you merge your new FX-based feature.
@vadimkantorov the magic with the FX-based approach is that if you specify any part of the model, you can actually delete the end of the model (removing both the compute and the parameters). So it is actually a generalization of what you are proposing.
I understand that it also solves this case :) But `self.flatten` is so simple, doesn't break any back-compat, and doesn't require learning the new FX machinery just for this simple goal :)
I think both are worthwhile :) (And maybe even figuring out how to derive from nn.Sequential for DeepSpeed benefits as well, but that's separate)
And FX would probably still forward through the "layers-to-be-removed", and the large fully-connected layer can be quite costly.
@vadimkantorov we merged https://github.com/pytorch/vision/pull/4302, which should provide a generic way of performing model surgery.
Could you give this a try and provide us feedback if it is enough for your use-cases?
Very simple: remove the average pooling and the last fc layer. I feel that having to plug in generic FX / graph rewriting for this is overkill.
@vadimkantorov the thing you just mentioned can be done in one line
```python
model_surgery = create_feature_extractor(model, ['layer4'])
```
or more generically
```python
nodes, _ = get_graph_node_names(model)
model_surgery = create_feature_extractor(model, [nodes[-3]])
```
I don't even need the fc layer to run. Will it still run it? Or will it do DCE (dead code elimination)?
I mean, I understand that FX is also a way to achieve this, but just being able to `del backbone.fc` or set `backbone.fc = nn.Identity()` is a much simpler way to achieve this for simple models.
The FC layer (and its parameters) will be DCEd from the graph and won't be executed, so this is taken care of for you.
I mean, it's great to have a generic solution, but having an equivalent, much simpler way (when possible) that users already know how to use is also a win.
It will then also probably be less debuggable. E.g. can I put a breakpoint in the transformed code?
The transformed code can be printed etc, and is executed as standard Python code by the Python interpreter, so you can jump to every line of it in a Python debugger.
For example
```python
m = torchvision.models.resnet18()
mm = torchvision.models.feature_extraction.create_feature_extractor(m, ['layer2'])
print(mm.code)
```
will give you
```python
def forward(self, x : torch.Tensor):
    conv1 = self.conv1(x); x = None
    bn1 = self.bn1(conv1); conv1 = None
    relu = self.relu(bn1); bn1 = None
    maxpool = self.maxpool(relu); relu = None
    layer1_0_conv1 = getattr(self.layer1, "0").conv1(maxpool)
    layer1_0_bn1 = getattr(self.layer1, "0").bn1(layer1_0_conv1); layer1_0_conv1 = None
    layer1_0_relu = getattr(self.layer1, "0").relu(layer1_0_bn1); layer1_0_bn1 = None
    layer1_0_conv2 = getattr(self.layer1, "0").conv2(layer1_0_relu); layer1_0_relu = None
    layer1_0_bn2 = getattr(self.layer1, "0").bn2(layer1_0_conv2); layer1_0_conv2 = None
    add = layer1_0_bn2 + maxpool; layer1_0_bn2 = maxpool = None
    layer1_0_relu_1 = getattr(self.layer1, "0").relu(add); add = None
    layer1_1_conv1 = getattr(self.layer1, "1").conv1(layer1_0_relu_1)
    layer1_1_bn1 = getattr(self.layer1, "1").bn1(layer1_1_conv1); layer1_1_conv1 = None
    layer1_1_relu = getattr(self.layer1, "1").relu(layer1_1_bn1); layer1_1_bn1 = None
    layer1_1_conv2 = getattr(self.layer1, "1").conv2(layer1_1_relu); layer1_1_relu = None
    layer1_1_bn2 = getattr(self.layer1, "1").bn2(layer1_1_conv2); layer1_1_conv2 = None
    add_1 = layer1_1_bn2 + layer1_0_relu_1; layer1_1_bn2 = layer1_0_relu_1 = None
    layer1_1_relu_1 = getattr(self.layer1, "1").relu(add_1); add_1 = None
    layer2_0_conv1 = getattr(self.layer2, "0").conv1(layer1_1_relu_1)
    layer2_0_bn1 = getattr(self.layer2, "0").bn1(layer2_0_conv1); layer2_0_conv1 = None
    layer2_0_relu = getattr(self.layer2, "0").relu(layer2_0_bn1); layer2_0_bn1 = None
    layer2_0_conv2 = getattr(self.layer2, "0").conv2(layer2_0_relu); layer2_0_relu = None
    layer2_0_bn2 = getattr(self.layer2, "0").bn2(layer2_0_conv2); layer2_0_conv2 = None
    layer2_0_downsample_0 = getattr(getattr(self.layer2, "0").downsample, "0")(layer1_1_relu_1); layer1_1_relu_1 = None
    layer2_0_downsample_1 = getattr(getattr(self.layer2, "0").downsample, "1")(layer2_0_downsample_0); layer2_0_downsample_0 = None
    add_2 = layer2_0_bn2 + layer2_0_downsample_1; layer2_0_bn2 = layer2_0_downsample_1 = None
    layer2_0_relu_1 = getattr(self.layer2, "0").relu(add_2); add_2 = None
    layer2_1_conv1 = getattr(self.layer2, "1").conv1(layer2_0_relu_1)
    layer2_1_bn1 = getattr(self.layer2, "1").bn1(layer2_1_conv1); layer2_1_conv1 = None
    layer2_1_relu = getattr(self.layer2, "1").relu(layer2_1_bn1); layer2_1_bn1 = None
    layer2_1_conv2 = getattr(self.layer2, "1").conv2(layer2_1_relu); layer2_1_relu = None
    layer2_1_bn2 = getattr(self.layer2, "1").bn2(layer2_1_conv2); layer2_1_conv2 = None
    add_3 = layer2_1_bn2 + layer2_0_relu_1; layer2_1_bn2 = layer2_0_relu_1 = None
    layer2_1_relu_1 = getattr(self.layer2, "1").relu(add_3); add_3 = None
    return {'layer2': layer2_1_relu_1}
```
which is exactly what the Python interpreter executes
Well, it's great to have, but compared to just stepping into resnet.py, this is not the same :) jk
I agree this is a great, powerful feature that allows a lot of backbones to be used, and it's good that it unblocks this issue, but I don't see a reason not to improve the basic ResNet by moving flatten into an attribute, and maybe transforming the quantized ResNet into a Sequential.
I think the key idea here is that we are providing a single, generic solution that handles many more use-cases than what was possible before. Overriding modules with `nn.Identity()` was a hacky solution that didn't work most of the time (only inside `nn.Sequential`-style models), and can lead to confusion / silent bugs in many cases, especially for newer users I think.
Having the `forward` implemented for the models (even if it is a straight Sequential) can still be beneficial for users, as it allows for the exact same things you have been advocating for in your past messages (namely seeing the execution code, more easily putting breakpoints / prints, etc.).
Still, users are free to use what they think best fit their needs.
I was of course not talking about "newer users". These hacks are for someone who already understands the backbone model source code and wants to avoid the boilerplate of copy-pasting the forward. This is no more "advanced" than the recommended individual model-component reinitialization (e.g. `model.roi_heads.box_predictor`) found in https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html
In-the-wild surgery by resetting `model.fc = nn.Identity()`: https://lernapparat.de/resnet-how-many-models/
One more reason for models to be `nn.Sequential` whenever possible: https://pytorch.org/docs/stable/checkpoint.html?highlight=checkpoint_sequential#torch.utils.checkpoint.checkpoint_sequential
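For illustration, a minimal sketch of `checkpoint_sequential` on a toy `nn.Sequential` (the toy layers are made up; the point is that this API only works out of the box for Sequential-style models):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
    nn.ReLU(),
)
x = torch.randn(2, 3, 16, 16, requires_grad=True)

# Run in 2 checkpointed segments: intermediate activations are freed and
# recomputed during backward, trading compute for memory.
out = checkpoint_sequential(model, 2, x)
out.sum().backward()
print(out.shape)  # torch.Size([2, 8, 16, 16])
```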
Currently https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py#L243 uses `x = torch.flatten(x, 1)`. If it instead used `x = self.flatten(x)`, then it would simplify model surgery: `del model.avgpool, model.flatten, model.fc`. Also, in this case the class could just derive from Sequential and use an OrderedDict to pass submodules (like in https://discuss.pytorch.org/t/ux-mix-of-nn-sequential-and-nn-moduledict/104724/2?u=vadimkantorov); this would preserve checkpoint compat as well. The method `forward` could then be removed.
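A hypothetical sketch of that last point, using only the classifier head (the names `avgpool`/`flatten`/`fc` mirror the ResNet attributes):

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# A Sequential with named submodules: surgery becomes plain attribute deletion.
head = nn.Sequential(OrderedDict([
    ('avgpool', nn.AdaptiveAvgPool2d(1)),
    ('flatten', nn.Flatten(1)),
    ('fc', nn.Linear(512, 1000)),
]))

x = torch.randn(2, 512, 7, 7)
print(head(x).shape)  # torch.Size([2, 1000])

# Surgery is now literally deletion: no forward rewrite needed.
del head.fc
print(head(x).shape)  # torch.Size([2, 512])
```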