Jumpat / SegAnyGAussians

The official implementation of SAGA (Segment Any 3D GAussians)
Apache License 2.0
466 stars, 36 forks

Can't get the rendered masks in v2 branch #86

Closed: shenggedeqiang closed this issue 2 days ago

shenggedeqiang commented 3 days ago

Thanks for your great work! In the v2 branch, I used "python render.py -m <path to the pre-trained 3DGS model> --precomputed_mask --target seg" to get the 2D rendered masks, but it requires "seg_cfg_args".

So I changed "--target seg" to "--target scene --segment" and modified render.py to use the render_mask function, like this:

[Screenshot of the modified render.py; image not preserved]

But I got a RuntimeError:

Traceback (most recent call last):
  File "render.py", line 137, in <module>
    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.segment, args.target, args.idx, args.precomputed_mask)
  File "render.py", line 106, in render_sets
    render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background, target, precomputed_mask=precomputed_mask)
  File "render.py", line 48, in render_set
    mask_res = render_mask(view, gaussians, pipeline, background, precomputed_mask=precomputed_mask)
  File "/mnt/disk2t/lrs/3DGaussian/SegAnyGAussians-2/gaussian_renderer/__init__.py", line 173, in render_mask
    rendered_mask, radii = rasterizer(
  File "/root/anaconda3/envs/saga/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/saga/lib/python3.8/site-packages/diff_gaussian_rasterization/__init__.py", line 210, in forward
    return rasterize_gaussians(
  File "/root/anaconda3/envs/saga/lib/python3.8/site-packages/diff_gaussian_rasterization/__init__.py", line 32, in rasterize_gaussians
    return _RasterizeGaussians.apply(
  File "/root/anaconda3/envs/saga/lib/python3.8/site-packages/diff_gaussian_rasterization/__init__.py", line 92, in forward
    num_rendered, color, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
RuntimeError: expected scalar type Float but found Bool

Could you please tell me how to get the gt_mask and the rendered mask?

Jumpat commented 3 days ago

Hello. You need to add .float() to precomputed_mask to convert it from bool to float, because the rasterizer expects a float tensor here.
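A minimal sketch of the suggested conversion, assuming precomputed_mask is a PyTorch bool tensor. The helper name to_rasterizer_dtype and the toy mask values are hypothetical illustrations, not part of SAGA; in practice you would simply pass precomputed_mask.float() wherever the mask reaches the rasterizer (for example, inside render_mask).

```python
import torch


def to_rasterizer_dtype(precomputed_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: cast a mask tensor to float32 for the rasterizer.

    The rasterizer extension raised "expected scalar type Float but found Bool"
    when given a bool tensor. .float() keeps the tensor's shape and device and
    maps True -> 1.0, False -> 0.0.
    """
    return precomputed_mask.float()


if __name__ == "__main__":
    # Toy stand-in for a precomputed bool mask (not real SAGA data).
    mask = torch.tensor([True, False, True])
    mask_f = to_rasterizer_dtype(mask)
    print(mask_f.dtype)     # torch.float32
    print(mask_f.tolist())  # [1.0, 0.0, 1.0]
```

precomputed_mask.to(torch.float32) is an equivalent, more explicit spelling of the same cast.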

shenggedeqiang commented 2 days ago


It works, thanks a lot!