wdm0006 / pygeohash

Python module for interacting with geohashes

coding/decoding vector optimisation #8

Open IlyasMoutawwakil opened 3 years ago

IlyasMoutawwakil commented 3 years ago

Is there any way to optimize the code for faster decoding, in particular when decoding multiple geohashes?

DahnJ commented 3 years ago

I was wondering the same, so I quickly tried writing encode in Numba:

import numpy as np
import numba

__base32 = '0123456789bcdefghjkmnpqrstuvwxyz'

@numba.njit()
def encode_numba(latitude, longitude):
    precision = 12
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='<U1')
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even

        if bit < 4:
            bit += 1
        else: 
            geohash[n] = __base32[ch]
            bit = 0
            ch = 0
            n += 1

    return ''.join(geohash)

This already provides a significant speedup (roughly 5x):

from pygeohash import encode

%timeit encode(50, 14)
# 15.3 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit encode_numba(50, 14)
# 2.81 µs ± 84.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

It might also make sense to use an integer representation. As a quick PoC, I just changed the array type to int:


@numba.njit()
def encode_numba_int(latitude, longitude, precision=12):
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='int')  #  <-- CHANGE HERE
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even

        if bit < 4:
            bit += 1
        else: 
            geohash[n] = ch  #  <-- CHANGE HERE
            bit = 0
            ch = 0
            n += 1

    return geohash  #  <-- CHANGE HERE

This results in another speedup (we're down to about 1/25 of the original time)

%timeit encode_numba_int(50, 14)
# 641 ns ± 49.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

One could of course change the code so that it doesn't build an array of base-32 digits at all and instead packs all the bits into a single int.
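A minimal plain-Python sketch of that single-int idea, assuming the bits are packed most-significant first in the same order encode_numba produces them (longitude on even bits, latitude on odd bits). `encode_int` and `int_to_geohash` are hypothetical names, not pygeohash API; since they only use floats and ints they should be amenable to numba.njit, but that has not been benchmarked here.

```python
_BASE32 = '0123456789bcdefghjkmnpqrstuvwxyz'


def encode_int(latitude, longitude, precision=12):
    """Hypothetical helper: bisect and pack 5 * precision interleaved bits into one int."""
    lat_lo, lat_hi = -90.0, 90.0
    lon_lo, lon_hi = -180.0, 180.0
    code = 0
    for i in range(5 * precision):
        code <<= 1  # make room for the next bit
        if i % 2 == 0:  # even bits refine longitude
            mid = (lon_lo + lon_hi) / 2
            if longitude > mid:
                code |= 1
                lon_lo = mid
            else:
                lon_hi = mid
        else:  # odd bits refine latitude
            mid = (lat_lo + lat_hi) / 2
            if latitude > mid:
                code |= 1
                lat_lo = mid
            else:
                lat_hi = mid
    return code


def int_to_geohash(code, precision=12):
    """Hypothetical helper: turn the packed int back into base-32 text, 5 bits per char."""
    chars = []
    for _ in range(precision):
        chars.append(_BASE32[code & 31])  # lowest 5 bits are the last character
        code >>= 5
    return ''.join(reversed(chars))
```

The string conversion is kept separate so that batch code can stay in integer form and only format geohashes when text is actually needed, e.g. `int_to_geohash(encode_int(42.605, -5.603, 5), 5)` gives `'ezs42'`.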

If the authors are interested in this direction, I'd be happy to start a PR (after specifying the resulting API a bit). I'd also be happy for more feedback on Numba, as I'm hardly an expert.

DahnJ commented 3 years ago

There also appears to be some recent effort to refactor and speed up PyGeohash here: https://github.com/tastatham/gsoc_dask_geopandas_2021/issues/2

IlyasMoutawwakil commented 3 years ago

I've tried doing about the same with the decode function. It's a bit more complicated since the original code uses dictionaries. I tried Numba's typed dictionaries, but I couldn't get item lookups to work with them. In the end I used the builtin ord function, which Numba supports; that way you don't even need access to the global variable __base32. Some of my other modifications didn't add much performance, but I kept them anyway. My numba_decode function is the following:

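from math import log10
from numba import njit

# Map one geohash character to its 5-bit value without a lookup table:
# '0'-'9' -> 0-9, then letters shifted down, skipping 'a', 'i', 'l' and 'o'
@njit('int8(unicode_type)')
def base32_to_int(s):
    res = ord(s) - 48
    if res>9: res-=40   # 'b' -> 10
    if res>16: res-=1   # skip 'i'
    if res>18: res-=1   # skip 'l'
    if res>20: res-=1   # skip 'o'
    return res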

@njit('UniTuple(float64, 4)(unicode_type)')
def numba_decode_exactly(geohash):
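    # Bisection bounds ("neg" = lower, "pos" = upper); float literals keep the types stable for Numba
    lat_interval_neg, lat_interval_pos, lon_interval_neg, lon_interval_pos = -90.0, 90.0, -180.0, 180.0
    lat_err, lon_err = 90.0, 180.0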
    is_even = True
    for c in geohash:
        cd=base32_to_int(c)
        for mask in (16, 8, 4, 2, 1):
            if is_even:  # adds longitude info
                lon_err /= 2
                if cd & mask:
                    lon_interval_neg = (lon_interval_neg + lon_interval_pos) / 2
                else:
                    lon_interval_pos = (lon_interval_neg + lon_interval_pos) / 2
            else:  # adds latitude info
                lat_err /= 2
                if cd & mask:
                    lat_interval_neg = (lat_interval_neg + lat_interval_pos) / 2
                else:
                    lat_interval_pos = (lat_interval_neg + lat_interval_pos) / 2
            is_even = not is_even
    lat = (lat_interval_neg + lat_interval_pos) / 2
    lon = (lon_interval_neg + lon_interval_pos) / 2
    return lat, lon, lat_err, lon_err

@njit('UniTuple(float64, 2)(unicode_type)')
def numba_decode(geohash):
    """
    Decode geohash, returning two float with latitude and longitude
    containing only relevant digits and with trailing zeroes removed.
    """

    lat, lon, lat_err, lon_err = numba_decode_exactly(geohash)
    # Format to the number of decimals that are known
    lat_prec = max(1, int(round(-log10(lat_err))))
    lon_prec = max(1, int(round(-log10(lon_err))))
    lat = round(lat, lat_prec)
    lon = round(lon, lon_prec)

    return lat, lon

I also changed the rounding: instead of formatting the coordinates as strings and converting them back into floats, the code rounds them directly with round(), so that I could find a way to vectorize it later.
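To see how many decimals this rounding keeps, here is a small plain-Python helper (hypothetical name `decimals_for_length`) applying the same formula as numba_decode to the error bounds of a geohash of a given length. Longitude takes the even bit positions, so it gets the extra bit when the total number of bits is odd.

```python
from math import log10


def decimals_for_length(length):
    """Return (lat_decimals, lon_decimals) that numba_decode's rounding keeps."""
    total_bits = 5 * length
    lon_bits = (total_bits + 1) // 2  # bits 0, 2, 4, ... refine longitude
    lat_bits = total_bits // 2        # bits 1, 3, 5, ... refine latitude
    lat_err = 90 / 2 ** lat_bits      # each bit halves the error
    lon_err = 180 / 2 ** lon_bits
    lat_prec = max(1, int(round(-log10(lat_err))))
    lon_prec = max(1, int(round(-log10(lon_err))))
    return lat_prec, lon_prec
```

For a 12-character geohash both errors are below 2e-7 degrees, so `decimals_for_length(12)` returns `(7, 7)`.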

The performance improvement depends on the precision (length) of the geohashes, so I fixed it at 12:

import random
from pygeohash import decode

# __base32 as defined in the encode snippet above
geohash = ''.join(random.sample(__base32, 12))
%timeit decode(geohash)
# 19.9 µs ± 935 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit numba_decode(geohash)
# 4.37 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

IlyasMoutawwakil commented 3 years ago

Update: I've made some modifications to speed up processing of geohash arrays (arrays of n geohashes), and the results are pretty promising (a 10x speedup compared to a NumPy vectorization). I don't know if the maintainers are still around, but I would love to start a PR.

wdm0006 commented 3 years ago

Would be happy to review a PR if you're still looking at this.

IlyasMoutawwakil commented 2 years ago

@wdm0006 check the code structure and performance gains in this repo. If it looks worth adding to this package, I'll submit a PR. Also, if you have any ideas on how to fully vectorize the computations (with some matrix or tensor operations), I would love to implement them.
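One possible way to fully vectorize decoding with NumPy, sketched under the assumption that every geohash in a batch has the same length (this is not the code from the linked repo). Instead of bisecting character by character, it unpacks all characters into a bit matrix, splits the interleaved longitude and latitude bits, turns each row of bits into an integer cell index with a matrix-vector product, and computes the cell centre directly. `decode_array` and `_LUT` are hypothetical names.

```python
import numpy as np

_BASE32 = '0123456789bcdefghjkmnpqrstuvwxyz'

# ASCII code -> 5-bit value; 255 marks bytes outside the geohash alphabet
_LUT = np.full(256, 255, dtype=np.uint8)
_LUT[np.frombuffer(_BASE32.encode('ascii'), dtype=np.uint8)] = np.arange(32, dtype=np.uint8)


def decode_array(geohashes):
    """Decode equal-length geohashes into (lat, lon, lat_err, lon_err)."""
    codes = np.asarray(geohashes, dtype='S')  # fixed-width bytes, shape (n,)
    n, length = codes.shape[0], codes.dtype.itemsize
    values = _LUT[codes.view(np.uint8).reshape(n, length)]  # (n, length) 5-bit values
    if (values == 255).any():  # shorter strings are NUL-padded, which also lands here
        raise ValueError('invalid character or geohashes of different lengths')

    # 8 bits per value, most significant first; keep the low 5 -> (n, 5 * length)
    bits = np.unpackbits(values[:, :, None], axis=2)[:, :, 3:].reshape(n, 5 * length)
    lon_bits, lat_bits = bits[:, 0::2], bits[:, 1::2]  # even bits: lon, odd bits: lat

    def to_int(b):
        # Matrix-vector product with powers of two turns each bit row into an integer
        weights = 2 ** np.arange(b.shape[1] - 1, -1, -1, dtype=np.int64)
        return b.astype(np.int64) @ weights

    lat_err = 90.0 / 2 ** lat_bits.shape[1]
    lon_err = 180.0 / 2 ** lon_bits.shape[1]
    # Cell centre = lower bound + (index + 0.5) cell widths, where width = 2 * err
    lat = -90.0 + (to_int(lat_bits) + 0.5) * 2 * lat_err
    lon = -180.0 + (to_int(lon_bits) + 0.5) * 2 * lon_err
    return lat, lon, lat_err, lon_err
```

This reproduces the midpoints and errors of numba_decode_exactly for each row (e.g. 'ezs42' gives roughly 42.605, -5.603), and the int64 weights limit it to geohashes of about 25 characters.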

wdm0006 commented 2 years ago

I think it's worth doing as an optional path, e.g. if installed as pygeohash[numba] or something like that. The existing library has very light dependencies, which is probably preferable for the many use cases where performance isn't critical. Happy to review a PR with that caveat.
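A minimal sketch of such an optional path, assuming the extra would be declared with something like `extras_require={'numba': ['numba']}` in setup.py (or an optional-dependencies table in pyproject.toml): import Numba when it is installed, and otherwise fall back to a no-op decorator so the same functions still run as plain Python. The `njit` shim and the `HAS_NUMBA` flag are hypothetical, not part of pygeohash.

```python
try:
    from numba import njit  # present only when the optional extra is installed
    HAS_NUMBA = True
except ImportError:
    HAS_NUMBA = False

    def njit(*args, **kwargs):
        """No-op stand-in for numba.njit (hypothetical shim)."""
        if len(args) == 1 and callable(args[0]) and not kwargs:
            return args[0]  # used as a bare @njit
        return lambda func: func  # used as @njit('signature') or @njit(cache=True)


@njit('int8(unicode_type)')
def base32_to_int(s):
    # Same arithmetic as in the thread; compiled only when Numba is available
    res = ord(s) - 48
    if res > 9: res -= 40
    if res > 16: res -= 1
    if res > 18: res -= 1
    if res > 20: res -= 1
    return res
```

The shim handles both decorator forms used in the thread (bare and with an explicit signature), so the Numba-specific code would not need a separate pure-Python copy.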