eigenvivek / DiffDRR

Auto-differentiable digitally reconstructed radiographs in PyTorch
https://vivekg.dev/DiffDRR
MIT License

Interpolate voxel values instead of nearest neighbor in siddon #186

Closed: gmedan closed this issue 7 months ago

gmedan commented 7 months ago

The nearest-neighbor approach taken in the implementation of Siddon's method (_get_voxel: torch.take(volume, idxs)) produces visible artifacts in DRRs, especially when they are derived from coarse volumes. Another approach would be to interpolate the density value at the mid-alpha position and use that value instead of the nearest voxel. Although there is a performance penalty, the results are smoother and more aesthetically pleasing:

[image]

The implementation is based on https://github.com/sbarratt/torch_interpolations (slightly modified), and refactors siddon_raycast into a torch.nn.Module.
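
For comparison, here is a minimal PyTorch sketch of the two lookups at continuous sample positions. It is not the code from the gist or from torch_interpolations: it assumes points are already expressed in voxel index coordinates (i, j, k) along the volume's (D, H, W) axes, with voxel centers at integer coordinates, and the names `nearest_lookup` and `trilinear_lookup` are hypothetical.

```python
import torch


def nearest_lookup(volume: torch.Tensor, pts: torch.Tensor) -> torch.Tensor:
    """Nearest-neighbor lookup, in the spirit of torch.take(volume, idxs)."""
    D, H, W = volume.shape
    # Round each continuous coordinate to the closest voxel center and clamp to the grid
    i = pts[..., 0].round().long().clamp(0, D - 1)
    j = pts[..., 1].round().long().clamp(0, H - 1)
    k = pts[..., 2].round().long().clamp(0, W - 1)
    # torch.take indexes the volume as if it were flattened (row-major)
    return torch.take(volume, (i * H + j) * W + k)


def trilinear_lookup(volume: torch.Tensor, pts: torch.Tensor) -> torch.Tensor:
    """Trilinear interpolation of the 8 voxels surrounding each point (needs D, H, W >= 2)."""
    D, H, W = volume.shape
    z = pts[..., 0].clamp(0, D - 1)
    y = pts[..., 1].clamp(0, H - 1)
    x = pts[..., 2].clamp(0, W - 1)
    # Lower corner of the enclosing cell, capped so the upper corner (+1) stays in bounds
    z0 = z.floor().long().clamp(max=D - 2)
    y0 = y.floor().long().clamp(max=H - 2)
    x0 = x.floor().long().clamp(max=W - 2)
    # Fractional offsets in [0, 1] inside the cell
    dz, dy, dx = z - z0, y - y0, x - x0
    z1, y1, x1 = z0 + 1, y0 + 1, x0 + 1
    # Interpolate along x (W) ...
    c00 = volume[z0, y0, x0] * (1 - dx) + volume[z0, y0, x1] * dx
    c01 = volume[z0, y1, x0] * (1 - dx) + volume[z0, y1, x1] * dx
    c10 = volume[z1, y0, x0] * (1 - dx) + volume[z1, y0, x1] * dx
    c11 = volume[z1, y1, x0] * (1 - dx) + volume[z1, y1, x1] * dx
    # ... then along y (H) ...
    c0 = c00 * (1 - dy) + c01 * dy
    c1 = c10 * (1 - dy) + c11 * dy
    # ... then along z (D)
    return c0 * (1 - dz) + c1 * dz
```

Trilinear interpolation blends the eight surrounding voxels with weights that vary continuously with the sample position, which is what removes the blocky artifacts. It also makes the sampled value differentiable with respect to the sample position, whereas rounding is not.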

I haven't examined the performance cost or how it behaves under autodiff and pose regression. Relates to #164 and #172.

eigenvivek commented 7 months ago

Wow that result is awesome. I knew the one-ray-per-pixel approach behind Siddon's method was aliased, but I didn't realize how much the aliasing could be abated with trilinear interpolation.

A few thoughts w.r.t. performance and implementation:

A potential improvement:

A previous approach:

[Screenshot 2024-02-08 at 08.15.49]

eigenvivek commented 7 months ago

I think this should definitely be added to DiffDRR as a rendering backend.

Also, thanks for teaching me about the @property decorator! Love learning new python tricks.

eigenvivek commented 7 months ago

One clarification is that Siddon's method is the exact computation of the line integral over a voxel grid (see the figure below). That is, the nearest-neighbor approach is what makes it Siddon's method. Trilinear interpolation-based rendering would be a distinct algorithm.

[Screenshot 2024-02-08 at 10.05.23]
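
To make the distinction concrete, here is a single-ray sketch of the exact computation Siddon's method performs: find every parameter value alpha at which the ray crosses a voxel boundary plane, sort them, and sum each segment's length times the density of the voxel containing that segment's midpoint (the mid-alpha lookup). It is an illustration only, not DiffDRR's vectorized implementation; it assumes voxel (i, j, k) occupies the unit cube [i, i+1) x [j, j+1) x [k, k+1) in index coordinates, and `siddon_line_integral` is a hypothetical name.

```python
import torch


def siddon_line_integral(volume, source, target):
    """Exact line integral of a voxel-wise constant volume along source -> target.

    Assumes voxel (i, j, k) occupies [i, i+1) x [j, j+1) x [k, k+1) in index space,
    and that source/target are float tensors of shape (3,) in the same space.
    """
    shape = volume.shape
    direction = target - source
    # Parametric positions alpha in [0, 1]: the endpoints plus every boundary-plane crossing
    alphas = [torch.tensor([0.0, 1.0], dtype=source.dtype)]
    for axis in range(3):
        if direction[axis] != 0:
            planes = torch.arange(shape[axis] + 1, dtype=source.dtype)
            alphas.append((planes - source[axis]) / direction[axis])
    alpha = torch.cat(alphas)
    alpha = torch.unique(alpha[(alpha >= 0) & (alpha <= 1)])  # sorted, duplicates removed
    # Each pair of consecutive alphas bounds a segment lying inside a single voxel
    alpha_mid = (alpha[:-1] + alpha[1:]) / 2
    lengths = (alpha[1:] - alpha[:-1]) * direction.norm()
    # The voxel containing the segment midpoint supplies that segment's (constant) density
    idx = (source + alpha_mid[:, None] * direction).floor().long()
    inside = ((idx >= 0) & (idx < torch.tensor(shape))).all(dim=-1)
    density = torch.zeros_like(alpha_mid)
    hit = idx[inside]
    density[inside] = volume[hit[:, 0], hit[:, 1], hit[:, 2]].to(density.dtype)
    return (lengths * density).sum()
```

Because the density is treated as constant inside each voxel, the sum is the exact integral of that piecewise-constant volume; an interpolation-based renderer instead approximates the integral of a smoothly varying density.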

gmedan commented 7 months ago

Thanks!

  • Perhaps the interpolation could be performed with torch.nn.functional.interpolate instead of Shane Barratt's package? Should be faster since it has dedicated C++ kernels.

From my understanding, torch.nn.functional.interpolate is a resizing transform: it resamples onto a regular grid, which makes it unsuitable for raycasting.

  • If we're using trilinear, there's no need to calculate the alphamid tensor. Instead, you could just define the rays with source and target, and uniformly sample along them. Skipping the intermediate computation will probably speed it up.

Interesting idea, and also easy to implement. I tried it: the results look visually similar, but there was no significant speedup. I've updated the code gist with the new option.
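
The uniform-sampling variant replaces the plane-crossing computation with a fixed number of equally spaced samples per ray, i.e. a midpoint-rule quadrature of the line integral. A minimal sketch, assuming batched source/target points in voxel index coordinates and a caller-supplied `density_fn` (hypothetical) that returns the interpolated density at each point and zero outside the volume:

```python
import torch
from typing import Callable


def uniform_ray_samples(source: torch.Tensor, target: torch.Tensor, n_points: int):
    """Midpoints of n_points equal segments on each ray; source/target are (n_rays, 3)."""
    t = (torch.arange(n_points, dtype=source.dtype) + 0.5) / n_points    # (n_points,)
    direction = target - source                                          # (n_rays, 3)
    pts = source[:, None, :] + t[None, :, None] * direction[:, None, :]  # (n_rays, n_points, 3)
    step = direction.norm(dim=-1) / n_points                             # (n_rays,)
    return pts, step


def render_uniform(density_fn: Callable[[torch.Tensor], torch.Tensor],
                   source: torch.Tensor, target: torch.Tensor, n_points: int) -> torch.Tensor:
    """Approximate each ray's line integral as sum(density at samples) * step length."""
    pts, step = uniform_ray_samples(source, target, n_points)
    return density_fn(pts).sum(dim=-1) * step
```

Because the sample positions no longer depend on where each ray crosses voxel boundaries, every ray gets the same number of samples, so the batch stays a dense (n_rays, n_points) tensor and `n_points` becomes a direct speed/quality knob.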

I think this should definitely be added to DiffDRR as a rendering backend.

I don't have the time to make a proper PR, but if you wish to incorporate it go ahead!

  • Instead of doing trilinear interpolation along a ray, it could be cool to do trilinear interpolation along a cone cast from the source to the target (like the approach behind mip-NeRF). I think casting cones is the ultimate approach to minimizing aliasing.

Yes, this approach should be better equipped to handle low-resolution DRRs.

Also, thanks for teaching me about the @property decorator! Love learning new python tricks.

Super useful! Also note the use of einops.rearrange instead of tensor.view/unsqueeze, which makes the code more readable.

One clarification is that Siddon's method is the exact computation of the line integral over a voxel grid (see the figure below). That is, the nearest-neighbor approach is what makes it Siddon's method. Trilinear interpolation-based rendering would be a distinct algorithm.

Right, but I think the hidden assumption is that the density is constant within each voxel, whereas in reality each voxel value is just the average density over the voxel's volume. Interpolating does indeed eliminate the need to find the grid intersection points.

gmedan commented 7 months ago

Also, the performance penalty for interpolating seems to be roughly 2x the rendering time.

gmedan commented 7 months ago

Actually, it's possible to regain some performance even when interpolating by sampling at fewer uniformly spaced points, without a significant impact on rendering quality when subsampling by up to ~3x. Beyond that, it starts looking like this: [image]

This downsampling could be used for training the pose regressor, with the slower full ray sampling used for fine-tuning the pose.

eigenvivek commented 7 months ago

What's your base sampling rate? ~300 samples per ray?

gmedan commented 7 months ago

I took the maximum dimension of the volume as the base number of samples per ray.
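
The sampling budget described here (a base number of samples per ray equal to the volume's largest dimension, optionally divided by a subsampling factor of up to ~3x) fits in a tiny helper. The name `points_per_ray` and the coarse/fine split are illustrative assumptions, not an existing DiffDRR API.

```python
import math


def points_per_ray(volume_shape, subsample: float = 1.0) -> int:
    """Samples per ray: the largest volume dimension, reduced by a subsampling factor."""
    base = max(volume_shape)
    return max(1, math.ceil(base / subsample))


# Hypothetical coarse-to-fine use: cheaper renders while training the pose regressor,
# full sampling when fine-tuning the pose
coarse = points_per_ray((512, 512, 300), subsample=3.0)
fine = points_per_ray((512, 512, 300))
```

For a 512 x 512 x 300 volume this gives 512 samples per ray at full rate and 171 at 3x subsampling.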

eigenvivek commented 7 months ago

I was looking for torch.nn.functional.grid_sample, not torch.nn.functional.interpolate.

Kernelized trilinear interpolation is muuuuch faster than the pure Pytorch code.

Renderings below compare Siddon's method against trilinear interpolation with different numbers of points per ray. At ~200 points per ray, trilinear is indistinguishable from Siddon's method (to me) while being ~6X faster. You need to sample ~2000 points for trilinear to be as slow as Siddon's method.

You can expect a trilinear backend in DiffDRR soon.

This might also help us break the 1 sec threshold for registration in DiffPose!

[Screenshot 2024-02-08 at 21.17.29]

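
For reference, torch.nn.functional.grid_sample performs trilinear interpolation when given a 5-D input with mode="bilinear", and it is differentiable with respect to both the volume and the sampling grid. The sketch below combines it with uniform sampling along each ray. It is not DiffDRR's trilinear backend: it assumes source/target points in voxel index coordinates along the (D, H, W) axes, ignores voxel spacing and detector geometry, and `render_trilinear` is a hypothetical name. grid_sample expects normalized coordinates in [-1, 1] ordered (x, y, z), i.e. (W, H, D), which is why the last axis is flipped.

```python
import torch
import torch.nn.functional as F


def render_trilinear(volume: torch.Tensor, source: torch.Tensor,
                     target: torch.Tensor, n_points: int) -> torch.Tensor:
    """Line integrals through `volume` (D, H, W) for rays source -> target, each (n_rays, 3).

    Points are in voxel index coordinates (i, j, k) along (D, H, W); each dimension must be >= 2.
    """
    D, H, W = volume.shape
    # Midpoints of n_points equal segments along every ray
    t = (torch.arange(n_points, dtype=source.dtype) + 0.5) / n_points
    direction = target - source
    pts = source[:, None, :] + t[None, :, None] * direction[:, None, :]  # (n_rays, n_points, 3)
    # With align_corners=True, index 0 maps to -1 and index (size - 1) maps to +1
    sizes = torch.tensor([D - 1, H - 1, W - 1], dtype=source.dtype)
    normalized = 2 * pts / sizes - 1
    # grid_sample wants (x, y, z) = (W, H, D) order and a grid of shape (N, D_out, H_out, W_out, 3)
    grid = normalized.flip(-1).to(volume.dtype)[None, None]  # (1, 1, n_rays, n_points, 3)
    samples = F.grid_sample(volume[None, None], grid, mode="bilinear",
                            padding_mode="zeros", align_corners=True)  # (1, 1, 1, n_rays, n_points)
    step = direction.norm(dim=-1) / n_points
    return samples[0, 0, 0].sum(dim=-1) * step.to(volume.dtype)
```

Samples falling outside the volume read as zero because of padding_mode="zeros", and `n_points` is the parameter behind the ~200 vs ~2000 points-per-ray comparison above.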
eigenvivek commented 7 months ago

Actually the more I look at it, the more I think trilinear interpolation @ 200 points / ray looks even better than the Siddon's method rendering.

gmedan commented 7 months ago

Glad this helped. I'll try the newly released version soon.

Actually the more I look at it, the more I think trilinear interpolation @ 200 points / ray looks even better than the Siddon's method rendering.

[image]