Dask is a big data processing library for parallelizing and scaling analytical workflows.
Most notably, it provides support for working with larger-than-memory datasets. In that case, Dask partitions the dataset into smaller chunks, loads only a few chunks from disk at a time, and discards intermediate values once the necessary processing is complete. This way, computations are performed without exceeding the memory limit.
Check out these links if you're unsure whether your workflow can benefit from using Dask:

- {doc}`dask:why`
- {doc}`dask:array-best-practices`
Excerpt from the "Dask Array Best Practices" doc:

> If your data fits comfortably in RAM and you are not performance bound, then using NumPy might be the right choice. Dask adds another layer of complexity which may get in the way. If you are just looking for speedups rather than scalability then you may want to consider a project like Numba.
:::{caution}
Dask is an optional dependency of ArviZ, and the integration is still being actively developed. Currently, only a few functions in the diagnostics and stats modules can take advantage of Dask's capabilities.
:::
import arviz as az
import numpy as np
import timeit
import dask
import dask.array  # needed below: dask.array.random and dask.array.var
from arviz.utils import conditional_jit, Dask
# optional imports
from dask.distributed import Client
from dask.diagnostics import ResourceProfiler
from bokeh.resources import INLINE
import bokeh.io
bokeh.io.output_notebook(INLINE)
%reload_ext memory_profiler
:::{note}
{func}`~dask.diagnostics.ResourceProfiler` and {class}`~distributed.Client` are optional. They are only used for visualizing and profiling the Dask-enabled methods; the ArviZ-Dask integration can be used without them.
:::
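If you skip these optional pieces, Dask falls back to its local threaded scheduler and everything else in this guide still works; you only lose the dashboard and the profiling plots. The snippet below is a minimal illustrative sketch of that, not part of the original workflow.

```python
# Minimal sketch: without a distributed Client, dask uses its default
# local (threaded) scheduler, so computations still run lazily and in chunks;
# only the dashboard and profiling widgets are unavailable.
import dask.array

x = dask.array.random.normal(size=1_000_000, chunks="auto")
x.var(ddof=1).compute()  # executed with the default scheduler
```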
client = Client(threads_per_worker=4, n_workers=1, memory_limit="1.2GB")
client
Client connected to a LocalCluster: 1 worker, 4 threads, 1.12 GiB memory limit. Dashboard: http://127.0.0.1:8787/status
array_size = 250_000_000
Calculating variance using NumPy:
%%memit
data = np.random.randn(array_size)
np.var(data, ddof=1)
del data
peak memory: 4072.28 MiB, increment: 3815.28 MiB
Calculating variance using Dask arrays:
%memit data = dask.array.random.normal(size=array_size, chunks="auto")
data
peak memory: 258.30 MiB, increment: 0.28 MiB
var = dask.array.var(data, ddof=1)
var.visualize()
with ResourceProfiler(dt=0.25) as rprof:
var.compute()
rprof.visualize();
del data
Here, the NumPy version consumed around 4 GB of memory, but the Dask version was able to compute the variance within the 1.2 GB memory limit set in the Client configuration above, which shows how beneficial Dask can be when dealing with large datasets.
InferenceData is the central data format for ArviZ, and there are several ways to generate this object (see {ref}`here <creating_InferenceData>`).
However, as the ArviZ-Dask integration is still a work in progress, to use InferenceData objects with Dask-compatible methods we'll have to generate them in a different way. {func}`arviz.from_netcdf` has an experimental `group_kwargs` argument that can be used to read netCDF files directly with Dask.
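For example, a minimal sketch of this approach (the filename and chunk sizes below are hypothetical) could look like:

```python
# Hypothetical example: load only the posterior group with dask-backed arrays
# by forwarding xarray.open_dataset keyword arguments for that group.
idata = az.from_netcdf(
    "model_output.nc",  # placeholder filename
    group_kwargs={"posterior": {"chunks": {"draw": 250}}},
)
```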
We will progressively add more ways to generate Dask-backed InferenceData and document them here. If you are interested in helping out, reach out on Gitter.
## dask.array

We start by creating a dask array with random samples, which we can then convert to InferenceData using {func}`arviz.from_dict`. ArviZ passes values and coordinate values as-is to xarray, so by passing a dask array we get a dask-backed InferenceData automatically.
%memit daskdata = dask.array.random.random((10, 1000, 10000), chunks=(10, 1000, 625))
daskdata
peak memory: 260.43 MiB, increment: 0.07 MiB
daskdata.visualize()  # each chunk is evaluated lazily, nothing is computed yet
:::{note}
Setting the right value for the `chunks` parameter is very important. Computations on Dask arrays with small chunks are slow because each operation on a chunk has some overhead. On the other hand, if the chunks are too big, they might not fit in memory.
:::
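As a rough illustration of this trade-off (the chunk shapes below are only hypothetical examples), compare how many tasks Dask creates for the same array depending on the chunking:

```python
# Hypothetical chunk choices for the same (10, 1000, 10000) float64 array
tiny = dask.array.random.random((10, 1000, 10000), chunks=(1, 100, 100))
balanced = dask.array.random.random((10, 1000, 10000), chunks=(10, 1000, 625))
tiny.npartitions      # 10000 chunks of ~80 KB each -> lots of per-task overhead
balanced.npartitions  # 16 chunks of ~50 MB each -> negligible overhead, still fits in memory
```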
datadict = {"x": daskdata}
%memit idata_dask = az.from_dict(posterior=datadict, dims={"x": ["dim_1"]})
idata_dask
peak memory: 260.55 MiB, increment: 0.07 MiB
<xarray.Dataset>
Dimensions:  (chain: 10, draw: 1000, dim_1: 10000)
Coordinates:
  * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * dim_1    (dim_1) int64 0 1 2 3 4 5 6 ... 9993 9994 9995 9996 9997 9998 9999
Data variables:
    x        (chain, draw, dim_1) float64 dask.array<chunksize=(10, 1000, 625), meta=np.ndarray>
Attributes:
    created_at:     2021-11-14T11:51:52.050185
    arviz_version:  0.11.4
{class}`arviz.Dask` provides the functionality to enable and disable Dask. This is an ArviZ-specific class and therefore it works only with the ArviZ functions that support computation via Dask. We can also use it to set default arguments, which are then taken by the Dask-supporting functions and passed to {func}`xarray.apply_ufunc`.
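As a quick sketch of how this is used (mirroring the calls made later in this notebook):

```python
# Enable Dask in the ArviZ functions that support it, storing default kwargs
# that will be forwarded to xarray.apply_ufunc
Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]})
# ... call az.ess, az.rhat, az.hdi on a dask-backed InferenceData here ...
Dask.disable_dask()  # restore the default, eager NumPy-based computation
```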
For comparison, let's first create an InferenceData object using a NumPy array.
%memit npdata = np.random.rand(10, 1000, 10000)
datadict = {"x": npdata}
idata_numpy = az.from_dict(posterior=datadict, dims={"x": ["dim_1"]})
idata_numpy
peak memory: 1023.65 MiB, increment: 762.97 MiB
<xarray.Dataset>
Dimensions:  (chain: 10, draw: 1000, dim_1: 10000)
Coordinates:
  * chain    (chain) int64 0 1 2 3 4 5 6 7 8 9
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * dim_1    (dim_1) int64 0 1 2 3 4 5 6 ... 9993 9994 9995 9996 9997 9998 9999
Data variables:
    x        (chain, draw, dim_1) float64 0.3969 0.9378 0.349 ... 0.8059 0.6392
Attributes:
    created_at:     2021-11-14T11:51:53.110470
    arviz_version:  0.11.4
## arviz.ess

%%time
%%memit
az.ess(idata_numpy)
peak memory: 1034.65 MiB, increment: 10.89 MiB
CPU times: user 21 s, sys: 192 ms, total: 21.2 s
Wall time: 21 s
:::{tip}
Set the most common default `dask_kwargs` when enabling Dask in order to simplify future function calls. If needed, those default kwargs can always be overridden with the function-specific `dask_kwargs` argument.
:::
Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]})
%%time
%%memit
ess = az.ess(idata_dask)
with ResourceProfiler(dt=0.25) as rprof:
ess.compute()
peak memory: 1035.45 MiB, increment: 0.79 MiB
CPU times: user 643 ms, sys: 104 ms, total: 747 ms
Wall time: 15.8 s
Each chunk also contains the evaluation expression, which will be calculated in parallel and on the fly:
ess.data_vars["x"].data.visualize()
rprof.visualize()
Dask.disable_dask()
Here, the Dask-enabled method consumed around 400 MB of memory, which is roughly 360 MB less than the vanilla method (also accounting for the memory already consumed by the NumPy array). The Dask-enabled method is also a bit faster.
## arviz.rhat

%%time
%%memit
az.rhat(idata_numpy)
peak memory: 1035.81 MiB, increment: 0.30 MiB
CPU times: user 32.7 s, sys: 167 ms, total: 32.9 s
Wall time: 32.5 s
Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [int]})
We have now enabled Dask with incorrect default kwargs, which we have to override in the function call:
%%time
%%memit
rhat = az.rhat(idata_dask, dask_kwargs={"output_dtypes": [float]})
with ResourceProfiler(dt=0.25) as rprof:
rhat.compute()
peak memory: 1036.70 MiB, increment: 0.88 MiB
CPU times: user 709 ms, sys: 156 ms, total: 865 ms
Wall time: 20.6 s
rprof.visualize()
Dask.disable_dask()
## arviz.hdi

%%time
%%memit
az.hdi(idata_numpy, hdi_prob=.68)
peak memory: 1037.05 MiB, increment: 0.20 MiB
CPU times: user 5.9 s, sys: 63.8 ms, total: 5.96 s
Wall time: 5.95 s
Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]})
With {func}`arviz.hdi` we are introducing a new dimension in the output, the one containing the lower and upper HDI limits, so we need to use the `dask_gufunc_kwargs` argument of {func}`xarray.apply_ufunc`, which is passed as `**kwargs` first to {func}`arviz.wrap_xarray_ufunc` and then to {func}`xarray.apply_ufunc`.
%%time
%%memit
hdi = az.hdi(idata_dask, hdi_prob=0.68, dask_gufunc_kwargs={"output_sizes": {"hdi": 2}})
with ResourceProfiler(dt=0.25) as rprof:
hdi.compute()
peak memory: 1037.55 MiB, increment: 0.50 MiB
CPU times: user 266 ms, sys: 78.1 ms, total: 344 ms
Wall time: 2.78 s
rprof.visualize()
Dask.disable_dask()
client.close()
In all the examples, it's noticeable that: