lix19937 / tensorrt-insight

Deep insight into TensorRT, including but not limited to QAT, PTQ, plugins, Triton inference, and CUDA
12 stars 0 forks source link

How to invert a permutation #55

Open lix19937 opened 1 week ago

lix19937 commented 1 week ago

import torch

# https://stackoverflow.com/questions/66832716/how-to-quickly-inverse-a-permutation-by-using-pytorch
# https://discuss.pytorch.org/t/how-to-quickly-inverse-a-permutation-by-using-pytorch/116205/2

def inverse_permutation(perm):
    """Return the inverse of a permutation tensor.

    The result ``inv`` satisfies ``inv[perm[i]] == i`` for every ``i``.
    Applying ``perm`` as a dimension order and then ``inv`` restores the
    original order.

    Args:
        perm: integer index tensor expected to hold a permutation of
            ``0 .. perm.size(0) - 1`` (only its first dimension is used
            for the length).

    Returns:
        A tensor with the same dtype, device and shape as ``perm``.

    Note:
        ``perm`` is not validated. The output is created with
        ``torch.empty_like``, so if ``perm`` repeats a value, some entries
        are never written and contain uninitialized memory. Out-of-range
        values make the indexed assignment fail.
    """
    n = perm.size(0)
    positions = torch.arange(n, device=perm.device)
    result = torch.empty_like(perm)
    # Scatter: slot perm[i] of the result receives the value i.
    result[perm] = positions
    return result

bev_feat_base = torch.randn(4, 2, 1, 3, 4)

# Single source of truth for the forward dimension order. The undo step below
# uses its computed inverse instead of a hand-written one, so the round trip
# stays correct if this order is ever changed.
fwd_perm = torch.tensor([0, 2, 3, 1])

bev_feat = bev_feat_base
bev_feat = bev_feat.squeeze(2)  # drop size-1 dim 2: (4, 2, 1, 3, 4) -> (4, 2, 3, 4)
print(bev_feat.shape)

bev_feat = bev_feat.permute(*fwd_perm.tolist())  # (4, 2, 3, 4) -> (4, 3, 4, 2)
print(bev_feat.shape)

bev_feat_rp = inverse_permutation(fwd_perm)  # tensor([0, 3, 1, 2]) for fwd_perm above
print(bev_feat_rp)

# Undo the forward permute with the computed inverse, then restore the
# squeezed dimension: (4, 3, 4, 2) -> (4, 2, 3, 4) -> (4, 2, 1, 3, 4).
bev_feat = bev_feat.permute(*bev_feat_rp.tolist()).contiguous()
bev_feat = bev_feat.unsqueeze(2)
print(bev_feat.shape)

# squeeze/permute/unsqueeze only move data, so the round trip must reproduce
# the original tensor exactly; expected output: True.
print(torch.equal(bev_feat, bev_feat_base))