mosure / bevy_gaussian_splatting

bevy gaussian splatting render pipeline plugin
https://mosure.github.io/bevy_gaussian_splatting?gaussian_count=1000
MIT License
134 stars 9 forks source link

fix transforms - dynamic offset using DynamicUniformIndex #18

Closed: github-actions[bot] closed this issue 9 months ago.

github-actions[bot] commented 10 months ago

https://github.com/mosure/bevy_gaussian_splatting/blob/eedd27e0f32bdf33f324238284c915dc6e1574d2/src/render/mod.rs#L1147


            None => return RenderCommandResult::Failure,
        };

        pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]);
        pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]);

        pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0);

        RenderCommandResult::Success
    }
}

/// Render-graph node that radix-sorts every gaussian cloud, once for each view.
struct RadixSortNode {
    /// Becomes `true` once every radix sort compute pipeline reports `CachedPipelineState::Ok`.
    initialized: bool,
    /// Round-robin index into the radix sort pipelines; `None` until the first update after
    /// initialization.
    pipeline_idx: Option<u32>,
    /// Every cloud to sort, paired with the bind groups used by the sort passes.
    gaussian_clouds: QueryState<(&'static Handle<GaussianCloud>, &'static GaussianCloudBindGroup)>,
    /// Each view's bind group together with the dynamic offset of its view uniform.
    view_bind_group: QueryState<(&'static GaussianViewBindGroup, &'static ViewUniformOffset)>,
}

impl FromWorld for RadixSortNode {
    /// Creates the node in its uninitialized state with both queries registered against `world`.
    fn from_world(world: &mut World) -> Self {
        // Queries are created in the same order as the struct's query fields are consumed.
        let gaussian_clouds = world.query();
        let view_bind_group = world.query();

        Self {
            initialized: false,
            pipeline_idx: None,
            gaussian_clouds,
            view_bind_group,
        }
    }
}

// Records the GPU radix sort of every gaussian cloud, for every view, each frame.
impl render_graph::Node for RadixSortNode {
    /// Waits until every radix sort pipeline is `CachedPipelineState::Ok`. Then it advances
    /// `pipeline_idx` round-robin and refreshes the archetypes of both queries.
    fn update(&mut self, world: &mut World) {
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let pipeline_cache = world.resource::<PipelineCache>();

        if !self.initialized {
            self.initialized = pipeline.radix_sort_pipelines.iter().all(|sort_pipeline| {
                matches!(
                    pipeline_cache.get_compute_pipeline_state(*sort_pipeline),
                    CachedPipelineState::Ok(_)
                )
            });

            if !self.initialized {
                return;
            }
        }

        // Round-robin over the sort pipelines. `run` does not consume this yet (temporal sort
        // TODO). The `pipeline_count > 0` guard avoids a `% 0` panic on an empty pipeline list.
        let pipeline_count = pipeline.radix_sort_pipelines.len() as u32;
        self.pipeline_idx = Some(match self.pipeline_idx {
            Some(idx) if pipeline_count > 0 => (idx + 1) % pipeline_count,
            _ => 0,
        });

        self.gaussian_clouds.update_archetypes(world);
        self.view_bind_group.update_archetypes(world);
    }

    /// For every (view, cloud) pair, this records:
    /// 1. a full clear of the cloud's `sorting_global_buffer`;
    /// 2. one compute pass running radix sort pipelines 0 and 1;
    /// 3. one compute pass per radix digit place running pipeline 2. Before every pass after
    ///    the first, the leading `radix_base * max_tile_count_c` u32s of the global buffer are
    ///    re-zeroed.
    ///
    /// The whole sort is skipped for this frame (returning `Ok`) when any of these is not yet
    /// available: the uniform bind group, the prepared cloud assets, or a sort pipeline.
    /// A single cloud whose render asset is not prepared yet is skipped on its own.
    fn run(
        &self,
        _graph: &mut render_graph::RenderGraphContext,
        render_context: &mut RenderContext,
        world: &World,
    ) -> Result<(), render_graph::NodeRunError> {
        if !self.initialized {
            return Ok(());
        }

        // TODO: temporal sort — `pipeline_idx` is advanced by `update` but not used here yet.
        let _pipeline_idx = match self.pipeline_idx {
            Some(idx) => idx as usize,
            None => return Ok(()),
        };

        let pipeline_cache = world.resource::<PipelineCache>();
        let pipeline = world.resource::<GaussianCloudPipeline>();
        let gaussian_uniforms = world.resource::<GaussianUniformBindGroups>();

        // These are produced elsewhere; skip the frame instead of panicking when missing.
        let base_bind_group = match gaussian_uniforms.base_bind_group.as_ref() {
            Some(bind_group) => bind_group,
            None => return Ok(()),
        };
        let gaussian_cloud_assets = match world.get_resource::<RenderAssets<GaussianCloud>>() {
            Some(assets) => assets,
            None => return Ok(()),
        };
        let (radix_sort_a, radix_sort_b, radix_sort_c) = match (
            pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]),
            pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]),
            pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]),
        ) {
            (Some(a), Some(b), Some(c)) => (a, b, c),
            _ => return Ok(()),
        };

        // Hoisted: the defines are loop-invariant.
        let shader_defines = ShaderDefines::default();
        let radix_digit_places = shader_defines.radix_digit_places;
        let workgroup_entries_a = shader_defines.workgroup_entries_a;
        let workgroup_entries_c = shader_defines.workgroup_entries_c;
        // Byte count re-zeroed between digit passes; computed in u64 so it cannot overflow u32.
        let pass_clear_size = u64::from(shader_defines.radix_base)
            * u64::from(shader_defines.max_tile_count_c)
            * std::mem::size_of::<u32>() as u64;

        let command_encoder = render_context.command_encoder();

        for (view_bind_group, view_uniform_offset) in self.view_bind_group.iter_manual(world) {
            for (cloud_handle, cloud_bind_group) in self.gaussian_clouds.iter_manual(world) {
                // The cloud's GPU buffers may not be prepared yet; sort it on a later frame.
                let cloud = match gaussian_cloud_assets.get(cloud_handle) {
                    Some(cloud) => cloud,
                    None => continue,
                };

                command_encoder.clear_buffer(&cloud.sorting_global_buffer, 0, None);

                {
                    let mut pass =
                        command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    // TODO: view/global
                    pass.set_bind_group(0, &view_bind_group.value, &[view_uniform_offset.offset]);
                    // TODO: fix transforms - dynamic offset using DynamicUniformIndex.
                    // Every cloud currently reads the uniform at offset 0.
                    pass.set_bind_group(1, base_bind_group, &[0]);
                    pass.set_bind_group(2, &cloud_bind_group.cloud_bind_group, &[]);
                    pass.set_bind_group(3, &cloud_bind_group.radix_sort_bind_groups[1], &[]);

                    pass.set_pipeline(radix_sort_a);
                    pass.dispatch_workgroups(
                        (cloud.count + workgroup_entries_a - 1) / workgroup_entries_a,
                        1,
                        1,
                    );

                    pass.set_pipeline(radix_sort_b);
                    pass.dispatch_workgroups(1, radix_digit_places, 1);
                }

                for pass_idx in 0..radix_digit_places {
                    if pass_idx > 0 {
                        let clear_size = std::num::NonZeroU64::new(pass_clear_size)
                            .expect("radix sort clear size must be non-zero");
                        command_encoder.clear_buffer(
                            &cloud.sorting_global_buffer,
                            0,
                            Some(clear_size),
                        );
                    }

                    let mut pass =
                        command_encoder.begin_compute_pass(&ComputePassDescriptor::default());

                    pass.set_pipeline(radix_sort_c);

                    pass.set_bind_group(0, &view_bind_group.value, &[view_uniform_offset.offset]);
                    // TODO: fix transforms - dynamic offset using DynamicUniformIndex.
                    pass.set_bind_group(1, base_bind_group, &[0]);
                    pass.set_bind_group(2, &cloud_bind_group.cloud_bind_group, &[]);
                    pass.set_bind_group(
                        3,
                        &cloud_bind_group.radix_sort_bind_groups[pass_idx as usize],
                        &[],
                    );

                    pass.dispatch_workgroups(
                        1,
                        (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c,
                        1,
                    );
                }
            }
        }

        Ok(())
    }
}
github-actions[bot] commented 9 months ago

Closed in 4356f87a7a5353e33e997297ed96714d11cdc4be

github-actions[bot] commented 9 months ago

Closed in 4356f87a7a5353e33e997297ed96714d11cdc4be