MarilynKeller / HIT

Official repo of the 2024 paper: HIT: Estimating Internal Human Implicit Tissues from the Body Surface

Generate OCC for entire mesh #1

Open personal-coding opened 1 week ago

personal-coding commented 1 week ago

I want to start by saying this is amazing research. Is there a way to generate the occupancy for the entire mesh, instead of just a slice?

MarilynKeller commented 1 week ago

Hi, thanks :)

What do you mean by "for the entire mesh"?

You can evaluate the occupancy at any 3D point, inside or outside the body mesh. So you can evaluate it for 3D points on a plane, as in the slice example, or for 3D points on a 3D grid to evaluate a volume.

The meshes of the different tissues are generated by sampling the occupancy function of a specific tissue on a 3D grid and running marching cubes on the result.
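For concreteness, here is a minimal sketch of that mesh-extraction pipeline. `query_occupancy` is a hypothetical placeholder for the HIT model's per-tissue occupancy query (not the repo's actual API), and the tissue label, grid bounds, and resolution are illustrative:

    import numpy as np
    import torch
    from skimage.measure import marching_cubes

    def query_occupancy(points, tissue):
        """Hypothetical wrapper: map (N, 3) points to (N,) occupancies in [0, 1]."""
        raise NotImplementedError  # replace with the HIT model's occupancy call

    # Regular grid over a box enclosing the body (bounds in meters, illustrative).
    res = 128
    xs = np.linspace(-1.0, 1.0, res)
    grid = np.stack(np.meshgrid(xs, xs, xs, indexing="ij"), axis=-1).reshape(-1, 3)

    # Sample the occupancy of one tissue (e.g. adipose) on the grid.
    occ = query_occupancy(torch.FloatTensor(grid), tissue="adipose")
    occ = occ.reshape(res, res, res).cpu().numpy()

    # Run marching cubes on the sampled volume to extract a triangle mesh.
    verts, faces, _, _ = marching_cubes(occ, level=0.5)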

Does this clarify?

personal-coding commented 1 week ago

@MarilynKeller I'm trying to calculate each tissue's percentage of the entire mesh, similar to here. I was able to generate meshes for each tissue, but since each mesh is not watertight, I can't accurately compute the volume of each. I was thinking of taking the occupancy prediction for all points within the SMPL mesh, instead of just a slice, to accomplish this. I tried updating the evaluate_slice function to predict points within the mesh, but the number of points consumed too much memory. There is likely an easier way to approach this, but it wasn't obvious to me.

Separately, the paper only predicts long bones. Did you consider predicting the spine, ribs, skull, etc.? Why did you choose long bones only instead of including other skeletal structures?

MarilynKeller commented 1 week ago

Ok I see, indeed this might be too many points to evaluate. The key is to only sample and evaluate points inside the body.

So: create a 3D array indexed by (u, v, w) and choose a resolution to convert these indices to 3D points (x, y, z). For each of these points, test whether it is inside the body mesh or not; this gives you a mask B (of shape U, V, W). Then only evaluate the occupancy at the points where B == True.

For example, you could have u in [0, 1000] (indices) mapping to x in [-0.5, 0.5] meters.

This could still give you ~3e6 points to evaluate, so one option is to evaluate the occupancy in batches of 3e5 points or smaller, depending on what you manage to load in memory.

I'd recommend starting with a coarse resolution, like points spaced by 5 cm, to test your implementation, and then decreasing the spacing to see at which resolution you converge towards fixed tissue percentages (see the sketch below).
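A rough sketch of this whole recipe, under the assumption that `occupancy_fn(points, tissue)` wraps the HIT model's query and returns per-point occupancies in [0, 1] (the function name, tissue labels, and the 0.5 threshold are illustrative, not the repo's API):

    import numpy as np
    import torch
    import trimesh

    def tissue_percentages(body_mesh, occupancy_fn, tissues, res=0.05, batch=300_000):
        # Regular grid over the mesh bounding box, points spaced `res` meters apart.
        lo, hi = body_mesh.bounds
        axes = [np.arange(lo[i], hi[i] + res, res) for i in range(3)]
        pts = np.stack(np.meshgrid(*axes, indexing="ij"), axis=-1).reshape(-1, 3)

        # Mask B: keep only the grid points inside the body surface.
        pts = pts[body_mesh.contains(pts)]

        # Evaluate each tissue's occupancy in batches to bound memory use.
        fractions = {}
        for tissue in tissues:
            occupied = 0
            for i in range(0, len(pts), batch):
                chunk = torch.FloatTensor(pts[i:i + batch])
                occupied += int((occupancy_fn(chunk, tissue) > 0.5).sum())
            fractions[tissue] = occupied / len(pts)
        return fractions

Starting with res=0.05 (5 cm spacing) and then shrinking it, the returned fractions should stabilize once the grid is fine enough.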

As for the bones, we only predict the long bones because the other ones were too small or thin to be segmented in the MRI. With an adequate dataset, the approach could be extended to all bones; it all depends on the training data.

personal-coding commented 5 days ago

@MarilynKeller This is really helpful. Filtering points to only those within the mesh would greatly reduce the number of points to evaluate. As a test, I tried to filter out points outside the mesh in your slice example: I took all points outside the mesh and arbitrarily moved them to [-0.9000, -1.2000, 0.0000], a point known to return an empty prediction. However, the resulting slice below looks different from the generated example slice; the adipose tissue is thinner. Does this imply the model is predicting tissues outside the mesh, or did I make a mistake? The first image is my test, and the second is the example slice.

    def evaluate_slice(self, batch, smpl_output, z0, axis='z', values=["occ", "sw", "beta", "fwd_beta"], res=0.01):
        """ 
        Infer the different values on a slice of the 3D space
        Args:
            batch : dict containing the input smpl parameters
            smpl_output : the output of the SMPL model
            z0 : the z coordinate of the slice
            axis : the axis of the slice
            values : the values to infer on the slice, can contain :
                "occ" (occupancy) 
                "sw" (skinning weights)
            res : the size of a slice pixel (in meters) 
        Returns:
            out_images : list of pillow images of the different values
        """

        sl = SliceLevelSet(nbins=10, xbounds=[-0.2,0.2], ybounds=[-0.2,0.2], z_bounds=[-0.2, 0.2], res=res)

        xc = sl.gen_slice_points(z0=z0, axis=axis)
        xc_batch = torch.FloatTensor(xc).to(batch['betas'].device).unsqueeze(0).expand(batch['betas'].shape[0], -1, -1).clone()  # clone() so the expanded batch can be written to in place below

        points = xc_batch.squeeze(0).cpu().numpy()

        # Generate the canonical x-pose SMPL output for the current betas
        smpl = self.smpl
        betas = smpl_output.betas
        smpl_output_xpose = smpl.forward(betas=betas, body_pose=smpl.x_cano().to(betas.device), global_orient=None,
                                         transl=None)

        # Extract the vertices as a numpy array
        if isinstance(smpl_output_xpose.vertices, torch.Tensor):
            verts_numpy = smpl_output_xpose.vertices.squeeze().cpu().numpy()
        else:
            verts_numpy = smpl_output_xpose.vertices

        # Check which query points are inside the x-pose mesh
        mesh = trimesh.Trimesh(verts_numpy, smpl_output_xpose.faces, process=False)
        from leap.tools.libmesh import check_mesh_contains
        is_inside = check_mesh_contains(mesh, points).astype(float)
        is_outside = torch.FloatTensor(is_inside).to(smpl_output_xpose.vertices.device) != 1
        is_outside = is_outside[None, :]

        # Move all points outside the mesh to an arbitrary point known to give an empty prediction
        xc_batch[is_outside] = torch.FloatTensor([-0.9000, -1.2000,  0.0000]).to(smpl_output.vertices.device)

[Image: occ slice at y=0 — test with outside points remapped]

[Image: occ slice at y=0 — original example slice]