import functools
import xarray
import numpy as np
import pandas as pd
# import xhistogram for comparison
from xhistogram.xarray import histogram as xhistogram
Conceptually, da.groupby_bins(bins).apply(count)
is essentially a pure-xarray histogram function. Can we do something similar to write a high-level histogram function in xarray without messing with low-level numpy code?
We want to first replicate np.histogram:
data = xarray.DataArray(np.random.randn(20), dims='x', name='a')
bins = np.linspace(-4, 4, 10)
expected, _ = np.histogram(data, bins=bins)
expected
array([0, 1, 1, 5, 6, 5, 2, 0, 0])
result = data.groupby_bins(data, bins).apply(xarray.DataArray.count).fillna(0.0)
result
<xarray.DataArray 'a' (a_bins: 9)> array([0., 1., 1., 5., 6., 5., 2., 0., 0.]) Coordinates: * a_bins (a_bins) object (-4.0, -3.111] (-3.111, -2.222] ... (3.111, 4.0]
np.allclose(result, expected)
True
That actually works!
(Maybe we could get rid of the floatiness later by specifying dtypes...?)
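For instance, casting afterwards should recover integer counts (just a guess at one option; int_result is an illustrative name):

int_result = result.astype(np.int64)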
But how fast is it? Let's try it on a larger array (stealing Stephan's test example DataArray):
def make_test_data(t, x, y, seed=0):
signal = xarray.DataArray(
np.random.RandomState(seed).rand(t, x, y),
dims=['time', 'y', 'x'],
coords={
'time': np.arange(t),
'y': np.arange(x),
'x': np.arange(y),
},
name='signal')
return signal
signal = make_test_data(t=20, x=500, y=500)
bins = np.linspace(-4, 4, 50)
signal.nbytes / 1e6
40.0
%time _ = signal.groupby_bins(signal, bins=bins).apply(xarray.DataArray.count).fillna(0.0)
CPU times: user 23.1 s, sys: 2.96 s, total: 26.1 s Wall time: 26.5 s
It's very, very slow.
numpy_groupies

But Stephan has a trick using numpy_groupies...
from numpy_groupies import aggregate_np
def _binned_agg(
array: np.ndarray,
indices: np.ndarray,
num_bins: int,
*,
func,
fill_value,
dtype,
npg_aggregate=aggregate_np,
) -> np.ndarray:
"""NumPy helper function for aggregating over bins."""
    # keep only the samples that received a valid (non-NaN) bin index
    mask = np.logical_not(np.isnan(indices))
    int_indices = indices[mask].astype(int)
result = npg_aggregate(
int_indices, array[..., mask],
func=func,
size=num_bins,
fill_value=fill_value,
dtype=dtype,
axis=-1,
)
return result
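For reference, numpy_groupies' aggregate reduces values into integer-labelled groups. A minimal illustration (my own, not from the original notebook) of the semantics _binned_agg relies on:

aggregate_np(np.array([0, 0, 1, 2]), np.array([1.0, 2.0, 3.0, 4.0]), func='sum', size=4)
# -> array([3., 3., 4., 0.])  # group 3 is empty, so it gets the fill_value of 0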
def groupby_bins_count(
array: xarray.DataArray,
bins,
func=np.count_nonzero,
fill_value=0,
dtype=None,
**cut_kwargs,
) -> xarray.DataArray:
"""Faster equivalent of Xarray's groupby_bins(...).map()."""
# TODO: implement this upstream in xarray:
# https://github.com/pydata/xarray/issues/4473
binned = pd.cut(np.ravel(array), bins, **cut_kwargs)
new_dim_name = array.name + "_bins"
indices = array.copy(data=binned.codes.reshape(array.shape))
result = xarray.apply_ufunc(
_binned_agg, array, indices,
input_core_dims=[indices.dims, indices.dims],
output_core_dims=[[new_dim_name]],
output_dtypes=[array.dtype],
dask_gufunc_kwargs=dict(
output_sizes={new_dim_name: binned.categories.size},
allow_rechunk=True,
),
kwargs={
'num_bins': binned.categories.size,
'func': func,
'fill_value': fill_value,
'dtype': dtype,
},
dask='parallelized',
)
result.coords[new_dim_name] = binned.categories
return result
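(As an aside, here is how pd.cut encodes bin membership, using the 50-edge bins from above; my own illustration. Each value gets its bin's integer index, and out-of-range values get the code -1, which is what the masking in _binned_agg needs to guard against.)

pd.cut(np.array([-5.0, 0.0, 5.0]), bins).codes
# -> array([-1, 24, -1], dtype=int8)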
groupby_bins_count(data, bins)
/home/tegn500/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy/core/_asarray.py:171: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray. return array(a, dtype, copy=False, order=order, subok=True)
<xarray.DataArray 'a' (a_bins: 49)> array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 2., 1., 0., 1., 1., 1., 3., 1., 0., 3., 0., 1., 1., 0., 0., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]) Coordinates: * a_bins (a_bins) object (-4.0, -3.837] (-3.837, -3.673] ... (3.837, 4.0]
(Why does this warning happen though? It even happens if I pass keepdims=True. Do I need to explicitly tell it that the input_core_dims will be reduced to length zero?)
Is this faster now?
%time gbc_result = groupby_bins_count(signal, bins=bins)
CPU times: user 988 ms, sys: 80 ms, total: 1.07 s Wall time: 1.08 s
It's way faster!
Let's compare to xhistogram's current implementation...
%time xhist_result = xhistogram(signal, bins=bins)
CPU times: user 180 ms, sys: 16 ms, total: 196 ms Wall time: 206 ms
Ah. So we could simplify the xhistogram code significantly, but it would be roughly 5x slower.
# double check
np.allclose(xhist_result, gbc_result)
True
But numpy_groupies can also use numba, so let's try that
from numpy_groupies import aggregate_nb
_binned_agg = functools.partial(_binned_agg, npg_aggregate=aggregate_nb)
groupby_bins_count(signal, bins=bins)
--------------------------------------------------------------------------- TypingError Traceback (most recent call last) <ipython-input-21-2c58b33edcfc> in <module> ----> 1 groupby_bins_count(signal, bins=bins) <ipython-input-14-3d6cd8c5c89e> in groupby_bins_count(array, bins, func, fill_value, dtype, **cut_kwargs) 15 indices = array.copy(data=binned.codes.reshape(array.shape)) 16 ---> 17 result = xarray.apply_ufunc( 18 _binned_agg, array, indices, 19 input_core_dims=[indices.dims, indices.dims], ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args) 1169 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc 1170 elif any(isinstance(a, DataArray) for a in args): -> 1171 return apply_dataarray_vfunc( 1172 variables_vfunc, 1173 *args, ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args) 286 287 data_vars = [getattr(a, "variable", a) for a in args] --> 288 result_var = func(*data_vars) 289 290 if signature.num_outputs > 1: ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args) 737 ) 738 --> 739 result_data = func(*input_data) 740 741 if signature.num_outputs == 1: <ipython-input-13-ea0b63a650e7> in _binned_agg(array, indices, num_bins, func, fill_value, dtype, npg_aggregate) 12 mask = np.logical_not(np.isnan(indices)) 13 int_indices = indices[mask].astype(int) ---> 14 result = npg_aggregate( 15 int_indices, array[..., mask], 16 func=func, ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy_groupies/aggregate_numba.py in aggregate(group_idx, a, func, size, fill_value, order, dtype, axis, cache, **kwargs) 436 cache = _default_cache 437 aggregate_op = cache.setdefault(func, AggregateGeneric(func)) --> 438 return aggregate_op(group_idx, a, size, fill_value, order, dtype, axis, **kwargs) 439 else: 440 func = _impl_dict[func] ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy_groupies/aggregate_numba.py in __call__(self, group_idx, a, size, fill_value, order, dtype, axis, ddof) 202 203 sortidx = np.argsort(group_idx, kind='mergesort') --> 204 self._jitfunc(sortidx, group_idx, a, ret) 205 206 # Deal with ndimensional indexing ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numba/core/dispatcher.py in _compile_for_args(self, *args, **kws) 418 e.patch_message(msg) 419 --> 420 error_rewrite(e, 'typing') 421 except errors.UnsupportedError as e: 422 # Something unsupported is present in the user code, add help info ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numba/core/dispatcher.py in error_rewrite(e, issue_type) 359 raise e 360 else: --> 361 raise e.with_traceback(None) 362 363 argtypes = [] TypingError: Failed in nopython mode pipeline (step: nopython frontend) Internal error at <numba.core.typeinfer.CallConstraint object at 0x7fcb47a89ac0>. tuple index out of range During: resolving callee type: type(CPUDispatcher(<function count_nonzero at 0x7fcb9003c9d0>)) During: typing of call at /home/tegn500/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy_groupies/aggregate_numba.py (228) Enable logging at debug level for details. 
File "../../../../miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy_groupies/aggregate_numba.py", line 228: def _loop(sortidx, group_idx, a, ret): <source elided> raise ValueError("one or more indices in group_idx are too large") ret[ri] = jitfunc(a_srt[start_idx:stop_idx]) ^
Oh dear. Not sure if that's related to the warning I'm getting above?
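One possible explanation: the numba backend tries to jit-compile whatever callable we hand it, and np.count_nonzero apparently can't be typed by numba in that generic path. A workaround worth trying (untested here) would be to request one of numpy_groupies' built-in aggregations by name instead, since those have dedicated numba implementations:

# e.g. ask for the built-in 'count' aggregation rather than a Python callable;
# 'count' counts group sizes, which is arguably what a histogram wants anyway
groupby_bins_count(signal, bins=bins, func='count')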
# Set numpy_groupies.aggregate back to numpy implementation
_binned_agg = functools.partial(_binned_agg, npg_aggregate=aggregate_np)
groupby_bins_count does the binning and then the counting separately - which is the bottleneck?
def groupby_bins(
array: xarray.DataArray,
bins,
**cut_kwargs,
):
binned = pd.cut(np.ravel(array), bins, **cut_kwargs)
indices = array.copy(data=binned.codes.reshape(array.shape))
return binned, indices
def count_groups(
    array: xarray.DataArray,
    func=np.count_nonzero,
    binned=None,
    indices=None,
    fill_value=0,
    dtype=None,
) -> xarray.DataArray:
new_dim_name = array.name + "_bins"
result = xarray.apply_ufunc(
_binned_agg, array, indices,
input_core_dims=[indices.dims, indices.dims],
output_core_dims=[[new_dim_name]],
output_dtypes=[array.dtype],
dask_gufunc_kwargs=dict(
output_sizes={new_dim_name: binned.categories.size},
allow_rechunk=True,
),
kwargs={
'num_bins': binned.categories.size,
'func': func,
'fill_value': fill_value,
'dtype': dtype,
},
dask='parallelized',
)
result.coords[new_dim_name] = binned.categories
return result
%time binned, indices = groupby_bins(signal, bins=bins)
CPU times: user 248 ms, sys: 136 ms, total: 384 ms Wall time: 399 ms
%time result = count_groups(signal, binned=binned, indices=indices)
CPU times: user 584 ms, sys: 392 ms, total: 976 ms Wall time: 989 ms
/home/tegn500/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy/core/_asarray.py:171: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray. return array(a, dtype, copy=False, order=order, subok=True)
Either part alone is slower than xhistogram's whole call...
Parallelize over a single core dimension, time
import dask
dask_signal = signal.chunk({'time': 1})
dask.config.set(num_workers=4)
dask_signal
<xarray.DataArray 'signal' (time: 20, y: 500, x: 500)> dask.array<xarray-<this-array>, shape=(20, 500, 500), dtype=float64, chunksize=(1, 500, 500), chunktype=numpy.ndarray> Coordinates: * time (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 * y (y) int64 0 1 2 3 4 5 6 7 8 ... 491 492 493 494 495 496 497 498 499 * x (x) int64 0 1 2 3 4 5 6 7 8 ... 491 492 493 494 495 496 497 498 499
%time result = dask_signal.groupby_bins(signal, bins).apply(xarray.DataArray.count).fillna(0.0)
%time result.compute()
result
CPU times: user 19.4 s, sys: 872 ms, total: 20.2 s Wall time: 20.3 s
CPU times: user 128 ms, sys: 12 ms, total: 140 ms Wall time: 102 ms
<xarray.DataArray 'signal' (signal_bins: 49)> dask.array<where, shape=(49,), dtype=float64, chunksize=(25,), chunktype=numpy.ndarray> Coordinates: * signal_bins (signal_bins) object (-4.0, -3.837] ... (3.837, 4.0]
def _dask_count_nonzero(arr, axis=None, keepdims=False):
"""The keepdims argument is not implemented in the dask version of count_nonzero"""
counts = dask.array.count_nonzero(arr, axis=axis)
if keepdims:
if axis is None:
axis = tuple(range(arr.ndim))
return np.expand_dims(counts, axis)
else:
return counts
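For reference, this shim mirrors numpy's own behaviour (np.count_nonzero has supported keepdims since numpy 1.19; illustration mine):

np.count_nonzero(np.array([[1, 0], [2, 3]]), axis=-1, keepdims=True)
# -> array([[1], [2]])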
%time result = groupby_bins_count(dask_signal, func=_dask_count_nonzero, bins=bins)
display(result.data.visualize())
%time result.compute()
result
CPU times: user 304 ms, sys: 704 ms, total: 1.01 s Wall time: 1.33 s
/home/tegn500/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy/core/_asarray.py:171: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray. return array(a, dtype, copy=False, order=order, subok=True)
CPU times: user 1.07 s, sys: 2.27 s, total: 3.34 s Wall time: 4.34 s
<xarray.DataArray 'signal' (signal_bins: 49)> dask.array<transpose, shape=(49,), dtype=float64, chunksize=(49,), chunktype=numpy.ndarray> Coordinates: * signal_bins (signal_bins) object (-4.0, -3.837] ... (3.837, 4.0]
Performance seems as expected, but it is cool that we can achieve parallelization over all input dimensions with so little code!
Let's compare to xhistogram with its new dask.array.blockwise implementation:
%time result = xhistogram(dask_signal, bins=bins)
display(result.data.visualize())
%time result.compute()
result
CPU times: user 8 ms, sys: 4 ms, total: 12 ms Wall time: 19.1 ms
CPU times: user 220 ms, sys: 16 ms, total: 236 ms Wall time: 293 ms
<xarray.DataArray 'histogram_signal' (signal_bin: 49)> dask.array<sum-aggregate, shape=(49,), dtype=int64, chunksize=(49,), chunktype=numpy.ndarray> Coordinates: * signal_bin (signal_bin) float64 -3.918 -3.755 -3.592 ... 3.592 3.755 3.918
xhistogram performs way better, and has a much more linear task graph.
For this implementation to replace the current xhistogram implementation we would also need to work out how to generalise it in various ways:
Try counting over time, but not over broadcast dimensions x and y
dask_signal
<xarray.DataArray 'signal' (time: 20, y: 500, x: 500)> dask.array<xarray-<this-array>, shape=(20, 500, 500), dtype=float64, chunksize=(1, 500, 500), chunktype=numpy.ndarray> Coordinates: * time (time) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 * y (y) int64 0 1 2 3 4 5 6 7 8 ... 491 492 493 494 495 496 497 498 499 * x (x) int64 0 1 2 3 4 5 6 7 8 ... 491 492 493 494 495 496 497 498 499
Need to generalise our function to accept count_dims to count over:
def _binned_agg(
array: np.ndarray,
indices: np.ndarray,
num_bins: int,
axis=-1,
*,
func,
fill_value,
dtype,
) -> np.ndarray:
"""NumPy helper function for aggregating over bins."""
mask = np.logical_not(np.isnan(indices))
int_indices = indices[mask].astype(int)
    shape = array.shape[:-indices.ndim] + (num_bins,)  # NOTE: computed but never used
result = aggregate_np(
int_indices, array[..., mask],
func=func,
size=num_bins,
fill_value=fill_value,
dtype=dtype,
axis=axis,
)
return result
def groupby_bins_count(
array: xarray.DataArray,
count_dims=None,
bins=None,
func=np.count_nonzero,
fill_value=0,
dtype=None,
**cut_kwargs,
) -> xarray.DataArray:
"""Faster equivalent of Xarray's groupby_bins(...).map()."""
# TODO: implement this upstream in xarray:
# https://github.com/pydata/xarray/issues/4473
binned = pd.cut(np.ravel(array), bins, **cut_kwargs)
new_dim_name = array.name + "_bins"
indices = array.copy(data=binned.codes.reshape(array.shape))
if count_dims is None:
# Count over flattened array by default
count_dims = indices.dims
if any(d not in indices.dims for d in count_dims):
raise ValueError
result = xarray.apply_ufunc(
_binned_agg, array, indices,
input_core_dims=[count_dims, count_dims],
output_core_dims=[[new_dim_name]],
output_dtypes=[array.dtype],
dask_gufunc_kwargs=dict(
            output_sizes={new_dim_name: binned.categories.size},
allow_rechunk=True,
),
kwargs={
'num_bins': binned.categories.size,
'func': func,
'fill_value': fill_value,
'dtype': dtype,
'axis': tuple(range(-len(count_dims), 0))
},
dask='parallelized',
)
result.coords[new_dim_name] = binned.categories
return result
groupby_bins_count(signal, count_dims=('x','y'), bins=bins, func=_dask_count_nonzero)
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-38-2e42c3ccdebe> in <module> ----> 1 groupby_bins_count(signal, count_dims=('x','y'), bins=bins, func=_dask_count_nonzero) <ipython-input-37-3679abaeabf3> in groupby_bins_count(array, count_dims, bins, func, fill_value, dtype, **cut_kwargs) 22 raise ValueError 23 ---> 24 result = xarray.apply_ufunc( 25 _binned_agg, array, indices, 26 input_core_dims=[count_dims, count_dims], ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/xarray/core/computation.py in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args) 1169 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc 1170 elif any(isinstance(a, DataArray) for a in args): -> 1171 return apply_dataarray_vfunc( 1172 variables_vfunc, 1173 *args, ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/xarray/core/computation.py in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args) 286 287 data_vars = [getattr(a, "variable", a) for a in args] --> 288 result_var = func(*data_vars) 289 290 if signature.num_outputs > 1: ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/xarray/core/computation.py in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args) 737 ) 738 --> 739 result_data = func(*input_data) 740 741 if signature.num_outputs == 1: <ipython-input-36-af667a9143ce> in _binned_agg(array, indices, num_bins, axis, func, fill_value, dtype) 14 int_indices = indices[mask].astype(int) 15 shape = array.shape[:-indices.ndim] + (num_bins,) ---> 16 result = aggregate_np( 17 int_indices, array[..., mask], 18 func=func, ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy_groupies/aggregate_numpy.py in aggregate(group_idx, a, func, size, fill_value, order, dtype, axis, **kwargs) 289 def aggregate(group_idx, a, func='sum', size=None, fill_value=0, order='C', 290 dtype=None, axis=None, **kwargs): --> 291 return _aggregate_base(group_idx, a, size=size, fill_value=fill_value, 292 order=order, dtype=dtype, func=func, axis=axis, 293 _impl_dict=_impl_dict, _nansqueeze=True, **kwargs) ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy_groupies/aggregate_numpy.py in _aggregate_base(group_idx, a, func, size, fill_value, order, dtype, axis, _impl_dict, _nansqueeze, cache, **kwargs) 254 order='C', dtype=None, axis=None, _impl_dict=_impl_dict, 255 _nansqueeze=False, cache=None, **kwargs): --> 256 group_idx, a, flat_size, ndim_idx, size = input_validation(group_idx, a, 257 size=size, order=order, axis=axis) 258 if group_idx.dtype == np.dtype("uint64"): ~/miniconda3/envs/py38-mamba/lib/python3.8/site-packages/numpy_groupies/utils_numpy.py in input_validation(group_idx, a, size, order, axis, ravel_group_idx, check_bounds) 216 raise ValueError("a must be scalar or 1 dimensional, use .ravel to" 217 " flatten. Alternatively specify axis.") --> 218 elif axis >= ndim_a or axis < -ndim_a: 219 raise ValueError("axis arg too large for np.ndim(a)") 220 else: TypeError: '>=' not supported between instances of 'tuple' and 'int'
It seems like numpy_groupies.aggregate might only accept one axis? We could try reshaping so all dims to be reduced are along axis=-1?
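A minimal sketch of that reshaping idea (my speculation, untested): flatten all the core dims to a single trailing axis inside the helper, and let apply_ufunc's vectorize=True loop over any remaining broadcast dims, so numpy_groupies only ever sees a 1-d group_idx:

def _binned_agg_flat(array, indices, num_bins, *, func, fill_value, dtype):
    """Hypothetical helper: array and indices hold only the count_dims here."""
    array = array.reshape(-1)
    indices = indices.reshape(-1)
    mask = indices >= 0  # pd.cut marks out-of-bins values with code -1
    return aggregate_np(
        indices[mask].astype(int), array[mask],
        func=func, size=num_bins, fill_value=fill_value, dtype=dtype,
    )

This would be used with vectorize=True in the apply_ufunc call, at the cost of a Python-level loop over the broadcast dims.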
To get an ND histogram (i.e. binning along multiple dimensions) would we need to calculate one set of indices for each input array? And then perform the aggregation in a loop: once for each dimension to bin? (One possible vectorized alternative is sketched at the end of this section.)
To apply weights we would need a vectorized (and dask-parallelizable) function which, instead of

def count_nonzero(arr):
    count = 0
    for v in arr:
        if v != 0:
            count += 1
    return count

worked like
def sum_weights(arr, weights):
    total = 0
    for i, v in enumerate(arr):
        if v != 0:
            total += weights[i]
    return total
We would also want it to have an axis argument, I think.
This is really like a cumulative sum of weights, but masked by the non-zero values of arr. So we could just do:

mask = arr != 0
total = weights[mask].sum()
This should be possible with dask.arrays?
(Also perhaps using bottleneck.nansum might be faster somehow?)
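For the binned case specifically, numpy_groupies might already give us this for free: passing the weights as the values and summing them per bin is exactly a weighted histogram. A sketch under that assumption (names are mine):

def _binned_weighted_sum(weights, indices, num_bins):
    # sum the weight of every sample that landed in each bin
    mask = indices >= 0
    return aggregate_np(
        indices[mask].astype(int), weights[mask],
        func='sum', size=num_bins, fill_value=0.0,
    )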
Would need a version of this logic which could accept an N-dimensional bins argument:
binned = pd.cut(np.ravel(array), bins, **cut_kwargs)
indices = array.copy(data=binned.codes.reshape(array.shape))
I have no idea how to do that in a vectorized manner.
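One possible vectorized route (pure speculation on my part): compute 1-d bin codes for each input array separately with pd.cut, collapse the N-dimensional bin index into a single flat group index with np.ravel_multi_index, aggregate once, and reshape. The helper below is hypothetical:

def _multidim_indices(arrays, bins_per_array):
    """Combine per-array bin codes into one flat group index."""
    codes = [pd.cut(np.ravel(a), b).codes for a, b in zip(arrays, bins_per_array)]
    valid = np.all([c >= 0 for c in codes], axis=0)  # drop out-of-bins samples
    shape = tuple(len(b) - 1 for b in bins_per_array)
    flat = np.ravel_multi_index([c[valid].astype(int) for c in codes], shape)
    return flat, valid, shape

flat, valid, shape = _multidim_indices([data, data], [bins, bins])
counts = aggregate_np(flat, np.ones(flat.size), func='sum', size=np.prod(shape)).reshape(shape)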