I've reproduced the error on main, and it is occurring on the line linked below, where the operation `torch::jit::EliminateExceptions` does not complete. This is the source code for `torch::jit::EliminateExceptions`.

https://github.com/pytorch/TensorRT/blob/a245b861d75fe0cb007eca5d23b3a992113b268b/core/lowering/lowering.cpp#L106
The graph at that point is shown below, and has a `prim::RaiseException` node within a `prim::If`.
@bowang007 - this seems related to your work with exceptions and control flow; do you have any suggestions on this?
Thanks @gs-olive for taking a look at this issue so quickly! Nice that you could reproduce it! Following up on that, here is another simple network (though slightly less trivial than the one above).
Interestingly, when converting a TorchScript generated from that network using `torch.jit.script` on Linux, the behaviour is the same as with the trivial network previously shared: using a recent Torch-TensorRT commit from main, it hangs (i.e., the operation `torch::jit::EliminateExceptions` does not complete).

Nevertheless, when converting the TorchScript on Windows using the same recent Torch-TensorRT commit from main, it works! Note that it works whether the TorchScript is generated on Windows or on Linux.
Both graphs, as printed by Torch-TensorRT when calling `compile`, are included in archive.zip, as well as the mini_net.py network definition/TorchScript generation file. In summary, here is a diff between the graphs, as printed by Torch-TensorRT:
Please let me know if I can provide more details to help solve this issue!
Thank you for the additional details - this is very helpful!
Hi @gs-olive
I have been working on the idea you described in #1842; see this commit. Of course, this initial implementation is overly specific and only solves the case of `upsample_bilinear2d`, but together with your proposed fix of `torch::jit::EliminateExceptions` and an additional small modification of the Torch-TensorRT custom exception elimination pass, it seems that I am able to successfully convert the two networks linked in this issue (the simple upsample, as well as the slightly less trivial one).
I would appreciate some feedback on this. Is my approach going in the right direction, with respect to what you had in mind when describing issue #1842?
Also, what do you (or @bowang007, maybe?) think of the changes I made to `TensorRT/core/lowering/passes/exception_elimination.cpp`? The rationale is that, instead of catching only something like:
```
 = prim::If(%5958)
  block0():
    = prim::RaiseException(%45)
   -> ()
  block1():
   -> ()
```
or
```
 = prim::If(%5958)
  block0():
   -> ()
  block1():
    = prim::RaiseException(%45)
   -> ()
```
you could also catch more complex blocks, given that:

- the node is a `prim::If`, containing two blocks,
- one of the two blocks ends with a `prim::RaiseException` (instead of beginning with one, as the pass currently requires).

What I've typically been observing in the scope of the Upsample investigation is something like:
```
 = prim::If(%5958)
  block0():
   -> ()
  block1():
    %191 : str = aten::format(%108, %171, %117)
    = prim::RaiseException(%191, %100)
   -> ()
```
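As a rough Python transliteration of that broader matching rule (a sketch only; the actual pass is C++, and restricting the non-raising nodes to `aten::format` is an assumption based on the patterns above; it uses the `torch._C` graph bindings):

```python
def arm_only_raises(block) -> bool:
    """True if this arm does nothing but raise: its last node is a
    prim::RaiseException, and any preceding nodes merely build the error
    message (only aten::format is allowed here, matching the patterns
    observed in the Upsample investigation)."""
    kinds = [n.kind() for n in block.nodes()]
    return (
        len(kinds) > 0
        and kinds[-1] == "prim::RaiseException"
        and all(k == "aten::format" for k in kinds[:-1])
    )
```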
Hi @gcuendet - thanks for the update! The commit linked here is definitely along the lines of what was intended for #1842. One thing I was wondering about for that commit - on line 83 - was there an issue with calling `if_node->destroy()`?
Regarding the changes made to `TensorRT/core/lowering/passes/exception_elimination.cpp`, I think this rationale/edit idea is a good one. I have a few comments:
1. I wonder if `auto arm1_last = arm1->nodes().rbegin()` is always a `prim::Return` node, since every block in a `prim::If` must have a return to be valid. In that case, we could potentially require something like `bool arm1_ends_with_exception = (*(++arm1_last))->kind() == prim::RaiseException`.
2. Additionally, in the case where the `prim::RaiseException` is the last node in the block, I am unsure if we can be certain that the block returns nothing. For example, consider the outermost `prim::If` denoted by `%out2.1` in the graph here: https://github.com/pytorch/TensorRT/issues/1823#issuecomment-1509005086. This graph has a `prim::RaiseException` as the last node of a `prim::If`, yet still returns something. I think the current code logic would result in an elimination in this case.

Also @narendasan for any comments on the proposed edits to `TensorRT/core/lowering/passes/exception_elimination.cpp`.
Thanks for the quick feedback! I am still fiddling with these changes and trying to make them work in more generic cases than the two overly simplified networks shared above, but regarding your questions:
> One thing I was wondering about for that commit - on line 83 - was there an issue with calling `if_node->destroy()`?
At some point I had the impression that destroying the node was unnecessary, because some dead-code-removal pass would do it for you (I did observe that in some cases, but that might not be true in all cases). Maybe it would still be good to destroy it at that point; that way, if the outputs are not properly replaced, this would fail. I'll check that.
> I wonder if `auto arm1_last = arm1->nodes().rbegin()` is always a `prim::Return` node, since every block in a `prim::If` must have a return to be valid.
I don't think `auto arm1_last = arm1->nodes().rbegin()` is always a `prim::Return` node, no. It's true that every block in a `prim::If` must have a return to be valid, but I think that the iterator returned by `generic_graph_node_list<Node>::rbegin()` (which is what you get when you call `Block::nodes().rbegin()`) starts on the last node just before the return node.
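This is easy to check from Python; a minimal sketch (assuming the `torch._C` graph bindings `Node.kind()`, `Node.blocks()`, and `Block.nodes()`, and that scripting a bare `raise` lowers to a `prim::If` arm):

```python
import torch

@torch.jit.script
def f(x: torch.Tensor, flag: bool) -> torch.Tensor:
    if flag:
        raise ValueError("bad input")
    return x

# List the node kinds in each arm of the prim::If: the block's return node
# never shows up when iterating Block.nodes(), only the "real" nodes do,
# so rbegin() starts on the last real node, not on a prim::Return.
for node in f.graph.nodes():
    if node.kind() == "prim::If":
        for block in node.blocks():
            print([n.kind() for n in block.nodes()])
```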
> Additionally, in the case where the `prim::RaiseException` is the last node in the block, I am unsure if we can be certain that the block returns nothing.
You are right, but we also check earlier that:
```cpp
if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
  // Make sure that the node doesn't actually produce any Value that are
  // used by other nodes
  return false;
}
```
So the hypothesis at that point is that neither of the arms actually returns a Value. In the example you point to (the `prim::If` denoted by `%out2.1`), both `block0` and `block1` have one output, so we would not even check whether the last node is a `prim::RaiseException`, but would return `false` earlier.
Hi! I have a small update regarding this work. A new commit implementing the custom and specific `upsample_bilinear2d` exception removal is here.
A small note on that commit: even though the `isUpsampleBlock` method is very specific, I think that the `copyAllNodes` one could potentially be reused in the scope of #1842. What I mean by that is that, once the `prim::If` node has been identified, as well as which block of that node is doing the computation of interest, copying that whole block is implemented by `copyAllNodes`, including renaming the inputs of nodes that reuse outputs of previous nodes in that same block, as well as replacing the usages of the `prim::If` outputs in the rest of the graph with the corresponding outputs of the nodes inside that block.
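To make the idea concrete, here is a minimal sketch of that unfolding in terms of the Python `torch._C` graph bindings (not the actual C++ implementation; `moveBefore`, `returnNode`, `replaceAllUsesWith`, and `destroy` are the bindings I believe correspond to the C++ calls):

```python
def unfold_live_arm(if_node, live_block):
    """Hoist every node of the arm that does the real computation out in
    front of the prim::If, rewire users of the If outputs to the values
    that arm returned, then drop the now-dead If node."""
    # Hoist the arm's nodes above the If, preserving their order. Since the
    # moved nodes keep their own output Values, intra-block uses stay valid.
    for node in list(live_block.nodes()):
        node.moveBefore(if_node)
    # Redirect every use of the If outputs to the arm's returned values.
    for if_out, arm_out in zip(if_node.outputs(), live_block.returnNode().inputs()):
        if_out.replaceAllUsesWith(arm_out)
    # The If no longer has any users; remove it explicitly.
    if_node.destroy()
```

(The C++ `copyAllNodes` copies rather than moves, which is why it also has to rename the inputs of nodes that reuse outputs of previous nodes in the same block; moving sidesteps that.)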
Regarding your comment 1. above,

> One thing I was wondering about for that commit - on line 83 - was there an issue with calling `if_node->destroy()`?

I think my previous answer was not completely accurate: the `prim::If` node is actually destroyed by `TensorRT/core/lowering/passes/exception_elimination.cpp` (with my changes, see the next point), and not by some dead-code-removal pass, as I was assuming before.

A new commit implementing changes to `TensorRT/core/lowering/passes/exception_elimination.cpp` is here.
I changed the implementation just slightly, most importantly to verify that the block of the `prim::If` node which is not raising an exception is also not computing anything. I initially thought that checking that the `prim::If` node had zero outputs (or, similarly, that each of the two blocks of the `prim::If` node had zero outputs) would be enough, but I saw cases where apparently that was not the case. (I am basically saying that the Values are not scoped to blocks; not sure if that's plausible or completely crazy.)
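For what it's worth, the combined safety check could look roughly like this in Python terms (a sketch against the `torch._C` bindings; treating an arm as "not computing anything" when it contains no nodes at all is my reading of the change, not a quote of the actual C++):

```python
def if_node_is_removable(raising_arm, other_arm) -> bool:
    """Neither arm may return a Value, and the non-raising arm must perform
    no computation of its own, so eliminating the prim::If cannot drop
    anything the rest of the graph depends on."""
    if len(list(raising_arm.outputs())) != 0 or len(list(other_arm.outputs())) != 0:
        return False
    return len(list(other_arm.nodes())) == 0
```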
Let me know if that's of interest to you; I'll be happy to open a PR!
Just for completeness: using `Upsample` directly in the network definition still does not seem to work, even with the changes proposed above and the fix in `torch::jit::EliminateExceptions`.
I am now observing the following error in `shape_analysis.cpp` with the network described in this comment:
```
GRAPH: [Torch-TensorRT] - Running shape analysis on block Segment Block @1:
    Target: Torch
    Graph: graph(%1 : bool,
          %4 : bool?):
      %self.block1.norm.training.15 : bool = prim::Constant[value=0]()
      %3 : str = prim::Constant[value="builtins.ValueError"]()
      %2 : str = prim::Constant[value="align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"]()
      %align_corners0.1 : bool? = prim::If(%1)
        block0():
          = prim::RaiseException(%2, %3)
         -> (%4)
        block1():
         -> (%self.block1.norm.training.15)
      return (%align_corners0.1)

terminate called after throwing an instance of 'torch_tensorrt::Error'
  what():  [Error thrown at core/partitioning/shape_analysis.cpp:187] Expected ivalues_maps.count(input) to be true but got false
Could not find torch::jit::Value* 474 produced from %474 : bool? = prim::Uninitialized() # :0:0 in lowering graph for mini graph input.
```
This seems linked to the check on the validity of the `align_corners` option (checking that the interpolating mode is one of linear, bilinear, bicubic, or trilinear) and the fact that, when it's invalid, a `bool? = prim::Uninitialized()` is returned, which seems to cause the shape analysis to fail..? This is not simplified by the exception elimination pass, since the block actually returns something.
A workaround for this is to not call `Upsample`, but rather to call `upsample_bilinear2d` directly (or one of the other functions corresponding to the correct interpolating mode and dimensions).
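For illustration, a minimal sketch of that workaround (assuming the `torch.ops.aten.upsample_bilinear2d` overload that takes an explicit output size; the module name is made up):

```python
import torch
import torch.nn as nn

class UpsampleBilinear2x(nn.Module):
    """Calls the mode-specific aten op directly, so scripting never pulls in
    interpolate()'s align_corners validation branch."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out_h, out_w = x.size(2) * 2, x.size(3) * 2
        # aten::upsample_bilinear2d(input, output_size, align_corners)
        return torch.ops.aten.upsample_bilinear2d(x, [out_h, out_w], False)
```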
I also include the lowered graph below (lowered with all the changes described above):
Hi @gcuendet - thank you very much for all of the work and detailed answers on this topic. I made a few comments on the `upsample_bilinear2d` exception removal and the new changes to `TensorRT/core/lowering/passes/exception_elimination.cpp`. I think both of these updates look good and would be welcome additions via a PR, though for the `upsample_bilinear2d` exception removal we would also need to add some testing, since it is a new feature.

@narendasan - do you have any input on the proposed changes to `TensorRT/core/lowering/passes/exception_elimination.cpp` and the custom `upsample_bilinear2d` exception removal pass?
Regarding the `Could not find torch::jit::Value*` issue, we are tracking and investigating it with @bowang007 across multiple reports, including #1834 and #1815. A common thread among all of these seems to be the presence of nested "If" blocks elsewhere in the code, though it is not yet clear whether this is the root cause of the issue.
Hey @gcuendet, can you try this PR: https://github.com/pytorch/TensorRT/pull/1933? I reproduced your bug for the uninitialized error, and I think this PR might help with it.
Thanks @bowang007. Did you test on the small network described above? Is it working for you? The conversion is not working for me; I get the following error:
```
GRAPH: [Torch-TensorRT] - Running shape analysis on block Segment Block @3:
    Target: Torch
    Graph: graph(%out0.2 : Tensor,
          %align_corners0.1 : bool?):
      %self.block1.conv.bias.15 : NoneType = prim::Constant()
      %10 : int = prim::Constant[value=4]()
      %8 : float = prim::Constant[value=2.]()
      %5 : bool = prim::Constant[value=1]()
      %3 : int = prim::Constant[value=2]()
      %0 : int = aten::dim(%out0.2)
      %dim.2 : int = aten::sub(%0, %3)
      %scale_factors2.2 : float[] = prim::ListConstruct()
       = prim::Loop(%dim.2, %5)
        block0(%6 : int):
          %7 : float[] = aten::append(%scale_factors2.2, %8)
         -> (%5)
      %9 : bool = aten::eq(%0, %10)
      %11 : bool = aten::__isnot__(%align_corners0.1, %self.block1.conv.bias.15)
      return (%scale_factors2.2, %0, %9, %11)

terminate called after throwing an instance of 'torch_tensorrt::Error'
  what():  [Error thrown at core/partitioning/shape_analysis.cpp:212] Expected to find type bool? for value align_corners0.1 but get nothing.
```
For reference, the lowered graph is the following.
Hey @gcuendet, I was using the model provided:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import argparse
import torch_tensorrt


class Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Block, self).__init__()
        self.conv = nn.Conv2d(
            in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.norm = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(
            kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False
        )

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = self.relu(out)
        out = self.maxpool(out)
        return out


class Network(torch.nn.Module):
    def __init__(self, num_classes=2):
        super(Network, self).__init__()
        self.num_classes = num_classes
        self.block1 = Block(3, 32)
        self.block2 = Block(32, 64)
        self.upsample1 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        self.upsample2 = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )
        self.conv = nn.Conv2d(64, num_classes, 1, bias=True)

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        gclayer1 = self.upsample1(out)
        gclayer2 = self.upsample2(gclayer1)
        out = self.conv(gclayer2)
        return out


input = torch.randn([3, 3, 224, 224]).cuda()
model = Network()
model = model.eval().cuda()
model = torch.jit.script(model)

torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)
compile_settings = {
    "inputs": [
        torch_tensorrt.Input([2, 2], dtype=torch.int32),
    ],
    "min_block_size": 1,
    # "truncate_long_and_double": True,
    # "enabled_precisions": {torch.int64},
    # "torch_executed_ops": ['aten::conv2d']
}
trt_mod = torch_tensorrt.ts.compile(model, **compile_settings)
output = trt_mod(*input)
```
Any details that I might have missed?
This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
## Bug Description

Scripting a simple "network" containing two `torch.nn.Upsample` modules and trying to convert the resulting TorchScript does not work.

## To Reproduce

Steps to reproduce the behavior:

1. Generate a TorchScript from the network definition using `torch.jit.script`.
2. Convert the resulting TorchScript with Torch-TensorRT.

## Expected behavior
The conversion succeeds and a new valid TorchScript is obtained.
## Environment
I managed to reproduce the bug both when using PyTorch 1.11 with Torch-TensorRT 1.1.0, and when using PyTorch 1.13.1 with Torch-TensorRT main.
### Torch-TensorRT 1.1.0

- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): `pip` for the Python package used to generate the TorchScript, source for the C++ dependency linked to Torch-TensorRT

When using Torch-TensorRT 1.1.0, I get the following error:
That looked kind of similar to this issue, and patching Torch-TensorRT with this PR makes the behavior exactly the same as in the second case (i.e., when using PyTorch 1.13.1 and Torch-TensorRT main).
### Torch-TensorRT main (commit 861edd03a510c600146575836b02c993ac386b00)

- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): `pip` for the Python package used to generate the TorchScript, source for the C++ dependency linked to Torch-TensorRT

When using Torch-TensorRT main, the conversion just hangs forever.
## Additional context
Interestingly, when using the tracing mechanism of PyTorch to generate the TorchScript, everything seems fine (I didn't check the results, but the conversion finishes properly). Also, when scripting with PyTorch 1.9, everything works fine 🤯
The thing I noticed is that PyTorch slightly changed the `torch.nn.functional.interpolate` API, and I am wondering if that could explain (at least partially) the problem:

- before: `torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)`
- after: `torch.nn.functional.interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None, antialias=False)`
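One quick way to see how much exception machinery `interpolate`'s argument validation injects is to count the `prim::RaiseException` nodes in the scripted module (a sketch; it assumes the `inlined_graph` property and the `Graph.findAllNodes` binding):

```python
import torch
import torch.nn as nn

model = torch.jit.script(
    nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
)
# Count the exception branches left in the scripted graph -- these are the
# prim::RaiseException nodes the elimination passes have to deal with.
raises = model.inlined_graph.findAllNodes("prim::RaiseException")
print(len(raises))
```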
See the attached .zip file containing a Python file to generate the TorchScript: upsample.zip
Let me know if you need more details to reproduce the problem. Thanks!