f-dangel / unfoldNd

(N=1,2,3)-dimensional unfold (im2col) and fold (col2im) in PyTorch

ONNX to TensorRT fails because of overloaded scatter_add_ #35

Open · avdhoeke opened this issue 7 months ago

avdhoeke commented 7 months ago

Context

Under torch 2.1.1+cu121, torch.onnx.export supports opset versions up to 17 (included). This means that the col2im operation is not supported since it depends on opset version 18. This operation is necessary when converting torch.nn.functional.fold, which is why I'm using this repo's FoldNd implementation.
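For illustration, a minimal sketch of the limitation (the Fold2d wrapper and its shapes are made up for this example; at opset 17 the exporter has no lowering for fold, since the ONNX Col2Im operator only exists from opset 18):

import torch
import torch.nn.functional as F

class Fold2d(torch.nn.Module):
    def forward(self, x):
        return F.fold(x, output_size=(4, 8), kernel_size=2)

# Fails with an unsupported-operator error: fold needs ONNX Col2Im,
# which is only available from opset 18 onwards.
torch.onnx.export(Fold2d(), torch.randn(64, 12, 21), "fold.onnx", opset_version=17)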

Problem

The issue is that the scatter_add_ operation is overloaded:

https://github.com/f-dangel/unfoldNd/blob/cbe9e07ca6a5d008e226852177140f8dc6a544de/unfoldNd/fold.py#L93

According to PyTorch's documentation, scatter_add_ should be such that:

self, index and src should all have the same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

Using the simple example provided in your README, with some padding added:

import torch

from unfoldNd import FoldNd

# random output of an im2col operation: with padding=1 the padded output
# is (6, 10), which yields 5 * 9 = 45 patches of size 2 x 2
inputs = torch.randn(64, 3 * 2 * 2, 5 * 9)
output_size = (4, 8)

# other module hyperparameters
kernel_size = 2
dilation = 1
padding = 1
stride = 1

fold = FoldNd(
    output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
)
outputs = fold(inputs)

you'll see that the shapes handed to scatter_add_ are the following:

  • self (output): torch.Size([64, 3, 60])
  • index: torch.Size([64, 3, 180])
  • src: torch.Size([64, 3, 180])

Here index exceeds self along the scatter dimension, which TensorRT does not permit: its importer checks index.size(d) <= self.size(d) for every dimension d, with no exemption for d = dim. Hence trtexec raises the following error during conversion:

trtexec --onnx=fold.onnx --saveEngine=fold.trt
Assertion failed: indicesDims.d[i] <= dataDims.d[i] && "Indices dimensions must be less than data dimensions!"
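For reference, a minimal standalone sketch (shapes copied from the list above) showing that eager PyTorch accepts exactly the call TensorRT rejects:

import torch

output = torch.zeros(64, 3, 60)            # self: flattened padded image
idx = torch.randint(0, 60, (64, 3, 180))   # index exceeds self along dim 2
src = torch.randn(64, 3, 180)              # src: patch values

# Eager PyTorch runs this without complaint: along the scatter dimension,
# index may be larger than self. TensorRT's ScatterElements importer
# instead requires index.size(d) <= self.size(d) for every d.
output.scatter_add_(2, idx, src)
print(output.shape)  # torch.Size([64, 3, 60])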

Solution

Instead of scatter_add_ing along the flattened kernel_size_numel * output_size_numel dimension, I suggest reshaping index and src to (batch_size, in_channels, kernel_size_numel, -1) and applying scatter_add_ in a loop, slicing through the third dimension:

# Replicate indices over batch and channels, then scatter the patch values
# back to the padded image
input = input.reshape(batch_size, in_channels, kernel_size_numel, -1)
idx = idx.reshape(1, 1, -1).long().expand(batch_size, in_channels, -1)
idx = idx.reshape(batch_size, in_channels, kernel_size_numel, -1)

output = torch.zeros(
    batch_size,
    in_channels,
    padded_output_size_numel,
    device=device,
    dtype=input.dtype,
)

# Chunking along the kernel dimension keeps each index slice the same
# shape as its src slice and no larger than output in any dimension,
# which satisfies TensorRT's requirements for the scatter.
for k in range(kernel_size_numel):
    output.scatter_add_(2, idx[:, :, k, :], input[:, :, k, :])

output = output.reshape(batch_size, in_channels, *padded_output_size)

I checked that this solution produces correct results for various configurations. Additionally, trtexec successfully converts FoldNd to TensorRT.
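As a quick numerical sanity check, a sketch assuming the 2D configuration above (FoldNd with the chunked scatter should agree with PyTorch's reference fold):

import torch
import torch.nn.functional as F
from unfoldNd import FoldNd

inputs = torch.randn(64, 3 * 2 * 2, 5 * 9)
fold = FoldNd((4, 8), 2, dilation=1, padding=1, stride=1)

# The looped scatter_add_ must reproduce torch.nn.functional.fold exactly.
expected = F.fold(inputs, (4, 8), 2, dilation=1, padding=1, stride=1)
assert torch.allclose(fold(inputs), expected)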

Environment

  • PyTorch 2.1.1+cu121
  • tensorrt 8.6.1.6-1+cuda12.0

f-dangel commented 7 months ago

Interesting. I'd be happy to merge this if we can add an onnx test. FYI, I believe one can get rid of some idx.reshapes. If you decide to set up a PR, I can provide feedback.

avdhoeke commented 7 months ago

Interesting. I'd be happy to merge this if we can add an onnx test.

Well, torch.onnx.export will run even with the current scatter_add_ operation. What will fail, however, is the trtexec conversion, which would imply installing TensorRT on whatever machine you decide to run this test on. I could help with this if you think there is value.
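For reference, a sketch of the export that does go through, using the hyperparameters from the issue:

import torch
from unfoldNd import FoldNd

fold = FoldNd((4, 8), 2, dilation=1, padding=1, stride=1)
inputs = torch.randn(64, 3 * 2 * 2, 5 * 9)

# scatter_add_ lowers to ONNX ScatterElements(reduction="add"), available
# from opset 16, so the export itself succeeds at opset 17; the failure
# only appears later, inside trtexec.
torch.onnx.export(fold, inputs, "fold.onnx", opset_version=17)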

If you decide to set up a PR, I can provide feedback.

Sure I can submit a PR. Am I allowed to push to a new remote branch?

git push origin HEAD:onnx_to_trt
ERROR: Permission to f-dangel/unfoldNd.git denied to avdhoeke.
fatal: Could not read from remote repository.

Please make sure you have the correct access rights
and the repository exists.
f-dangel commented 7 months ago

I believe you'll have to fork the repo and then set up a PR from your fork to this repo.

Regarding a test: It should be possible to install TensorRT in the Github action which runs the tests for this lib. I'm not familiar with this tool. Can you execute trtexec on CPU-only machines (the Github runners don't have GPUs)?

If not, it would be great to explain in a comment in the code why we're chunking the scatter_add (basically explaining that there are constraints on idx that are documented in the PyTorch docs, but not enforced).

avdhoeke commented 7 months ago

Pretty busy atm. I'll follow up on this asap once I've figured out a way to comprehensively adapt the code for conversion purposes.

f-dangel commented 7 months ago

Sounds good.

I am also thinking about a fold implementation that exclusively relies on einsum (different project). It would fix the TensorRT problem, but requires benchmarking to make sure there is no disadvantageous performance difference.
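For intuition, one possible shape of that idea; fold_via_einsum and its arguments are hypothetical, and the dense one-hot matrix trades memory for avoiding ScatterElements in the exported graph:

import torch

def fold_via_einsum(inputs, idx, padded_numel):
    # inputs: (batch, channels, kernel_size_numel * num_patches) patch values
    # idx:    (kernel_size_numel * num_patches,) flat scatter indices into
    #         the flattened padded image of padded_numel pixels
    # Column p of the one-hot matrix marks the patch entries that land on
    # pixel p, so the contraction performs the same summation as scatter_add_.
    scatter = torch.nn.functional.one_hot(idx, num_classes=padded_numel)
    return torch.einsum("bci,ip->bcp", inputs, scatter.to(inputs.dtype))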

Liupei1101 commented 5 months ago

I converted the ONNX model to TensorRT successfully, but the result is wrong. I found that output.scatter_add_(2, idx[:, :, k, :], input[:, :, k, :]) does not work in TensorRT.

Liupei1101 commented 5 months ago

My ONNX model converts to TensorRT successfully, but the result is wrong. I found that output.scatter_add_(2, idx[:, :, k, :], input[:, :, k, :]) has no effect in TensorRT, and no error is raised. Do you know why?

f-dangel commented 5 months ago

Hi, I did not fully follow your last message. Did you apply the suggested fix but it did not lead to the correct result? Could you clarify?

Liupei1101 commented 5 months ago

I have investigated this: TensorRT does not support scatter_add, so after converting the model the scatter_add is executed as a plain scatter (assignment) operation. That is why the scatter_add does not take effect.
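One way to check what actually lands in the exported graph is to inspect the ScatterElements nodes; a sketch using the onnx package (a backend that ignores reduction="add" degenerates to plain assignment, matching the symptom above):

import onnx

model = onnx.load("fold.onnx")
for node in model.graph.node:
    if node.op_type == "ScatterElements":
        for attr in node.attribute:
            # Expect a reduction attribute with value b"add" for the fold's scatter.
            print(attr.name, onnx.helper.get_attribute_value(attr))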