eigenvivek / DiffDRR

Auto-differentiable digitally reconstructed radiographs in PyTorch
https://vivekg.dev/DiffDRR
MIT License

Regarding the DRR parameter issue #324

Open linquanxu opened 1 month ago

linquanxu commented 1 month ago

@eigenvivek hello, I segmented and cropped an L4 vertebra from the original DICOM data, and the results of the DRR rendering are as follows:

[image: drr_image_plot]

My original parameters are as follows:

source_dicom
PixelSpacing: [0.40625,0.40625]
SliceThickness: 1.0
ImagePositionPatient: [-101.796875, -188.296875, -409.2]
ImageOrientationPatient: [1, 0, 0, 0, 1, 0]
Rows: 512
Columns: 512
DistanceSourceToDetector: 1085.6
DistanceSourceToPatient: 595.0

crop vertebrae_L4
shape: (228, 199, 49)
affine: array([[-0.40625,  0.     ,  0.     , -0.     ],
       [ 0.     , -0.40625,  0.     , -0.     ],
       [ 0.     ,  0.     ,  1.     ,  0.     ],
       [ 0.     ,  0.     ,  0.     ,  1.     ]])
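For anyone reproducing this setup, the geometry above can be read straight from the files. A minimal sketch (the DICOM slice path is hypothetical; the NIfTI path is the one used later in this thread):

import nibabel as nib
import pydicom

# Read the geometry tags from one slice of the source DICOM series
ds = pydicom.dcmread("source_dicom/slice_000.dcm")  # hypothetical path
print(ds.PixelSpacing)              # [0.40625, 0.40625]
print(ds.SliceThickness)            # 1.0
print(ds.ImagePositionPatient)      # [-101.796875, -188.296875, -409.2]
print(ds.DistanceSourceToDetector)  # 1085.6
print(ds.DistanceSourceToPatient)   # 595.0

# Inspect the cropped L4 segmentation
seg = nib.load("data/L4.nii.gz")
print(seg.shape)   # (228, 199, 49)
print(seg.affine)  # the 4x4 voxel-to-world matrix shown above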

The configuration parameters for the DRR are as follows:

    SDD = 1085.6
    HEIGHT = 200
    DELX = 0.40625
    # Initialize the DRR module for generating synthetic X-rays
    drr = DRR(subject, sdd=SDD, height=HEIGHT, delx=DELX).to(device)

    # Make a pose
    rot = torch.tensor([[0.0, 0.0, 0.0]], device=device) / 180 * torch.pi
    xyz = torch.tensor([[0.0, 595.0, 0.0]], device=device)
    pose = convert(rot, xyz, parameterization="euler_angles", convention="ZXY")

My DRR parameter configuration is inconsistent with the configuration you provided for cxr.nii.gz, mainly regarding delx and by. I analyzed the cxr.nii.gz file and found that its spacing is (0.703125, 0.703125, 2.5). Why is delx set to 2.0 in your code instead of 0.703125? As for by, your code sets it to 800; how was that value determined, and what should I set it to in my case?

cxr.nii.gz
shape: (512, 512, 133)
spacing: (0.703125, 0.703125, 2.5)
affine: array([[  -0.703125  ,   -0.        ,    0.        ,  166.        ],
       [   0.        ,    0.703125  ,    0.        , -187.59687805],
       [   0.        ,    0.        ,    2.5       , -340.        ],
       [   0.        ,    0.        ,    0.        ,    1.        ]])

your code:

drr = DRR(
    subject,     # An object storing the CT volume, origin, and voxel spacing
    sdd=1020.0,  # Source-to-detector distance (i.e., focal length)
    height=200,  # Image height (if width is not provided, the generated DRR is square)
    delx=2.0,    # Pixel spacing (in mm)
).to(device)

# Set the camera pose with rotations (yaw, pitch, roll) and translations (x, y, z)
rotations = torch.tensor([[0.0, 0.0, 0.0]], device=device)
translations = torch.tensor([[0.0, 850.0, 0.0]], device=device)
eigenvivek commented 1 month ago

Looks good!

I analyzed the cxr.nii.gz file and found that its spacing is (0.703125, 0.703125, 2.5). Why is delx set to 2.0 in your code instead of 0.703125?

You're confusing the voxel spacing of the CT with the pixel spacing of the X-ray detector. The two are independent: (0.703125, 0.703125, 2.5) is the XYZ spacing of the CT, while 2.0 mm is the spacing I set for the virtual X-ray detector.

That value, along with by=800, was an arbitrary choice made to demonstrate the renderer. It's not based on a real X-ray system; it's just an example of how the renderer works.
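To make the distinction concrete, here is a minimal sketch of the two independent spacings, using the same API as the snippet above (my paraphrase of this answer, not code from the repo):

import torch
from diffdrr.data import load_example_ct
from diffdrr.drr import DRR

# The CT voxel spacing (0.703125, 0.703125, 2.5) travels with the subject;
# it is read from the volume's affine and is never passed to DRR directly
subject = load_example_ct()

# delx is the pixel spacing of the *virtual detector*, a free parameter of
# the synthetic X-ray system; the demo value 2.0 mm is unrelated to the CT
drr = DRR(subject, sdd=1020.0, height=200, delx=2.0)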

eigenvivek commented 1 month ago

For rendering X-rays based on a real physical imaging system, see DiffPose: https://github.com/eigenvivek/DiffPose

linquanxu commented 1 month ago

@eigenvivek, thank you for your answer. So you mean that the parameters sdd, height, and delx should all be set based on the real X-ray equipment, while by is just a suitable initial value.

1. For my real X-ray, delx = 1.0. In the code above, the DRR height is set to 200. To achieve registration against the real X-ray image, I cropped the spine from it, yielding a 300x300 image, and then resized that crop to 200x200 to match the DRR output. Does that make delx = 1.5 in this case? (See the arithmetic sketch below.)
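If the crop is resampled that way, the arithmetic supports delx = 1.5: resizing shrinks the pixel grid but not the physical field of view. A quick check (my own, not from the thread):

# 300 x 300 crop at 1.0 mm/pixel covers a 300 mm field of view
crop_pixels = 300
native_delx = 1.0                   # mm/pixel on the real detector
fov_mm = crop_pixels * native_delx  # 300.0 mm

# After resizing to 200 x 200, the same 300 mm spans only 200 pixels
target_pixels = 200
effective_delx = fov_mm / target_pixels
print(effective_delx)  # 1.5 mm/pixel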

import torch
from diffdrr.data import read
from diffdrr.drr import DRR
from diffdrr.pose import convert
from diffdrr.visualization import drr_to_mesh

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

seg_ct_path = "data/L4.nii.gz"
subject = read(seg_ct_path)
# subject = load_example_ct()

# Make a mesh from the CT volume
ct = drr_to_mesh(subject, "surface_nets", threshold=225, verbose=True)

SDD = 1085.6
HEIGHT = 200
# DELX = 0.40625
DELX = 1.0

# Initialize the DRR module for generating synthetic X-rays
drr = DRR(subject, sdd=SDD, height=HEIGHT, delx=DELX).to(device)

# Make a pose
rot = torch.tensor([[0.0, 0.0, 0.0]], device=device) / 180 * torch.pi
xyz = torch.tensor([[0.0, 595.0, 0.0]], device=device)
pose = convert(rot, xyz, parameterization="euler_angles", convention="ZXY")

points = torch.tensor([[[0.0, 0.0, 0.0],
                        [-100.0, 0.0, 0.0]]], device=device)
points_det = drr.perspective_projection(pose, points)
print(points_det)

# inverse_projection here uses the patched signature shown later in this thread
points_3d_det = drr.inverse_projection(pose, points_det, 595)
print(points_3d_det)
Output:

tensor([[[100.0000, 100.0000],
         [282.4538, 100.0000]]], device='cuda:0')
tensor([[[   0.0000,    0.0000,    0.0000],
         [-100.0000,    0.0000,    0.0000]]], device='cuda:0')

2. Additionally, when calculating the perspective projection coordinates, I set the two initial points to (0, 0, 0) and (-100, 0, 0). The projected 2D coordinates are (100.0000, 100.0000) and (282.4538, 100.0000). The second point looks incorrect, because the DRR image is 200x200 and 282.4538 exceeds 200. What could be the reason for this?
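One explanation consistent with these numbers (my own check, not from the thread): the second point is not wrong, it simply projects off the detector. A point 100 mm lateral of the isocenter is magnified by sdd divided by the source-to-point distance before it hits the image plane:

# Pinhole magnification check using the values from this thread
sdd = 1085.6        # source-to-detector distance (mm)
src_to_pt = 595.0   # source-to-isocenter distance from the pose translation (mm)
delx = 1.0          # detector pixel spacing (mm/pixel)
height = 200        # detector size (pixels)

magnification = sdd / src_to_pt           # ~1.8245
offset_px = 100.0 * magnification / delx  # ~182.45 pixels from the center
u = height / 2 + offset_px                # principal point sits at (100, 100)
print(u)  # ~282.45, i.e., outside the 200-pixel-wide image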

The DRR results are as follows:

[images omitted]

Lastly, I modified the inverse_projection code by adding a bias term, set to 595, which successfully recovers the original 3D points:

import torch
from fastcore.basics import patch
from torch.nn.functional import pad

from diffdrr.drr import DRR
from diffdrr.pose import RigidTransform


@patch
def inverse_projection(
    self: DRR,
    pose: RigidTransform,
    pts: torch.Tensor,
    bias: float,
):
    """Backproject points in the pixel plane (2D) onto the image plane in world coordinates (3D)."""
    pts = pts.flip(-1)
    if self.detector.reverse_x_axis:
        pts[..., 1] = self.detector.width - pts[..., 1]

    # Original implementation scaled by the source-to-detector distance:
    # x = self.detector.sdd * torch.einsum(
    #     "ij, bnj -> bni",
    #     self.detector.intrinsic.inverse(),
    #     pad(pts, (0, 1), value=1),  # Convert to homogeneous coordinates
    # )

    # Modified version scales by the caller-supplied bias instead
    x = bias * torch.einsum(
        "ij, bnj -> bni",
        self.detector.intrinsic.inverse(),
        pad(pts, (0, 1), value=1),  # Convert to homogeneous coordinates
    )
    extrinsic = self.detector.reorient.compose(pose)
    return extrinsic(x)
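A plausible reading of why bias = 595 works (my interpretation, not confirmed in the thread): the inverse intrinsic matrix turns each pixel into a ray, and the scalar prefactor chooses the depth along that ray at which the 3D point is reconstructed. The stock sdd prefactor places points on the detector plane, while 595 mm, the source-to-patient distance used in the pose, places them on the plane through the isocenter, where the test points actually live:

# Round trip with the patched function (same points as earlier)
points = torch.tensor([[[0.0, 0.0, 0.0],
                        [-100.0, 0.0, 0.0]]], device=device)
points_det = drr.perspective_projection(pose, points)

# bias = 595.0 reconstructs at the isocenter depth and recovers the inputs;
# bias = drr.detector.sdd would instead land the points on the detector plane
points_3d = drr.inverse_projection(pose, points_det, bias=595.0)
print(points_3d)  # recovers (0, 0, 0) and (-100, 0, 0)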