IlyasMoutawwakil opened 3 years ago
I was wondering about the same, so I tried to quickly write encode in Numba:
import numpy as np
import numba

__base32 = '0123456789bcdefghjkmnpqrstuvwxyz'

@numba.njit()
def encode_numba(latitude, longitude):
    precision = 12
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='<U1')
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even
        if bit < 4:
            bit += 1
        else:
            geohash[n] = __base32[ch]
            bit = 0
            ch = 0
            n += 1
    return ''.join(geohash)
This already provides a significant speedup:
from pygeohash import encode
%timeit encode(50, 14)
# 15.3 µs ± 422 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit encode_numba(50, 14)
# 2.81 µs ± 84.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
It might also make sense to use an int representation. As a quick PoC, I just changed the dtype to int:
@numba.njit()
def encode_numba_int(latitude, longitude, precision=12):
    lat_interval = (-90.0, 90.0)
    lon_interval = (-180.0, 180.0)
    geohash = np.zeros(precision, dtype='int')  # <-- CHANGE HERE
    bits = np.array([16, 8, 4, 2, 1])
    bit = 0
    ch = 0
    n = 0
    even = True
    while n < precision:
        if even:
            mid = (lon_interval[0] + lon_interval[1]) / 2
            if longitude > mid:
                ch |= bits[bit]
                lon_interval = (mid, lon_interval[1])
            else:
                lon_interval = (lon_interval[0], mid)
        else:
            mid = (lat_interval[0] + lat_interval[1]) / 2
            if latitude > mid:
                ch |= bits[bit]
                lat_interval = (mid, lat_interval[1])
            else:
                lat_interval = (lat_interval[0], mid)
        even = not even
        if bit < 4:
            bit += 1
        else:
            geohash[n] = ch  # <-- CHANGE HERE
            bit = 0
            ch = 0
            n += 1
    return geohash  # <-- CHANGE HERE
This results in another speedup (we're down to about 1/25 of the original time)
%timeit encode_numba_int(50, 14)
# 641 ns ± 49.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
One could of course change the code so that it doesn't even attempt to create an array with base-32 representation and merely constructs a single int.
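A plain-Python sketch of that single-int idea (my assumption of what it could look like, not the author's code): pack 5 bits per character into one 60-bit integer, with a helper to turn that integer back into the base-32 string when a human-readable geohash is needed.

```python
__base32 = '0123456789bcdefghjkmnpqrstuvwxyz'

def encode_int(latitude, longitude, precision=12):
    # Same bisection loop as encode_numba, but each bit is shifted into a
    # single integer (5 bits per geohash character, MSB first).
    lat_lo, lat_hi = -90.0, 90.0
    lon_lo, lon_hi = -180.0, 180.0
    result = 0
    even = True
    for _ in range(5 * precision):  # one bit per iteration
        result <<= 1
        if even:
            mid = (lon_lo + lon_hi) / 2
            if longitude > mid:
                result |= 1
                lon_lo = mid
            else:
                lon_hi = mid
        else:
            mid = (lat_lo + lat_hi) / 2
            if latitude > mid:
                result |= 1
                lat_lo = mid
            else:
                lat_hi = mid
        even = not even
    return result

def int_to_geohash(value, precision=12):
    # Peel off 5 bits per character, most significant character first.
    return ''.join(__base32[(value >> (5 * i)) & 31]
                   for i in range(precision - 1, -1, -1))
```

The same loop should compile under `@numba.njit` since it only uses integer and float arithmetic.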
If the authors are interested in this direction, I'd be happy to start a PR (after specifying the resulting API a bit). I'd also be happy for more feedback on Numba, as I'm hardly an expert.
There also appears to be some recent effort to refactor and speed up PyGeohash here: https://github.com/tastatham/gsoc_dask_geopandas_2021/issues/2
I've tried doing about the same with the decode function. It's a bit more complicated since it uses dictionaries. I tried Numba's typed dictionaries, but for some reason they only supported setting items, not getting them. So in the end I used the ord builtin, since it's implemented in Numba, and that way you don't even need access to the global variable __base32. Some other modifications didn't add much performance, but I kept them anyway. My numba_decode function is the following:
from math import log10
from numba import njit

@njit('int8(unicode_type)')
def base32_to_int(s):
    # Map a geohash character to its index in '0123456789bcdefghjkmnpqrstuvwxyz'
    res = ord(s) - 48
    if res > 9:
        res -= 40
    if res > 16:
        res -= 1
    if res > 18:
        res -= 1
    if res > 20:
        res -= 1
    return res
@njit('UniTuple(float64, 4)(unicode_type)')
def numba_decode_exactly(geohash):
    lat_interval_neg, lat_interval_pos = -90.0, 90.0
    lon_interval_neg, lon_interval_pos = -180.0, 180.0
    lat_err, lon_err = 90.0, 180.0
    is_even = True
    for c in geohash:
        cd = base32_to_int(c)
        for mask in (16, 8, 4, 2, 1):
            if is_even:  # adds longitude info
                lon_err /= 2
                if cd & mask:
                    lon_interval_neg = (lon_interval_neg + lon_interval_pos) / 2
                else:
                    lon_interval_pos = (lon_interval_neg + lon_interval_pos) / 2
            else:  # adds latitude info
                lat_err /= 2
                if cd & mask:
                    lat_interval_neg = (lat_interval_neg + lat_interval_pos) / 2
                else:
                    lat_interval_pos = (lat_interval_neg + lat_interval_pos) / 2
            is_even = not is_even
    lat = (lat_interval_neg + lat_interval_pos) / 2
    lon = (lon_interval_neg + lon_interval_pos) / 2
    return lat, lon, lat_err, lon_err
@njit('UniTuple(float64, 2)(unicode_type)')
def numba_decode(geohash):
    """
    Decode a geohash, returning latitude and longitude as two floats
    containing only the relevant digits, with trailing zeroes removed.
    """
    lat, lon, lat_err, lon_err = numba_decode_exactly(geohash)
    # Round to the number of decimals that are actually known
    lat_prec = max(1, int(round(-log10(lat_err))))
    lon_prec = max(1, int(round(-log10(lon_err))))
    lat = round(lat, lat_prec)
    lon = round(lon, lon_prec)
    return lat, lon
I also modified the way the precision is applied (originally by casting into strings and then back into floats), so that I could find a way to vectorize it later. The performance improvement depends on the precision/length of the geohashes, so I fixed it to 12:
geohash = ''.join(random.sample(__base32, 12))
%%timeit
decode(geohash)
# 19.9 µs ± 935 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%%timeit
numba_decode(geohash)
# 4.37 µs ± 110 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Update: I've made some modifications to speed up processing of geohash arrays (an array with n geohashes), and the results are pretty promising (a ~10x speedup compared to a numpy vectorization). I don't know if the owners are still around, but I would love to start a PR.
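The batch idea can be sketched in plain Python (an assumption on my part, not the author's actual implementation; decode_exactly_py mirrors numba_decode_exactly so the block is self-contained):

```python
def decode_exactly_py(geohash):
    # Same bit-peeling loop as numba_decode_exactly, in plain Python for clarity.
    base32 = '0123456789bcdefghjkmnpqrstuvwxyz'
    lat_lo, lat_hi = -90.0, 90.0
    lon_lo, lon_hi = -180.0, 180.0
    is_even = True
    for c in geohash:
        cd = base32.index(c)
        for mask in (16, 8, 4, 2, 1):
            if is_even:
                if cd & mask:
                    lon_lo = (lon_lo + lon_hi) / 2
                else:
                    lon_hi = (lon_lo + lon_hi) / 2
            else:
                if cd & mask:
                    lat_lo = (lat_lo + lat_hi) / 2
                else:
                    lat_hi = (lat_lo + lat_hi) / 2
            is_even = not is_even
    return (lat_lo + lat_hi) / 2, (lon_lo + lon_hi) / 2

def decode_batch(geohashes):
    # With Numba, both functions would be @njit-decorated and the input passed
    # as a numba.typed.List of strings, so the whole batch runs inside one
    # compiled call instead of paying Python dispatch once per geohash.
    return [decode_exactly_py(gh) for gh in geohashes]
```

Moving the loop over geohashes inside the compiled function is where the batch speedup would come from.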
Would be happy to review a PR if you're still looking at this.
@wdm0006 check the code structure and performance gains in this repo; if it looks worth adding to this package, I'll submit a PR. Also, if you have any ideas on how to fully vectorize the computations (with some matrix or tensor operations), I would love to implement them.
I think it's worth doing as an optional path, e.g. if installed as pygeohash[numba] or something like that. The dependencies of the existing library are very light, so for many non-performance use cases it would probably be preferred. Happy to review a PR with that caveat.
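That optional path could be declared as a packaging extra; a sketch, assuming a setuptools build (the extra's name and the bare pin are illustrative):

```python
# setup.py sketch: expose Numba as an optional extra so that
# `pip install pygeohash[numba]` pulls in the accelerated path,
# while a plain install stays dependency-light.
from setuptools import setup

setup(
    name="pygeohash",
    extras_require={"numba": ["numba"]},
)
```

At import time the library would then try `import numba` and fall back to the pure-Python implementations when it is missing.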
Is there any way to optimize the code for faster decoding, in particular on multiple geohashes?