import sys
import numpy as np
from scipy import sparse
from matplotlib import pyplot as plt
import pygsp as pg
sys.path.insert(0, '..')
from modules import remap
# Run cdo from conda environment from a jupyter that is not in the environment.
# import os, subprocess
# cdo = os.path.join(sys.exec_prefix, 'bin/cdo')
# p = subprocess.run([cdo, '-V'], stderr=subprocess.PIPE)
# print(p.stderr)
# Graphs to remap between: graph1 (finer) is pooled onto graph2 (coarser) and unpooled back.
graph1, graph2 = (
    pg.graphs.SphereHealpix(subdivisions=level, nest=True, k=4, kernel_width=None)
    for level in (2, 1)
)
Update `SphericalVoronoiMesh_from_pygsp` (in `remap.py`) with the code below.
# Work on the coarse graph, with vertices on the unit sphere.
graph, radius = graph2, 1
def xyz2lonlat(x, y, z, radius=6371.0e6):
    """From cartesian geocentric coordinates to 2D geographic coordinates.

    Parameters
    ----------
    x, y, z : float or array_like
        Cartesian coordinates of points expected to lie on the sphere of the
        given ``radius``.
    radius : float
        Radius of the sphere. NOTE(review): the default 6371.0e6 is 1000 times
        the Earth radius in metres (6.371e6 m); pass ``radius`` explicitly.

    Returns
    -------
    longitude, latitude : float or ndarray
        In degrees; longitude in [-180, 180], latitude in [-90, 90].
    """
    # Floating-point rounding can push |z / radius| slightly above 1 (e.g. at
    # the poles), which would make arcsin return NaN.
    sin_latitude = np.clip(np.asarray(z) / radius, -1, 1)
    latitude = np.arcsin(sin_latitude)/np.pi*180
    longitude = np.arctan2(y, x)/np.pi*180
    return longitude, latitude
# Hack to get HEALPix true vertices (quadrilateral polygons).
import healpy as hp
npix = graph.n_vertices
# npix2nside returns an integer nside and raises if npix is not a valid
# HEALPix pixel count, whereas np.sqrt(npix / 12) yields an unchecked float.
nside = hp.npix2nside(npix)
step = 8  # number of vertices per edge (edges are not geodesics)
vertices = hp.boundaries(nside, range(npix), nest=graph.nest, step=step)
assert vertices.shape == (npix, 3, 4*step)
# One (4*step, 2) array of (longitude, latitude) in degrees per pixel.
list_polygons_lonlat = [
    np.column_stack(xyz2lonlat(*xyz, radius=radius)) for xyz in vertices
]
# Check visually that HEALPix vertices are ordered counter-clockwise:
# label the boundary points of pixel 7 by their index.
lat, lon = pg.utils.xyz2latlon(*vertices[7])
plt.scatter(lon, lat)
plt.xlim(0, 2*np.pi)
plt.ylim(-np.pi/2, np.pi/2)
for idx, (lon_i, lat_i) in enumerate(zip(lon, lat)):
    plt.text(lon_i, lat_i, idx)
remap.get_available_interp_methods()
['nearest_neighbors', 'idw', 'bilinear', 'bicubic', 'conservative', 'conservative_SCRIP', 'conservative2', 'largest_area_fraction']
# Inspect the conservative remapping weights from graph1 to graph2 (an xarray.Dataset in SCRIP format).
remap.compute_interpolation_weights(graph1, graph2, method='conservative', normalization='fracarea')  # alternative: normalization='destarea'
<xarray.Dataset> Dimensions: (dst_grid_corners: 6, dst_grid_rank: 1, dst_grid_size: 12, num_links: 80, num_wgts: 1, src_grid_corners: 7, src_grid_rank: 1, src_grid_size: 48) Dimensions without coordinates: dst_grid_corners, dst_grid_rank, dst_grid_size, num_links, num_wgts, src_grid_corners, src_grid_rank, src_grid_size Data variables: src_grid_dims (src_grid_rank) int32 48 dst_grid_dims (dst_grid_rank) int32 12 src_grid_center_lat (src_grid_size) float64 0.3398 0.7297 ... -0.3398 dst_grid_center_lat (dst_grid_size) float64 0.7297 0.7297 ... -0.7297 src_grid_center_lon (src_grid_size) float64 0.7854 1.178 ... 5.105 5.498 dst_grid_center_lon (dst_grid_size) float64 0.7854 2.356 ... 3.927 5.498 src_grid_corner_lat (src_grid_size, src_grid_corners) float64 0.0 ... -0... src_grid_corner_lon (src_grid_size, src_grid_corners) float64 0.8348 ...... dst_grid_corner_lat (dst_grid_size, dst_grid_corners) float64 -1.732e-16... dst_grid_corner_lon (dst_grid_size, dst_grid_corners) float64 0.8394 ...... src_grid_imask (src_grid_size) int32 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 dst_grid_imask (dst_grid_size) int32 1 1 1 1 1 1 1 1 1 1 1 1 src_grid_area (src_grid_size) float64 0.2592 0.2582 ... 0.2582 0.2592 dst_grid_area (dst_grid_size) float64 1.085 1.085 ... 1.085 1.085 src_grid_frac (src_grid_size) float64 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 dst_grid_frac (dst_grid_size) float64 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 src_address (num_links) int32 1 2 3 4 18 20 ... 29 30 45 46 47 48 dst_address (num_links) int32 1 1 1 1 1 1 1 ... 12 12 12 12 12 12 remap_matrix (num_links, num_wgts) float64 0.2389 0.238 ... 0.2389 Attributes: title: SCRIP remapping with CDO normalization: fracarea map_method: Conservative remapping using clipping on sphere conventions: SCRIP source_grid: unstructured dest_grid: unstructured history: 15 Mar 2021 : cdo -b 64 genycon,/tmp/dst_CDO_grid_ypnubqk... CDO: Climate Data Operators version 1.9.9 (https://mpimet.mpg....
array([48], dtype=int32)
array([12], dtype=int32)
array([ 0.339837, 0.729728, 0.729728, 1.159658, 0.339837, 0.729728, 0.729728, 1.159658, 0.339837, 0.729728, 0.729728, 1.159658, 0.339837, 0.729728, 0.729728, 1.159658, -0.339837, 0. , 0. , 0.339837, -0.339837, 0. , 0. , 0.339837, -0.339837, 0. , 0. , 0.339837, -0.339837, 0. , 0. , 0.339837, -1.159658, -0.729728, -0.729728, -0.339837, -1.159658, -0.729728, -0.729728, -0.339837, -1.159658, -0.729728, -0.729728, -0.339837, -1.159658, -0.729728, -0.729728, -0.339837])
array([ 0.729728, 0.729728, 0.729728, 0.729728, 0. , 0. , 0. , 0. , -0.729728, -0.729728, -0.729728, -0.729728])
array([0.785398, 1.178097, 0.392699, 0.785398, 2.356194, 2.748894, 1.963495, 2.356194, 3.926991, 4.31969 , 3.534292, 3.926991, 5.497787, 5.890486, 5.105088, 5.497787, 0. , 0.392699, 5.890486, 0. , 1.570796, 1.963495, 1.178097, 1.570796, 3.141593, 3.534292, 2.748894, 3.141593, 4.712389, 5.105088, 4.31969 , 4.712389, 0.785398, 1.178097, 0.392699, 0.785398, 2.356194, 2.748894, 1.963495, 2.356194, 3.926991, 4.31969 , 3.534292, 3.926991, 5.497787, 5.890486, 5.105088, 5.497787])
array([0.785398, 2.356194, 3.926991, 5.497787, 0. , 1.570796, 3.141593, 4.712389, 0.785398, 2.356194, 3.926991, 5.497787])
array([[ 0. , 0.36486 , 0.651497, ..., 0. , 0. , 0. ], [ 0.36486 , 0.36486 , 0.651497, ..., 0.857571, 0.651497, 0.651497], [ 0.651497, 1.018891, 1.018891, ..., 0.651497, 0.36486 , 0.36486 ], ..., [-0.651497, -1.018891, -0.857571, ..., -0.36486 , -0.36486 , -0.36486 ], [-0.651497, -0.36486 , -0.36486 , ..., -1.018891, -0.857571, -0.857571], [-0.36486 , 0. , 0. , ..., -0.651497, -0.651497, -0.651497]])
array([[0.834823, 1.173563, 0.785398, ..., 0.735973, 0.735973, 0.735973], [1.173563, 1.182632, 1.570796, ..., 0.785398, 0.785398, 0.785398], [6.283185, 6.283185, 6.283185, ..., 0.785398, 0.397234, 0.388165], ..., [6.283185, 6.283185, 5.497787, ..., 5.885952, 5.895021, 5.895021], [5.497787, 5.109622, 5.100554, ..., 4.712389, 5.497787, 5.497787], [5.885952, 5.547212, 5.448362, ..., 5.497787, 5.497787, 5.497787]])
array([[-1.731973e-16, 6.170272e-01, 1.570796e+00, 6.170272e-01, -1.795578e-16, -1.795578e-16], [ 6.170272e-01, 1.570796e+00, 1.570796e+00, 6.170272e-01, -1.731973e-16, -1.795578e-16], [-1.731973e-16, -1.795578e-16, 6.170272e-01, 1.570796e+00, 6.170272e-01, 6.170272e-01], [ 6.170272e-01, 1.570796e+00, 1.570796e+00, 6.170272e-01, -1.175850e-16, -1.175850e-16], [-1.795578e-16, 6.170272e-01, -1.175850e-16, -6.170272e-01, -6.170272e-01, -6.170272e-01], [-1.731973e-16, 6.170272e-01, -1.795578e-16, -6.170272e-01, -6.170272e-01, -6.170272e-01], [-6.170272e-01, -1.731973e-16, 6.170272e-01, -1.731973e-16, -1.731973e-16, -1.731973e-16], [-6.170272e-01, -1.175850e-16, 6.170272e-01, -1.795578e-16, -1.795578e-16, -1.795578e-16], [-1.731973e-16, -1.795578e-16, -6.170272e-01, -1.570796e+00, -6.170272e-01, -6.170272e-01], [-6.170272e-01, -1.795578e-16, -1.731973e-16, -6.170272e-01, -1.570796e+00, -1.570796e+00], [-6.170272e-01, -1.570796e+00, -6.170272e-01, -1.795578e-16, -1.731973e-16, -1.731973e-16], [-6.170272e-01, -1.570796e+00, -1.570796e+00, -6.170272e-01, -1.175850e-16, -1.175850e-16]])
array([[0.839438, 1.570796, 0. , 6.283185, 0.731358, 0.731358], [1.570796, 0. , 0. , 3.141593, 2.410234, 2.302155], [3.872951, 3.981031, 4.712389, 0. , 3.141593, 3.141593], [6.283185, 0. , 0. , 4.712389, 5.443747, 5.551827], [0.731358, 6.283185, 5.551827, 6.283185, 6.283185, 6.283185], [0.839438, 1.570796, 2.302155, 1.570796, 1.570796, 1.570796], [3.141593, 2.410234, 3.141593, 3.872951, 3.872951, 3.872951], [4.712389, 5.443747, 4.712389, 3.981031, 3.981031, 3.981031], [0.839438, 0.731358, 6.283185, 0. , 1.570796, 1.570796], [1.570796, 2.302155, 2.410234, 3.141593, 0. , 0. ], [3.141593, 0. , 4.712389, 3.981031, 3.872951, 3.872951], [6.283185, 0. , 0. , 4.712389, 5.443747, 5.551827]])
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)
array([0.259184, 0.258215, 0.258215, 0.277227, 0.259184, 0.258215, 0.258215, 0.277227, 0.259184, 0.258215, 0.258215, 0.277227, 0.259184, 0.258215, 0.258215, 0.277227, 0.259184, 0.258771, 0.258771, 0.259184, 0.259184, 0.258771, 0.258771, 0.259184, 0.259184, 0.258771, 0.258771, 0.259184, 0.259184, 0.258771, 0.258771, 0.259184, 0.277227, 0.258215, 0.258215, 0.259184, 0.277227, 0.258215, 0.258215, 0.259184, 0.277227, 0.258215, 0.258215, 0.259184, 0.277227, 0.258215, 0.258215, 0.259184])
array([1.085092, 1.085092, 1.085092, 1.085092, 0.971408, 0.971408, 0.971408, 0.971408, 1.085092, 1.085092, 1.085092, 1.085092])
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
array([ 1, 2, 3, 4, 18, 20, 23, 24, 5, 6, 7, 8, 22, 24, 27, 28, 9, 10, 11, 12, 26, 28, 31, 32, 13, 14, 15, 16, 19, 20, 30, 32, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 17, 18, 21, 23, 33, 34, 35, 36, 21, 22, 25, 27, 37, 38, 39, 40, 25, 26, 29, 31, 41, 42, 43, 44, 17, 19, 29, 30, 45, 46, 47, 48], dtype=int32)
array([ 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12], dtype=int32)
array([[0.238859], [0.237966], [0.237966], [0.255487], [0.005867], [0.008994], [0.005867], [0.008994], [0.238859], [0.237966], [0.237966], [0.255487], [0.005867], [0.008994], [0.005867], [0.008994], [0.238859], [0.237966], [0.237966], [0.255487], [0.005867], [0.008994], [0.005867], [0.008994], [0.238859], [0.237966], [0.237966], [0.255487], [0.005867], [0.008994], [0.005867], [0.008994], [0.24672 ], [0.25328 ], [0.25328 ], [0.24672 ], [0.24672 ], [0.25328 ], [0.25328 ], [0.24672 ], [0.24672 ], [0.25328 ], [0.25328 ], [0.24672 ], [0.24672 ], [0.25328 ], [0.25328 ], [0.24672 ], [0.008994], [0.005867], [0.008994], [0.005867], [0.255487], [0.237966], [0.237966], [0.238859], [0.008994], [0.005867], [0.008994], [0.005867], [0.255487], [0.237966], [0.237966], [0.238859], [0.008994], [0.005867], [0.008994], [0.005867], [0.255487], [0.237966], [0.237966], [0.238859], [0.008994], [0.005867], [0.008994], [0.005867], [0.255487], [0.237966], [0.237966], [0.238859]])
def build_interpolation_matrix(src_graph, dst_graph):
    """Return the sparse matrix that interpolates between two spherical samplings.

    Conservative remapping weights are computed by ``remap`` (CDO) and checked
    against the graphs' pixel centers. The returned matrix is non-normalized:
    entry (i, j) is the fraction of destination pixel i covered by source
    pixel j, multiplied by the area of destination pixel i. Its row sums are
    thus the destination pixel areas and its column sums the source pixel
    areas (both asserted below).

    Parameters
    ----------
    src_graph, dst_graph : pygsp graphs
        Source and destination samplings. They must carry ``lat`` and ``lon``
        signals matching the pixel centers CDO computes (presumably radians).

    Returns
    -------
    scipy.sparse.csr_matrix, shape (dst_graph.n_vertices, src_graph.n_vertices)
    """
    # Alternative normalization: 'destarea'.
    ds = remap.compute_interpolation_weights(
        src_graph, dst_graph, method='conservative', normalization='fracarea')

    # Sanity checks: CDO must have used the graphs' own pixel centers and
    # treated every pixel as fully valid (unmasked, whole fraction).
    np.testing.assert_allclose(ds.src_grid_center_lat, src_graph.signals['lat'])
    np.testing.assert_allclose(ds.src_grid_center_lon, src_graph.signals['lon'])
    np.testing.assert_allclose(ds.dst_grid_center_lat, dst_graph.signals['lat'])
    np.testing.assert_allclose(ds.dst_grid_center_lon, dst_graph.signals['lon'])
    np.testing.assert_allclose(ds.src_grid_frac, 1)
    np.testing.assert_allclose(ds.dst_grid_frac, 1)
    np.testing.assert_allclose(ds.src_grid_imask, 1)
    np.testing.assert_allclose(ds.dst_grid_imask, 1)

    # CDO addresses are 1-based.
    row = np.asarray(ds.dst_address) - 1
    col = np.asarray(ds.src_address) - 1
    # remap_matrix is (num_links, num_wgts); squeeze() would collapse a single
    # link to a 0-d array, so select the (only) weight column explicitly.
    dat = np.asarray(ds.remap_matrix)
    if dat.ndim != 2 or dat.shape[1] != 1:
        raise ValueError(
            f'expected one weight per link, got remap_matrix of shape {dat.shape}')
    dat = dat[:, 0]

    # Pass the shape explicitly rather than inferring it from the largest address.
    shape = (dst_graph.n_vertices, src_graph.n_vertices)
    weights = sparse.csr_matrix((dat, (row, col)), shape=shape)

    src_area = np.asarray(ds.src_grid_area)
    dst_area = np.asarray(ds.dst_grid_area)

    # Destination pixels are normalized to 1 (row-sum = 1).
    # Weights represent the fractions of area attributed to source pixels.
    np.testing.assert_allclose(np.asarray(weights.sum(axis=1)).ravel(), 1)
    # Interpolation is conservative: it preserves area.
    np.testing.assert_allclose(weights.T @ dst_area, src_area)

    # Unnormalize. multiply() with a dense column may return a COO matrix,
    # which cannot be indexed, so convert back to CSR.
    weights = weights.multiply(dst_area[:, np.newaxis]).tocsr()

    # Another way to assert that the interpolation is conservative.
    np.testing.assert_allclose(np.asarray(weights.sum(1)).ravel(), dst_area)
    np.testing.assert_allclose(np.asarray(weights.sum(0)).ravel(), src_area)
    return weights
def _reciprocal_or_zero(values):
    """Return 1 / values element-wise as a flat float array, with 0 where values is 0."""
    values = np.asarray(values, dtype=float).ravel()
    inverse = np.zeros_like(values)
    np.divide(1.0, values, out=inverse, where=values != 0)
    return inverse


def build_pooling_matrices(weights):
    """Normalize an interpolation matrix for pooling and unpooling.

    Parameters
    ----------
    weights : scipy sparse matrix or array_like, shape (n_dst, n_src)
        Non-normalized interpolation weights, e.g. as returned by
        ``build_interpolation_matrix`` (rows: destination pixels, columns:
        source pixels).

    Returns
    -------
    pool : scipy.sparse.csr_matrix, shape (n_dst, n_src)
        ``weights`` with each row divided by its sum, so that rows sum to one.
    unpool : scipy.sparse.csr_matrix, shape (n_src, n_dst)
        Transpose of ``weights`` with each column divided by its sum, so that
        rows sum to one.

    A row (resp. column) that sums to zero is left as zeros instead of
    triggering a division by zero (inf, and NaN on stored zeros).
    """
    weights = sparse.csr_matrix(weights)
    row_scale = sparse.diags(_reciprocal_or_zero(weights.sum(axis=1)))
    col_scale = sparse.diags(_reciprocal_or_zero(weights.sum(axis=0)))
    pool = row_scale @ weights
    unpool = (weights @ col_scale).T
    return pool.tocsr(), unpool.tocsr()
weights = build_interpolation_matrix(graph1, graph2)
pool, unpool = build_pooling_matrices(weights)
# Both operators must be row-stochastic: every row sums to one.
for matrix in (pool, unpool):
    np.testing.assert_allclose(matrix.sum(1), 1)
def plot_interpolation_matrix(weights):
    """Plot a non-normalized interpolation matrix and the pixel areas it encodes.

    Parameters
    ----------
    weights : scipy sparse matrix, shape (n_dst, n_src)
        Non-normalized weights, e.g. as returned by ``build_interpolation_matrix``
        (rows: destination pixels, columns: source pixels).

    Draws four panels: a histogram of the stored weights, the matrix itself,
    and the source (column-sum) and destination (row-sum) pixel areas.

    Returns
    -------
    fig, axes
        The matplotlib figure and its four axes.
    """
    fig, axes = plt.subplots(1, 4, figsize=(24, 4))

    axes[0].hist(weights.data, bins=100)
    axes[0].set_title('histogram of overlapping areas')

    im = axes[1].imshow(weights.toarray())
    fig.colorbar(im, ax=axes[1])
    axes[1].set_title('non-normalized interpolation matrix')

    def plot_area(area, name, ax):
        # Sums of a sparse matrix are 2-D np.matrix objects; flatten to 1-D.
        area = np.asarray(area).ravel()
        mean_area = area.mean()
        ax.plot(area, '.')
        # Expects the pixels to tile the unit sphere (total area 4*pi).
        assert np.allclose(mean_area, 4*np.pi / len(area))
        ax.axhline(mean_area, ls='--', c='grey')
        ax.text(0, mean_area, 'mean area', c='grey', va='top')
        ax.set_title(f'{name} pixel areas')

    plot_area(weights.sum(0), 'source', axes[2])
    plot_area(weights.sum(1), 'destination', axes[3])
    return fig, axes
plot_interpolation_matrix(weights)
# Pooling (left) and unpooling (right) operators, side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, matrix in zip(axes, (pool, unpool)):
    im = ax.imshow(matrix.toarray())
    fig.colorbar(im, ax=ax)
<matplotlib.colorbar.Colorbar at 0x7fc96d902a90>
- `pool @ unpool = I` when the bipartite interpolation graph is disconnected, i.e., when parent vertices have disjoint supports.
- `pool` is the Moore–Penrose inverse of `unpool`.
- `unpool @ pool` should be a block-diagonal matrix (averaging over pooled vertices) if every pixel is included in a single parent (and pixels are properly ordered).
- Can `unpool` be the Moore–Penrose inverse of `pool`?
- The two properties above hold for true HEALPix pixels (not the Voronoi cells), with pooling weights `[0.25, 0.25, 0.25, 0.25]` and unpooling weights `[1, 1, 1, 1]`, because that sampling scheme is exactly hierarchical.

Can we use this to evaluate the quality of a coarsening/interpolation or of a (hierarchical) sampling?
- `1 - np.diag(pool @ unpool)` is, for each pixel, the fraction of its value that comes from other pixels after an unpool-then-pool round trip (as the row-sum is one, that is also the sum of the off-diagonal elements).
- `np.sum(1 - np.diag(pool @ unpool)) / npix`
# is the fraction of averaged/mixed pixel values.
def example(weights):
    """Print the pooling/unpooling matrices derived from a toy weight matrix.

    Parameters
    ----------
    weights : array_like, shape (n_parent, n_child)
        Dense non-normalized weights (rows: parent pixels, columns: children).

    Prints, in order: ``unpool``, ``pool``, ``pool @ unpool`` and
    ``unpool @ pool``.

    Returns
    -------
    pool : ndarray, shape (n_parent, n_child)
        ``weights`` with rows normalized to sum to one.
    unpool : ndarray, shape (n_child, n_parent)
        Transpose of ``weights`` with columns normalized to sum to one.
    """
    weights = np.asarray(weights)
    unpool = (weights / weights.sum(0)).T
    pool = weights / weights.sum(1)[:, np.newaxis]
    print(unpool)
    print(pool)
    print(pool @ unpool)
    print(unpool @ pool)
    return pool, unpool
# Disjoint parent supports (invertible) versus one shared child (not invertible).
for heading, toy_weights in [
    ('Is invertible:', [[1, 3, 0, 0], [0, 0, 3, 1]]),
    ('Is not invertible:', [[1, 3, 0, 0], [0, 1, 3, 1]]),
]:
    print(heading)
    example(np.array(toy_weights))
Is invertible: [[1. 0.] [1. 0.] [0. 1.] [0. 1.]] [[0.25 0.75 0. 0. ] [0. 0. 0.75 0.25]] [[1. 0.] [0. 1.]] [[0.25 0.75 0. 0. ] [0.25 0.75 0. 0. ] [0. 0. 0.75 0.25] [0. 0. 0.75 0.25]] Is not invertible: [[1. 0. ] [0.75 0.25] [0. 1. ] [0. 1. ]] [[0.25 0.75 0. 0. ] [0. 0.2 0.6 0.2 ]] [[0.8125 0.1875] [0.15 0.85 ]] [[0.25 0.75 0. 0. ] [0.1875 0.6125 0.15 0.05 ] [0. 0.2 0.6 0.2 ] [0. 0.2 0.6 0.2 ]]
def plot_matrices(mat1, mat2, axes=None):
    """Show two matrices side by side, each with its own colorbar.

    Sparse inputs are densified before display. When ``axes`` is None a new
    1x2 figure is created; otherwise ``axes[0]`` and ``axes[1]`` are used.
    """
    if axes is None:
        _, axes = plt.subplots(1, 2, figsize=(15, 4))
    for position, matrix in enumerate((mat1, mat2)):
        ax = axes[position]
        dense = matrix.toarray() if sparse.issparse(matrix) else matrix
        image = ax.imshow(dense)
        ax.figure.colorbar(image, ax=ax)
p = pool @ unpool
# pool @ unpool equals the identity only if pooling is non-destructive
# (disjoint parent supports).
# assert np.allclose(p.toarray(), np.identity(graph2.n_vertices), atol=1e-10)
p_dense = p.toarray()
err = np.identity(graph2.n_vertices) - p_dense
plot_matrices(p_dense, err)
# Another way to see the error.
# pool_pinv = np.linalg.pinv(unpool.toarray())
# assert np.allclose(pool_pinv @ unpool, np.identity(graph2.n_vertices), atol=1e-10)
# err = pool.toarray() - pool_pinv
# plot_matrices(pool_pinv, err)
def plot_inversion_error(pool, unpool, ax=None):
    """Plot, per pixel, how far ``pool @ unpool`` is from the identity.

    For each pixel the error is ``1 - (pool @ unpool)[i, i]``. The title
    reports the mean error over all pixels. A new figure is created when
    ``ax`` is None.
    """
    if ax is None:
        _, ax = plt.subplots()
    # Diagonal of pool @ unpool without forming the full product:
    # (pool @ unpool)[i, i] == sum_j pool[i, j] * unpool[j, i].
    diagonal = np.asarray(pool.multiply(unpool.T).sum(1)).ravel()
    per_pixel_error = 1 - diagonal
    ax.plot(per_pixel_error, '.')
    overall_error = np.sum(per_pixel_error) / len(per_pixel_error)
    ax.set_title(f'averaging error per pixel ({overall_error:.1%} overall error)')
plot_inversion_error(pool, unpool)
# Unpooling followed by pooling: ideally a block-diagonal averaging operator.
p = unpool.dot(pool)
def block_diag(blocksize, nblock):
    """Return a sparse block-diagonal matrix made of ``nblock`` averaging blocks.

    Each block is ``int(blocksize)`` x ``int(blocksize)`` and filled with
    ``1 / blocksize``, so that it averages the values within the block.
    """
    size = int(blocksize)
    averaging_block = np.ones((size, size)) / blocksize
    return sparse.block_diag([averaging_block for _ in range(nblock)])
# Only a true error for the original HEALPix pixels, not for the Voronoi cells (which may overlap).
children_per_parent = graph1.n_vertices // graph2.n_vertices
ideal = block_diag(children_per_parent, graph2.n_vertices)
err = ideal - p
plot_matrices(p.toarray(), err.toarray())
# Another way to see the error.
# unpool_pinv = np.linalg.pinv(pool.toarray())
# err = unpool.toarray() - unpool_pinv
# plot_matrices(unpool_pinv, err)
# graph1 = pg.graphs.SphereHealpix(subdivisions=8, nest=False, k=20, kernel_width=None)
# graph2 = pg.graphs.SphereHealpix(subdivisions=2, nest=False, k=20, kernel_width=None)
# weights = build_interpolation_matrix(graph1, graph2)
# pool, unpool = build_pooling_matrices(weights)
def plot_laplacians(L, graph):
    """Compare a (reconstructed) Laplacian ``L`` with ``graph.L``.

    Draws three panels: ``L``, the difference ``L - graph.L``, and the
    eigenvalues of ``graph.L`` ("original") and of ``L`` ("reconstructed").
    Calls ``graph.compute_fourier_basis()``, which provides ``graph.e``.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 4))
    err = L - graph.L
    plot_matrices(L, err, axes)

    graph.compute_fourier_basis()
    dense = L.toarray() if sparse.issparse(L) else np.asarray(L)
    # Reconstructions such as pool @ L @ unpool are generally not symmetric
    # (pool != unpool.T), and eigh/eigvalsh would silently read only one
    # triangle of the matrix.
    if np.allclose(dense, dense.T):
        e = np.linalg.eigvalsh(dense)
    else:
        e = np.sort(np.linalg.eigvals(dense).real)
    axes[2].plot(graph.e, '.-', label='original')
    axes[2].plot(e, '.-', label='reconstructed')
    axes[2].legend()
plot_matrices(graph1.L, graph2.L)

# Processing on a graph of lower or higher resolution.
# TODO: a scaling factor is missing.
plot_laplacians(pool @ graph1.L @ unpool, graph2)
plot_laplacians(unpool @ graph2.L @ pool, graph1)

# Graph compression: go to the other resolution and back, on both sides.
coarse_round_trip = pool @ unpool
fine_round_trip = unpool @ pool
plot_laplacians(coarse_round_trip @ graph2.L @ pool @ unpool, graph2)
plot_laplacians(fine_round_trip @ graph1.L @ unpool @ pool, graph1)
# Three nested HEALPix resolutions, from finest to coarsest.
graphs = [pg.graphs.SphereHealpix(subdivisions, k=8) for subdivisions in (4, 2, 1)]
# fine -> middle, middle -> coarse, and fine -> coarse directly.
weights1, weights2, weights3 = (
    build_interpolation_matrix(graphs[src], graphs[dst])
    for src, dst in [(0, 1), (1, 2), (0, 2)]
)
# Toy example illustrating the mixing.
# weights2 = sparse.csr_matrix(np.array([
# [1, 1],
# [0.5, 0],
# ]))
# weights1 = sparse.csr_matrix(np.array([
# [0.5, 1, 0, 0, 0],
# [0, 0.1, 0.6, 0.1, 0.2],
# ]))
# weights3 = sparse.csr_matrix(np.array([
# [0.2, 0.9, 0.6, 0.1, 0.2],
# [0.3, 0.2, 0, 0, 0],
# ]))
# Same areas: each resolution's pixel areas agree across the weight matrices.
np.testing.assert_allclose(weights1.sum(1), weights2.sum(0).T)
np.testing.assert_allclose(weights1.sum(0), weights3.sum(0))
np.testing.assert_allclose(weights2.sum(1), weights3.sum(1))

# Reuse the shared helper instead of duplicating its normalization.
pool1, unpool1 = build_pooling_matrices(weights1)
pool2, unpool2 = build_pooling_matrices(weights2)
pool3, unpool3 = build_pooling_matrices(weights3)

# Chain through the intermediate resolution; rows must still sum to one.
pool = pool2 @ pool1
np.testing.assert_allclose(pool.sum(1), 1)
np.testing.assert_allclose(pool3.sum(1), 1)
unpool = unpool1 @ unpool2
np.testing.assert_allclose(unpool.sum(1), 1)
np.testing.assert_allclose(unpool3.sum(1), 1)
# Encoder-decoder on multi-scale sampling: shapes of the chained operators.
tuple(matrix.shape for matrix in (unpool1, unpool2, pool2, pool1))
((192, 48), (48, 12), (12, 48), (48, 192))
# Chaining is conservative: transposed operators distribute area back.
areas = weights2.sum(1)
for pooling, expected in [
    (pool2, weights1.sum(1)),
    (pool, weights1.sum(0).T),
    (pool3, weights1.sum(0).T),
]:
    np.testing.assert_allclose(pooling.T @ areas, expected)

areas = weights1.sum(0)
for unpooling, expected in [
    (unpool1, weights2.sum(0).T),
    (unpool, weights2.sum(1)),
    (unpool3, weights2.sum(1)),
]:
    np.testing.assert_allclose(unpooling.T @ areas.T, expected)

# Going through intermediary pixels mixes/averages more than the direct path.
assert not np.allclose(pool.toarray(), pool3.toarray())
assert not np.allclose(unpool.toarray(), unpool3.toarray())
# Three resolutions (fine, medium, coarse) per sampling scheme.
samplings = {
    'healpix': [pg.graphs.SphereHealpix(n) for n in (16, 8, 4)],
    'icosahedral': [pg.graphs.SphereIcosahedral(n) for n in (16, 8, 4)],
    'cubed': [pg.graphs.SphereCubed(n) for n in (22, 11, 5)],
    'gausslegendre': [
        pg.graphs.SphereGaussLegendre(n, nlon='ecmwf-octahedral') for n in (45, 22, 11)
    ],
    'equiangular': [pg.graphs.SphereEquiangular(n, 2*n) for n in (38, 19, 10)],
    'random': [pg.graphs.SphereRandom(n, seed=1) for n in (2800, 700, 175)],
}
for name, sampling in samplings.items():
    # Interpolate from the finest to the middle resolution.
    weights = build_interpolation_matrix(sampling[0], sampling[1])
    plot_interpolation_matrix(weights)
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    plot_inversion_error(*build_pooling_matrices(weights), axes[0])
    # Number of source pixels overlapping each destination pixel.
    axes[1].hist(np.asarray((weights > 0).sum(1)).ravel())
    n_dst, n_src = weights.shape
    # nnz / size is the fraction of non-zeros, i.e. the density (not sparsity).
    print(f'{name}: averaging over {weights.nnz / n_dst:.1f} pixels '
          f'({weights.nnz} non-zeros, {weights.nnz / (n_dst * n_src):.2%} density)')
averaging over 8.2 pixels, (6272 non-zeros, 0.27% sparsity) averaging over 8.9 pixels, (5742 non-zeros, 0.35% sparsity) averaging over 8.8 pixels, (6384 non-zeros, 0.30% sparsity) averaging over 7.5 pixels, (6640 non-zeros, 0.27% sparsity) averaging over 7.9 pixels, (5700 non-zeros, 0.27% sparsity) averaging over 9.0 pixels, (6283 non-zeros, 0.32% sparsity)