https://opensourceoptions.com/blog/python-geographic-object-based-image-analysis-geobia/
This notebook needs at least 16 GB of RAM to run.
import numpy as np
from osgeo import gdal
from skimage import exposure
from skimage.segmentation import quickshift, slic
import time
import scipy.stats  # importing scipy alone does not import the stats submodule
import fsspec
naip_fn = 'https://mghp.osn.xsede.org/rsignellbucket1/obia/m_4107027_se_19_060_20210904.tif'
fs_https = fsspec.filesystem('https')
fs_https.info(naip_fn)
{'name': 'https://mghp.osn.xsede.org/rsignellbucket1/obia/m_4107027_se_19_060_20210904.tif', 'size': 505021950, 'ETag': '"ba8c8e2c8e6aed3943214a5fa641552d-61"', 'type': 'file'}
%%time
fs_https.download(naip_fn, 'naip.tif')
CPU times: user 1.29 s, sys: 806 ms, total: 2.1 s Wall time: 15.1 s
[None]
naip_fn = 'naip.tif'
%%time
driverTiff = gdal.GetDriverByName('GTiff')
naip_ds = gdal.Open(naip_fn)
nbands = naip_ds.RasterCount
band_data = []
print('bands', naip_ds.RasterCount, 'rows', naip_ds.RasterYSize, 'columns',
      naip_ds.RasterXSize)
for i in range(1, nbands + 1):
    band = naip_ds.GetRasterBand(i).ReadAsArray()
    band_data.append(band)
band_data = np.dstack(band_data)
bands 4 rows 12823 columns 9844 CPU times: user 442 ms, sys: 580 ms, total: 1.02 s Wall time: 2.17 s
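As a sanity check on the stacking step: np.dstack turns a list of 2-D bands into a single (rows, columns, bands) array, which is the layout skimage expects. A toy illustration with made-up shapes:

```python
import numpy as np

# np.dstack stacks 2-D arrays along a new third axis, so a list of
# (rows, cols) bands becomes one (rows, cols, nbands) image array.
bands = [np.zeros((3, 4)), np.ones((3, 4))]
stacked = np.dstack(bands)
print(stacked.shape)  # (3, 4, 2)
```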
img = exposure.rescale_intensity(band_data)
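rescale_intensity stretches the image values to the full output range. A minimal illustration on a made-up uint8 array:

```python
import numpy as np
from skimage import exposure

# With no in_range/out_range given, rescale_intensity maps the
# image minimum to 0 and the image maximum to 255 for uint8 input.
arr = np.array([[50, 100], [150, 200]], dtype=np.uint8)
stretched = exposure.rescale_intensity(arr)
print(stretched.min(), stretched.max())  # 0 255
```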
%%time
segments = slic(img, n_segments=500000, compactness=0.1)
CPU times: user 2min 53s, sys: 3.62 s, total: 2min 57s Wall time: 2min 57s
Now we need to describe each segment by its spectral properties, because those properties are the variables that will classify each segment as a land cover type.
First, write a function that, given an array of pixel values, calculates the min, max, mean, variance, skewness, and kurtosis for each band. The code below takes all the pixels in a segment, computes these statistics for each band, and collects them in the features list, which is returned. Later we will extract each segment's pixels and save the returned features.
def segment_features(segment_pixels):
    features = []
    npixels, nbands = segment_pixels.shape
    for b in range(nbands):
        stats = scipy.stats.describe(segment_pixels[:, b])
        # keep (min, max) plus mean, variance, skewness, kurtosis
        band_stats = list(stats.minmax) + list(stats)[2:]
        if npixels == 1:
            # a single-pixel segment has variance = nan; change it to 0.0
            band_stats[3] = 0.0
        features += band_stats
    return features
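The indexing into scipy.stats.describe is worth unpacking: it returns a namedtuple (nobs, minmax, mean, variance, skewness, kurtosis), so taking minmax plus the fields from index 2 onward yields six statistics per band. A small check on made-up pixel values:

```python
import numpy as np
import scipy.stats

pixels = np.array([1.0, 2.0, 3.0, 4.0])
stats = scipy.stats.describe(pixels)
# (min, max) followed by mean, variance, skewness, kurtosis
band_stats = list(stats.minmax) + list(stats)[2:]
print(len(band_stats))  # 6
```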
In object-based image analysis each segment represents an object. Objects represent buildings, roads, trees, fields or pieces of those features, depending on how the segmentation is done.
The code below gets the list of segment ID numbers, then sets up a list to hold the statistics describing each object (i.e., each segment ID) returned from the segment_features function above. The pixels for each segment are identified and passed to segment_features, which returns the statistics describing the spectral properties of that segment/object.
The statistics are appended to the objects list, and each object_id is stored in a separate list.
%%time
segment_ids = np.unique(segments)
print(len(segment_ids))
375291 CPU times: user 4.21 s, sys: 244 ms, total: 4.45 s Wall time: 4.45 s
The following cell loops over the segments (here just the first 80, as a test). It seems to require only about 2.5 GB of RAM.
%%time
objects = []
object_ids = []
for id in segment_ids[:80]:  # test with a few segments
    segment_pixels = img[segments == id]
    object_features = segment_features(segment_pixels)
    objects.append(object_features)
    object_ids.append(id)
CPU times: user 26.1 s, sys: 1.67 s, total: 27.7 s Wall time: 27.7 s
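The indexing img[segments == id] relies on a NumPy behavior worth noting: indexing a 3-D (rows, cols, bands) array with a 2-D boolean mask returns a 2-D (npixels, bands) array, which is exactly the shape segment_features expects. A toy illustration:

```python
import numpy as np

img = np.arange(24).reshape(2, 3, 4)   # (rows, cols, bands)
segments = np.array([[1, 1, 2],
                     [2, 2, 2]])       # (rows, cols) segment labels

# A 2-D boolean mask selects whole pixels, keeping the band axis:
pixels = img[segments == 1]
print(pixels.shape)  # (2, 4)
```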
Okay, let's parallelize!
First we define a function that takes the id as input and returns the stats for that segment:
def get_features(id):
    segment_pixels = img[segments == id]
    return segment_features(segment_pixels)
Verify it takes the same amount of time and gives the same results:
%%time
objects2 = [get_features(id) for id in segment_ids[:80]]
CPU times: user 26 s, sys: 1.75 s, total: 27.7 s Wall time: 27.7 s
assert objects == objects2
Use Dask Bag to parallelize the loop.
from dask.distributed import Client
client = Client(n_workers=4, threads_per_worker=1)
client
Client-d75f07df-b876-11ed-a03a-ba4f5d667a1f
LocalCluster: 4 workers, 1 thread per worker, 30.91 GiB total memory
Dashboard: http://127.0.0.1:8787/status
client.cluster.workers
{0: <Nanny: tcp://127.0.0.1:37529, threads: 1>, 1: <Nanny: tcp://127.0.0.1:41711, threads: 1>, 2: <Nanny: tcp://127.0.0.1:43525, threads: 1>, 3: <Nanny: tcp://127.0.0.1:40001, threads: 1>}
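The Dask Bag pattern used next is: wrap a sequence into a partitioned bag, map a function over it, then compute. A toy version with a stand-in function (not the real get_features) looks like:

```python
import dask.bag as db

# Stand-in for get_features: any pure function of one element works.
def square(x):
    return x * x

# Partition the sequence, map the function, and gather the results.
b = db.from_sequence(range(8), npartitions=4)
result = b.map(square).compute()
print(result)  # [0, 1, 4, 9, 16, 25, 36, 49]
```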
import dask.bag as db
%%time
b = db.from_sequence(segment_ids[:80], npartitions=4)
b1 = b.map(get_features).compute()
CPU times: user 21.7 s, sys: 9.69 s, total: 31.4 s Wall time: 40.6 s
The parallel run above was actually slower than the serial loop: each task shipped its own serialized copy of img and segments to the workers. Instead, scatter both arrays to the workers once, and bind the resulting futures as default arguments of get_features:
%%time
scattered_img = client.scatter(img, broadcast=True)
scattered_segments = client.scatter(segments, broadcast=True)
def get_features(id, img=scattered_img, segments=scattered_segments):
    segment_pixels = img[segments == id]
    return segment_features(segment_pixels)
%%time
b1 = b.map(get_features).compute()