xuebinqin / U-2-Net

The code for our newly accepted paper in Pattern Recognition 2020: "U^2-Net: Going Deeper with Nested U-Structure for Salient Object Detection."
Apache License 2.0

Using with torch.jit.trace and C++? #29

Closed by DBraun 4 years ago

DBraun commented 4 years ago

I started a discussion here https://discuss.pytorch.org/t/debugging-runtime-error-module-forward-inputs-libtorch-1-4/82415

I modified u2net_test.py and used torch.jit.trace to save a traced module:

traced_script_module = torch.jit.trace(net, inputs_test)
traced_script_module.save("traced_model.pt")
print(inputs_test.size()) # shows (1, 3, 320, 320)

Then in C++:

auto module = torch::jit::load("traced_model.pt");
std::vector<torch::jit::IValue> torchinputs;
torchinputs.push_back(torch::ones({1, 3, 320, 320}, torch::kCUDA).to(at::kFloat)); // float, because the Python input was a torch.FloatTensor
module.forward(torchinputs); // throws the error below

The error:

Unhandled exception at 0x00007FFFD8FFA799 in TouchDesigner.exe: Microsoft C++ exception: std::runtime_error at memory location 0x000000EA677F1B30 occurred


The error is thrown at https://github.com/pytorch/pytorch/blob/4c0bf93a0e61c32fd0432d8e9b6deb302ca90f1e/torch/csrc/jit/api/module.h#L112. It says inputs has size 0, but I don't know whether that is the cause of the exception or a consequence of it.
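For reference, here is a stripped-down, self-contained sketch of what I'm attempting, with the exception caught so its message becomes visible (the path and input shape are from my setup, and the tuple indexing assumes U-2-Net's forward returns the fused side output d0 first):

#include <torch/script.h>
#include <iostream>
#include <vector>

int main() {
    try {
        // Load the traced model and move it to the GPU
        torch::jit::script::Module module = torch::jit::load("traced_model.pt");
        module.to(torch::kCUDA);
        module.eval();

        // (1, 3, 320, 320) float input on the GPU, matching the Python side
        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(torch::ones({1, 3, 320, 320},
            torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA)));

        // U-2-Net's forward returns a tuple of side outputs; d0 is the fused mask
        auto outputs = module.forward(inputs).toTuple();
        torch::Tensor d0 = outputs->elements()[0].toTensor();
        std::cout << d0.sizes() << std::endl;
    } catch (const c10::Error& e) {
        // libtorch errors carry a readable message in what()
        std::cerr << "forward failed: " << e.what() << std::endl;
        return 1;
    }
    return 0;
}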

Do you have advice about running U-2-Net in C++? Thank you.

wpmed92 commented 4 years ago

Hi! I use U-2-Net on Android, which uses the C++ runtime internally, and I managed to make it work. I couldn't get the traced model to work either; instead I converted to TorchScript as follows:

net.load_state_dict(torch.load(model_dir, map_location=torch.device('cpu')))
if torch.cuda.is_available():
    net.cuda()
scripted = torch.jit.script(net)
torch.jit.save(scripted, "fod.p")

Hope this helps.

DBraun commented 4 years ago

@wpmed92 Thanks for your suggestion. I did what you suggested, and torch.cuda.is_available() was still true. However, in C++ the model now stalls indefinitely on the forward() call: no error, it just never returns.
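One way to tell a genuine hang from very slow first-call CUDA warm-up is to time consecutive forward calls, along these lines (a sketch; fod.p and the input shape are taken from the comments above):

#include <torch/script.h>
#include <chrono>
#include <iostream>
#include <vector>

int main() {
    auto module = torch::jit::load("fod.p");  // scripted model from the comment above
    module.to(torch::kCUDA);

    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({1, 3, 320, 320},
        torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA)));

    for (int i = 0; i < 3; ++i) {
        auto t0 = std::chrono::steady_clock::now();
        auto d0 = module.forward(inputs).toTuple()->elements()[0].toTensor();
        d0 = d0.to(torch::kCPU);  // copying to CPU waits for the GPU work to finish
        double ms = std::chrono::duration<double, std::milli>(
            std::chrono::steady_clock::now() - t0).count();
        std::cout << "forward " << i << ": " << ms << " ms" << std::endl;
    }
    return 0;
}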

wpmed92 commented 4 years ago

Oh, I see. Not completely related, but we found something really interesting. On Android, inference with the jit.script() saved model runs in about a second; on iOS it took about 30 seconds, so on iOS we use the traced model instead, which runs at roughly the same speed as on Android. Both platforms use the same C++ library internally, so I wonder what causes such a difference.

DBraun commented 4 years ago

Glad to have finally resolved this. My environment is TouchDesigner, which loads and executes code from my custom DLL. It turns out I needed to copy ALL the libtorch DLLs into a location where TouchDesigner itself loads them (C:/Program Files/Derivative/TouchDesigner/bin), not just alongside my custom DLL in Documents/Derivative/Plugins.
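For anyone hitting the same thing: a quick way to see which libtorch dependency fails to resolve is to load the core DLLs explicitly before touching the model (a Windows-only sketch; the exact DLL names depend on the libtorch version and build):

#include <windows.h>
#include <iostream>

int main() {
    // Loading the core libtorch DLLs by hand surfaces missing-dependency
    // errors that otherwise show up as a silent failure in the host app.
    // DLL names vary by libtorch version/build; these two are illustrative.
    const char* dlls[] = { "c10.dll", "torch.dll" };
    for (const char* name : dlls) {
        HMODULE handle = LoadLibraryA(name);
        std::cout << name << (handle ? ": loaded" : ": FAILED (check the host app's bin dir / PATH)") << std::endl;
    }
    return 0;
}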