grimoire / torch2trt_dynamic

A PyTorch-to-TensorRT converter with dynamic shape support
MIT License

No warnings, but [TensorRT] ERROR: INVALID_ARGUMENT: Cannot find binding of given name: input_0 #19

Closed iAlexKai closed 3 years ago

iAlexKai commented 3 years ago

I have run into the [TensorRT] ERROR: INVALID_ARGUMENT: Cannot find binding of given name: input_0 error a few times before, but usually there are warnings indicating that some methods are not supported. This time is quite different.

The TRTModule is built successfully without any warnings, but the engine cannot find the tensor via engine.get_binding_index(input_name) (torch2trt_dynamic.py, line 450).
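For reference, a minimal sketch of how one can list the bindings the built engine actually exposes (this assumes the returned TRTModule keeps its engine in a .engine attribute, as in upstream torch2trt; the helper name is my own):

    def print_bindings(engine):
        # engine is a tensorrt.ICudaEngine; walk its bindings by index
        for i in range(engine.num_bindings):
            kind = "input" if engine.binding_is_input(i) else "output"
            print(i, kind, engine.get_binding_name(i))

    # print_bindings(trt_model.engine)  # shows which input_* names really exist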

What's more, when I use pdb to debug line by line, I find that sometimes the TRTModule builds successfully, but other times errors like "'Tensor' object has no attribute '_trt'" occur.

To be honest, this is driving me crazy. I'd rather see some warnings and write the corresponding unsupported methods myself...

iAlexKai commented 3 years ago

I believe this is still due to an unsuccessful conversion, because when I simplify my model all the errors go away... I just want the warnings back so I can find out what causes the error.

grimoire commented 3 years ago

Hi, would you mind providing more information, such as the model and scripts?

iAlexKai commented 3 years ago

Hi, would you mind providing more information, such as the model and scripts?

Thank you so much for your quick reply. My model is an encoder-decoder model. The forward() is like this:

    def forward(self, title):
        self.eval()

        title_last_hidden, _ = self.seq_encoder(title)
        context_last_hidden, _ = self.seq_encoder(title)

        cond = torch.cat((title_last_hidden, context_last_hidden), 1)  # (batch, 2 * hidden_size * 2)

        z, _, _ = self.prior_net(cond)  # e: (batch, z_size)
        z = self.prior_generator(z)  # z: (batch, z_size)

        input_to_init_decoder_hidden = torch.cat((z, cond), 1)

        decoder_init = self.init_decoder_hidden(input_to_init_decoder_hidden)
        output = self.decoder(decoder_init, maxlen=self.maxlen, go_id=self.go_id)

        flattened_output = output.view(-1, self.vocab_size)
        return flattened_output

self.seq_encoder and self.decoder are both ordinary GRU-based modules:

class Encoder(nn.Module):
    def __init__(self, embedder, input_size, hidden_size, bidirectional, n_layers, noise_radius=0.2):
        super(Encoder, self).__init__()

        self.hidden_size = hidden_size
        self.noise_radius = noise_radius
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        self.embedding = embedder
        self.rnn = nn.GRU(input_size, hidden_size, n_layers, batch_first=True, bidirectional=bidirectional)

    def forward(self, inputs):

        # if self.embedding is not None:
        inputs = self.embedding(inputs)  # pass through the embedding layer

        batch_size, seq_len, emb_size = inputs.size()  # (batch, len, emb_size); len is 12, the maximum title length

        init_hidden = to_tensor(torch.zeros(self.n_layers * (1 + self.bidirectional), batch_size, self.hidden_size))
        hids, h_n = self.rnn(inputs, init_hidden)

        h_n = h_n.view(self.n_layers, (1 + self.bidirectional), batch_size, self.hidden_size)
        try:
            h_n = h_n[-1]  
        except Exception:
            import time
            time.sleep(0.5)

        enc = h_n.transpose(1, 0).contiguous().view(batch_size, -1)

        return enc,  hids

The prior_net is a Variation module (the prior network of the CVAE):

class Variation(nn.Module):
    def __init__(self, input_size, z_size, dropout_rate, init_weight):
        super(Variation, self).__init__()
        self.input_size = input_size
        self.z_size=z_size
        self.init_w = init_weight
        self.fc = nn.Sequential(
            nn.Linear(input_size, 1200),
            nn.BatchNorm1d(1200, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.1),
            nn.Dropout(dropout_rate),
            nn.Linear(1200, z_size),
            nn.BatchNorm1d(z_size, eps=1e-05, momentum=0.1),
            nn.LeakyReLU(0.1),
            # nn.Dropout(dropout_rate),
        )
        self.context_to_mu = nn.Linear(z_size, z_size)  # activation???
        self.context_to_logsigma = nn.Linear(z_size, z_size)

        self.fc.apply(self.init_weights)
        self.init_weights(self.context_to_mu)
        self.init_weights(self.context_to_logsigma)

    def init_weights(self, m):
        if isinstance(m, nn.Linear):        
            m.weight.data.uniform_(-self.init_w, self.init_w)
            m.bias.data.fill_(0)

    # def forward(self, context, epsilon):
    def forward(self, context):
        batch_size, _ = context.size()  # prior: (batch, 4 * hidden)
        context = self.fc(context)
        mu = self.context_to_mu(context)
        logsigma = self.context_to_logsigma(context) 
        std = torch.exp(0.5 * logsigma)

        # epsilon = to_tensor(torch.randn([batch_size, self.z_size]))

        epsilon = to_tensor(torch.ones([batch_size, self.z_size]))
        z = epsilon * std + mu
        return z, mu, logsigma 

Thanks so much!

grimoire commented 3 years ago

OK, I will let you know once I find something. This might take some time.

iAlexKai commented 3 years ago

Thanks! I found that I need to write a GRU converter first; that is the first place that causes the error when I simplify my model. I found a reference here: https://github.com/NVIDIA-AI-IOT/torch2trt/issues/144#issuecomment-553321172. However, that is an LSTM converter, so I have to adapt it to GRU first.
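For reference, a rough sketch of the pieces that change when adapting the LSTM converter to GRU (gate names taken from the TensorRT Python enums; the full converter further down in this thread uses exactly these):

    import tensorrt as trt

    # GRU uses a different RNNOperation and a different set of gates than LSTM:
    op = trt.RNNOperation.GRU                 # instead of trt.RNNOperation.LSTM
    gru_gates = (trt.RNNGateType.UPDATE,      # z gate
                 trt.RNNGateType.RESET,       # r gate
                 trt.RNNGateType.HIDDEN)      # candidate hidden state
    # The LSTM converter fills the INPUT, OUTPUT, FORGET and CELL gates instead,
    # so every set_weights_for_gate / set_bias_for_gate call has to be remapped,
    # and the stacked weight/bias slices shrink from 4*hidden_size to 3*hidden_size rows.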

iAlexKai commented 3 years ago

I have fixed most of the bugs. The cause of this behavior is still the unsupported methods. I added the GRU converter and the code ran well until the point where an additional input is needed:

    def forward(self, title, decoder_input):
        self.eval()
        title_last_hidden = self.seq_encoder(title)
        context_last_hidden = self.seq_encoder(title)
        cond = torch.cat((title_last_hidden, context_last_hidden), 1)  
        z, _, _ = self.prior_net(cond)  
        z = self.prior_generator(z)  
        input_to_init_decoder_hidden = torch.cat((z, cond), 1)
        decoder_init = self.init_decoder_hidden(input_to_init_decoder_hidden)

        output = self.decoder(init_hidden=decoder_init, maxlen=self.maxlen, decoder_input=decoder_input)  # Here, decoder_input is the second input tensor
        flattened_output = output.view(-1, self.vocab_size)

        return flattened_output

I pass the two tensors into the torch2trt_dynamic function like this:

model = torch2trt_dynamic(model, [title_tensor, decoder_input], max_workspace_size=1 << 28)

The [TensorRT] ERROR: INVALID_ARGUMENT: Cannot find binding of given name: input_1 (not input_0) error occurred again, which means something is wrong with the second tensor.

grimoire commented 3 years ago

That is so cool! I have always wanted to add better RNN support to this project. The output of an unsupported layer is treated as a constant value, which disables any input that is only involved in that layer. I guess that is what causes the 'cannot find input' error. Seems I still have a lot to do...

iAlexKai commented 3 years ago

Thanks for your reply! I've fixed all the bugs in my code. I'd like to summarize two points:

  1. The 'cannot find input' bug is always caused by some unsupported layer in your model. You have to find it and write the corresponding op if necessary.
  2. One problem is that if you use from torch2trt_dynamic import torch2trt_dynamic in your Python code, you will not see the traceback; you will only see one line: segmentation fault (core dumped). This bothered me a lot while debugging. Here is my GRU converter:
    
import tensorrt as trt
import torch
from torch import nn
from torch2trt_dynamic.torch2trt_dynamic import *

@tensorrt_converter('torch.nn.GRU.forward')
def convert_gru(ctx):
    module = ctx.method_args[0]
    input_tensor = ctx.method_args[1]
    if len(ctx.method_args) == 3:
        init_state_tensor = ctx.method_args[2]
    output_0, output_1 = ctx.method_return[0], ctx.method_return[1]

    layer_count = module.num_layers
    hidden_size = module.hidden_size
    batch_first = module.batch_first
    max_seq_length = input_tensor.shape[1] if batch_first else input_tensor.shape[0]
    op = trt.RNNOperation.GRU
    layer = ctx.network.add_rnn_v2(input_tensor._trt, layer_count, hidden_size, max_seq_length, op)
    if len(ctx.method_args) == 3:
        layer.hidden_state = init_state_tensor._trt

    if module.bidirectional is True:
        layer.direction = trt.RNNDirection.BIDIRECTION

    for i in range(layer_count):
        iw = getattr(module, "weight_ih_l%s" % i).detach().cpu().numpy()
        hw = getattr(module, "weight_hh_l%s" % i).detach().cpu().numpy()

        rela_index = 2 * i if module.bidirectional is True else i

        layer.set_weights_for_gate(rela_index, trt.RNNGateType.UPDATE, True, iw[:hidden_size, :].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.RESET, True, iw[hidden_size: hidden_size * 2, :].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.HIDDEN, True, iw[hidden_size * 2: hidden_size * 3, :].copy())

        layer.set_weights_for_gate(rela_index, trt.RNNGateType.UPDATE, False, hw[:hidden_size, :].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.RESET, False, hw[hidden_size: hidden_size * 2, :].copy())
        layer.set_weights_for_gate(rela_index, trt.RNNGateType.HIDDEN, False, hw[hidden_size * 2: hidden_size * 3, :].copy())

        ib = getattr(module, "bias_ih_l%s" % i).detach().cpu().numpy()
        hb = getattr(module, "bias_hh_l%s" % i).detach().cpu().numpy()

        layer.set_bias_for_gate(rela_index, trt.RNNGateType.UPDATE, True, ib[:hidden_size].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.RESET, True, ib[hidden_size: hidden_size * 2].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.HIDDEN, True, ib[hidden_size * 2: hidden_size * 3].copy())

        layer.set_bias_for_gate(rela_index, trt.RNNGateType.UPDATE, False, hb[:hidden_size].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.RESET, False, hb[hidden_size: hidden_size * 2].copy())
        layer.set_bias_for_gate(rela_index, trt.RNNGateType.HIDDEN, False, hb[hidden_size * 2: hidden_size * 3].copy())

        if module.bidirectional is True:
            # ================ reverse direction ================
            iw_r = getattr(module, "weight_ih_l%s_reverse" % i).detach().cpu().numpy()
            hw_r = getattr(module, "weight_hh_l%s_reverse" % i).detach().cpu().numpy()

            layer.set_weights_for_gate(2 * i + 1, trt.RNNGateType.UPDATE, True, iw_r[:hidden_size, :].copy())
            layer.set_weights_for_gate(2 * i + 1, trt.RNNGateType.RESET, True, iw_r[hidden_size: hidden_size * 2, :].copy())
            layer.set_weights_for_gate(2 * i + 1, trt.RNNGateType.HIDDEN, True, iw_r[hidden_size * 2: hidden_size * 3, :].copy())

            layer.set_weights_for_gate(2 * i + 1, trt.RNNGateType.UPDATE, False, hw_r[:hidden_size, :].copy())
            layer.set_weights_for_gate(2 * i + 1, trt.RNNGateType.RESET, False, hw_r[hidden_size: hidden_size * 2, :].copy())
            layer.set_weights_for_gate(2 * i + 1, trt.RNNGateType.HIDDEN, False, hw_r[hidden_size * 2: hidden_size * 3, :].copy())

            ib_r = getattr(module, "bias_ih_l%s_reverse" % i).detach().cpu().numpy()
            hb_r = getattr(module, "bias_hh_l%s_reverse" % i).detach().cpu().numpy()

            layer.set_bias_for_gate(2 * i + 1, trt.RNNGateType.UPDATE, True, ib_r[:hidden_size].copy())
            layer.set_bias_for_gate(2 * i + 1, trt.RNNGateType.RESET, True, ib_r[hidden_size: hidden_size * 2].copy())
            layer.set_bias_for_gate(2 * i + 1, trt.RNNGateType.HIDDEN, True, ib_r[hidden_size * 2: hidden_size * 3].copy())

            layer.set_bias_for_gate(2 * i + 1, trt.RNNGateType.UPDATE, False, hb_r[:hidden_size].copy())
            layer.set_bias_for_gate(2 * i + 1, trt.RNNGateType.RESET, False, hb_r[hidden_size: hidden_size * 2].copy())
            layer.set_bias_for_gate(2 * i + 1, trt.RNNGateType.HIDDEN, False, hb_r[hidden_size * 2: hidden_size * 3].copy())

    gru_output_0 = layer.get_output(0)
    gru_output_1 = layer.get_output(1)
    output_0._trt = gru_output_0
    output_1._trt = gru_output_1

def main():
    class TestNet(torch.nn.Module):
        def __init__(self):
            super(TestNet, self).__init__()
            self.gru = nn.GRU(10, 20, 1, batch_first=True, bidirectional=False)

        def forward(self, x, init):
            out, out1 = self.gru(x, init)
            return out, out1

    net = TestNet()
    net = net.cuda()

    x = torch.randn(1, 5, 10).cuda()
    init = torch.randn(1, 1, 20).cuda()
    output_0, hn_0 = net(x, init)

    trt_net = torch2trt_dynamic(net, [x, init], max_workspace_size=1 << 28)

    output_1, hn_1 = trt_net(x, init)

    print(torch.max(torch.abs(output_0 - output_1)))
    print(torch.max(torch.abs(hn_0 - hn_1)))

if __name__ == "__main__":
    main()


Jetson Nano, Pytorch 1.8.0, TensorRT 7.1.3.0
Thanks to https://github.com/NVIDIA-AI-IOT/torch2trt/issues/144#issuecomment-553321172
grimoire commented 3 years ago

Amazing! Would you mind if I add the converter to this repo? Or would you prefer to submit a PR yourself?

iAlexKai commented 3 years ago

Haha, my pleasure. Just take it and add it~

iAlexKai commented 3 years ago

By the way, is there any way I can convert torch.Tensor.__bool__? Without it I cannot use if xxxx == xxx: in my code. I've tried for a long time, but with no results...

grimoire commented 3 years ago

Actually, no. Like most tracing-based deployment tools, this project does not support Python control-flow statements such as if or while. That means that even if you could convert torch.Tensor.__bool__, the if statement still would not work. Try something like:

output = input * alpha + other * (1 - alpha)

to avoid control flow.
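To make that concrete, a minimal self-contained sketch (toy tensors of my own; whether each op here already has a converter depends on the model) of replacing a data-dependent if with this kind of arithmetic blend:

    import torch

    x = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.0, 2.0, 3.0])
    a = torch.zeros(3)
    b = torch.ones(3)

    # Instead of a Python branch such as:
    #   out = a if bool((x == target).all()) else b
    # build a 0/1 alpha tensor and blend, so the trace contains only tensor ops:
    alpha = (x == target).all().float()   # 1.0 when the condition holds, else 0.0
    out = alpha * a + (1.0 - alpha) * b
    print(out)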

iAlexKai commented 3 years ago

Alright! I'll try, thanks!

iAlexKai commented 3 years ago

Actually, no. Like most tracing-based deployment tools, this project does not support Python control-flow statements such as if or while. That means that even if you could convert torch.Tensor.__bool__, the if statement still would not work. Try something like:

output = input * alpha + other * (1 - alpha)

to avoid control flow.

I took your advice and fixed everything, thanks again! I'll close this issue.

yanghgai commented 3 years ago

Thanks for your reply! I've fixed all the bugs in my code. I'd like to summarize two points:

  1. The 'cannot find input' bug is always caused by some unsupported layer in your model. You have to find it and write the corresponding op if necessary.
  2. One problem is that if you use from torch2trt_dynamic import torch2trt_dynamic in your Python code, you will not see the traceback; you will only see one line: segmentation fault (core dumped). [GRU converter code quoted in full above.]

Jetson Nano, PyTorch 1.8.0, TensorRT 7.1.3.0. Thanks to NVIDIA-AI-IOT/torch2trt#144 (comment)

Hi, how can I see the traceback during debugging? I converted Swin Transformer to TRT, but I get Warning: Encountered known unsupported method torch.Tensor.__matmul__ followed by Segmentation fault (core dumped).

iAlexKai commented 3 years ago

As far as I know, there is no way to see the traceback with this project. You can try printing some logs to locate the bug.
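One thing that sometimes helps (a general Python trick, not something this project provides): enable the standard-library faulthandler before converting, so a crash inside a converter at least dumps the Python stack before the segmentation fault kills the process:

    import faulthandler
    faulthandler.enable()  # dumps the Python traceback of all threads on a fatal signal

    # ... then run the conversion as usual, e.g.
    # trt_model = torch2trt_dynamic(model, [x], max_workspace_size=1 << 28)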