krrish94 / nerf-pytorch

A PyTorch re-implementation of Neural Radiance Fields
Other
887 stars 122 forks

Accelerate get_ray_bundle() #29

Open Depersonalizc opened 3 years ago

Depersonalizc commented 3 years ago

A small tweak: compute the camera-to-world (c2w) ray transform with torch.tensordot() instead of a broadcasted multiply-and-sum, for up to ~6x speedup.
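A minimal sketch of the two forms of the rotation, assuming directions is an (H, W, 3) tensor of camera-space ray directions and using random stand-in inputs. rotate_sum and rotate_tensordot are hypothetical names, not functions from the repository. The tensordot version writes directions.T as permute(2, 1, 0), which is what .T does to a 3-D tensor, so it runs without the .T deprecation warning on newer PyTorch versions.

```python
import torch

def rotate_sum(rot: torch.Tensor, directions: torch.Tensor) -> torch.Tensor:
    # Broadcast (H, W, 1, 3) * (3, 3) -> (H, W, 3, 3), then sum over the last axis:
    # out[h, w, m] = sum_n rot[m, n] * directions[h, w, n]
    return torch.sum(directions[..., None, :] * rot, dim=-1)

def rotate_tensordot(rot: torch.Tensor, directions: torch.Tensor) -> torch.Tensor:
    # Same contraction via tensordot, mirroring the PR's directions.T ... .T pattern.
    d_rev = directions.permute(2, 1, 0)                 # (3, W, H), what directions.T gives
    out = torch.tensordot(rot, d_rev, dims=([1], [0]))  # contract rot's n with d_rev's n -> (3, W, H)
    return out.permute(2, 1, 0)                         # back to (H, W, 3)

# Hypothetical stand-in inputs: rot plays the role of tform_cam2world[:3, :3].
torch.manual_seed(0)
rot = torch.randn(3, 3)
directions = torch.randn(120, 160, 3)

a = rotate_sum(rot, directions)
b = rotate_tensordot(rot, directions)
print(a.shape, torch.allclose(a, b, atol=1e-5))  # torch.Size([120, 160, 3]) True
```

The tensordot form avoids materialising the (H, W, 3, 3) intermediate that the broadcasted product creates, which is a plausible source of the speedup.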

BjoernHaefner commented 1 year ago

I used torch.einsum instead:

ray_directions = torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)
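For reference, the subscripts "mn,whn->whm" mean: multiply R[m, n] with directions[w, h, n], sum over the shared index n, and keep the two grid axes plus R's row index m. The sketch below uses random stand-in inputs of a hypothetical size and checks this against an explicit per-pixel loop and against the equivalent batched matrix product directions @ R.T.

```python
import torch

torch.manual_seed(0)
rot = torch.randn(3, 3)             # stands in for tform_cam2world[:3, :3]
directions = torch.randn(4, 5, 3)   # hypothetical (H, W, 3) grid of camera-space directions

# "mn,whn->whm": out[w, h, m] = sum over n of rot[m, n] * directions[w, h, n]
out = torch.einsum("mn,whn->whm", rot, directions)

# Same contraction as an explicit loop: rotate each pixel's direction vector.
ref = torch.empty_like(directions)
for i in range(directions.shape[0]):
    for j in range(directions.shape[1]):
        ref[i, j] = rot @ directions[i, j]   # (3, 3) @ (3,) -> (3,)

# ...and as a batched matrix product with the transposed rotation.
mm = directions @ rot.T

print(torch.allclose(out, ref, atol=1e-5), torch.allclose(out, mm, atol=1e-5))  # True True
```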

I switched because the proposed tensordot version raises the following UserWarning with PyTorch 1.13 (directions.T is called on a 3-D tensor):

UserWarning: The use of `x.T` on tensors of dimension other than 2 to reverse their shape is deprecated and it will throw an error in a future release. Consider `x.mT` to transpose batches of matrices or `x.permute(*torch.arange(x.ndim - 1, -1, -1))` to reverse the dimensions of a tensor. (Triggered internally at /opt/conda/conda-bld/pytorch_1670525541702/work/aten/src/ATen/native/TensorShape.cpp:3277.)
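If you prefer to keep tensordot, one transpose-free sketch is to contract the last axis of directions with the column axis of the rotation directly: the output dimensions then come out as (H, W, 3) and no .T is needed. The function name and input sizes below are illustrative assumptions; the check only confirms that the result matches einsum and that no x.T warning is recorded.

```python
import warnings
import torch

def rotate_tensordot_no_transpose(rot: torch.Tensor, directions: torch.Tensor) -> torch.Tensor:
    # Contract directions' last axis (n) with rot's column axis (n).
    # Result dims: directions' remaining (H, W), then rot's remaining row axis (m) -> (H, W, 3).
    return torch.tensordot(directions, rot, dims=([2], [1]))

torch.manual_seed(0)
rot = torch.randn(3, 3)              # stands in for tform_cam2world[:3, :3]
directions = torch.randn(64, 48, 3)  # hypothetical (H, W, 3) direction grid

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = rotate_tensordot_no_transpose(rot, directions)

# Keep only warnings about the deprecated x.T behaviour.
transpose_warnings = [w for w in caught if "x.T" in str(w.message)]

ref = torch.einsum("mn,whn->whm", rot, directions)
print(out.shape, torch.allclose(out, ref, atol=1e-5), len(transpose_warnings))
```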

einsum also turned out faster than both the broadcasted sum and the tensordot version (1000 calls each):

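import time
import torch

# Stand-in inputs (hypothetical; the size and device behind the timings below are not stated).
# directions: (H, W, 3) camera-space ray directions; tform_cam2world: 4x4 camera-to-world pose.
# On a GPU, kernels run asynchronously, so call torch.cuda.synchronize() before each time.time().
directions = torch.randn(400, 400, 3)
tform_cam2world = torch.eye(4)

start = time.time()
for i in range(1000):
    torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)
stop = time.time()
print(f"time: {(stop-start)} [einsum]")

start = time.time()
for i in range(1000):
    torch.sum(directions[..., None, :] * tform_cam2world[:3, :3], dim=-1)
stop = time.time()
print(f"time: {(stop-start)} [sum]")

start = time.time()
for i in range(1000):
    torch.tensordot(tform_cam2world[:3, :3], directions.T, dims=([1], [0])).T
stop = time.time()
print(f"time: {(stop-start)} [tensordot]")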

time: 0.06764698028564453 [einsum]
time: 0.29248738288879395 [sum]
time: 0.10972046852111816 [tensordot]
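A bare time.time() loop can mostly measure kernel launch overhead when the tensors are on a GPU, and it includes first-call overhead. A sketch using torch.utils.benchmark.Timer, which performs warm-up and synchronizes CUDA, follows. The input size, iteration count and the transpose-free tensordot expression are assumptions rather than what was timed above; note that Timer defaults to a single thread, so num_threads is set explicitly to match a normal run.

```python
import torch
import torch.utils.benchmark as benchmark

# Hypothetical inputs; use the device and image size you actually render at.
device = "cuda" if torch.cuda.is_available() else "cpu"
tform_cam2world = torch.eye(4, device=device)
directions = torch.randn(400, 400, 3, device=device)

stmts = {
    "einsum": 'torch.einsum("mn,whn->whm", tform_cam2world[:3, :3], directions)',
    "sum": "torch.sum(directions[..., None, :] * tform_cam2world[:3, :3], dim=-1)",
    # Transpose-free tensordot, so no x.T deprecation warning is triggered.
    "tensordot": "torch.tensordot(directions, tform_cam2world[:3, :3], dims=([2], [1]))",
}

results = {}
for name, stmt in stmts.items():
    timer = benchmark.Timer(
        stmt=stmt,
        globals={"torch": torch, "tform_cam2world": tform_cam2world, "directions": directions},
        num_threads=torch.get_num_threads(),  # Timer would otherwise use 1 thread
    )
    m = timer.timeit(100)  # returns a Measurement; .mean is seconds per call
    results[name] = m.mean
    print(f"{name:>10}: {m.mean * 1e6:.1f} us per call")
```

Absolute numbers depend on hardware, device and image size, so they will not match the timings above; the relative ordering is what matters.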