dotnet / TorchSharp

A .NET library that provides access to the library that powers PyTorch.
MIT License

Classes of torchvision\ops #790

Open xhuan8 opened 2 years ago

xhuan8 commented 2 years ago

Is there any plan to add classes of torchvision/ops?

roi_align and boxes are required, but they are not implemented in TorchSharp. In torchvision, roi_align calls into the native op:

return torch.ops.torchvision.roi_align(
        input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned
    )
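For readers unfamiliar with what that native call computes, here is a minimal single-channel C++ sketch of RoIAlign. It is modeled on the sampling and bilinear-interpolation scheme of torchvision's CPU kernel as I understand it, not copied from it. The names `roi_align_single` and `bilinear`, the row-major `std::vector<float>` layout, and the 4-value ROI (the batch-index column of `rois` is dropped) are hypothetical simplifications for illustration, not a TorchSharp or torchvision API.

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Bilinear sample of a row-major height x width feature map at (y, x).
// Points more than one pixel outside the map contribute zero.
static float bilinear(const std::vector<float>& img, int height, int width, double y, double x) {
    if (y < -1.0 || y > height || x < -1.0 || x > width) return 0.0f;
    if (y <= 0) y = 0;
    if (x <= 0) x = 0;
    int y_low = static_cast<int>(y);
    int x_low = static_cast<int>(x);
    int y_high, x_high;
    // Clamp to the last row/column so we never read past the edge.
    if (y_low >= height - 1) { y_high = y_low = height - 1; y = y_low; } else { y_high = y_low + 1; }
    if (x_low >= width - 1) { x_high = x_low = width - 1; x = x_low; } else { x_high = x_low + 1; }
    double ly = y - y_low, lx = x - x_low;
    double hy = 1.0 - ly, hx = 1.0 - lx;
    double v = hy * hx * img[y_low * width + x_low] + hy * lx * img[y_low * width + x_high]
             + ly * hx * img[y_high * width + x_low] + ly * lx * img[y_high * width + x_high];
    return static_cast<float>(v);
}

// RoIAlign for one ROI {x1, y1, x2, y2} (input-image coordinates) on one channel.
// Returns pooled_h x pooled_w averages, row-major.
std::vector<float> roi_align_single(const std::vector<float>& img, int height, int width,
                                    const float roi[4], double spatial_scale,
                                    int pooled_h, int pooled_w, int sampling_ratio, bool aligned) {
    // aligned=true shifts the box by half a pixel so pixel centers line up.
    const double offset = aligned ? 0.5 : 0.0;
    double start_w = roi[0] * spatial_scale - offset;
    double start_h = roi[1] * spatial_scale - offset;
    double roi_w = roi[2] * spatial_scale - offset - start_w;
    double roi_h = roi[3] * spatial_scale - offset - start_h;
    if (!aligned) { roi_w = std::max(roi_w, 1.0); roi_h = std::max(roi_h, 1.0); }
    double bin_h = roi_h / pooled_h, bin_w = roi_w / pooled_w;
    // Sampling points per bin; chosen adaptively when sampling_ratio <= 0.
    int grid_h = sampling_ratio > 0 ? sampling_ratio : static_cast<int>(std::ceil(roi_h / pooled_h));
    int grid_w = sampling_ratio > 0 ? sampling_ratio : static_cast<int>(std::ceil(roi_w / pooled_w));
    double count = std::max(grid_h * grid_w, 1);
    std::vector<float> out(pooled_h * pooled_w, 0.0f);
    for (int ph = 0; ph < pooled_h; ++ph)
        for (int pw = 0; pw < pooled_w; ++pw) {
            double sum = 0.0;
            for (int iy = 0; iy < grid_h; ++iy) {
                double y = start_h + ph * bin_h + (iy + 0.5) * bin_h / grid_h;
                for (int ix = 0; ix < grid_w; ++ix) {
                    double x = start_w + pw * bin_w + (ix + 0.5) * bin_w / grid_w;
                    sum += bilinear(img, height, width, y, x);
                }
            }
            out[ph * pooled_w + pw] = static_cast<float>(sum / count);
        }
    return out;
}
```

Because bilinear interpolation is exact for linear inputs, a horizontal ramp (value = column index) is a convenient sanity check: each output bin should equal the mean x-coordinate of its sampling points.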
GeorgeS2019 commented 2 years ago

TORCHVISION.OPS

suggestion: Need to approach this ALSO from the perspective of Onnx contributed Ops e.g. com.microsoft.vision

flowchart TD
    image-->|com.microsoft.vision\nops| TorchSharp\nmodel-->|com.microsoft.vision\nops| output\nimage
GeorgeS2019 commented 2 years ago

@kaiidams

Why are there no torchaudio ops?

Are they in here? https://github.com/pytorch/audio/tree/main/torchaudio/csrc

e.g. com.microsoft.audio

flowchart TD
    audio-->|com.microsoft.audio\nops| TorchSharp\nmodel-->|com.microsoft.audio\nops| output\naudio
GeorgeS2019 commented 2 years ago

Most of the torchtext ops are here: https://github.com/pytorch/text/tree/main/torchtext/csrc

com.microsoft.nlp

flowchart TD
    text-->|com.microsoft.nlp\nops| TorchSharp\nmodel-->|com.microsoft.nlp\nops| output\ntext
kaiidams commented 2 years ago

@GeorgeS2019

The ONNX Runtime approach is probably not related to this.

Why are there no torchaudio ops?

torchaudio has C code that uses Kaldi, SoX, and FFmpeg, none of which is implemented in TorchSharp. IMHO, modern models don't depend on these except for I/O, but they are still useful in some cases. torchaudio doesn't call them ops.

@NiklasGustafsson I don't know what happened to torchvision.ops. If features are missing, do you plan to add them?

NiklasGustafsson commented 2 years ago

There's a small number of ops in torchvision.ops. Feel free to contribute more, if you have time. It would be good to have an issue to track them. I'll open something up.

NiklasGustafsson commented 2 years ago

BTW, it seems FFmpeg is now disabled by default in PyTorch:

https://pytorch.org/vision/stable/#torchvision.set_video_backend

You have to build from source to enable it, apparently.

xhuan8 commented 2 years ago

It seems there is no prebuilt package for torchvision. What is the best way to add it after building from source?

NiklasGustafsson commented 2 years ago

It seems there is no prebuilt package for torchvision. What is the best way to add it after building from source?

TorchVision is now available on NuGet: https://www.nuget.org/packages/TorchVision, starting with version 0.98.1, when it was split out from TorchSharp.

xhuan8 commented 2 years ago

I'm trying to call the native torchvision method torch.ops.torchvision.nms: https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/nms.cpp https://github.com/pytorch/vision/blob/main/torchvision/ops/_register_onnx_ops.py

How can I do this in TorchSharp?

NiklasGustafsson commented 2 years ago

I believe someone already implemented nms in C#.

https://github.com/dotnet/TorchSharp/blob/2f89e580090624ab7329391841e0884c506a0813/src/TorchVision/Ops.cs#L48
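For reference, greedy non-maximum suppression (the algorithm behind torchvision.ops.nms) is short enough to sketch directly. This is a standalone C++ illustration with a hypothetical `Box` struct and plain `std::vector` inputs, not the linked C# implementation. It assumes torchvision's conventions as I understand them: boxes in (x1, y1, x2, y2) format, a box is suppressed when its IoU with a kept box is strictly greater than the threshold, and kept indices come back in decreasing score order.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

struct Box { float x1, y1, x2, y2; };  // corner format, x2 >= x1 and y2 >= y1

// Intersection-over-union of two axis-aligned boxes.
float box_iou(const Box& a, const Box& b) {
    float iw = std::max(0.0f, std::min(a.x2, b.x2) - std::max(a.x1, b.x1));
    float ih = std::max(0.0f, std::min(a.y2, b.y2) - std::max(a.y1, b.y1));
    float inter = iw * ih;
    float area_a = (a.x2 - a.x1) * (a.y2 - a.y1);
    float area_b = (b.x2 - b.x1) * (b.y2 - b.y1);
    float uni = area_a + area_b - inter;
    return uni > 0.0f ? inter / uni : 0.0f;
}

// Greedy NMS: visit boxes by descending score and keep a box unless it
// overlaps an already-kept box with IoU > iou_threshold. Returns kept indices.
std::vector<std::size_t> nms(const std::vector<Box>& boxes, const std::vector<float>& scores,
                             float iou_threshold) {
    std::vector<std::size_t> order(boxes.size());
    std::iota(order.begin(), order.end(), 0);
    std::stable_sort(order.begin(), order.end(),
                     [&](std::size_t i, std::size_t j) { return scores[i] > scores[j]; });
    std::vector<bool> suppressed(boxes.size(), false);
    std::vector<std::size_t> keep;
    for (std::size_t a = 0; a < order.size(); ++a) {
        std::size_t i = order[a];
        if (suppressed[i]) continue;
        keep.push_back(i);
        // Suppress every lower-scoring box that overlaps the one just kept too much.
        for (std::size_t b = a + 1; b < order.size(); ++b) {
            std::size_t j = order[b];
            if (!suppressed[j] && box_iou(boxes[i], boxes[j]) > iou_threshold) suppressed[j] = true;
        }
    }
    return keep;
}
```

The loop is O(n²) in the number of boxes. A tensor-based version (as in TorchSharp) would instead compute the full IoU matrix at once, which is what lets it run on CUDA without a custom kernel.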

NiklasGustafsson commented 2 years ago

To answer your question, though -- I haven't figured out where to find the native binaries and header files for TV. If you figure that out, we should bundle them with the other native code we already have.

xhuan8 commented 2 years ago

They are under https://github.com/pytorch/vision/tree/main/torchvision/csrc/ops, which contains both CPU and CUDA implementations. We don't need the I/O operations, so FFmpeg and JPEG support are not needed. For now, I'll try building from source and using it on Windows.

NiklasGustafsson commented 2 years ago

Have you found where pre-built binaries are available for download?

xhuan8 commented 2 years ago

No, only by building from source.

NiklasGustafsson commented 2 years ago

Note that the current libtorch binaries are 1.11, with CUDA 11.3 -- that will matter for building compatible TV binaries, I believe.

GeorgeS2019 commented 2 years ago

I'm trying to call the native torchvision method torch.ops.torchvision.nms: https://github.com/pytorch/vision/blob/main/torchvision/csrc/ops/nms.cpp https://github.com/pytorch/vision/blob/main/torchvision/ops/_register_onnx_ops.py

For ONNX Runtime, here is some preliminary discussion on how to register ONNX ops: WIP [Documentation]: C# Workflow for consuming "Augmented" Onnx model with Custom Operators

xhuan8 commented 2 years ago

Fortunately, I have finished the build. It produces three files, torchvision.dll, torchvision.exp, and torchvision.lib, and they are quite small, only 2 MB.

NiklasGustafsson commented 2 years ago

@xhuan8 -- this is really cool! There are a couple of things to think about next:

  1. A different output name (more on that later), since 'torchvision.dll' will conflict with the DLL name for the .NET project.

  2. How we distribute the native library. Currently, we have a package (libtorch-*) containing the CPU backends for all three platforms, and one for CUDA on Windows, and one for CUDA on Linux. We needed this because the CUDA backends are gigantic.

So, we have to decide where these native libraries go, whether in the same package as the managed TorchVision binaries, or in a separate one. It will depend on the overall size (separate Windows + MacOS + Linux binaries), as well as whether a CUDA backend can be loaded on a machine without a CUDA-capable GPU. This will impact the library names we pick, too.

NiklasGustafsson commented 2 years ago

Also, how to automate the build -- do we do this manually and add the binaries to the build, or do we somehow integrate the pytorch/vision repo into the TorchSharp build process? We should consider pytorch/text and pytorch/audio at the same time.

Thanks for doing this work, it's going to make a huge difference!

NiklasGustafsson commented 2 years ago

TORCHVISION.OPS

suggestion: Need to approach this ALSO from the perspective of Onnx contributed Ops e.g. com.microsoft.vision

@GeorgeS2019 -- I love your enthusiasm for ONNX runtime ops. However, I want to say that TorchSharp is about providing a .NET layer on top of the native library (libraries) underlying PyTorch. Nothing less, nothing more. It's simplistic, perhaps, but the main point is that it is simple. Other projects can and should go beyond TorchSharp and provide all kinds of additional features, including interactions with ONNX.

Thus, ONNX runtime ops (contributed or not) lie outside the scope of what TorchSharp will be designed for.

xhuan8 commented 2 years ago

I tried to load the torchvision native DLL with LoadLibrary, but it fails with error 1114 (ERROR_DLL_INIT_FAILED). @NiklasGustafsson @kaiidams, do you have any idea why? The DLL is uploaded to: https://drive.google.com/file/d/1LHd1jwuFlFT87-vT09hzsOWpUUeD-TKF/view?usp=sharing

GeorgeS2019 commented 2 years ago

Instead of loading a native DLL, the latest approach is to embed a static library (.lib) through the .csproj and compile it into a normal C# DLL. More and more Microsoft products that were previously C++ are now in C#.

Godot4 .NET6 NativeAOT is an excellent recent example.


GeorgeS2019 commented 2 years ago

The Godot C# library attaches a C++ shared library, _internal (or _internal.lib).

GeorgeS2019 commented 2 years ago

@xhuan8 An interesting MIT-licensed visual programming approach with pre/post vision-processing nodes (ops) around ONNX.

FYI: this could be interesting, and the example involves TorchScript; however, it is not related to TorchSharp. More investigation is needed.


kaiidams commented 2 years ago

@xhuan8 If this is the build from torchvision, it is a C++ torchvision library. You'll need to make a C wrapper so that C# can use it with P/Invoke.

(py310) C:\local>dumpbin /dependents "libtorchvision.dll"
Microsoft (R) COFF/PE Dumper Version 14.33.31630.0
Copyright (C) Microsoft Corporation.  All rights reserved.

Dump of file libtorchvision.dll

File Type: DLL

  Image has the following dependencies:

    c10.dll
    c10_cuda.dll
    torch_cuda_cu.dll
    torch_cuda_cpp.dll
    torch_cpu.dll
    KERNEL32.dll
    MSVCP140.dll
    VCRUNTIME140.dll
    VCRUNTIME140_1.dll
    api-ms-win-crt-runtime-l1-1-0.dll
    api-ms-win-crt-math-l1-1-0.dll
    api-ms-win-crt-heap-l1-1-0.dll
    api-ms-win-crt-stdio-l1-1-0.dll
    api-ms-win-crt-filesystem-l1-1-0.dll
    api-ms-win-crt-string-l1-1-0.dll
    api-ms-win-crt-time-l1-1-0.dll

  Summary

        D000 .data
        1000 .nvFatBi
      149000 .nv_fatb
        7000 .pdata
       4E000 .rdata
        2000 .reloc
        1000 .rsrc
       F0000 .text

If Torch C++ provides a function such as torch::fft::hfft2(), then we have a wrapper like THSTensor_hfft2 for C# to call. https://github.com/dotnet/TorchSharp/blob/07047718d3f9a7c9f946223ea8646b2988cdc65d/src/Native/LibTorchSharp/THSFFT.cpp#L64
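To make the wrapper idea concrete, here is a minimal sketch of the pattern: a C++ function inside a namespace (which the compiler exports under a decorated, mangled symbol name) is exposed through an `extern "C"` entry point with a flat, C-compatible signature that P/Invoke can bind by name. Everything here is hypothetical: `vision_demo::ops::scale_boxes` stands in for a real torchvision op, and `THSVisionDemo_scale_boxes` only mimics the THS* naming style of LibTorchSharp. Real TorchSharp wrappers pass Tensor handles rather than raw float buffers. Because C++ exceptions must not cross the C boundary, the wrapper catches them and reports failure through a return code plus a retrievable message.

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

namespace vision_demo { namespace ops {
// C++ entry point: exported under a mangled name, so binding to it from C# is awkward.
void scale_boxes(float* boxes, std::int64_t n, double scale) {
    if (n < 0 || scale <= 0.0) throw std::invalid_argument("bad arguments");
    for (std::int64_t i = 0; i < 4 * n; ++i) boxes[i] = static_cast<float>(boxes[i] * scale);
}
}}  // namespace vision_demo::ops

namespace {
thread_local std::string last_error;  // message from the most recent failed call on this thread
}

#if defined(_WIN32)
#define DEMO_EXPORT __declspec(dllexport)
#else
#define DEMO_EXPORT __attribute__((visibility("default")))
#endif

extern "C" {
// Unmangled, C-compatible entry point. Returns 0 on success, -1 on failure.
DEMO_EXPORT int THSVisionDemo_scale_boxes(float* boxes, std::int64_t n, double scale) {
    try {
        vision_demo::ops::scale_boxes(boxes, n, scale);
        return 0;
    } catch (const std::exception& e) {
        last_error = e.what();  // never let the exception escape into managed code
        return -1;
    }
}

// Lets the managed side read the error text after a failed call.
// The pointer stays valid until the next failure on the same thread.
DEMO_EXPORT const char* THSVisionDemo_last_error() { return last_error.c_str(); }
}
```

On the .NET side, a `[DllImport]` declaration naming `THSVisionDemo_scale_boxes` would then bind by that plain name rather than by a compiler-specific mangled one.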

kaiidams commented 2 years ago

@NiklasGustafsson To build torchvision.dll you'll need Python (and zlib, libpng, CUDA, etc.). I think it should be built outside TorchSharp.

NiklasGustafsson commented 2 years ago

@NiklasGustafsson To build torchvision.dll you'll need Python (and zlib, libpng, CUDA, etc.). I think it should be built outside TorchSharp.

Yeah, you're right... Ideally, we would re-implement the C++ bodies of TorchVision in C# -- that would solve a lot of problems.

NiklasGustafsson commented 2 years ago

@xhuan8 -- would that be feasible, do you think? To just re-implement the ops in C# using Torch operators, like how nms() was implemented?

xhuan8 commented 2 years ago

@NiklasGustafsson It takes time to implement this in C# and verify its correctness, and I'm also not sure what the performance would be without CUDA. For now, I'll use the C++ wrapper. The method names exported by torchvision are a little strange, like ?nms@ops@vision@@YA?AVTensor@at@@AEBV34@0N@Z

NiklasGustafsson commented 2 years ago

If the ops are implemented in terms of TorchSharp ops, CUDA should come for free.

That name is a C++ mangled name, which means that it's missing the 'extern "C"' declaration.

NiklasGustafsson commented 1 year ago

Same as I asked about Faster RCNN -- you've been working on this, it appears. Anything ready for a PR?

GeorgeS2019 commented 1 year ago

@NiklasGustafsson Now we have a user routinely compiling NLP ops in TorchSharp. Perhaps a wiki on this with links to user projects?