graphdeco-inria / gaussian-splatting

Original reference implementation of "3D Gaussian Splatting for Real-Time Radiance Field Rendering"
https://repo-sam.inria.fr/fungraph/3d-gaussian-splatting/

External rendering: is the opacity taken into account during rasterization? #830

Open EarltShirt opened 4 months ago

EarltShirt commented 4 months ago

Thanks a lot for your excellent work! I'm trying to add Meta's Segment Anything Model to the repo, but some anomalies appear in the rendering step I added, which is handled via cv2 (Python OpenCV). When looking at the generated image in the interactive viewer, the image is rendered correctly (gaussians with the right colors and opacities), but when I load the image on the CPU, convert it to a numpy array, cast it from float32 to uint8, and finally render the view using cv2 and PIL with the following code, some unwanted gaussians appear (specifically very bright and colorful ones). Here is my added code:

if iteration == 10000:
    # Render a fixed training view and hand the result to SAM
    seg_viewpoint_stack = scene.getTrainCameras().copy()
    seg_viewpoint_cam = seg_viewpoint_stack.pop(20)
    seg_render_pkg = render(seg_viewpoint_cam, gaussians, pipe, bg)
    seg_image = seg_render_pkg["render"]  # float32 CHW tensor on the GPU
    seg_image = seg_image.detach().cpu().numpy()
    segment_anything(seg_image)

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

def segment_anything(image):  # image arrives as float32 CHW, which PIL doesn't support
    print("Reading image...")
    image = np.transpose(image, (1, 2, 0))  # CHW -> HWC
    print(f"Value example in image : {image[500, 500, :]}")
    image = (image * 255).astype(np.uint8)
    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    print("Loading model...")
    sam_checkpoint = "./sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda"
    print("Loading model on device...")
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)
    predictor.set_image(image)
    interactive_segmenter = InteractiveSegmenter(image, predictor)
    interactive_segmenter.ask_for_number_of_points()
    interactive_segmenter.retrieve_all()

As you can see, some gaussians at the edge of the object (I'm using a white background during training in order to only generate gaussians on the given object) are being plotted with extremely bright colors, which aren't visible when looking at the object from the same view in the interactive viewer. Is it possible that the opacities aren't taken into account during rendering with the GaussianRasterizer? I suspect this is the problem because it only occurs at the edge of the object, where the opacities are smaller than inside the object. Or could it be that some anomalies occur during the conversion from float32 to uint8?
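To illustrate that second suspicion, here is a minimal, self-contained sketch (with made-up pixel values) of how a float32 channel above 1.0 misbehaves when cast to uint8. The scaled value wraps around modulo 256, so the channels end up wildly out of proportion and the pixel takes on a saturated, wrong color:

import numpy as np

# Hypothetical edge pixel: alpha blending can push a channel above 1.0.
pixel = np.array([1.2, 0.1, 0.1], dtype=np.float32)

# 1.2 * 255 = 306.0; the cast to uint8 wraps modulo 256 on typical
# platforms, so the over-bright channel becomes 50 instead of 255.
print((pixel * 255).astype(np.uint8))  # -> [50 25 25]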

I hope I was clear enough; here are the images highlighting the error (the first is the one rendered using cv2, the second is a screenshot from the SIBR_Viewer).

Thanks a lot for your time.

[image: problematic image rendered via cv2, with bright gaussians at the object's edge]

MACILLAS commented 4 months ago

Hey, I don't know if you've solved this, but I had a similar problem. The GaussianRasterizer that render.py uses does use the opacity. My problem was that I was doing torchvision.transforms.ToPILImage()(output); if you do torchvision.utils.save_image(output, "test.png") instead, the artifacts you showed should disappear (save_image clamps values to the valid range before quantizing, while ToPILImage does not).

Update:

The cause is that the output (float32) has values greater than 1, so you will have to do a normalization first:

result = torchvision.transforms.ToPILImage()((result*255/result.max()).type(torch.uint8))
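Note that dividing by result.max() rescales every pixel, not just the out-of-range ones. A minimal sketch of an alternative, clamping to [0, 1] before the cast (result here is a made-up stand-in for the float32 render tensor):

import torch
from torchvision.transforms import ToPILImage

# Made-up stand-in for the rasterizer output: float32 CHW with values > 1.0.
result = torch.rand(3, 64, 64) * 1.3

# Clamp instead of rescaling: out-of-range pixels saturate to white,
# while in-range pixels keep their original brightness.
img = ToPILImage()((result.clamp(0.0, 1.0) * 255).to(torch.uint8))
img.save("test.png")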

EarltShirt commented 2 months ago

@MACILLAS The artifacts disappear, but the whole image gets dimmed and the background turns gray; do you know what is happening in my case? I saw you changed the background to grey in your code to work around this, but that isn't really viable in my case. Thanks a lot!

EarltShirt commented 2 months ago

I finally used the following, which works perfectly for me:

import numpy as np
from PIL import Image
from utils.general_utils import PILtoTorch  # helper from this repo

image = Image.open(img_pth)
im_data = np.array(image.convert("RGBA"))
norm_data = im_data / 255.0
# Alpha-composite onto the background color (self.bg), so low-alpha fringe pixels fade into it
arr = norm_data[:, :, :3] * norm_data[:, :, 3:4] + self.bg * (1 - norm_data[:, :, 3:4])
image = PILtoTorch(Image.fromarray(np.array(arr * 255.0, dtype=np.byte), "RGB"), image.size)[:3, ...]
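This is the same alpha compositing that the repo's scene/dataset_readers.py applies when loading RGBA training images: each pixel's RGB is blended onto the background color by its alpha, so semi-transparent edge gaussians fade into the background instead of being cast to uint8 with out-of-range values.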