import numpy as np
import matplotlib.pyplot as plt
import pys2index
from sklearn.neighbors import BallTree
from pyinterp import RTree
def generate_points(shape, rng=None):
    """Draw random (lat, lon) points expressed in degrees.

    Latitudes are drawn uniformly from [-90, 90) and longitudes uniformly
    from [-180, 180). The sampling is uniform in degrees (not area-uniform
    on the sphere).

    Parameters
    ----------
    shape : int or tuple of int
        Passed as ``size`` to ``uniform`` for both the latitude and the
        longitude draw. Multi-dimensional draws are flattened.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. When omitted, NumPy's global random state is
        used, so ``np.random.seed`` keeps controlling the output.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n, 2)`` where ``n`` is the total number of drawn
        points; column 0 holds latitudes and column 1 holds longitudes.
    """
    sampler = np.random if rng is None else rng
    # Latitude is drawn before longitude; keep this order so that seeded
    # runs stay reproducible.
    lat = sampler.uniform(-90, 90, size=shape)
    lon = sampler.uniform(-180, 180, size=shape)
    # Stack as (2, n) then transpose to get one (lat, lon) pair per row.
    return np.stack([np.ravel(arr) for arr in [lat, lon]]).T
# Reference set of random (lat, lon) points, in degrees, to be indexed.
idx_points = generate_points(10000)

# Build the three indexes over the same points:
#   - S2PointIndex receives the (lat, lon) degrees as generated,
#   - BallTree receives the same columns converted to radians (haversine),
#   - RTree receives the columns in reversed order, i.e. (lon, lat) degrees.
s2index = pys2index.S2PointIndex(idx_points)
idx_points_rad = np.deg2rad(idx_points)
btindex = BallTree(idx_points_rad, metric='haversine')
rtindex = RTree()
# use point index as values
rtindex.insert(
    coordinates=idx_points[:, ::-1],
    values=np.arange(idx_points.shape[0]),
)

# Look up the nearest indexed point for a batch of random query locations.
query_points = generate_points(100)
res_s2 = s2index.query(query_points)
query_points_rad = np.deg2rad(query_points)
res_bt = btindex.query(query_points_rad, return_distance=False)
_, values = rtindex.query(
    coordinates=query_points[:, ::-1], k=1, within=False
)
# k=1: keep the single neighbour column and cast the stored values to int
# so they can be compared with the other two results.
res_rt = values[:, 0].astype('int')
Test `pys2index.S2PointIndex` query results against `sklearn.neighbors.BallTree` query results and against `pyinterp.RTree` query results
# True only when, for every query point, the S2 result equals both the
# BallTree result and the RTree result.
same_as_bt = res_s2 == res_bt.ravel()
same_as_rt = res_s2 == res_rt
np.all(same_as_bt & same_as_rt)
True
Show query results
# One scalar per indexed point (lat * lon), used only to colour the markers.
color = idx_points[:, 0] * idx_points[:, 1]
# Share a single colour scale between both scatter calls. Without explicit
# limits each call normalises to its own data range, so a query point would
# not be drawn in the same colour as the indexed point it was matched to.
color_limits = dict(vmin=color.min(), vmax=color.max())

fig, ax = plt.subplots(figsize=(8, 8))
# Indexed points: x = longitude, y = latitude.
ax.scatter(idx_points[:, 1], idx_points[:, 0], c=color, cmap=plt.cm.RdBu,
           alpha=0.5, **color_limits)
# Query points, coloured with the value of their nearest indexed point as
# returned by the S2 index.
ax.scatter(query_points[:, 1], query_points[:, 0], c=color[res_s2],
           cmap=plt.cm.RdBu, alpha=0.5, **color_limits);
def bm_build_index(npoints):
idx_points_bm = generate_points(npoints)
idx_points_bm_rad = np.deg2rad(idx_points_bm)
idx_points_bm_flip = np.flip(idx_points_bm, axis=-1)
values = np.arange(npoints)
res_s2 = %timeit -o pys2index.S2PointIndex(idx_points_bm)
res_bt = %timeit -o BallTree(idx_points_bm_rad, metric='haversine')
rtindex = RTree()
res_rt = %timeit -o rtindex.insert(coordinates=idx_points_bm_flip, values=values)
return res_s2.best, res_bt.best, res_rt.best
# Index sizes to benchmark: 10^3, 10^4, ..., 10^7 points.
idx_npoints_range = [10**exponent for exponent in range(3, 8)]
# One (s2, balltree, rtree) tuple of best build times per index size.
res = list(map(bm_build_index, idx_npoints_range))
505 µs ± 7.19 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 422 µs ± 7.46 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 2.68 ms ± 91.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each) 5.38 ms ± 70.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 6.23 ms ± 73.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 30.1 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 100 loops each) 61.1 ms ± 890 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 103 ms ± 810 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 231 ms ± 30.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 924 ms ± 15.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 2.08 s ± 83.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 3.3 s ± 182 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 13.4 s ± 857 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 31.1 s ± 473 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 41.5 s ± 879 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Transpose the per-size result tuples into one timing series per library.
res_s2, res_bt, res_rt = zip(*res)

fig, ax = plt.subplots(figsize=(6, 6))
build_series = [
    (res_s2, 'pys2index.S2PointIndex'),
    (res_bt, 'sklearn.neighbors.BallTree'),
    (res_rt, 'pyinterp.RTree'),
]
# Build time against index size, one line per library.
for timings, series_label in build_series:
    ax.plot(idx_npoints_range, timings, label=series_label, marker='+')
ax.set_xlabel("nb. of lat/lon points")
ax.set_ylabel("time (s)")
# Log scale on the x axis only.
ax.semilogx()
ax.legend();
# Same build-time curves as before, this time with both axes on a log scale.
fig, ax = plt.subplots(figsize=(6, 6))
build_series = [
    (res_s2, 'pys2index.S2PointIndex'),
    (res_bt, 'sklearn.neighbors.BallTree'),
    (res_rt, 'pyinterp.RTree'),
]
for timings, series_label in build_series:
    ax.plot(idx_npoints_range, timings, label=series_label, marker='+')
ax.set_xlabel("nb. of lat/lon points")
ax.set_ylabel("time (s)")
ax.loglog()
ax.legend();
# Fixed set of one million indexed points used by the query benchmarks.
idx_points_bm = generate_points(1_000_000)
idx_points_bm_rad = np.deg2rad(idx_points_bm)

# Rebind the module-level indexes to the larger point set.
s2index = pys2index.S2PointIndex(idx_points_bm)
btindex = BallTree(idx_points_bm_rad, metric='haversine')
rtindex = RTree()
# RTree takes the columns in reversed order and stores each point's row
# number as its value.
rtindex.insert(
    coordinates=idx_points_bm[:, ::-1],
    values=np.arange(idx_points_bm.shape[0]),
)
def bm_query_index(npoints):
query_points_bm = generate_points(npoints)
query_points_bm_rad = np.deg2rad(query_points_bm)
query_points_bm_flip = np.flip(query_points_bm, axis=-1)
res_s2 = %timeit -o s2index.query(query_points_bm)
res_bt = %timeit -o btindex.query(query_points_bm_rad, return_distance=False)
res_rt = %timeit -o rtindex.query(coordinates=query_points_bm_flip, k=1, within=False, num_threads=1)
return res_s2.best, res_bt.best, res_rt.best
# Query batch sizes to benchmark: 10^3, 10^4 and 10^5 points.
query_npoints_range = [10**exponent for exponent in range(3, 6)]
# One (s2, balltree, rtree) tuple of best query times per batch size.
res_query = list(map(bm_query_index, query_npoints_range))
9.14 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 377 ms ± 36.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 7.07 ms ± 204 µs per loop (mean ± std. dev. of 7 runs, 100 loops each) 99.7 ms ± 1.69 ms per loop (mean ± std. dev. of 7 runs, 10 loops each) 3.41 s ± 490 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 70.9 ms ± 646 µs per loop (mean ± std. dev. of 7 runs, 10 loops each) 963 ms ± 5.67 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 41.1 s ± 635 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 777 ms ± 20.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Transpose the per-size result tuples into one timing series per library.
res_s2, res_bt, res_rt = zip(*res_query)

fig, ax = plt.subplots(figsize=(6, 6))
query_series = [
    (res_s2, 'pys2index.S2PointIndex'),
    (res_bt, 'sklearn.neighbors.BallTree'),
    (res_rt, 'pyinterp.RTree'),
]
# Query time against number of query points, one line per library.
for timings, series_label in query_series:
    ax.plot(query_npoints_range, timings, label=series_label, marker='+')
ax.set_xlabel("nb. of lat/lon query points")
ax.set_ylabel("time (s)")
# Log scale on the x axis only.
ax.semilogx()
ax.legend();
# Same query-time curves as before, this time with both axes on a log scale.
fig, ax = plt.subplots(figsize=(6, 6))
query_series = [
    (res_s2, 'pys2index.S2PointIndex'),
    (res_bt, 'sklearn.neighbors.BallTree'),
    (res_rt, 'pyinterp.RTree'),
]
for timings, series_label in query_series:
    ax.plot(query_npoints_range, timings, label=series_label, marker='+')
ax.set_xlabel("nb. of lat/lon query points")
ax.set_ylabel("time (s)")
ax.loglog()
ax.legend();