Depersonalizc opened 3 years ago
I used
ray_directions = torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)
instead, because the proposed tensordot version raised the following UserWarning with PyTorch 1.13:
UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541702/work/aten/src/ATen/native/TensorShape.cpp:3277.)
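As a sanity check, here is a minimal sketch that assumes random inputs with the shapes implied by the einsum subscripts: a 4x4 `tform_cam2world` and a `(W, H, 3)` `directions` grid, with arbitrary sizes. It compares three formulations of the same per-pixel rotation. The first is the einsum. The second is a plain broadcast matmul, `directions @ R.T`, which is a swapped-in alternative and does not come from the thread. The third is a tensordot call that reverses the dimensions with `permute`, as the warning suggests, instead of using the deprecated 3-D `.T`.

```python
import torch

torch.manual_seed(0)

# Hypothetical shapes for illustration: a 4x4 camera-to-world matrix and a
# (W, H, 3) grid of camera-space ray directions.
W, H = 64, 48
tform_cam2world = torch.randn(4, 4)
directions = torch.randn(W, H, 3)
R = tform_cam2world[:3, :3]  # rotation part of the c2w transform

# einsum: out[w, h, m] = sum_n R[m, n] * directions[w, h, n]
out_einsum = torch.einsum("mn,whn->whm", R, directions)

# Same contraction as a broadcast matmul: each row vector d becomes d @ R^T = (R d)^T.
out_matmul = directions @ R.T

# tensordot without the deprecated 3-D `.T`: reverse the dims with permute,
# contract R's columns with the xyz axis, then reverse the dims back.
out_tensordot = torch.tensordot(
    R, directions.permute(2, 1, 0), dims=([1], [0])
).permute(2, 1, 0)
```

`R.T` does not trigger the warning because `R` is 2-D. Only `.T` on tensors of other ranks is deprecated.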
It also turned out to be faster:
import time
import torch

# Assumes tform_cam2world (4x4) and directions (W, H, 3) are already defined.
start = time.time()
for i in range(1000):
    torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)
stop = time.time()
print(f"time: {(stop-start)} [einsum]")
start = time.time()
for i in range(1000):
    torch.sum(directions[..., None, :] * tform_cam2world[:3, :3], dim=-1)
stop = time.time()
print(f"time: {(stop-start)} [sum]")
start = time.time()
for i in range(1000):
    # directions.T on a 3-D tensor is what triggers the UserWarning above
    torch.tensordot(tform_cam2world[:3, :3], directions.T, dims=([1], [0])).T
stop = time.time()
print(f"time: {(stop-start)} [tensordot]")
time: 0.06764698028564453 [einsum]
time: 0.29248738288879395 [sum]
time: 0.10972046852111816 [tensordot]
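The loops above use `time.time()` with no warm-up. If the tensors live on a GPU, kernels launch asynchronously, so unsynchronized wall-clock timings can be misleading. One option is `torch.utils.benchmark.Timer`, which handles warm-up and synchronizes CUDA when it is available. The sketch below is a minimal example of that approach. The tensor sizes are hypothetical, and the `matmul` and permute-based `tensordot` statements are swapped-in alternatives, not the ones timed in this thread. `Timer` runs single-threaded by default, so `num_threads` is set explicitly here to match normal CPU execution.

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical inputs (sizes chosen only for illustration).
device = "cuda" if torch.cuda.is_available() else "cpu"
tform_cam2world = torch.randn(4, 4, device=device)
directions = torch.randn(200, 200, 3, device=device)

stmts = {
    "einsum": 'torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)',
    "sum": "torch.sum(directions[..., None, :] * tform_cam2world[:3, :3], dim=-1)",
    "tensordot": (
        "torch.tensordot(tform_cam2world[:3, :3], directions.permute(2, 1, 0),"
        " dims=([1], [0])).permute(2, 1, 0)"
    ),
    "matmul": "directions @ tform_cam2world[:3, :3].T",
}

results = {}
for label, stmt in stmts.items():
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"torch": torch, "tform_cam2world": tform_cam2world, "directions": directions},
        num_threads=torch.get_num_threads(),  # Timer defaults to 1 thread
        label="c2w ray transform",
        sub_label=label,
    )
    # blocked_autorange picks the number of runs to fill min_run_time seconds
    results[label] = timer.blocked_autorange(min_run_time=0.2)
    print(results[label])
```

Each printed `Measurement` reports per-call statistics such as the median, rather than a total over 1000 iterations.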
A tiny tweak: using torch.tensordot() to compute the c2w ray transform, for up to ~6x acceleration.