dotnet / TorchSharpExamples

Repository for TorchSharp examples and tutorials.
MIT License

Loading Python Exported Model into TorchSharp #22

Open NiklasGustafsson opened 2 years ago

NiklasGustafsson commented 2 years ago

Originally posted in dotnet/TorchSharp by @jimquittenton:

https://github.com/dotnet/TorchSharp/issues/586

The naming scheme for layers differs between the ResNet example model in this repo and the ResNet models in TorchVision, which prevents a model saved from Python from being loaded in TorchSharp using this example code.

Original post:


Hi, I'm new to TorchSharp and am having trouble loading a Python-trained ResNet18 model. I've been following this article: https://github.com/dotnet/TorchSharp/blob/main/docfx/articles/saveload.md and have exported my Python model using the 'save_state_dict' function in this script: https://github.com/dotnet/TorchSharp/blob/main/src/Python/exportsd.py .

In TorchSharp I have copied the ResNet model from https://github.com/dotnet/TorchSharpExamples/blob/main/src/CSharp/Models/ResNet.cs and then call the following:

        int numClasses = 3;
        ResNet myModel = ResNet.ResNet18(numClasses);
        myModel.to(DeviceType.CPU);
        myModel.load(mPath);

The load() line throws an exception with the message: Mismatched module state names: the target modules does not have a submodule or buffer named 'conv1.weight'.

If I examine the state_dict from 'myModel' prior to load(), it contains entries like:

        {[layers.conv2d-first.weight, {TorchSharp.Modules.Parameter}]}
        {[layers.bnrm2d-first.weight, {TorchSharp.Modules.Parameter}]}
        {[layers.bnrm2d-first.bias, {TorchSharp.Modules.Parameter}]}
        {[layers.bnrm2d-first.running_mean, {TorchSharp.torch.Tensor}]}
        {[layers.bnrm2d-first.running_var, {TorchSharp.torch.Tensor}]}
        {[layers.bnrm2d-first.num_batches_tracked, {TorchSharp.torch.Tensor}]}
        {[layers.blck-64-0.layers.blck-64-0-conv2d-1.weight, {TorchSharp.Modules.Parameter}]}
        {[layers.blck-64-0.layers.blck-64-0-bnrm2d-1.weight, {TorchSharp.Modules.Parameter}]}
        {[layers.blck-64-0.layers.blck-64-0-bnrm2d-1.bias, {TorchSharp.Modules.Parameter}]}

whereas the corresponding entries prior to saving from Python are:

        conv1.weight torch.Size([64, 3, 7, 7])
        bn1.weight torch.Size([64])
        bn1.bias torch.Size([64])
        bn1.running_mean torch.Size([64])
        bn1.running_var torch.Size([64])
        bn1.num_batches_tracked torch.Size([])
        layer1.0.conv1.weight torch.Size([64, 64, 3, 3])
        layer1.0.bn1.weight torch.Size([64])
        layer1.0.bn1.bias torch.Size([64])

I tried amending the ResNet.cs code to reflect the Python names, but could not get them to match exactly.
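To see the mismatch at a glance, the two key lists can be diffed before attempting a load. The snippet below is only a sketch: it uses plain Python lists holding the first few names quoted in this post, and `diff_keys` is a hypothetical helper, not part of TorchSharp or PyTorch.

```python
# Key names copied from the post above; only the first few entries are used.
torchsharp_keys = [
    "layers.conv2d-first.weight",
    "layers.bnrm2d-first.weight",
    "layers.bnrm2d-first.bias",
]
python_keys = [
    "conv1.weight",
    "bn1.weight",
    "bn1.bias",
]


def diff_keys(target_keys, saved_keys):
    """Hypothetical helper: compare the names a target model expects
    with the names found in a saved state dict.

    Returns (missing_in_target, unused_in_target) as sorted lists.
    """
    target, saved = set(target_keys), set(saved_keys)
    return sorted(saved - target), sorted(target - saved)


missing, unused = diff_keys(torchsharp_keys, python_keys)
print("saved names the target model does not have:", missing)
print("target names the saved file does not provide:", unused)
```

With these inputs every saved name is reported as missing from the target, which is consistent with the strict load failing on the very first entry, 'conv1.weight'.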

I also tried calling load() with strict=false, i.e. myModel.load(mPath, false);. This got past the "Mismatched module state names" exception, but throws another exception with the message: Too many bytes in what should have been a 7 bit encoded Int32.
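The wording of that second exception is what a reader produces when it decodes a variable-length ("7-bit encoded", LEB128-style) integer and the continuation bit never clears within five bytes. The thread does not establish why this happened here; one possibility is that the reader was no longer aligned with the start of a length field. The sketch below only illustrates the encoding itself, assuming the file stores lengths this way. The function names are made up for the example and this is not TorchSharp's or .NET's code.

```python
def encode_7bit(value: int) -> bytes:
    """Encode a non-negative integer as 7-bit groups, low group first.
    The high bit of each byte says whether another byte follows."""
    if value < 0:
        raise ValueError("only non-negative values are supported in this sketch")
    out = bytearray()
    while True:
        byte = value & 0x7F
        value >>= 7
        if value:
            out.append(byte | 0x80)  # more groups follow
        else:
            out.append(byte)         # last group, high bit clear
            return bytes(out)


def decode_7bit_int32(data: bytes, offset: int = 0):
    """Decode one 7-bit encoded integer starting at `offset`.
    Returns (value, next_offset). A 32-bit value needs at most 5 bytes,
    so a sixth continuation is treated as corrupt or misaligned input."""
    result = 0
    for i in range(5):
        byte = data[offset + i]
        result |= (byte & 0x7F) << (7 * i)
        if not byte & 0x80:
            return result, offset + i + 1
    raise ValueError("too many bytes for a 7-bit encoded Int32")


encoded = encode_7bit(300)
value, next_offset = decode_7bit_int32(encoded)
print(encoded.hex(), value, next_offset)
```

Reading from the wrong position, for example in the middle of tensor data where many bytes have the high bit set, would trip the final check in `decode_7bit_int32`.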

I've been struggling with this for a couple of days now so would really appreciate any help you guys could offer.

Thanks Jim

jimquittenton commented 2 years ago

Hi Niklas,

Firstly, I've been away until today, so apologies for the late reply. I saw that the pull request you made is now merged, so I tried cloning the main repo and building TorchSharp locally. I got it building to an extent, but couldn't build the Test projects or run the examples due to CUDA dependencies (I'm on a CPU-only machine, so I have only used the CPU version of libtorch so far). I'm assuming this change fixes loading of the weights and skipping layers as per your message, but I have not been able to test it yet. Do you have an idea of when this will make it into a NuGet package (as that would be way easier for me)?

As for changing the ResNet layer names, I had a go at it and started by creating a spreadsheet matching up all the torchvision model names to those in the TorchSharp ResNet18 state_dict (I've attached this so you can see the differences). However, while I can see the lists of (name, module) being built up in, say, BasicBlock, I can't see a way to prevent the name of the root module, 'layers', from becoming part of the final name. Also, the AddLayer() method adds the block name into each (string, module) tuple as well as passing it into the delegate for BasicBlock, which means it ends up in the final name twice. If you could start me off in resolving these issues then I'd be happy to have a go at the rest of the renaming.
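The two naming effects described here (the 'layers' container appearing in every path, and the block name appearing twice) follow from how dotted state-dict names are composed: each level of nesting contributes its own name. The sketch below models that with plain nested dicts. `qualified_names` and the two trees are illustrative assumptions, not the actual module classes.

```python
def qualified_names(module, prefix=""):
    """Yield dotted parameter names for a nested {name: child} tree.

    A child is either a dict (a submodule) or a list of parameter names.
    This is a simplified stand-in for how module hierarchies produce
    state-dict keys.
    """
    for name, child in module.items():
        path = f"{prefix}.{name}" if prefix else name
        if isinstance(child, dict):
            yield from qualified_names(child, path)
        else:
            for param in child:
                yield f"{path}.{param}"


# Shape of the example model as reported in the original post: a 'layers'
# container at each level, and the block name repeated in the child name.
torchsharp_like = {
    "layers": {
        "blck-64-0": {
            "layers": {
                "blck-64-0-conv2d-1": ["weight"],
            }
        }
    }
}

# Shape of the torchvision model: children registered directly on the parent.
torchvision_like = {"layer1": {"0": {"conv1": ["weight"]}}}

print(list(qualified_names(torchsharp_like)))
print(list(qualified_names(torchvision_like)))
```

Under this model, matching the torchvision names means registering children directly on their parent rather than inside a named container, and giving each child a name that does not repeat the block name.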

Thanks Jim

torchsharp-resnet18-names.xlsx

NiklasGustafsson commented 2 years ago

Hi!

The PR contains a ResNet with the right names, but the loading of weights will still be a manual thing, in a way -- you can pass the path to the weights file into the ResNet factory method, and it gives you the option to skip the last Linear layer so that you can repurpose it for your own classification:

        public static Modules.ResNet resnet18(int num_classes = 1000,
                string weights_file = null,
                bool skipfc = true,
                Device device = null)

This will be in the next release, which will probably be next week some time. If you don't pass in a path, then it will create the model with random weights.
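Conceptually, skipping the last Linear layer amounts to ignoring the saved entries for that layer, so that a freshly initialized classifier of a different size can be used in their place. The following is a Python sketch of that idea only, assuming the torchvision name `fc` for the final layer. `drop_final_layer` is a hypothetical helper, and the dictionary holds shapes rather than real tensors.

```python
def drop_final_layer(state, prefix="fc."):
    """Return a copy of `state` without the entries of the final layer.
    `prefix` is the dotted name of that layer followed by a dot."""
    return {name: value for name, value in state.items()
            if not name.startswith(prefix)}


# Shapes stand in for tensors; the fc entries are sized for 1000 classes.
saved = {
    "conv1.weight": (64, 3, 7, 7),
    "bn1.weight": (64,),
    "fc.weight": (1000, 512),
    "fc.bias": (1000,),
}

backbone_only = drop_final_layer(saved)
print(sorted(backbone_only))
```

The remaining entries can then be loaded into a model whose final layer has, say, 3 outputs, since nothing with a 1000-class shape is left to conflict with it.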

Note that we will keep the CIFAR-related versions in the Examples -- they can handle the 32x32 images that CIFAR provides and are about 10x faster than the torchvision.models version, which requires you to resize the images to 224x224.

NiklasGustafsson commented 2 years ago

@jimquittenton -- can I close this now?