jakubcerveny / gilbert

Space-filling curve for rectangular domains of arbitrary size.
BSD 2-Clause "Simplified" License

Use this as an alternative to the `flatten()` function in PyTorch? #16

Open drupol opened 1 week ago

drupol commented 1 week ago

Hello,

I’ve recently started taking AI classes, where I learned about PyTorch’s `flatten()` function and explored its pros and cons. One notable issue is that it flattens in row-major order, so it doesn't preserve spatial locality: vertically adjacent pixels end up a full row width apart in the output. This can matter in the contexts of AI and image processing.

This reminded me of the incredible @3blue1brown video on Hilbert curves (https://www.youtube.com/watch?v=3s7h2MHQtxc). I began wondering if there was an algorithm or method to flatten an image of any size with a Hilbert curve, preserving locality. And that’s what led me here.

Are you aware of an existing implementation of this for PyTorch? If not, do you think this would be a worthwhile feature to explore? The idea would be to use this algorithm to reshape a matrix into a simple vector.

Thank you!

Interesting reads:

drupol commented 1 week ago

I quickly drafted something, making minor changes here and there, and it seems to work.

Among the changes I made to the original algorithm, the biggest one was aimed at performance. The recursion itself remains, but instead of chaining generators through the recursive calls, `generate2d` now appends each point to a `result` list that is passed by reference into every inner call.
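To remove the recursion entirely, one option is to replace the call stack with an explicit one: each stack entry holds the `(x, y, ax, ay, bx, by)` arguments of one sub-rectangle, and sub-rectangles are pushed in reverse so they are popped in curve order. This is a minimal sketch assuming the same parameterization and splitting rules as `generate2d`; `gilbert2d_iterative` is a hypothetical name, not part of the gilbert repository.

```python
def sgn(x):
    # -1, 0 or 1 depending on the sign of x
    return (x > 0) - (x < 0)


def gilbert2d_iterative(width, height):
    """Gilbert curve points in order, using an explicit stack instead of recursion."""
    if width >= height:
        stack = [(0, 0, width, 0, 0, height)]
    else:
        stack = [(0, 0, 0, height, width, 0)]
    result = []
    while stack:
        x, y, ax, ay, bx, by = stack.pop()
        w, h = abs(ax + ay), abs(bx + by)
        dax, day = sgn(ax), sgn(ay)  # major direction
        dbx, dby = sgn(bx), sgn(by)  # orthogonal direction
        if h == 1:  # trivial row fill
            result.extend((x + i * dax, y + i * day) for i in range(w))
            continue
        if w == 1:  # trivial column fill
            result.extend((x + i * dbx, y + i * dby) for i in range(h))
            continue
        ax2, ay2 = ax // 2, ay // 2
        bx2, by2 = bx // 2, by // 2
        if 2 * w > 3 * h:
            if abs(ax2 + ay2) % 2 and w > 2:  # prefer even steps
                ax2, ay2 = ax2 + dax, ay2 + day
            subcalls = [(x, y, ax2, ay2, bx, by),
                        (x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by)]
        else:
            if abs(bx2 + by2) % 2 and h > 2:  # prefer even steps
                bx2, by2 = bx2 + dbx, by2 + dby
            subcalls = [(x, y, bx2, by2, ax2, ay2),
                        (x + bx2, y + by2, ax, ay, bx - bx2, by - by2),
                        (x + (ax - dax) + (bx2 - dbx), y + (ay - day) + (by2 - dby),
                         -bx2, -by2, -(ax - ax2), -(ay - ay2))]
        # Push in reverse so the first sub-rectangle is processed first
        stack.extend(reversed(subcalls))
    return result


print(gilbert2d_iterative(5, 3))
```

Since the recursion depth of `generate2d` only grows roughly logarithmically with the grid size, the speed-up from this change is likely modest; its main benefit is that no recursion limit is involved at all.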

```python
import math

import torch


# Sign of a value: -1, 0 or 1
def sgn(x):
    return -1 if x < 0 else (1 if x > 0 else 0)


# Recursive generate2d function for the Gilbert curve: instead of yielding
# points, it appends them to the shared `result` list
def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int, result):
    """Recursive generation of 2D coordinates using the Gilbert space-filling curve."""

    # Width and height of the grid to fill
    w = abs(ax + ay)
    h = abs(bx + by)

    # Unit direction vectors (calculated once and reused)
    dax, day = sgn(ax), sgn(ay)  # Major direction
    dbx, dby = sgn(bx), sgn(by)  # Orthogonal direction

    # Trivial row fill
    if h == 1:
        for _ in range(w):
            result.append((x, y))
            x, y = x + dax, y + day
        return

    # Trivial column fill
    if w == 1:
        for _ in range(h):
            result.append((x, y))
            x, y = x + dbx, y + dby
        return

    # Halve the movement vectors
    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2

    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        # Prefer even steps
        if w2 % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day

        # Long case: split in two parts only
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)

    else:
        # Prefer even steps
        if h2 % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby

        # Standard case: one step up, one long horizontal, one step down
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx),
                   y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)


# Top-level gilbert2d function: list of (x, y) points in curve order
def gilbert2d(width, height):
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result


# Lay the elements of `tensor` (in row-major order) along a Gilbert curve
# over a (height, width) grid, using a single batched index assignment
def reshape_via_gilbert(tensor, width=None, height=None, path=None):
    flattened_tensor = tensor.flatten()
    num_elements = flattened_tensor.numel()

    if width is None and height is None:
        # Near-square shape large enough to hold every element
        height = math.isqrt(num_elements)
        width = (num_elements + height - 1) // height
    elif width is None:
        width = (num_elements + height - 1) // height
    elif height is None:
        height = (num_elements + width - 1) // width

    if width * height < num_elements:
        raise ValueError(f"a {height}x{width} grid cannot hold {num_elements} elements")

    # Cells not reached by the data stay zero
    reshaped_tensor = torch.zeros((height, width), dtype=tensor.dtype, device=tensor.device)

    if path is None:
        # Get the Gilbert curve path
        path = gilbert2d(width, height)

    # (N, 2) tensor of (x, y) coordinates: column 1 is the row, column 0 the column
    idx = torch.tensor(path[:num_elements], dtype=torch.long, device=tensor.device)
    reshaped_tensor[idx[:, 1], idx[:, 0]] = flattened_tensor

    return reshaped_tensor


# Example usage
tensor = torch.tensor([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
    [17, 18, 19, 20],
    [21, 22, 23, 24],
    [25, 26, 27, 28],
    [29, 30, 31, 32],
    [33, 34, 35, 36],
])

reshaped_tensor = reshape_via_gilbert(tensor)
print(reshaped_tensor)

reshaped_tensor = reshape_via_gilbert(tensor, width=5)
print(reshaped_tensor)

reshaped_tensor = reshape_via_gilbert(tensor, height=4)
print(reshaped_tensor)

reshaped_tensor = reshape_via_gilbert(tensor, width=8, height=8)
print(reshaped_tensor)
```

And the resulting tensors:

```
tensor([[ 1,  4,  5, 32, 33, 36],
        [ 2,  3,  6, 31, 34, 35],
        [11, 10,  7, 30, 27, 26],
        [12,  9,  8, 29, 28, 25],
        [13, 16, 17, 20, 21, 24],
        [14, 15, 18, 19, 22, 23]])
tensor([[ 1,  4,  5,  8,  9],
        [ 2,  3,  6,  7, 10],
        [19, 18, 15, 14, 11],
        [20, 17, 16, 13, 12],
        [21, 24, 25, 28, 29],
        [22, 23, 26, 27, 30],
        [ 0,  0, 35, 34, 31],
        [ 0,  0, 36, 33, 32]])
tensor([[ 1,  2, 15, 16, 17, 18, 34, 35, 36],
        [ 4,  3, 14, 13, 20, 19, 33, 32, 31],
        [ 5,  8,  9, 12, 21, 24, 25, 30, 29],
        [ 6,  7, 10, 11, 22, 23, 26, 27, 28]])
tensor([[ 1,  4,  5,  6,  0,  0,  0,  0],
        [ 2,  3,  8,  7,  0,  0,  0,  0],
        [15, 14,  9, 10,  0,  0,  0,  0],
        [16, 13, 12, 11,  0,  0,  0,  0],
        [17, 18, 31, 32, 33, 34,  0,  0],
        [20, 19, 30, 29, 36, 35,  0,  0],
        [21, 24, 25, 28,  0,  0,  0,  0],
        [22, 23, 26, 27,  0,  0,  0,  0]])
```

jakubcerveny commented 1 week ago

Hi, I agree that a better-ordered flatten could be useful for AI, image compression, and maybe more.

In the code, however, why generate the Gilbert curve over a shape that is close to a square? I think it would be better to use the original tensor shape directly if the tensor is 2D or 3D.
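A minimal sketch of that suggestion for the 2D case: generate the curve over the image's own (width, height), convert each (x, y) into a row-major index `y * width + x`, and read the pixels in that order, so no padding is needed. Plain Python lists stand in for tensors so the sketch runs without PyTorch; `gilbert_permutation`, `gilbert_flatten` and `gilbert_unflatten` are hypothetical names, and the curve code is a condensed copy of the `generate2d` draft above.

```python
from functools import lru_cache


def sgn(x):
    return (x > 0) - (x < 0)


def generate2d(x, y, ax, ay, bx, by, result):
    """Append the Gilbert curve over the rectangle spanned by a and b to result."""
    w, h = abs(ax + ay), abs(bx + by)
    dax, day, dbx, dby = sgn(ax), sgn(ay), sgn(bx), sgn(by)
    if h == 1:  # trivial row fill
        result.extend((x + i * dax, y + i * day) for i in range(w))
        return
    if w == 1:  # trivial column fill
        result.extend((x + i * dbx, y + i * dby) for i in range(h))
        return
    ax2, ay2, bx2, by2 = ax // 2, ay // 2, bx // 2, by // 2
    if 2 * w > 3 * h:
        if abs(ax2 + ay2) % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)
    else:
        if abs(bx2 + by2) % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx), y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)


def gilbert2d(width, height):
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result


@lru_cache(maxsize=None)
def gilbert_permutation(width, height):
    """perm[i] = row-major index of the i-th cell on the curve (cached per shape)."""
    return tuple(y * width + x for x, y in gilbert2d(width, height))


def gilbert_flatten(image):
    """Flatten a list of rows along the Gilbert curve of its own shape."""
    height, width = len(image), len(image[0])
    flat = [v for row in image for v in row]
    return [flat[p] for p in gilbert_permutation(width, height)]


def gilbert_unflatten(seq, width, height):
    """Inverse of gilbert_flatten: scatter a curve-ordered sequence back into rows."""
    flat = [None] * (width * height)
    for value, p in zip(seq, gilbert_permutation(width, height)):
        flat[p] = value
    return [flat[r * width:(r + 1) * width] for r in range(height)]


image = [[4 * r + c + 1 for c in range(4)] for r in range(9)]  # the 9x4 example above
seq = gilbert_flatten(image)
print(seq)
print(gilbert_unflatten(seq, width=4, height=9) == image)
```

Because the curve depends only on the shape, the permutation can be computed once and reused. In PyTorch, the same permutation could then be applied to a `(..., H, W)` tensor `x` with a single gather, `x.flatten(-2)[..., torch.tensor(perm, device=x.device)]` where `perm = gilbert_permutation(W, H)`, so only the curve generation runs in Python and the indexing runs on whatever device `x` lives on.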

Thanks for the links. The 3Blue1Brown video got me wondering whether the Gilbert curve is actually stable in the limit, like the Hilbert curve, which would be interesting to prove (or to fix the algorithm to make it true). A simple test would be to look at the curve over a [kn, km] grid, where k grows slowly (say in steps of 0.1), and observe any abrupt changes or discontinuities. What do you think, @abetusk?
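The proposed experiment could be scripted roughly as follows. This is a sketch under several assumptions of mine: grid sizes are rounded to round(k·n) × round(k·m), the curve is sampled at fixed parameters t in [0, 1), sampled cells are mapped to the unit square, and the metric is the largest displacement of any sampled point between consecutive values of k. `point_at` and `stability_profile` are hypothetical helpers; the curve code is a condensed copy of the `generate2d` draft above.

```python
import math


def sgn(x):
    return (x > 0) - (x < 0)


def generate2d(x, y, ax, ay, bx, by, result):
    w, h = abs(ax + ay), abs(bx + by)
    dax, day, dbx, dby = sgn(ax), sgn(ay), sgn(bx), sgn(by)
    if h == 1:  # trivial row fill
        result.extend((x + i * dax, y + i * day) for i in range(w))
        return
    if w == 1:  # trivial column fill
        result.extend((x + i * dbx, y + i * dby) for i in range(h))
        return
    ax2, ay2, bx2, by2 = ax // 2, ay // 2, bx // 2, by // 2
    if 2 * w > 3 * h:
        if abs(ax2 + ay2) % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)
    else:
        if abs(bx2 + by2) % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx), y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)


def gilbert2d(width, height):
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result


def point_at(path, width, height, t):
    """Centre of the cell at curve parameter t in [0, 1), scaled to the unit square."""
    x, y = path[min(int(t * len(path)), len(path) - 1)]
    return ((x + 0.5) / width, (y + 0.5) / height)


def stability_profile(n, m, ks, samples=256):
    """Largest displacement of the sampled curve points between consecutive k."""
    ts = [(i + 0.5) / samples for i in range(samples)]
    profile, previous = [], None
    for k in ks:
        width, height = max(1, round(k * n)), max(1, round(k * m))
        path = gilbert2d(width, height)
        points = [point_at(path, width, height, t) for t in ts]
        if previous is not None:
            profile.append(max(math.dist(p, q) for p, q in zip(points, previous)))
        previous = points
    return profile


ks = [1 + 0.1 * i for i in range(40)]
for k, jump in zip(ks[1:], stability_profile(3, 2, ks)):
    print(f"k={k:.1f}  max jump={jump:.3f}")
```

A profile that shrinks as k grows would be consistent with the curve settling down in the limit, while isolated spikes would point to the abrupt changes worth plotting side by side.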

I took a quick look at the linked paper. If point cloud processing is important to you, I would suggest using a Hilbert spatial sort (see e.g. https://doc.cgal.org/latest/Spatial_sorting/index.html) to "flatten" the point cloud directly, skipping voxelization altogether.
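To give an idea of what "flattening the point cloud directly" could look like, here is a rough 2D sketch of a Hilbert sort: each point is quantized to a 2^order × 2^order grid over its bounding box, its Hilbert index is computed with the classic bit-by-bit `xy2d`-style conversion, and the points are sorted by that index. This is not CGAL's implementation (CGAL's `hilbert_sort` subdivides the point set recursively rather than quantizing to a fixed grid); `hilbert_index`, `hilbert_sort` and the default `order=16` are hypothetical choices, and a real point cloud would need the 3D analogue.

```python
def hilbert_index(order, x, y):
    """Hilbert curve index of integer cell (x, y) on a 2**order x 2**order grid."""
    n = 1 << order
    d = 0
    s = n >> 1
    while s > 0:
        rx = 1 if x & s else 0
        ry = 1 if y & s else 0
        d += s * s * ((3 * rx) ^ ry)
        # Rotate/reflect so the sub-quadrant is traversed in canonical orientation
        if ry == 0:
            if rx == 1:
                x, y = n - 1 - x, n - 1 - y
            x, y = y, x
        s >>= 1
    return d


def hilbert_sort(points, order=16):
    """Sort 2D points by the Hilbert index of their grid cell in the bounding box."""
    if not points:
        return []
    n = 1 << order
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    x0, x1, y0, y1 = min(xs), max(xs), min(ys), max(ys)

    def cell(v, lo, hi):
        # Map [lo, hi] onto the integer range [0, n - 1]
        if hi == lo:
            return 0
        return min(n - 1, int((v - lo) / (hi - lo) * n))

    return sorted(points, key=lambda p: hilbert_index(order, cell(p[0], x0, x1),
                                                      cell(p[1], y0, y1)))


cloud = [(0.1, 0.2), (5.0, 3.0), (2.5, 2.5), (0.0, 0.0)]
print(hilbert_sort(cloud, order=4))
```

Points that are close in space tend to end up close in the sorted list, which is the same locality property the Gilbert-based flatten aims for, without building a voxel grid first.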

drupol commented 1 week ago

Thanks for your answer!

> In the code however, why generate the Gilbert curve over a shape that is close to a square? I think it would be better to use the original tensor shape directly, if the tensor is 2D or 3D.

I'm still learning and trying to figure out the best way to do this, so don't expect production-ready code here; I'm just experimenting. That said, I've updated the code so that the width and height of the target shape can now be passed optionally.

> I took a quick look at the linked paper. If point cloud processing is important to you, I would suggest using a Hilbert spatial sort (see e.g. https://doc.cgal.org/latest/Spatial_sorting/index.html) to "flatten" the point cloud directly, skipping voxelization altogether.

Thanks for this, I'm going to check it out.

In the meantime, I posted a message on the PyTorch forums to get some insight into how I could optimize the algorithm using GPUs.

Link: https://discuss.pytorch.org/t/custom-flatten-function-using-gpu-acceleration/211830