import xarray
import pandas
import mapclassify
from libpysal.weights import raster, Queen
import matplotlib.pyplot as plt
url = "https://github.com/darribas/ectqg19-workshop/raw/master/data/ghsl.tiff"
da = xarray.open_rasterio(url)
da
da.plot()
<matplotlib.collections.QuadMesh at 0x7f3352c2f950>
Note that the data is highly skewed:
da.plot.hist(bins=100);
It would be much better if we could use a choropleth classification scheme (such as those in mapclassify)
to better convey the distribution of values. Let's give it a try!
To make the plot, we will need to bring the data from xarray
to a data structure that PySAL understands, then use mapclassify
for the classification task, and then put the output back into an xarray
data structure that we can easily and efficiently plot.
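The round trip described above can be sketched with plain NumPy on a toy grid. This is only an illustration of the idea: the grid, the `NODATA` marker, and the break points are made up, and `numpy.digitize` stands in for the classifier. The actual workflow in this notebook uses libpysal and mapclassify instead.

```python
import numpy as np

NODATA = -1.0  # hypothetical missing-value marker for this toy example

# Toy 3x4 "raster" with two missing cells
grid = np.array([
    [0.0, 2.0, 50.0, NODATA],
    [1.0, 10.0, 90.0, 300.0],
    [NODATA, 5.0, 20.0, 150.0],
])

# 1. Flatten and keep only the valid cells, remembering where they came from
valid = grid != NODATA
values = grid[valid]

# 2. Classify the 1-D vector (made-up upper bounds, right-closed bins)
breaks = np.array([5.0, 50.0, 300.0])
labels = np.digitize(values, breaks, right=True)

# 3. Put the class labels back on the grid, leaving missing cells as NaN
classified = np.full(grid.shape, np.nan)
classified[valid] = labels

print(classified)
```

The boolean mask plays the role that the spatial weights object plays below: it records which cells hold data and in what order they were flattened.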
Let's build a spatial weights matrix from the array to map missing values and get the right order of observations in space:
%time w = Queen.from_xarray(da)
CPU times: user 9.27 s, sys: 159 ms, total: 9.43 s Wall time: 9.44 s
%time w = Queen.from_xarray(da, sparse=True)
CPU times: user 57.2 ms, sys: 4.21 ms, total: 61.4 ms Wall time: 60.3 ms
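To make the idea of Queen contiguity on a raster concrete, here is a minimal sketch on a 3x3 grid with one missing cell. The helper `queen_neighbors` is hypothetical and written only for illustration; it is not how libpysal builds its weights. Two cells are treated as neighbors if they share an edge or a corner, and missing cells are skipped entirely, so the keys of the result list the valid cells in a fixed order.

```python
import numpy as np

def queen_neighbors(valid):
    """Map each valid (row, col) cell to its valid queen neighbors.

    Hypothetical helper for illustration; `valid` is a 2-D boolean mask.
    """
    n_rows, n_cols = valid.shape
    neighbors = {}
    for r in range(n_rows):
        for c in range(n_cols):
            if not valid[r, c]:
                continue  # missing cells get no entry at all
            found = []
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    if dr == 0 and dc == 0:
                        continue  # a cell is not its own neighbor
                    rr, cc = r + dr, c + dc
                    inside = 0 <= rr < n_rows and 0 <= cc < n_cols
                    if inside and valid[rr, cc]:
                        found.append((rr, cc))
            neighbors[(r, c)] = found
    return neighbors

# Toy mask: the centre cell is missing
valid = np.array([
    [True, True, True],
    [True, False, True],
    [True, True, True],
])
neighbors = queen_neighbors(valid)
print(len(neighbors))     # number of valid cells
print(neighbors[(0, 0)])  # neighbors of the top-left corner
```

The loop over every cell is fine for a toy grid but would be slow on a real raster, which is one reason to rely on the library routine used above.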
Now we can bring the values from xarray
to pandas
in a way that is aligned with their location:
%%time
vals = da.to_series()\
.reindex(w.index)
vals.head()
CPU times: user 197 ms, sys: 20.1 ms, total: 217 ms Wall time: 216 ms
band  y          x
1     5892375.0  436875.0    0.0
                 437125.0    0.0
                 437375.0    0.0
                 437625.0    0.0
                 437875.0    0.0
dtype: float32
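The alignment step relies on ordinary pandas reindexing over a MultiIndex of coordinates. The sketch below shows the mechanics on a toy 2x3 grid; the coordinates are made up, and `keep` is a hypothetical stand-in for `w.index`, holding only the cells that are not missing.

```python
import numpy as np
import pandas as pd

# Toy 2x3 grid flattened into a Series indexed by (y, x) coordinates
index = pd.MultiIndex.from_product(
    [[10.0, 20.0], [1.0, 2.0, 3.0]], names=["y", "x"]
)
series = pd.Series(np.arange(6, dtype="float32"), index=index)

# Hypothetical stand-in for w.index: the valid cells, in the order
# the weights object expects them
keep = pd.MultiIndex.from_tuples(
    [(10.0, 1.0), (10.0, 3.0), (20.0, 2.0)], names=["y", "x"]
)

# Reindexing drops cells that are absent from `keep` and reorders the rest
aligned = series.reindex(keep)
print(aligned)
```

After reindexing, position `i` in the values corresponds to observation `i` in the index used for the lookup, which is what keeps the classification labels attached to the right cells later on.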
For the sake of the illustration, we will do a Fisher-Jenks classification:
%%time
fj = mapclassify.FisherJenksSampled(vals.values, k=7)
labels = pandas.Series(fj.yb, index=vals.index)
fj
CPU times: user 312 ms, sys: 4.11 ms, total: 316 ms Wall time: 313 ms
FisherJenksSampled

    Interval        Count
------------------------
[  0.00,  18.36] | 65704
( 18.36,  55.00] |  4968
( 55.00,  98.69] |  2356
( 98.69, 152.89] |  1359
(152.89, 219.62] |   692
(219.62, 287.44] |   262
(287.44, 588.57] |   119
xarray
With the labels at hand, we can then "put" them on a DataArray
again:
%time da_fj = raster.wsp2da(labels, w)
CPU times: user 11.5 ms, sys: 84 µs, total: 11.6 ms Wall time: 9.25 ms
da_fj.where(da_fj!=da_fj.attrs["nodatavals"]).plot(cmap="viridis")
<matplotlib.collections.QuadMesh at 0x7f8bec290d90>
Now let's streamline the process and compare the original xarray
classification (equal interval) with the Fisher-Jenks and an additional one using the StdMean
classifier, which bins values by how many standard deviations they lie from the mean.
%%time
classi = mapclassify.StdMean(vals.values)
labels_classi = pandas.Series(classi.yb, index=vals.index)
da2plot = raster.wsp2da(labels_classi, w)
classi
CPU times: user 219 ms, sys: 3.91 ms, total: 223 ms Wall time: 223 ms
StdMean

    Interval        Count
------------------------
(  -inf, -60.28] |     0
(-60.28, -24.80] |     0
(-24.80,  46.16] | 69913
( 46.16,  81.64] |  2396
( 81.64, 588.57] |  3151
The classification table above essentially says that 69,913 values are within one standard deviation of the mean, no values fall below that range (sensible given population counts cannot be negative), and there are two groups of high values: 2,396 lie between one and two standard deviations above the mean, and 3,151 lie more than two standard deviations above it.
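The break points in the table are evenly spaced around 10.68 in steps of 35.48 (for example, 10.68 - 35.48 = -24.80 and 10.68 + 2 × 35.48 = 81.64), which is consistent with upper bounds at the mean plus -2, -1, 1 and 2 standard deviations, with the top class closed at the maximum. A minimal NumPy sketch of that rule follows. The helper `std_mean_breaks` is hypothetical, the toy data is made up, and the use of the sample standard deviation (`ddof=1`) is an assumption, so details may differ from mapclassify's own implementation.

```python
import numpy as np

def std_mean_breaks(y, multiples=(-2, -1, 1, 2)):
    """Upper bounds at mean + m * std for each multiple, plus the maximum.

    Hypothetical helper; assumes the sample standard deviation (ddof=1).
    """
    y = np.asarray(y, dtype=float)
    mean, std = y.mean(), y.std(ddof=1)
    breaks = [mean + m * std for m in multiples]
    if breaks[-1] < y.max():
        breaks.append(y.max())  # close the top class at the maximum
    return np.array(breaks)

# Made-up, right-skewed toy data (not the GHSL values)
y = np.array([0, 0, 0, 0, 1, 1, 2, 3, 5, 40], dtype=float)

breaks = std_mean_breaks(y)
labels = np.digitize(y, breaks, right=True)  # right-closed bins
counts = np.bincount(labels, minlength=len(breaks))

print(breaks.round(2))
print(counts)
```

As in the table above, a strongly right-skewed, non-negative variable leaves the two lowest classes empty, because the mean minus one standard deviation is already below zero.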
da2plot.where(da2plot!=da2plot.attrs["nodatavals"])\
.plot(cmap="viridis", add_colorbar=False)
<matplotlib.collections.QuadMesh at 0x7f8bf4433790>
Now that we have seen how it works, let's try it on a larger raster: the NASA DEM for San Diego, which contains elevation values for more than 19 million cells.
url = "https://geographicdata.science/book/data/nasadem/nasadem_sd.tif"
da = xarray.open_rasterio(url).sel(band=1)
da
<xarray.DataArray (y: 3515, x: 5510)>
[19367650 values with dtype=int16]
Coordinates:
    band     int64 1
  * y        (y) float64 33.51 33.5 33.5 33.5 33.5 ... 32.53 32.53 32.53 32.53
  * x        (x) float64 -117.6 -117.6 -117.6 -117.6 ... -116.1 -116.1 -116.1
Attributes:
    transform:      (0.0002777777777777778, 0.0, -117.61125, 0.0, -0.00027777...
    crs:            +init=epsg:4326
    res:            (0.0002777777777777778, 0.0002777777777777778)
    is_tiled:       0
    nodatavals:     (-32768.0,)
    scales:         (1.0,)
    offsets:        (0.0,)
    AREA_OR_POINT:  Area
A straight plot will look like this (note we use imshow
for better performance):
%time da.where(da!=da.attrs["nodatavals"]).plot.imshow();
CPU times: user 847 ms, sys: 1.27 s, total: 2.11 s Wall time: 2.13 s
<matplotlib.image.AxesImage at 0x7f8bef99bbd0>
Now let's build the spatial weights matrix and classify it based on, for example, (sampled) Fisher-Jenks:
%time w = Queen.from_xarray(da, sparse=True)
%time vals = da.to_series()
%time vals = vals.reindex(w.index)
CPU times: user 17.8 s, sys: 3.1 s, total: 20.9 s Wall time: 46.2 s
CPU times: user 273 ms, sys: 51.5 ms, total: 325 ms Wall time: 322 ms
CPU times: user 56.1 s, sys: 5.78 s, total: 1min 1s Wall time: 1min 1s
%time fj = mapclassify.FisherJenksSampled(vals.values, k=7)
%time labels_fj = pandas.Series(fj.yb, index=vals.index)
%time da_fj = raster.wsp2da(labels_fj, w)
CPU times: user 48.2 s, sys: 1.25 s, total: 49.4 s Wall time: 49.3 s
CPU times: user 396 µs, sys: 38 µs, total: 434 µs Wall time: 448 µs
CPU times: user 516 ms, sys: 344 ms, total: 860 ms Wall time: 859 ms
da_fj.plot.imshow()
<matplotlib.image.AxesImage at 0x7fe9eb2eae90>
%time quantiles = mapclassify.Quantiles(vals.values, k=7)
%time labels_quantiles = pandas.Series(quantiles.yb, index=vals.index)
%time da_quantiles = raster.wsp2da(labels_quantiles, w)
CPU times: user 37.7 s, sys: 1.23 s, total: 38.9 s Wall time: 38.9 s
CPU times: user 414 µs, sys: 0 ns, total: 414 µs Wall time: 428 µs
CPU times: user 463 ms, sys: 367 ms, total: 830 ms Wall time: 825 ms
da_quantiles.plot.imshow()
<matplotlib.image.AxesImage at 0x7fe9eb0f6b50>
f, axs = plt.subplots(1, 3, figsize=(16, 5))
ax = axs[0]
da.where(da!=da.attrs["nodatavals"]).plot.imshow(ax=ax, add_colorbar=False)
ax.set_title("Equal Interval")
ax = axs[1]
da_fj.plot.imshow(ax=ax, add_colorbar=False)
ax.set_title("Fisher-Jenks")
ax = axs[2]
da_quantiles.plot.imshow(ax=ax, add_colorbar=False)
ax.set_title("Quantiles")
plt.show()