NVIDIA / warp

A Python framework for high performance GPU simulation and graphics
https://nvidia.github.io/warp/

[BUG] Periodic boundary conditions in HashGrid are not correctly traversed #283

Open fritzio opened 3 months ago

fritzio commented 3 months ago

Bug Description

The HashGrid does not traverse cells correctly across the periodic boundary. Particles close to the right (upper) edge wrap around as expected, but particles in the vicinity of the left (lower) edge do not.

Below is a working example that places one particle at the center of each cell and uses the HashGrid to count the neighboring particles of each one. It should produce a neighbor_count of 27 for every particle. However, particles close to the origin only have 8 neighboring particles. It seems that cells are visited through the upper boundary but not through the lower boundary.

import warp as wp

wp.init()

if __name__ == "__main__":

    r = wp.array(shape=1000, dtype=wp.vec3)

    # initialize the gas grid using 1000 cells
    grid = wp.HashGrid(dim_x=10, dim_y=10, dim_z=10)

    # grid size
    dx = float(0.1)

    # place the points at the respective cell centers
    @wp.kernel
    def initialize_particles(r: wp.array(dtype=wp.vec3), dx: float):

        tid = wp.tid()

        x = float(tid % 10) + 0.5
        y = float(tid / 10 % 10) + 0.5
        z = float(tid / (10 * 10)) + 0.5

        r[tid] = dx * wp.vec3(x, y, z)

    wp.launch(kernel=initialize_particles, dim=r.shape, inputs=[r, dx])

    # update the nearest neighbor list
    grid.build(points=r, radius=dx)

    neighbor_count = wp.array(shape=1000, dtype=wp.int32)

    # count the number of neighboring particles including particles through the periodic boundary condition, which
    # should be exclusively 27 for all particles
    @wp.kernel
    def count_neighboring_particles(grid: wp.uint64,
                                    r: wp.array(dtype=wp.vec3),
                                    neighbor_count: wp.array(dtype=wp.int32),
                                    dx: float):

        tid = wp.tid()

        i = wp.hash_grid_point_id(grid, tid)

        tmp = wp.int32(0)

        # loop over neighboring particles
        for _ in wp.hash_grid_query(grid, r[i], dx):
            tmp += wp.int32(1)

        neighbor_count[i] = tmp

    wp.launch(kernel=count_neighboring_particles, dim=r.shape, inputs=[grid.id, r, neighbor_count, dx])

    print(r)
    print(neighbor_count)

If something is not correct from our side in the code sample, please let us know.

System Information

Warp version: 1.3.0
CUDA version: 12.5
OS: Ubuntu 22.04
Python version: 3.12.4

fritzio commented 3 months ago

Hi @mmacklin

In the hash_grid_query function, the start coordinates of the neighboring-cell range are computed incorrectly for cells adjacent to the origin. If the particle's own cell index is zero, then pos[0] < radius (and likewise for pos[1] and pos[2]), so pos[0] - radius is negative and the int cast truncates toward zero instead of rounding down. For example, query.x_start = int((pos[0]-radius)*query.grid.cell_width_inv); gives, for pos[0] = 0.5 and radius = 1.0 (taking cell_width_inv = 1.0), a value of 0, which is the particle's own cell, and not -1.
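The truncation can be reproduced outside Warp. The following is a minimal sketch in plain Python, not the code from hashgrid.h. It assumes that the cell width equals the query radius (as in the reproduction above) and that the end index is computed as int((pos + radius) * cell_width_inv); the helper name cells_per_axis is hypothetical. Python's int() truncates toward zero, like the C cast.

```python
import math

def cells_per_axis(pos, radius, cell_width_inv, to_index):
    """Count the cells visited along one axis for a given float-to-index conversion."""
    start = to_index((pos - radius) * cell_width_inv)
    end = to_index((pos + radius) * cell_width_inv)  # assumed end-index formula
    return end - start + 1

pos, radius = 0.05, 0.1        # particle at the center of cell 0, as in the reproduction
cell_width_inv = 1.0 / radius  # assumption: cell width == query radius

truncated = cells_per_axis(pos, radius, cell_width_inv, int)
floored = cells_per_axis(pos, radius, cell_width_inv, math.floor)

print(truncated, truncated ** 3)  # cells per axis and in 3D with truncation
print(floored, floored ** 3)      # cells per axis and in 3D with floor
```

Under these assumptions, truncation visits 2 cells per axis (8 in 3D) for the corner particle, while flooring visits 3 per axis (27 in 3D), which is consistent with the counts reported in the issue.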

I propose in hashgrid.h to change

query.x_start = int((pos[0] - radius) * query.grid.cell_width_inv)

to

query.x_start = int((pos[0] - radius) * query.grid.cell_width_inv + query.grid.dim_x) - query.grid.dim_x

(and equivalently for the other start coordinates), which should produce the correct cell values. I think subtracting the grid dimension again is technically not needed, though.
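As a quick sanity check of the proposed expression, here is a small sketch in plain Python (the function name offset_index is hypothetical). It suggests that the shifted form agrees with floor as long as the scaled coordinate is not below -dim_x; further below the origin the shifted value is negative again and the truncation reappears.

```python
import math

def offset_index(x, dim):
    """Proposed conversion: shift into the non-negative range, truncate, shift back."""
    return int(x + dim) - dim

dim = 10
samples = [-9.5, -1.0, -0.5, 0.0, 0.5, 3.7]
for x in samples:
    # Within one grid length below the origin, the result matches floor.
    assert offset_index(x, dim) == math.floor(x)

# More than one grid length below the origin, the two differ again.
print(offset_index(-10.5, dim), math.floor(-10.5))
```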

sudo-panda commented 3 months ago

@fritzio I believe flooring the double values before converting them to integers is the way to go. I have tested it out, and it gives me the correct number of neighbours:

    query.x_start = int(floor((pos[0]-radius)*query.grid.cell_width_inv));
    query.y_start = int(floor((pos[1]-radius)*query.grid.cell_width_inv));
    query.z_start = int(floor((pos[2]-radius)*query.grid.cell_width_inv));

@mmacklin I have one question about this: can x_start take negative values?
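Whether a negative start index is safe depends on how the grid maps cell coordinates to storage, which this thread does not show. Purely as an illustration, and not as a description of what hashgrid.h does, the sketch below shows why it matters: a C-style remainder keeps the sign of the dividend, so a periodic wrap needs an extra shift to land in [0, dim). The helpers c_remainder and wrap are hypothetical names.

```python
import math

def c_remainder(i, dim):
    """Remainder with the sign of the dividend, as C's % operator behaves."""
    return int(math.fmod(i, dim))

def wrap(i, dim):
    """Map any integer cell coordinate into [0, dim) for a periodic grid."""
    return (c_remainder(i, dim) + dim) % dim

dim = 10
print(c_remainder(-1, dim))                  # stays negative without the shift
print(wrap(-1, dim))                         # wraps to the last cell
print([wrap(i, dim) for i in range(-1, 2)])  # cells visited around cell 0
```

With a wrap of this kind, a start index of -1 for a particle in cell 0 would resolve to the last cell along that axis, which is the neighbor expected through the lower boundary.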