bobby-chiu opened 1 year ago
Hi @vinthony, any idea how to fix this? For now, I have forced the generator model to bypass the grid_sample processing. Although this means the generated face is no longer morphed by the audio, end-to-end inference completes successfully, and I got a 4-5x speedup (batch size = 1, FP16 precision) on a Tesla T4 GPU for the face render model.
Below is a test code snippet for model conversion and inference via torch->onnx->engine, which you can use to test in your SadTalker pipeline: test_torch_onnx_tensorrt.zip
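The attached zip is not reproduced here. As a rough sketch of the torch->onnx step only, assuming a module with a single Tensor input and a single Tensor output (`TinyGenerator` and `export_to_onnx` are hypothetical stand-ins, not SadTalker code), the export could look like this; the engine is then built with NVIDIA's `trtexec` CLI as noted in the comment.

```python
import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Hypothetical stand-in for the face render generator (not SadTalker code)."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, source: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.conv(source))


def export_to_onnx(model: nn.Module, example: torch.Tensor, path: str, opset: int = 17) -> str:
    """Trace `model` with a fixed-shape example (e.g. batch size 1) and write an ONNX file."""
    model.eval()
    with torch.no_grad():
        torch.onnx.export(
            model,
            (example,),
            path,
            input_names=["source"],
            output_names=["prediction"],
            opset_version=opset,
        )
    return path


# Second step (outside Python): build an FP16 engine with NVIDIA's trtexec, e.g.
#   trtexec --onnx=generator.onnx --saveEngine=generator.engine --fp16
```

Usage would be along the lines of `export_to_onnx(TinyGenerator(), torch.randn(1, 3, 256, 256), "generator.onnx")`; a fixed example shape keeps the traced graph static, which is the simplest case for TensorRT.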
Hi, I am so glad that you are trying to make it happen.
However, this is not an issue we plan to fix first. We are currently working on debugging and more user-friendly settings; speed and deployment are not our main concern right now.
We would be very glad if you could make it work and share it with the community.
Hi @bobby-chiu, thank you very much for your work. I'm also trying to convert the SadTalker model to TensorRT. Unfortunately, I ran into a lot of errors during conversion, so I couldn't finish it. Could you share your modified SadTalker code? You have probably changed a lot of things.
@BbChip0103, are you using torch-tensorrt or NVIDIA's TensorRT? torch-tensorrt may not be compatible enough with Python-style torch code: I got lots of errors with torch-tensorrt, and eventually the compilation hung for an unknown reason, so I gave up on it. With TensorRT you can use the test code I released earlier. If any error happens, it is usually easy to fix by following the error message. However, 5D grid_sample still fails to convert.
BTW,
- ONNX/TensorRT does not seem to support a model forward() with List inputs/outputs; you need to change the inputs/outputs of generator.forward() to Tensors.
- nn.InstanceNorm2d also seems to fail in TensorRT; you can use the alternative below:
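import torch.nn as nn
from torch import Tensor

class InstanceNormAlternative(nn.InstanceNorm2d):
    # Instance norm from plain mean/var ops; affine weight/bias are not applied.
    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        # 1 / sqrt(biased per-channel variance over H, W + eps)
        desc = 1 / (input.var(dim=(2, 3), keepdim=True, unbiased=False) + self.eps) ** 0.5
        retval = (input - input.mean(dim=(2, 3), keepdim=True)) * desc
        return retval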
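For the first point, one way to keep the original generator untouched is to wrap it in a module whose forward() takes and returns plain Tensors and rebuilds any dict/list structures inside the graph. The sketch below assumes a generator called like `generator(source, kp_source={"value": ...}, kp_driving={"value": ...})` that returns a dict with a `"prediction"` entry; `ToyGenerator` and `TensorIOWrapper` are hypothetical names, so check the keys and signature against the real generator.forward().

```python
from typing import Dict

import torch
import torch.nn as nn


class ToyGenerator(nn.Module):
    """Hypothetical generator with dict inputs/outputs (stand-in for the real one)."""

    def forward(self, source: torch.Tensor, kp_source: Dict[str, torch.Tensor],
                kp_driving: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        # Toy computation: shift the image by the mean keypoint displacement.
        shift = (kp_driving["value"] - kp_source["value"]).mean()
        return {"prediction": source + shift, "mask": torch.ones_like(source[:, :1])}


class TensorIOWrapper(nn.Module):
    """Exposes a Tensor-only forward() so the exporter sees plain tensor I/O."""

    def __init__(self, generator: nn.Module) -> None:
        super().__init__()
        self.generator = generator

    def forward(self, source: torch.Tensor, kp_source: torch.Tensor,
                kp_driving: torch.Tensor) -> torch.Tensor:
        # Rebuild the dict inputs internally and return only the tensor we need.
        out = self.generator(source, kp_source={"value": kp_source},
                             kp_driving={"value": kp_driving})
        return out["prediction"]
```

The wrapper, not the generator, is what gets passed to the exporter, so the original training/inference code can keep its dict-based interface.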
It seems that not only 5D grid_sampler but also two other ops are unsupported by torch.compile (Inductor falls back to eager kernels for them). I got these messages when trying to use torch.compile:
Using FallbackKernel: torch.ops.aten.grid_sampler_3d.default
Using FallbackKernel: aten.avg_pool3d
Using FallbackKernel: aten.upsample_trilinear3d
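If (and only if) avg_pool3d and the trilinear upsample act on the depth axis with a kernel/scale of 1, for example a hypothetical `AvgPool3d(kernel_size=(1, 2, 2))` or an upsample with `scale_factor=(1, 2, 2)`, they can be rewritten as their 2D counterparts by folding depth into the batch axis. Whether the SadTalker blocks use exactly these kernels is an assumption to verify in the model code. grid_sampler_3d cannot be folded this way, because its grid moves samples along depth too. The helper names below are hypothetical.

```python
import torch
import torch.nn.functional as F


def fold_depth(x: torch.Tensor) -> torch.Tensor:
    """(N, C, D, H, W) -> (N*D, C, H, W) so 2D ops run on each depth slice."""
    n, c, d, h, w = x.shape
    return x.permute(0, 2, 1, 3, 4).reshape(n * d, c, h, w)


def unfold_depth(y: torch.Tensor, n: int, d: int) -> torch.Tensor:
    """(N*D, C, H, W) -> (N, C, D, H, W), the inverse of fold_depth."""
    _, c, h, w = y.shape
    return y.reshape(n, d, c, h, w).permute(0, 2, 1, 3, 4)


def avg_pool_1x2x2(x: torch.Tensor) -> torch.Tensor:
    """Same result as F.avg_pool3d(x, kernel_size=(1, 2, 2)), using avg_pool2d."""
    n, _, d, _, _ = x.shape
    return unfold_depth(F.avg_pool2d(fold_depth(x), kernel_size=2), n, d)


def upsample_1x2x2_trilinear(x: torch.Tensor) -> torch.Tensor:
    """Same result as trilinear upsampling with scale (1, 2, 2), using bilinear 2D."""
    n, _, d, _, _ = x.shape
    y = F.interpolate(fold_depth(x), scale_factor=2, mode="bilinear", align_corners=False)
    return unfold_depth(y, n, d)
```

With a depth scale of 1 every output depth index maps exactly onto an input slice, so the trilinear weights along depth reduce to 1 and 0 and only the per-slice bilinear part remains.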
@bobby-chiu
First of all, thank you for your proposed solution, but I still get an error in practice: torch.onnx.errors.OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input
Could you provide more implementation details, or open a new branch?
My code is as follows (src/facerender/modules/util.py):
import torch.nn as nn
from torch import Tensor

class InstanceNormAlternative(nn.InstanceNorm2d):
    def __init__(self, num_features, affine):
        super().__init__(num_features=num_features, affine=affine)

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        desc = 1 / (input.var(dim=(2, 3), keepdim=True, unbiased=False) + self.eps) ** 0.5
        retval = (input - input.mean(dim=(2, 3), keepdim=True)) * desc
        return retval

class SPADE(nn.Module):
    def __init__(self, norm_nc, label_nc):
        super().__init__()
        # self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        # TensorRT-friendly replacement for the parameter-free instance norm
        self.param_free_norm = InstanceNormAlternative(num_features=norm_nc, affine=False)
        nhidden = 128
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
            nn.ReLU())
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
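Before exporting, a quick numerical check can confirm that the replacement matches `nn.InstanceNorm2d(affine=False)` with its default `track_running_stats=False`. This is a sketch: the class is re-declared so the snippet is self-contained, and the 1e-4 tolerance is an assumption for FP32. Note that swapping the norm only changes the InstanceNorm node; on its own it does not address the 5D GridSample export error quoted above.

```python
import torch
import torch.nn as nn
from torch import Tensor


class InstanceNormAlternative(nn.InstanceNorm2d):
    """Instance norm from mean/var ops; ignores affine weight/bias (affine=False use)."""

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        inv_std = 1 / (input.var(dim=(2, 3), keepdim=True, unbiased=False) + self.eps) ** 0.5
        return (input - input.mean(dim=(2, 3), keepdim=True)) * inv_std


def max_abs_diff(shape=(2, 8, 16, 16)) -> float:
    """Largest elementwise gap between the built-in norm and the replacement."""
    torch.manual_seed(0)
    x = torch.randn(*shape) * 3 + 1
    num_features = shape[1]
    ref = nn.InstanceNorm2d(num_features, affine=False).eval()
    alt = InstanceNormAlternative(num_features, affine=False).eval()
    with torch.no_grad():
        return (ref(x) - alt(x)).abs().max().item()
```

Both versions normalize with the biased per-channel variance over H and W plus eps inside the square root, which is why they should agree to rounding error.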
Hello everyone, I am facing the same problems while optimizing SadTalker. If you solve them, could you share your solutions?
Hello, any updates?
torch.nn.utils.spectral_norm is not supported for FP16. Can you give me some suggestions?
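One possible workaround, assuming the layers in question use the legacy hook-based `torch.nn.utils.spectral_norm`: spectral norm only reparametrizes the weight, so for inference the current normalized weight can be baked into a plain parameter with `torch.nn.utils.remove_spectral_norm` before casting the model to FP16 and exporting. `strip_spectral_norm` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
from torch.nn.utils import remove_spectral_norm, spectral_norm


def strip_spectral_norm(model: nn.Module) -> nn.Module:
    """Remove legacy spectral_norm hooks in place, freezing the current normalized weights."""
    for module in model.modules():
        # The legacy hook stores the raw weight as `weight_orig`.
        if hasattr(module, "weight_orig"):
            try:
                remove_spectral_norm(module)
            except ValueError:
                pass  # `weight_orig` came from something else (e.g. pruning); leave it
    return model


# Typical use before export (on GPU): strip_spectral_norm(generator.eval()).half()
```

Calling it on a model in eval mode keeps the outputs unchanged, because eval-mode spectral norm already uses the stored singular-vector estimates without further power iterations.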
Can this code be used to convert SadTalker_V0.0.2_256.safetensors to TensorRT FP16 format?
Using @liqunfu's pull request, I built PyTorch from source, and with that build I was able to convert my model to ONNX. GridSample with 5D input works with opset 20.
Try ONNX opset 20 with the (CPU-only) PyTorch build in this Docker image: docker.io/saikiran321/pytorch_opset_20_
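A minimal sketch of the export call for a volumetric warp, assuming a PyTorch build whose ONNX exporter supports 5D GridSample at opset 20 (such as the source build or Docker image mentioned above; stock builds of that time raise the OnnxExporterError quoted in this thread). `VolumeWarp` is a hypothetical module standing in for the generator's 3D deformation step, and whether TensorRT can then parse the resulting node depends on its version.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VolumeWarp(nn.Module):
    """Hypothetical module: warp a 5D feature volume with a dense 3D sampling grid."""

    def forward(self, volume: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        # volume: (N, C, D, H, W); grid: (N, D_out, H_out, W_out, 3) in [-1, 1]
        return F.grid_sample(volume, grid, mode="bilinear",
                             padding_mode="zeros", align_corners=False)


def export_volume_warp(path: str = "volume_warp.onnx") -> None:
    """Export with opset 20, the ONNX opset whose GridSample accepts N-D (incl. 5D) input."""
    volume = torch.randn(1, 32, 16, 64, 64)
    grid = torch.rand(1, 16, 64, 64, 3) * 2 - 1
    torch.onnx.export(VolumeWarp(), (volume, grid), path,
                      input_names=["volume", "grid"], output_names=["warped"],
                      opset_version=20)
```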
Here is an article about using TensorRT to accelerate SadTalker: https://zhuanlan.zhihu.com/p/675551997
I am trying to run inference with TensorRT for a possible speedup, but I get errors when exporting the pretrained generator model to ONNX (with torch 2.0 the ONNX export itself raises an error; with torch 1.12 the error is raised when TensorRT loads the model):
torch.onnx.errors.OnnxExporterError: Unsupported: ONNX export of operator GridSample with 5D volumetric input. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues
It seems 5D grid_sample is still unsupported by ONNX (https://github.com/pytorch/pytorch/issues/92209). Is it possible to replace torch's grid_sample with an alternative implementation?
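One possible workaround is to reimplement trilinear grid sampling with ops that export at older opsets (floor, clamp, gather and elementwise arithmetic). The sketch below assumes `mode="bilinear"`, `padding_mode="zeros"` and `align_corners=False` (PyTorch's defaults); `grid_sample_3d` is a hypothetical helper, not a SadTalker or PyTorch API. It is slower and uses more memory than the native kernel, and TensorRT support for the resulting gather nodes should be checked for your version.

```python
import torch
import torch.nn.functional as F


def grid_sample_3d(inp: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Trilinear grid sampling (zeros padding, align_corners=False) from basic ops.

    inp:  (N, C, D, H, W)
    grid: (N, Do, Ho, Wo, 3), last dim ordered (x, y, z) in [-1, 1]
    returns (N, C, Do, Ho, Wo)
    """
    n, c, d, h, w = inp.shape
    _, do, ho, wo, _ = grid.shape
    # Unnormalize to voxel coordinates (align_corners=False convention).
    x = ((grid[..., 0] + 1) * w - 1) / 2
    y = ((grid[..., 1] + 1) * h - 1) / 2
    z = ((grid[..., 2] + 1) * d - 1) / 2
    x0, y0, z0 = torch.floor(x), torch.floor(y), torch.floor(z)
    # Weights for the upper (1) and lower (0) neighbour along each axis.
    wx1, wy1, wz1 = x - x0, y - y0, z - z0
    wx0, wy0, wz0 = 1 - wx1, 1 - wy1, 1 - wz1

    flat = inp.reshape(n, c, d * h * w)
    out = inp.new_zeros(n, c, do, ho, wo)
    for zi, wz in ((z0, wz0), (z0 + 1, wz1)):
        for yi, wy in ((y0, wy0), (y0 + 1, wy1)):
            for xi, wx in ((x0, wx0), (x0 + 1, wx1)):
                # Corners outside the volume contribute zero (padding_mode="zeros").
                valid = ((xi >= 0) & (xi <= w - 1) & (yi >= 0) & (yi <= h - 1)
                         & (zi >= 0) & (zi <= d - 1)).to(inp.dtype)
                xc = xi.clamp(0, w - 1).long()
                yc = yi.clamp(0, h - 1).long()
                zc = zi.clamp(0, d - 1).long()
                # Flat index into D*H*W, broadcast over channels for gather.
                idx = ((zc * h + yc) * w + xc).reshape(n, 1, -1).expand(n, c, -1)
                vals = torch.gather(flat, 2, idx).reshape(n, c, do, ho, wo)
                out = out + vals * (wx * wy * wz * valid).unsqueeze(1)
    return out
```

Swapping it in would mean replacing the 5D `F.grid_sample(inp, deformation)` call in the generator with `grid_sample_3d(inp, deformation)`, provided that call relies on the same default arguments.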