mosure / bevy_gaussian_splatting

bevy gaussian splatting render pipeline plugin
https://mosure.github.io/bevy_gaussian_splatting?gaussian_count=1000
MIT License
140 stars 9 forks source link

allow setting shader defines via API #15

Open github-actions[bot] opened 11 months ago

github-actions[bot] commented 11 months ago

https://github.com/mosure/bevy_gaussian_splatting/blob/eedd27e0f32bdf33f324238284c915dc6e1574d2/src/render/mod.rs#L509


            ],
        });

        let sorting_buffer_entry = BindGroupLayoutEntry {
            binding: 1,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: BufferSize::new(ShaderDefines::default().sorting_buffer_size as u64),
            },
            count: None,
        };

        let draw_indirect_buffer_entry = BindGroupLayoutEntry {
            binding: 2,
            visibility: ShaderStages::COMPUTE,
            ty: BindingType::Buffer {
                ty: BufferBindingType::Storage { read_only: false },
                has_dynamic_offset: false,
                min_binding_size: BufferSize::new(std::mem::size_of::<wgpu::util::DrawIndirect>() as u64),
            },
            count: None,
        };

        let radix_sort_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: Some("radix_sort_layout"),
            entries: &[
                BindGroupLayoutEntry {
                    binding: 0,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<u32>() as u64),
                    },
                    count: None,
                },
                sorting_buffer_entry,
                draw_indirect_buffer_entry,
                BindGroupLayoutEntry {
                    binding: 3,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
                BindGroupLayoutEntry {
                    binding: 4,
                    visibility: ShaderStages::COMPUTE,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
            ],
        });

        let sorted_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor {
            label: Some("sorted_layout"),
            entries: &vec![
                BindGroupLayoutEntry {
                    binding: 5,
                    visibility: ShaderStages::VERTEX_FRAGMENT,
                    ty: BindingType::Buffer {
                        ty: BufferBindingType::Storage { read_only: true },
                        has_dynamic_offset: false,
                        min_binding_size: BufferSize::new(std::mem::size_of::<(u32, u32)>() as u64),
                    },
                    count: None,
                },
            ],
        });

        let compute_layout = vec![
            view_layout.clone(),
            gaussian_uniform_layout.clone(),
            gaussian_cloud_layout.clone(),
            radix_sort_layout.clone(),
        ];
        let shader = GAUSSIAN_SHADER_HANDLE.typed();
        let shader_defs = shader_defs(false, false);

        let pipeline_cache = render_world.resource::<PipelineCache>();
        let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_a".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_a".into(),
        });

        let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_b".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_b".into(),
        });

        let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("radix_sort_c".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "radix_sort_c".into(),
        });

        let temporal_sort_flip = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("temporal_sort_flip".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "temporal_sort_flip".into(),
        });

        let temporal_sort_flop = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
            label: Some("temporal_sort_flop".into()),
            layout: compute_layout.clone(),
            push_constant_ranges: vec![],
            shader: shader.clone(),
            shader_defs: shader_defs.clone(),
            entry_point: "temporal_sort_flop".into(),
        });

        GaussianCloudPipeline {
            gaussian_cloud_layout,
            gaussian_uniform_layout,
            view_layout,
            shader: shader.clone(),
            radix_sort_layout,
            radix_sort_pipelines: [
                radix_sort_a,
                radix_sort_b,
                radix_sort_c,
            ],
            temporal_sort_pipelines: [
                temporal_sort_flip,
                temporal_sort_flop,
            ],
            sorted_layout,
        }
    }
}

// TODO: allow setting shader defines via API
/// Numeric constants shared between the Rust side and the sort shaders.
///
/// `shader_defs` forwards most fields to the shader as `ShaderDefVal::UInt` definitions;
/// `workgroup_entries_a` and `sorting_buffer_size` are not forwarded there.
/// Every use visible in this file obtains the values through `ShaderDefines::default()`.
///
/// NOTE(review): the `_a` / `_c` suffixes presumably refer to the `radix_sort_a` and
/// `radix_sort_c` compute entry points — confirm against the shader source.
struct ShaderDefines {
    /// Emitted as `RADIX_BITS_PER_DIGIT`. Default: 8.
    radix_bits_per_digit: u32,
    /// Emitted as `RADIX_DIGIT_PLACES`. Default: `32 / radix_bits_per_digit`.
    radix_digit_places: u32,
    /// Emitted as `RADIX_BASE`. Default: `1 << radix_bits_per_digit`.
    radix_base: u32,
    /// Emitted as `ENTRIES_PER_INVOCATION_A`. Default: 4.
    entries_per_invocation_a: u32,
    /// Emitted as `ENTRIES_PER_INVOCATION_C`. Default: 4.
    entries_per_invocation_c: u32,
    /// Emitted as `WORKGROUP_INVOCATIONS_A`. Default: `radix_base * radix_digit_places`.
    workgroup_invocations_a: u32,
    /// Emitted as `WORKGROUP_INVOCATIONS_C`. Default: `radix_base`.
    workgroup_invocations_c: u32,
    /// Default: `workgroup_invocations_a * entries_per_invocation_a`.
    /// Not emitted by `shader_defs`.
    workgroup_entries_a: u32,
    /// Emitted as `WORKGROUP_ENTRIES_C`.
    /// Default: `workgroup_invocations_c * entries_per_invocation_c`.
    workgroup_entries_c: u32,
    /// Emitted as `MAX_TILE_COUNT_C`.
    /// Default: 10000000 divided by `workgroup_entries_c`, rounded up.
    max_tile_count_c: u32,
    /// Size in bytes used as the `min_binding_size` of the sorting buffer binding
    /// (binding 1 of the radix sort bind group layout). Not emitted by `shader_defs`.
    sorting_buffer_size: usize,

    /// Emitted as `TEMPORAL_SORT_WINDOW_SIZE`. Default: 16.
    temporal_sort_window_size: u32,
}

impl Default for ShaderDefines {
    /// Builds the stock configuration.
    ///
    /// Everything is derived from three literals: 8 bits per radix digit, 4 entries per
    /// invocation (for both the `a` and `c` variants) and an entry count of 10000000
    /// that determines `max_tile_count_c`. `sorting_buffer_size` is expressed in bytes.
    fn default() -> Self {
        // Entry count from which the maximum tile count is derived.
        const ENTRY_COUNT: u32 = 10000000;
        // Byte width of one `u32` word in the sorting buffer size computation.
        const WORD_BYTES: usize = std::mem::size_of::<u32>();

        let bits: u32 = 8;
        let places: u32 = 32 / bits;
        let base: u32 = 1u32 << bits;

        let per_invocation_a: u32 = 4;
        let per_invocation_c: u32 = 4;

        let invocations_a = base * places;
        let invocations_c = base;

        let entries_a = invocations_a * per_invocation_a;
        let entries_c = invocations_c * per_invocation_c;

        // Ceiling division: a partially filled last tile still counts as a tile.
        let tiles_c = (ENTRY_COUNT + entries_c - 1) / entries_c;

        // `radix_base` words for each digit place and each tile, plus five extra words.
        let table_words = base as usize * (places as usize + tiles_c as usize);
        let buffer_bytes = (table_words * WORD_BYTES) + WORD_BYTES * 5;

        Self {
            radix_bits_per_digit: bits,
            radix_digit_places: places,
            radix_base: base,
            entries_per_invocation_a: per_invocation_a,
            entries_per_invocation_c: per_invocation_c,
            workgroup_invocations_a: invocations_a,
            workgroup_invocations_c: invocations_c,
            workgroup_entries_a: entries_a,
            workgroup_entries_c: entries_c,
            max_tile_count_c: tiles_c,
            sorting_buffer_size: buffer_bytes,

            temporal_sort_window_size: 16,
        }
    }
}

fn shader_defs(
    aabb: bool,
    visualize_bounding_box: bool,
) -> Vec<ShaderDefVal> {
    let defines = ShaderDefines::default();
    let mut shader_defs = vec![
        ShaderDefVal::UInt("MAX_SH_COEFF_COUNT".into(), MAX_SH_COEFF_COUNT as u32),
        ShaderDefVal::UInt("RADIX_BASE".into(), defines.radix_base),
        ShaderDefVal::UInt("RADIX_BITS_PER_DIGIT".into(), defines.radix_bits_per_digit),
        ShaderDefVal::UInt("RADIX_DIGIT_PLACES".into(), defines.radix_digit_places),
        ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_A".into(), defines.entries_per_invocation_a),
        ShaderDefVal::UInt("ENTRIES_PER_INVOCATION_C".into(), defines.entries_per_invocation_c),
        ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_A".into(), defines.workgroup_invocations_a),
        ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_C".into(), defines.workgroup_invocations_c),
        ShaderDefVal::UInt("WORKGROUP_ENTRIES_C".into(), defines.workgroup_entries_c),
        ShaderDefVal::UInt("MAX_TILE_COUNT_C".into(), defines.max_tile_count_c),

        ShaderDefVal::UInt("TEMPORAL_SORT_WINDOW_SIZE".into(), defines.temporal_sort_window_size),
    ];

    if aabb {
        shader_defs.push("USE_AABB".into());
    }

    if !aabb {
        shader_defs.push("USE_OBB".into());
    }

    if visualize_bounding_box {
        shader_defs.push("VISUALIZE_BOUNDING_BOX".into());
    }

    shader_defs
}

#[derive(PartialEq, Eq, Hash, Clone, Copy)]
pub struct GaussianCloudPipelineKey {
    pub aabb: bool,