sean-xiang-applovin opened this issue 1 month ago
> Since only some of the ops support dynamic shapes and others do not, what is the criterion for deciding whether an op supports dynamic shape?

`supports_dynamic_shapes` specifies whether a converter's implementation handles dynamic shapes properly.

> For some existing ops which are not marked as supports_dynamic_shapes=True, can I write a converter that wraps the existing converter and mark my own converter with high priority? Is this the recommended way?

If a converter is not marked `supports_dynamic_shapes=True`, then either it's not going to work because we know the converter doesn't support dynamic shapes, or `supports_dynamic_shapes` for that converter is stale and a PR flipping it would be welcome. You can totally write your own converter that supports dynamic shapes, mark it as high priority, and Torch-TensorRT will use it.

> Or should I just turn on assume_dynamic_shape_support, which seems to be a global flag for all converters?

You can do this; this setting existed mostly for when we had not yet verified that all converters support dynamic shapes but were pretty sure most did. At this point we expect it to be nearly equivalent to having it off.

If you are finding Core ATen ops which we convert that don't support dynamic shape, please file issues; my impression is that we should cover nearly all of them at this point. cc @apbose @chohk88
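To make the converter-wrapping route concrete, here is a rough sketch of what it might look like. This is not from the thread: the `priority` argument, `ConverterPriority`, the import paths, and the converter signature are assumptions about the registry in `torch_tensorrt.dynamo.conversion`, and `existing_embedding_bag_converter` is a placeholder rather than a real symbol, so check everything against your installed version.

```python
import torch
# Assumed import paths; the decorator itself is also used later in this thread.
from torch_tensorrt.dynamo.conversion import dynamo_tensorrt_converter
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority


@dynamo_tensorrt_converter(
    torch.ops.aten.embedding_bag.default,
    supports_dynamic_shapes=True,      # declare that this wrapper handles dynamic shapes
    priority=ConverterPriority.HIGH,   # assumed knob so it is chosen over the built-in converter
)
def my_embedding_bag_converter(ctx, target, args, kwargs, name):
    # Delegate to the built-in implementation (placeholder name, not a real symbol).
    return existing_embedding_bag_converter(ctx, target, args, kwargs, name)


# The global alternative discussed above (assumed to be a compile-time setting):
# trt_model = torch_tensorrt.dynamo.compile(ep, inputs=..., assume_dynamic_shape_support=True)
```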
@narendasan thank you for your explanation, and I think your suggestion totally makes sense to me.
BTW, the reason I originally asked this question is that I am seeing `torch.ops.aten._embedding_bag_forward_only.default` in my graph after decomposition, and I see we support `torch.ops.aten.embedding_bag.default`, but only for static shapes.
I haven't tried it yet, but I plan to convert `torch.ops.aten._embedding_bag_forward_only.default` with the existing converter; that's why I am asking.
Is `torch.ops.aten._embedding_bag_forward_only.default` an op that is simply missing coverage, or is it intentionally not covered?
Seems like embedding bag forward only is a new op in core aten. @zewenli98 any thoughts about supporting embedding bag forward only?
To my knowledge, `torch.ops.aten._embedding_bag_forward_only.default` is the same as `torch.ops.aten._embedding_bag.default`. @sean-xiang-applovin I think the most convenient way for you to try it out is to add the decorator below to the converter:
```python
@dynamo_tensorrt_converter(
    torch.ops.aten._embedding_bag_forward_only.default,
    capability_validator=embedding_bag_validator,
    supports_dynamic_shapes=True,
)
```
Thanks for your suggestion @zewenli98. I remember trying this, and the code failed on a strange shape assertion in `impl.embedding.embedding_bag`. Sorry, I cannot paste the error here since it was some days ago, but I remember checking in the debugger: the two shapes were the same but seemed to be of different types, so the assertion failed.
I had to comment out the assertion to let it compile. However, the compilation then failed for some other reasons, so I am switching to the traditional ONNX path for now.
@sean-xiang-applovin Thanks for letting us know. It looks like the assertion you pointed out only checks the shapes, not the types.
If you have runnable code at hand, could you try passing `None` to `per_sample_weights`? That would bypass the `if` branch and let us see whether it is the root cause.
Besides, I'm wondering whether you passed in 1-dim `indices`? If yes, can you provide more details about the "compilation failed due to some other reasons"? I'm happy to debug for you if you can extract a minimal model.
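In case it helps isolate the assertion, here is a rough sketch of the kind of minimal model being asked for: 1-dim `indices`, explicit `offsets`, and `per_sample_weights=None`. The module, sizes, and names are made up for illustration.

```python
import torch
import torch.nn.functional as F


class MiniEmbeddingBag(torch.nn.Module):
    """Minimal module that exercises the embedding_bag path with 1-dim indices."""

    def __init__(self, num_embeddings=1000, dim=32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(num_embeddings, dim))

    def forward(self, indices, offsets):
        # per_sample_weights=None bypasses the branch suspected above.
        return F.embedding_bag(
            indices, self.weight, offsets, mode="sum", per_sample_weights=None
        )


indices = torch.randint(0, 1000, (64,))   # 1-dim indices
offsets = torch.arange(0, 64, 8)          # bag boundaries
out = MiniEmbeddingBag()(indices, offsets)
print(out.shape)  # (num_bags, dim)
```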
Thanks @zewenli98, I will try to put together a minimal repro as soon as possible.
Hi @zewenli98, it took me some time to set everything up and reproduce the error. Since PyTorch 2.5 was released recently, along with a version bump of torch-tensorrt, I basically re-set up everything on my end, and this time some new errors/issues came up.
I have created 3 notebooks, included in the attached embedding_issue.zip, to explain what I did, what I found, and what the issue/bug is.
In **embedding_bag_forward_only**, I describe that decomposition generates `torch.ops.aten._embedding_bag_forward_only.default` only when compiling a loaded exported program. It is not a bug, just some background I want to bring up.
In **embedding_bag_compile_slow**, I describe what I found: compiling a simple embedding bag layer takes a long time to finish, and a lot of network layers are generated, which looks strange to me. This compilation time bothers me a lot, since my model has many embedding layers.
In **embedding_bag_compile_result_mismatch**, I describe a real bug or issue: when I compile an embedding bag layer, the compiled results are very different from the original results. In this notebook I compile a loaded exported program, model.ep, which I have also included in the zip. With this loaded exported program, `torch.ops.aten._embedding_bag_forward_only.default` will be generated, and you have to do something special to make the notebook work.
Please let me know if you need more information. I really appreciate your help, thank you.
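For anyone reading along without the notebooks, the scenario boils down to roughly the following sketch. It is not the notebook code: the shapes and the exact `torch_tensorrt.dynamo.compile` arguments are assumptions.

```python
import torch
import torch_tensorrt

model = torch.nn.EmbeddingBag(30000, 32, mode="sum", padding_idx=0).to("cuda")
example_input = (torch.randint(0, 30000, (8, 30), device="cuda"),)

# Export, save, and reload. Compiling the *loaded* program is what surfaces
# torch.ops.aten._embedding_bag_forward_only.default after decomposition.
ep = torch.export.export(model, example_input)
torch.export.save(ep, "model.ep")
loaded_ep = torch.export.load("model.ep")

trt_model = torch_tensorrt.dynamo.compile(loaded_ep, inputs=example_input)
print(trt_model(*example_input).shape)
```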
Another embedding bag bug/issue. @zewenli98 can you please help take a look, thank you.
@sean-xiang-applovin Thanks for the details. I'll take a look and get back to you soon.
Hi @sean-xiang-applovin ,
1. The first issue is that `torch.ops.aten.embedding_bag.padding_idx` is mapped to `torch.ops.aten._embedding_bag.default` or `torch.ops.aten._embedding_bag_forward_only.default`, as you observed in **embedding_bag_forward_only**. I can repro this on the old versions and on the main branch. Can you try `git checkout v2.5.0`? It works for compiling both `direct_ep` and `loaded_ep` on my side, i.e., it still keeps `torch.ops.aten.embedding_bag.padding_idx`:
```
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_fn_weight : [num_users=1] = placeholder[target=p_fn_weight]
    %input : [num_users=1] = placeholder[target=input]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 30, 30), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%input, [-1]), kwargs = {})
    %embedding_bag : [num_users=1] = call_function[target=torch.ops.aten.embedding_bag.padding_idx](args = (%p_fn_weight, %view, %arange, False, 0, False, None, False, 0), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%embedding_bag, 0), kwargs = {})
    return (getitem,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %weight : [num_users=1] = get_attr[target=weight]
    %input_1 : [num_users=1] = placeholder[target=input]
    %arange : [num_users=1] = call_function[target=torch.ops.aten.arange.start_step](args = (0, 30, 30), kwargs = {dtype: torch.int64, device: cpu, pin_memory: False})
    %view : [num_users=1] = call_function[target=torch.ops.aten.view.default](args = (%input_1, [-1]), kwargs = {})
    %embedding_bag : [num_users=1] = call_function[target=torch.ops.aten.embedding_bag.padding_idx](args = (%weight, %view, %arange, False, 0, False, None, False, 0), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%embedding_bag, 0), kwargs = {})
    return (getitem,)
```
2. I checked out v2.5.0, added the following decorator in front of the converter, and then ran your **embedding_bag_compile_slow**; it took 41s on an RTX 4080. Is this time reasonable to you?
```python
@dynamo_tensorrt_converter(
    torch.ops.aten.embedding_bag.padding_idx,
    capability_validator=embedding_bag_validator,
    supports_dynamic_shapes=True,
)
```
Output:

```
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.080986
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:41.396221
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 11947508 bytes of Memory
DEBUG: [Torch-TensorRT - Debug Build] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4080
```
Besides, you can also try passing `torch_executed_ops={"torch.ops.aten.embedding_bag.padding_idx"}` to `dynamo_compile()` to compare the time cost.
3. For **embedding_bag_compile_result_mismatch**, after the changes above we can keep `torch.ops.aten.embedding_bag.padding_idx`, but the results are still different. I was wondering whether you moved the model to cuda before exporting and saving?
I tried replacing the ep you saved (`loaded_ep = torch.export.load("model.ep")`) with a new embedding_bag model like this:

```python
model = torch.nn.EmbeddingBag(30000, 32, mode='sum', padding_idx=0).to("cuda")
loaded_ep = torch.export.export(model, args=model_input)
```
The outputs are the same. So I'm thinking maybe the issue is not related to `dynamo_compile` but to your ep?
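For completeness, the check behind "the outputs are the same" is just an element-wise comparison of the compiled module against the eager one, along these lines (a sketch using the names from the snippet above):

```python
# Compare the TensorRT output against the eager PyTorch output for the same input.
ref = model(*model_input)
trt_out = trt_model(*model_input)
print(torch.max(torch.abs(ref - trt_out)))                 # worst absolute error
print(torch.allclose(ref, trt_out, rtol=1e-3, atol=1e-3))  # True if they match within tolerance
```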
@zewenli98 thanks for debugging.
- You mentioned `v2.5.0`; I am a little confused. I installed torch-tensorrt via pip, and I am on 2.5.0 already, as shown in the first cell of the notebook. In the 2.5.0 version that I pip installed, I only see `torch.ops.aten.embedding_bag.padding_idx` getting decomposed to `torch.ops.aten._embedding_bag.default` or `torch.ops.aten._embedding_bag_forward_only.default`.
- I'd say 41 seconds is still slow to me. It feels weird that compiling a single embedding bag module would take that long, and the compilation time also seems to be proportional to the batch size.
- I didn't move the original model to cuda before exporting. But even if I do, the compiled result is still quite different from the original result. FYI, I could get the same results as your example shows, but couldn't get the same results with my model. I am not sure whether there is something wrong with the weights of the embedding bag. Is there any way we can collaborate more efficiently on this result-difference issue?
> You mentioned `v2.5.0`; I am a little confused. I installed torch-tensorrt via pip, and I am on 2.5.0 already, as shown in the first cell of the notebook. In the 2.5.0 version that I pip installed, I only see `torch.ops.aten.embedding_bag.padding_idx` getting decomposed to `torch.ops.aten._embedding_bag.default` or `torch.ops.aten._embedding_bag_forward_only.default`.
@sean-xiang-applovin Sorry for the confusion. I found the reason: I was using nightly PyTorch 2.6.0, in which `torch.ops.aten.embedding_bag.padding_idx` is not decomposed into other ops.
For the embedding bag converter, I just noticed that you're using 2d input, which is not supported yet for a few reasons. Besides, the compilation time for embedding bag is somewhat slow, probably because we are using TensorRT's ILoopLayer, which adds overhead for the data-dependent case.
In order to fully solve the issue, we would have to do a fair amount of additional work. So if you really rely on the 2d-input embedding bag, I would recommend seeking other paths, or forcing this op to fall back to PyTorch with something like `torch_executed_ops={"torch.ops.aten._embedding_bag.default"}`.
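As a sketch, that fallback would look something like this at the compile call (the surrounding variable names are illustrative, not from the notebooks):

```python
import torch_tensorrt

# Convert everything with TensorRT except the embedding_bag aten op, which
# falls back to eager PyTorch and avoids the slow ILoopLayer-based conversion.
trt_model = torch_tensorrt.dynamo.compile(
    loaded_ep,
    inputs=example_input,
    torch_executed_ops={"torch.ops.aten._embedding_bag.default"},
)
```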
Hi @zewenli98
Thanks for the response, I can try with nightly PyTorch.
> For the embedding bag converter, I just noticed that you're using 2d input, which is not supported yet for a few reasons.
You mentioned 2d input is not supported; I am a bit confused. My input shape is (x, y), where x is the batch size. I would be very surprised if the embedding bag module could not support this.
@sean-xiang-applovin Can you try something like `torch.ops.aten._embedding_bag.default` instead of `torch.nn.EmbeddingBag()`? Internally PyTorch transforms the module into these aten ops. Here are some unit tests for your reference.
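For reference, calling the aten overload directly (instead of going through `torch.nn.EmbeddingBag`) looks roughly like the following. The argument order follows the `aten::_embedding_bag` schema as I understand it, so double-check it against the unit tests mentioned above.

```python
import torch

weight = torch.randn(1000, 32)
indices = torch.randint(0, 1000, (64,))   # 1-dim indices
offsets = torch.arange(0, 64, 8)          # bag boundaries

# Schema (per upstream PyTorch, to the best of my knowledge):
# _embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode,
#                sparse, per_sample_weights, include_last_offset, padding_idx)
# mode=0 is "sum"; the op returns a tuple whose first element is the pooled output.
out, *_ = torch.ops.aten._embedding_bag.default(
    weight, indices, offsets, False, 0, False, None, False, -1
)
print(out.shape)  # (num_bags, embedding_dim) -> (8, 32)
```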