aras-p / UnityGaussianSplatting

Toy Gaussian Splatting visualization in Unity
MIT License
1.94k stars 217 forks source link

Encode splats in a compute shader #83

Open cecarlsen opened 6 months ago

cecarlsen commented 6 months ago

There is a LoadAndDecodeVector function in GaussianSplatting.hlsl, but no matching EncodeAndStoreVector. I want to manipulate the splats directly in GPU memory for various experiments, and for that I need to encode the vectors back again.

So I need the inverse of this:

// Loads a packed 3-component vector from dataBuffer at byte address addrU and unpacks it
// according to fmt (VECTOR_FMT_32F, VECTOR_FMT_16, VECTOR_FMT_11 or VECTOR_FMT_6).
//
// All buffer loads are issued at addresses rounded down to a multiple of 4. When addrU is
// not itself a multiple of 4, each 32-bit value is rebuilt from the upper 16 bits of one
// loaded word and the lower 16 bits of the next, i.e. the data is treated as starting
// 2 bytes into the first word (presumably addrU is always 2-byte aligned -- verify against callers).
//
// Returns 0 when fmt matches none of the handled formats.
float3 LoadAndDecodeVector(SplatBufferDataType dataBuffer, uint addrU, uint fmt)
{
    uint alignedAddr = addrU & ~0x3;
    bool misaligned = (addrU != alignedAddr);

    uint w0 = dataBuffer.Load(alignedAddr);
    float3 decoded = 0;

    if (fmt == VECTOR_FMT_32F)
    {
        // Three raw 32-bit floats (12 bytes).
        uint w1 = dataBuffer.Load(alignedAddr + 4);
        uint w2 = dataBuffer.Load(alignedAddr + 8);
        if (misaligned)
        {
            // Data straddles four words: stitch each float from two neighbouring halves.
            // Order matters: every line reads the not-yet-modified next word.
            uint w3 = dataBuffer.Load(alignedAddr + 12);
            w0 = ((w1 & 0xFFFF) << 16) | (w0 >> 16);
            w1 = ((w2 & 0xFFFF) << 16) | (w1 >> 16);
            w2 = ((w3 & 0xFFFF) << 16) | (w2 >> 16);
        }
        decoded = float3(asfloat(w0), asfloat(w1), asfloat(w2));
    }
    else if (fmt == VECTOR_FMT_16)
    {
        // 48 bits of payload spread over two words.
        uint w1 = dataBuffer.Load(alignedAddr + 4);
        if (misaligned)
        {
            w0 = ((w1 & 0xFFFF) << 16) | (w0 >> 16);
            w1 = w1 >> 16;
        }
        decoded = DecodePacked_16_16_16(uint2(w0, w1));
    }
    else if (fmt == VECTOR_FMT_11)
    {
        // 32 bits of payload; the second word is only needed when misaligned.
        uint w1 = dataBuffer.Load(alignedAddr + 4);
        uint packed = misaligned ? (((w1 & 0xFFFF) << 16) | (w0 >> 16)) : w0;
        decoded = DecodePacked_11_10_11(packed);
    }
    else if (fmt == VECTOR_FMT_6)
    {
        // 16 bits of payload: either half of the single loaded word.
        uint packed = misaligned ? (w0 >> 16) : w0;
        decoded = DecodePacked_6_5_5(packed);
    }

    return decoded;
}

The inverse exists on the C# side, in GaussianSplatAssetCreator.cs:

// Packs a float3 into 48 bits laid out as 16.16.16:
// x in bits 0-15, y in bits 16-31, z in bits 32-47.
// Each component is multiplied by 65535.5 and truncated to an integer. No clamping is
// done here; EmitEncodedVector saturates the input before calling.
static ulong EncodeFloat3ToNorm16(float3 v)
{
    ulong qx = (ulong) (v.x * 65535.5f);
    ulong qy = (ulong) (v.y * 65535.5f);
    ulong qz = (ulong) (v.z * 65535.5f);
    return qx | (qy << 16) | (qz << 32);
}
// Packs a float3 into 32 bits laid out as 11.10.11:
// x in bits 0-10, y in bits 11-20, z in bits 21-31.
// Components are scaled by 2047.5 / 1023.5 / 2047.5 and truncated. No clamping is done
// here; EmitEncodedVector saturates the input before calling.
static uint EncodeFloat3ToNorm11(float3 v)
{
    uint qx = (uint) (v.x * 2047.5f);
    uint qy = (uint) (v.y * 1023.5f);
    uint qz = (uint) (v.z * 2047.5f);
    return qx | (qy << 11) | (qz << 21);
}
// Packs a float3 into 16 bits laid out as 6.5.5:
// x in bits 0-5, y in bits 6-10, z in bits 11-15.
// Components are scaled by 63.5 / 31.5 / 31.5 and truncated. No clamping is done here;
// EmitEncodedVector saturates the input before calling.
static ushort EncodeFloat3ToNorm655(float3 v)
{
    uint qx = (uint) (v.x * 63.5f);
    uint qy = (uint) (v.y * 31.5f);
    uint qz = (uint) (v.z * 31.5f);
    uint packed = qx | (qy << 6) | (qz << 11);
    return (ushort) packed;
}
// Packs a float3 into 16 bits laid out as 5.6.5:
// x in bits 0-4, y in bits 5-10, z in bits 11-15.
// Components are scaled by 31.5 / 63.5 / 31.5 and truncated. No clamping is done here
// (presumably callers pass values already in [0,1] -- verify against call sites).
static ushort EncodeFloat3ToNorm565(float3 v)
{
    uint qx = (uint) (v.x * 31.5f);
    uint qy = (uint) (v.y * 63.5f);
    uint qz = (uint) (v.z * 31.5f);
    uint packed = qx | (qy << 5) | (qz << 11);
    return (ushort) packed;
}

// Packs a float4 into 32 bits laid out as 10.10.10.2:
// x in bits 0-9, y in bits 10-19, z in bits 20-29, w in bits 30-31.
// x/y/z are scaled by 1023.5 and w by 3.5, then truncated. No clamping is done here
// (presumably callers pass values already in [0,1] -- verify against call sites).
static uint EncodeQuatToNorm10(float4 v)
{
    uint qx = (uint) (v.x * 1023.5f);
    uint qy = (uint) (v.y * 1023.5f);
    uint qz = (uint) (v.z * 1023.5f);
    uint qw = (uint) (v.w * 3.5f);
    return qx | (qy << 10) | (qz << 20) | (qw << 30);
}

// Writes v to outputPtr using the storage layout selected by format:
//   Float32 - three raw floats (12 bytes)
//   Norm16  - 16.16.16 packing  (6 bytes)
//   Norm11  - 11.10.11 packing  (4 bytes)
//   Norm6   - 6.5.5 packing     (2 bytes)
// For the normalized formats the input is saturated to [0,1] before encoding.
// Any other format value writes nothing.
static unsafe void EmitEncodedVector(float3 v, byte* outputPtr, GaussianSplatAsset.VectorFormat format)
{
    if (format == GaussianSplatAsset.VectorFormat.Float32)
    {
        float* dst = (float*) outputPtr;
        dst[0] = v.x;
        dst[1] = v.y;
        dst[2] = v.z;
    }
    else if (format == GaussianSplatAsset.VectorFormat.Norm16)
    {
        ulong packed = EncodeFloat3ToNorm16(math.saturate(v));
        // Low 32 bits (x and y) first, then the remaining 16 bits (z) at byte offset 4.
        *(uint*) outputPtr = (uint) packed;
        *(ushort*) (outputPtr + 4) = (ushort) (packed >> 32);
    }
    else if (format == GaussianSplatAsset.VectorFormat.Norm11)
    {
        *(uint*) outputPtr = EncodeFloat3ToNorm11(math.saturate(v));
    }
    else if (format == GaussianSplatAsset.VectorFormat.Norm6)
    {
        *(ushort*) outputPtr = EncodeFloat3ToNorm655(math.saturate(v));
    }
}

But I am having difficulty translating that into HLSL. Has anyone already done this?