from collections import OrderedDict
# parameters
# Some advice
#
# - on a single big node, favor in-memory comms (i.e., many
# threads per few workers)
#
# - memory_limit is per worker and hence must be able to cope with
#   threads x chunk size x "a handful" (depends on the complexity of the
#   calculation; a rough sanity check follows the chunk-size prints below)
#
# - also keep the Dask graph in mind: Rechunking adds lots of info to the
# dependency graph and scheduling can quickly become the bottleneck.
#   (Play with the chunk size and watch the length of __dask_graph__.)
#
# - Rechunking is hard. As Dask tries very hard to not repeat itself,
#   rechunking will only release an initial chunk from memory after all
#   final chunks that depend on it are assembled. Hence, rechunking from
#   fully row-oriented to fully column-oriented will result in all data
#   being in memory at the same time.
#
# - To scale out horizontally, consider dask-jobqueue, but keep the
#   communication overhead in mind. The same goes for IO bandwidth, if
#   that is an issue.
#
#
# make cluster setup easy to see and change here
dask_local_cluster_args = {
"n_workers": 1,
"threads_per_worker": 8,
"memory_limit": 60e9
}
# make it easy to limit to fewer data
time_steps = int(3 * 24) # 3 days for now
# total data shape
data_shape = OrderedDict(time=time_steps, depth=2, Y=2000, X=4000)
# with xr.open_mfdataset, chunks cannot cross file boundaries:
data_chunk_shape_initial = OrderedDict(time=1, depth=1, Y=1000, X=2000)
# chunking that is optimized for time-filtering
# too aggressive: data_chunk_shape_for_time_ops = OrderedDict(time=None, depth=1, Y=1, X=1)
data_chunk_shape_for_time_ops = OrderedDict(time=time_steps, depth=1, Y=500, X=500)
from operator import mul
from functools import reduce
print(f"data size per var = {reduce(mul, data_shape.values()) * 8 / 1e9} GB")
print(f"initial chunk size = {reduce(mul, data_chunk_shape_initial.values()) * 8 / 1e9} GB")
print(f"optimized chunk size = {reduce(mul, data_chunk_shape_for_time_ops.values()) * 8 / 1e9} GB")
data size per var = 9.216 GB
initial chunk size = 0.016 GB
optimized chunk size = 0.144 GB
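A rough sanity check of the rule of thumb from the advice above (just a sketch; the factor of 5 is an arbitrary stand-in for "a handful"):
chunk_bytes = reduce(mul, data_chunk_shape_for_time_ops.values()) * 8
safety_factor = 5  # "a handful"; depends on the complexity of the calculation
per_worker_need = dask_local_cluster_args["threads_per_worker"] * chunk_bytes * safety_factor
print(f"rough per-worker need = {per_worker_need / 1e9:.2f} GB "
      f"(memory_limit = {dask_local_cluster_args['memory_limit'] / 1e9} GB)")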
from dask import array as darr
from dask.distributed import LocalCluster, Client, wait
import numpy as np
import xarray as xr
# set up Dask cluster
cluster = LocalCluster(**dask_local_cluster_args, ip="0.0.0.0")
client = Client(cluster)
client
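The LocalCluster also serves the distributed dashboard, which is handy for watching the runs below; a quick way to find it (the exact URL depends on the local setup):
# print the URL of the distributed dashboard for this client
print(client.dashboard_link)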
tuple(data_chunk_shape_initial.values())
(1, 1, 1000, 2000)
# create dataset with two vars
# no coordinate variables for now
initial_dataset = xr.Dataset(
{
"u": xr.DataArray(
darr.random.normal(
size=tuple(data_shape.values()),
chunks=tuple(data_chunk_shape_initial.values())
),
dims=tuple(data_shape.keys()),
coords={k: np.arange(v) for k, v in data_shape.items()}
),
"v": xr.DataArray(
darr.random.normal(
size=tuple(data_shape.values()),
chunks=tuple(data_chunk_shape_initial.values())
),
dims=tuple(data_shape.keys()),
coords={k: np.arange(v) for k, v in data_shape.items()}
)
}
)
display(initial_dataset)
print(f"size = {initial_dataset.nbytes / 1e9} GB")
print(f"number of Dask tasks = {len(initial_dataset.__dask_graph__())}")
size = 18.432048592 GB
number of Dask tasks = 1152
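A quick way to double-check the chunk layout Dask actually ended up with (pure introspection, nothing gets computed):
# chunk sizes per dimension of the whole Dataset
print(dict(initial_dataset.chunks))
# number of chunks along each dimension of u (roughly one task per chunk, plus overhead)
print(initial_dataset.u.data.numblocks)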
# time filtering doesn't work, because rolling over more
# than 2 chunks is not possible right now.
# So this will raise a ValueError:
try:
initial_dataset.rolling(time=12, center=True).mean()
except Exception as e:
print(f"{type(e).__name__}: {e}")
ValueError: For window size 12, every chunk should be larger than 6, but the smallest chunk size is 1. Rechunk your array with a larger chunk size or a chunk size that more evenly divides the shape of your array.
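Before calling rolling, the chunking along time can be checked against the window; a minimal sketch (the helper is made up here, and the exact constraint dask/xarray enforce may differ between versions):
def chunks_allow_rolling(ds, dim, window):
    # hypothetical helper: the smallest chunk along `dim` should be at least
    # about half the window, which is roughly the constraint behind the
    # ValueError above
    chunks = ds.chunks.get(dim)
    if chunks is None:
        return True  # not chunked along this dimension
    return min(chunks) >= window // 2

print(chunks_allow_rolling(initial_dataset, "time", 12))  # False: time chunks have size 1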
# try rechunking
rechunked_dataset = initial_dataset.chunk(data_chunk_shape_for_time_ops)
display(rechunked_dataset)
print(f"size = {rechunked_dataset.nbytes / 1e9} GB")
print(f"number of Dask tasks = {len(rechunked_dataset.__dask_graph__())}")
size = 18.432048592 GB
number of Dask tasks = 2432
# try filtering again
filtered_data_6hmean = rechunked_dataset.rolling(time=6, center=True).mean()
display(filtered_data_6hmean)
print(f"size = {filtered_data_6hmean.nbytes / 1e9} GB")
print(f"number of Dask tasks = {len(filtered_data_6hmean.__dask_graph__())}")
size = 18.432048592 GB
number of Dask tasks = 11136
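If the rechunked data is reused several times (as in the timing cells below) and the ~18 GB fit into the configured memory_limit, it can pay off to persist it on the workers once; wait (imported above) blocks until all chunks are in worker memory. Just a sketch, the timings below still run on the lazy datasets:
# optional: keep the rechunked data in worker memory so later computations
# don't repeat the rechunking
rechunked_persisted = rechunked_dataset.persist()
wait(rechunked_persisted)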
import hvplot.xarray
%%time
derived_data = initial_dataset.mean("X").var("time")
derived_data = derived_data.compute() # run full comp and load to frontend
CPU times: user 8.43 s, sys: 514 ms, total: 8.94 s
Wall time: 24.9 s
# some_calculation: area plot of the time variance of the zonal-mean data, upper level vs lower level
(
derived_data.isel(depth=0).hvplot.area(stacked=True, label="upper level")
+ derived_data.isel(depth=1).hvplot.area(stacked=True, label="lower level")
).cols(1)
%%time
derived_data = rechunked_dataset.mean("X").var("time")
derived_data = derived_data.compute() # run full comp and load to frontend
CPU times: user 7.95 s, sys: 471 ms, total: 8.42 s
Wall time: 26.4 s
# some_calculation: area plot of the time variance of the zonal-mean data, upper level vs lower level
(
derived_data.isel(depth=0).hvplot.area(stacked=True, label="upper level")
+ derived_data.isel(depth=1).hvplot.area(stacked=True, label="lower level")
).cols(1)
%%time
derived_data = filtered_data_6hmean.mean("X").var("time")
derived_data = derived_data.compute() # run full comp and load to frontend
CPU times: user 32.1 s, sys: 2.14 s, total: 34.2 s
Wall time: 1min 48s
# some_calculation: area plot of the time variance of the zonal-mean data, upper level vs lower level
(
derived_data.isel(depth=0).hvplot.area(stacked=True, label="upper level")
+ derived_data.isel(depth=1).hvplot.area(stacked=True, label="lower level")
).cols(1)
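Housekeeping at the end: release the workers by shutting down the client and the local cluster.
# shut down the Dask client and the local cluster
client.close()
cluster.close()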