#!/usr/bin/env python
# coding: utf-8

# In[1]:


from collections import OrderedDict


# In[2]:


# parameters

# Some advice
#
# - On a single big node, favor in-memory communication, i.e. many
#   threads per worker and few workers.
#
# - memory_limit is per worker and hence must be able to cope with roughly
#   threads x chunk size x "a handful of chunks" (the exact factor depends on
#   the complexity of the calculation).
#
# - Also keep the Dask graph in mind: rechunking adds lots of info to the
#   dependency graph, and scheduling can quickly become the bottleneck.
#   (Play with the chunk sizes and watch the length of the __dask_graph__.)
#
# - Rechunking is hard. As Dask tries very hard not to repeat itself,
#   rechunking will only release an initial chunk from memory after all
#   final chunks that depend on it are assembled. Hence, rechunking from fully
#   row-oriented to fully column-oriented will end up with all data in memory
#   at the same time.
#
# - To scale out horizontally, consider Dask-Jobqueue, but keep the
#   inter-worker communication in mind. At the same time, if IO bandwidth is
#   an issue, more nodes also provide more aggregate IO bandwidth.


# make the cluster setup easy to see and change here
dask_local_cluster_args = {
    "n_workers": 1,
    "threads_per_worker": 8,
    "memory_limit": 60e9,  # bytes per worker
}

# make it easy to limit to fewer data
time_steps = int(3 * 24)  # 3 days of hourly steps for now

# total data shape
data_shape = OrderedDict(time=time_steps, depth=2, Y=2000, X=4000)

# with xr.open_mfdataset, chunks cannot cross file boundaries:
data_chunk_shape_initial = OrderedDict(time=1, depth=1, Y=1000, X=2000)

# chunking that is optimized for time filtering
# too aggressive: data_chunk_shape_for_time_ops = OrderedDict(time=None, depth=1, Y=1, X=1)
data_chunk_shape_for_time_ops = OrderedDict(time=time_steps, depth=1, Y=500, X=500)


# In[3]:


from operator import mul
from functools import reduce


# In[4]:


# sizes in GB, assuming 8-byte (float64) elements
print(f"data size per var = {reduce(mul, data_shape.values()) * 8 / 1e9} GB")
print(f"initial chunk size = {reduce(mul, data_chunk_shape_initial.values()) * 8 / 1e9} GB")
print(f"optimized chunk size = {reduce(mul, data_chunk_shape_for_time_ops.values()) * 8 / 1e9} GB")


# In[5]:


from dask import array as darr
from dask.distributed import LocalCluster, Client, wait

import numpy as np
import xarray as xr


# In[6]:


# set up the Dask cluster
# ("host" replaces the deprecated "ip" keyword; 0.0.0.0 listens on all interfaces)
cluster = LocalCluster(**dask_local_cluster_args, host="0.0.0.0")
client = Client(cluster)

client


# In[7]:


# the initial chunk shape as the tuple that dask.array expects
tuple(data_chunk_shape_initial.values())


# In[8]:


# create a dataset with two variables
# coordinates are plain integer indices for now
initial_dataset = xr.Dataset(
    {
        "u": xr.DataArray(
            darr.random.normal(
                size=tuple(data_shape.values()),
                chunks=tuple(data_chunk_shape_initial.values())
            ),
            dims=tuple(data_shape.keys()),
            coords={k: np.arange(v) for k, v in data_shape.items()}
        ),
        "v": xr.DataArray(
            darr.random.normal(
                size=tuple(data_shape.values()),
                chunks=tuple(data_chunk_shape_initial.values())
            ),
            dims=tuple(data_shape.keys()),
            coords={k: np.arange(v) for k, v in data_shape.items()}
        )
    }
)


# In[9]:


display(initial_dataset)
print(f"size = {initial_dataset.nbytes / 1e9} GB")
print(f"number of Dask tasks = {len(initial_dataset.__dask_graph__())}")
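# In[ ]:


# Not part of the original notebook: a small sketch that inspects the chunk
# layout before attempting the rolling mean below.  Dataset.chunks maps each
# dimension name to its tuple of chunk lengths, so we can check whether the
# centered 12-step window used below could ever fit into the 1-step chunks
# along "time".
rolling_window = 12  # same window length as in the failing cell below
time_chunks = initial_dataset.chunks["time"]
print(f"number of chunks along time = {len(time_chunks)}")
print(f"length of first time chunk = {time_chunks[0]}")
print(f"window fits into one time chunk: {rolling_window <= time_chunks[0]}")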
# In[10]:


# time filtering doesn't work yet, because the rolling window would span more
# than two of the single-time-step chunks, which Dask cannot handle right now.
# So this will raise a ValueError:
try:
    initial_dataset.rolling(time=12, center=True).mean()
except Exception as e:
    print(f"{type(e).__name__}: {e}")


# In[11]:


# try rechunking
rechunked_dataset = initial_dataset.chunk(data_chunk_shape_for_time_ops)

display(rechunked_dataset)
print(f"size = {rechunked_dataset.nbytes / 1e9} GB")
print(f"number of Dask tasks = {len(rechunked_dataset.__dask_graph__())}")


# In[12]:


# try filtering again, now with a 6-hour window
filtered_data_6hmean = rechunked_dataset.rolling(time=6, center=True).mean()

display(filtered_data_6hmean)
print(f"size = {filtered_data_6hmean.nbytes / 1e9} GB")
print(f"number of Dask tasks = {len(filtered_data_6hmean.__dask_graph__())}")


# In[13]:


import hvplot.xarray


# In[14]:


get_ipython().run_cell_magic('time', '', '\nderived_data = initial_dataset.mean("X").var("time")\nderived_data = derived_data.compute()  # run the full computation and load the result into the notebook\n')


# In[15]:


# some_calculation: time variance of the zonal-mean (X-mean) data,
# upper level vs lower level, computed from the initially chunked data
(
    derived_data.isel(depth=0).hvplot.area(stacked=True, label="upper level")
    + derived_data.isel(depth=1).hvplot.area(stacked=True, label="lower level")
).cols(1)


# In[16]:


get_ipython().run_cell_magic('time', '', '\nderived_data = rechunked_dataset.mean("X").var("time")\nderived_data = derived_data.compute()  # run the full computation and load the result into the notebook\n')


# In[17]:


# some_calculation: time variance of the zonal-mean (X-mean) data,
# upper level vs lower level, computed from the rechunked data
(
    derived_data.isel(depth=0).hvplot.area(stacked=True, label="upper level")
    + derived_data.isel(depth=1).hvplot.area(stacked=True, label="lower level")
).cols(1)


# In[18]:


get_ipython().run_cell_magic('time', '', '\nderived_data = filtered_data_6hmean.mean("X").var("time")\nderived_data = derived_data.compute()  # run the full computation and load the result into the notebook\n')


# In[19]:


# some_calculation: time variance of the zonal-mean (X-mean) data,
# upper level vs lower level, computed from the 6-hour-filtered data
(
    derived_data.isel(depth=0).hvplot.area(stacked=True, label="upper level")
    + derived_data.isel(depth=1).hvplot.area(stacked=True, label="lower level")
).cols(1)
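# In[ ]:


# Not executed in the original notebook (although `wait` is imported above):
# a sketch of persisting the rechunked dataset in worker memory once, so that
# repeated reductions do not pay the rechunking cost again.  Only sensible if
# the persisted data fits within the configured memory_limit per worker.
persisted_dataset = rechunked_dataset.persist()
wait(persisted_dataset)  # block until all chunks are held by the workers

# further reductions now start from the in-memory chunks
derived_from_persisted = persisted_dataset.mean("X").var("time").compute()


# In[ ]:


# shut down the client and the local cluster when done
client.close()
cluster.close()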