LaurentMazare / tch-rs

Rust bindings for the C++ api of PyTorch.
Apache License 2.0

TorchScript Model Weight Conversion #732

Closed by bokenator 1 year ago

bokenator commented 1 year ago

I am following the instructions to convert PyTorch model weights into TorchScript weights, but when I try to load the weights, I get the following error:

Error: cannot find the tensor named layer4.0.downsample.1.running_mean in ./weights/resnet18.ot

The error message may name a different layer on each run, but it always seems to be associated with a batch norm layer. For example, I also get these error messages:

Error: cannot find the tensor named layer2.0.downsample.1.running_var in ./weights/resnet18.ot

or

Error: cannot find the tensor named layer3.0.bn2.running_var in ./weights/resnet18.ot

The model definition itself seems to be correct, because I'm able to load the resnet18.ot file from the release files and run inference. However, the weights that I converted myself do not load.

Here's the Python function I wrote to do the conversion. I have attempted both the trace method and the annotation method:

import os
from typing import List, Type, Union
from urllib.request import urlretrieve

import torch
# The ResNet definition comes straight from torchvision.models.
from torchvision.models.resnet import BasicBlock, Bottleneck, ResNet

def convert_model(Block: Union[Type[BasicBlock], Type[Bottleneck]], layer_count: List[int], url: str, name: str):
    model_path = './weights'
    input_name = f'{name}.pth'
    output_name = f'{name}.ot'

    # Download the pretrained weights unless they are already cached locally
    os.makedirs(model_path, exist_ok=True)
    if not os.path.exists(f'{model_path}/{input_name}'):
        urlretrieve(url, f'{model_path}/{input_name}')

    # Annotation method (torch.jit.script), kept for reference:
    # resnet = ResNet(Block, layer_count)
    # resnet.load_state_dict(torch.load(f'{model_path}/{input_name}')) # type: ignore
    # script_module = torch.jit.script(resnet)   # type: ignore
    # script_module.save(f'{model_path}/{output_name}')   # type: ignore

    # Trace method (torch.jit.trace) with a dummy ImageNet-sized input
    resnet = ResNet(Block, layer_count)
    resnet.load_state_dict(torch.load(f'{model_path}/{input_name}')) # type: ignore
    resnet.eval()
    example = torch.rand(1, 3, 224, 224)
    script_module = torch.jit.trace(resnet, example)   # type: ignore
    script_module.save(f'{model_path}/{output_name}')   # type: ignore

if __name__ == '__main__':
    for weight_url, name, Block, layer_count in [
        ('https://download.pytorch.org/models/resnet18-f37072fd.pth', 'resnet18', BasicBlock, [2, 2, 2, 2]),
        ('https://download.pytorch.org/models/resnet34-b627a593.pth', 'resnet34', BasicBlock, [3, 4, 6, 3]),
        ('https://download.pytorch.org/models/resnet50-11ad3fa6.pth', 'resnet50', Bottleneck, [3, 4, 6, 3]),
        ('https://download.pytorch.org/models/resnet101-cd907fc2.pth', 'resnet101', Bottleneck, [3, 4, 23, 3]),
        ('https://download.pytorch.org/models/resnet152-f82ba261.pth', 'resnet152', Bottleneck, [3, 8, 36, 3]),
    ]:
        convert_model(Block, layer_count, weight_url, name)

For reference, you can download the problematic weights at https://ml47.s3.amazonaws.com/resnet18.ot.

Any suggestion would be greatly appreciated.

LaurentMazare commented 1 year ago

> I am following the instruction to convert a pytorch model weight into a torchscript weight

First, note that we now recommend using safetensors to export weights from the Python side, as it's much simpler; see this export script.

One thing that could help debug this kind of issue is inspecting the content of the weight file via cargo run --example tensor-tools ls resnet18.ot. Looking at your file, it does seem that the batch norm running mean and var are not included. I'm not sure where your ResNet Python definition comes from; is it possible that it's set with track_running_stats=False or something like this (python doc)?
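
Once the names in the file are listed, the check boils down to comparing them with the names the model expects. A minimal sketch of that comparison, assuming both lists have been copied into Python by hand: the helper `missing_tensors` and the sample names are hypothetical (modelled on the error messages above), and nothing here parses tensor-tools output.

```python
from collections import defaultdict


def missing_tensors(expected, found):
    """Group the expected tensor names that are absent from the file by their last component."""
    found_set = set(found)
    missing = defaultdict(list)
    for name in expected:
        if name not in found_set:
            # "layer3.0.bn2.running_var" -> "running_var"
            suffix = name.rsplit(".", 1)[-1]
            missing[suffix].append(name)
    return dict(missing)


# Hypothetical sample names, modelled on the errors reported in this thread.
expected = [
    "layer3.0.bn2.weight",
    "layer3.0.bn2.bias",
    "layer3.0.bn2.running_mean",
    "layer3.0.bn2.running_var",
]
found = ["layer3.0.bn2.weight", "layer3.0.bn2.bias"]

report = missing_tensors(expected, found)
for suffix, names in sorted(report.items()):
    print(f"{suffix}: {len(names)} missing, e.g. {names[0]}")
```

If every missing name ends in running_mean or running_var, as in the errors above, that suggests the batch norm statistics as a group were left out of the file, rather than individual layers being misnamed.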

bokenator commented 1 year ago

I played around with it some more. The running_mean and running_var are definitely loaded by the model, because I can print their values out. It seems that either torch.jit.script is not converting them into the TorchScript representation or torch_script_module.save is not saving them into the .ot file.

The ResNet definition is straight from torchvision.models. Loading the weights manually with load_state_dict and instantiating the model directly with models.resnet18(pretrained=True) both yield the same result.

I also tried exporting with safetensors, but I see the following error when running tensor-tools against the output files:

Error: Internal torch error: PytorchStreamReader failed reading zip archive: failed finding central directory
Exception raised from valid at ../caffe2/serialize/inline_container.cc:184 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x6b (0x7efc4ec5a6bb in /opt/libtorch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xbf (0x7efc4ec555ef in /opt/libtorch/lib/libc10.so)
frame #2: caffe2::serialize::PyTorchStreamReader::valid(char const*, char const*) + 0x3ca (0x7efc529571fa in /opt/libtorch/lib/libtorch_cpu.so)
frame #3: caffe2::serialize::PyTorchStreamReader::init() + 0xad (0x7efc52957a0d in /opt/libtorch/lib/libtorch_cpu.so)
frame #4: caffe2::serialize::PyTorchStreamReader::PyTorchStreamReader(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x13a (0x7efc5295b1ea in /opt/libtorch/lib/libtorch_cpu.so)
frame #5: torch::jit::import_ir_module(std::shared_ptr<torch::jit::CompilationUnit>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > > >&, bool, bool) + 0x28d (0x7efc53b1955d in /opt/libtorch/lib/libtorch_cpu.so)
frame #6: torch::jit::import_ir_module(std::shared_ptr<torch::jit::CompilationUnit>, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, bool) + 0x92 (0x7efc53b199f2 in /opt/libtorch/lib/libtorch_cpu.so)
frame #7: torch::jit::load(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, c10::optional<c10::Device>, bool) + 0xd1 (0x7efc53b19b21 in /opt/libtorch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xddd59 (0x559de2652d59 in target/debug/examples/tensor-tools)
frame #9: <unknown function> + 0x6917f (0x559de25de17f in target/debug/examples/tensor-tools)
frame #10: <unknown function> + 0x64cd1 (0x559de25d9cd1 in target/debug/examples/tensor-tools)
frame #11: <unknown function> + 0x596cb (0x559de25ce6cb in target/debug/examples/tensor-tools)
frame #12: <unknown function> + 0x5cd6e (0x559de25d1d6e in target/debug/examples/tensor-tools)
frame #13: <unknown function> + 0x5cd31 (0x559de25d1d31 in target/debug/examples/tensor-tools)
frame #14: <unknown function> + 0x37643c (0x559de28eb43c in target/debug/examples/tensor-tools)
frame #15: <unknown function> + 0x5cd0a (0x559de25d1d0a in target/debug/examples/tensor-tools)
frame #16: <unknown function> + 0x666be (0x559de25db6be in target/debug/examples/tensor-tools)
frame #17: <unknown function> + 0x29d90 (0x7efc4ea01d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #18: __libc_start_main + 0x80 (0x7efc4ea01e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #19: <unknown function> + 0x2f545 (0x559de25a4545 in target/debug/examples/tensor-tools)

LaurentMazare commented 1 year ago

Did you ensure that the safetensors file is named with a .safetensors suffix? That suffix is what triggers safetensors decoding (a bit sad that we have this kind of implicit magic, but anyway).
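
Since the decoder is picked from the file name, it can help to check what a file actually contains before renaming it. The sketch below is only an illustration and is not how tch-rs decides: it assumes the safetensors layout (an 8-byte little-endian header length followed by a JSON header) and uses the fact that the error above comes from libtorch trying to read the file as a zip archive. The function names and the 100 MB header cap are hypothetical choices.

```python
import json
import struct
import zipfile

MAX_HEADER_BYTES = 100_000_000  # arbitrary sanity cap chosen for this sketch


def looks_like_safetensors(path):
    """Return True if the file starts with a length-prefixed JSON object header."""
    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) < 8:
            return False
        (header_len,) = struct.unpack("<Q", prefix)
        if header_len == 0 or header_len > MAX_HEADER_BYTES:
            return False
        header = f.read(header_len)
    if len(header) != header_len:
        return False
    try:
        parsed = json.loads(header.decode("utf-8"))
    except ValueError:  # covers both bad UTF-8 and invalid JSON
        return False
    return isinstance(parsed, dict)


def sniff_weight_format(path):
    """Guess the on-disk format of a weight file, ignoring its file name."""
    if looks_like_safetensors(path):
        return "safetensors"
    if zipfile.is_zipfile(path):
        return "zip"
    return "unknown"
```

A file that reports "safetensors" here but carries a different suffix would, going by the comment above, be handed to the zip-based loader, which would explain the "failed finding central directory" error.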

bokenator commented 1 year ago

That's the magic we needed!

I will open a PR over the weekend to update the tutorial for weight conversion.

Thanks so much for your help!