We have a 3D (time, lat, lon) cube of water level data, and we want to run a tidal analysis function (`solve` from UTide) at each (lat, lon) grid point (or at every other grid point). This is embarrassingly parallel, and we could use other Python parallel approaches, but can we use Dask Delayed?
%matplotlib inline
import xarray as xr
import matplotlib.pyplot as plt
import gcsfs
from dask.diagnostics import Profiler, ResourceProfiler, CacheProfiler
from dask.diagnostics import visualize
from bokeh.io import output_notebook
output_notebook()
from utide import solve
import numpy as np
import warnings
from dask.distributed import Client, progress, LocalCluster
from dask_kubernetes import KubeCluster
# Launch a Dask cluster on Kubernetes from the YAML file below
# (presumably the worker pod spec, judging by the file name — confirm)
cluster = KubeCluster.from_yaml('/home/jovyan/myworker.yml')
# Ask the cluster for 20 workers
cluster.scale(20)
# Bare expression: renders the cluster widget when it ends a notebook cell
cluster
Failed to display Jupyter Widget of type `VBox`.
If you're reading this message in the Jupyter Notebook or JupyterLab Notebook, it may mean that the widgets JavaScript is still loading. If this message persists, it likely means that the widgets JavaScript library is either not installed or not enabled. See the Jupyter Widgets Documentation for setup instructions.
If you're reading this message in another frontend (for example, a static rendering on GitHub or NBViewer), it may mean that your frontend doesn't currently support widgets.
# Attach a distributed client to the cluster so Dask work is submitted to it
client = Client(cluster)
# Bare expression: displays the client/cluster summary in the notebook
client
Client
|
Cluster
|
# Browser-token variant of the filesystem handle, kept for reference:
#fs = gcsfs.GCSFileSystem(project='pangeo-181919', token='browser', access='read_only')

# Read-only handle on the Google Cloud Storage project
fs = gcsfs.GCSFileSystem(
    project='pangeo-181919',
    access='read_only',
)

# Mapping view of the bucket path, which is opened as a zarr store below
gcsmap = gcsfs.mapping.GCSMap(
    'pangeo-data/rsignell/ocean_his_tide_zeta',
    gcs=fs,
    check=False,
    create=False,
)

# Open the dataset; ocean_time is left as raw numbers instead of being
# decoded into datetimes
ds = xr.open_zarr(gcsmap, decode_times=False)
ds.info()
xarray.Dataset { dimensions: eta_rho = 324 ; ocean_time = 1441 ; xi_rho = 1542 ; variables: float64 lat_rho(eta_rho, xi_rho) ; lat_rho:field = lat_rho, scalar ; lat_rho:long_name = latitude of RHO-points ; lat_rho:standard_name = latitude ; lat_rho:units = degree_north ; float64 lon_rho(eta_rho, xi_rho) ; lon_rho:field = lon_rho, scalar ; lon_rho:long_name = longitude of RHO-points ; lon_rho:standard_name = longitude ; lon_rho:units = degree_east ; float64 ocean_time(ocean_time) ; ocean_time:calendar = julian ; ocean_time:field = time, scalar, series ; ocean_time:long_name = time since initialization ; ocean_time:units = seconds since 0001-01-01 00:00:00 ; float32 zeta(ocean_time, eta_rho, xi_rho) ; zeta:field = free-surface, scalar, series ; zeta:grid = grid ; zeta:location = face ; zeta:long_name = free-surface ; zeta:time = ocean_time ; zeta:units = meter ; // global attributes: :CPP_options = GSB, ADD_FSOBC, ADD_M2OBC, ANA_BSFLUX, ANA_BTFLUX, ANA_FSOBC, ANA_INITIAL, ANA_M2OBC, ANA_SMFLUX, ANA_SSFLUX, ANA_STFLUX, ASSUMED_SHAPE, DJ_GRADPS, DOUBLE_PRECISION, GLS_MIXING, KANTHA_CLAYSON, MASKING, MIX_S_UV, MPI, NONLINEAR, !NONLIN_EOS, N2S2_HORAVG, POWER_LAW, PROFILE, K_GSCHEME, RAMP_TIDES, !RST_SINGLE, SOLVE3D, SSH_TIDES, TS_C4HADVECTION, TS_C4VADVECTION, TS_FIXED, UV_ADV, UV_COR, UV_U3HADVECTION, UV_C4VADVECTION, UV_LOGDRAG, UV_TIDES, UV_VIS2, VAR_RHO_2D, WET_DRY ; :Conventions = CF-1.4 ; :NCO = 4.7.3 ; :NLM_LBC = EDGE: WEST SOUTH EAST NORTH zeta: Cha Cha Cha Cha ubar: Fla Fla Fla Fla vbar: Fla Fla Fla Fla u: Gra Gra Gra Gra v: Gra Gra Gra Gra temp: Gra Gra Gra Gra salt: Gra Gra Gra Gra tke: Gra Gra Gra Gra ; :ana_file = ROMS/Functionals/ana_btflux.h, ROMS/Functionals/ana_fsobc.h, ROMS/Functionals/ana_initial.h, ROMS/Functionals/ana_m2obc.h, ROMS/Functionals/ana_smflux.h, ROMS/Functionals/ana_stflux.h ; :code_dir = /cxfs/projects/usgs/hazards/cmgp/woodshole/aaretxabaleta/models/COAWST ; :compiler_command = /opt/intel/impi/5.0.1.035/intel64/bin/mpif90 ; 
:compiler_flags = -heap-arrays -fp-model precise -ip -O3 -xW -free ; :compiler_system = ifort ; :cpu = x86_64 ; :file = ocean_his_gsb_tides_55nb.nc ; :format = netCDF-3 64bit offset file ; :frc_file_01 = ../forcings/tide_forc_GSB_55.nc ; :grd_file = ../grids/GSB_55nb.nc ; :header_dir = /cxfs/projects/usgs/hazards/cmgp/woodshole/aaretxabaleta/projects/GSB_tides_55nb ; :header_file = gsb.h ; :his_file = ocean_his_gsb_tides_55nb.nc ; :history = Mon Mar 12 09:36:34 2018: ncks -O -v ocean_time,zeta,lon_rho,lat_rho ocean_his_gsb_tides_55nb.nc /home/rsignell/ocean_his_zeta.nc ROMS/TOMS, Version 3.7, Monday - February 26, 2018 - 10:23:23 AM ; :os = Linux ; :rst_file = ocean_rst.nc ; :script_file = ; :svn_rev = ; :svn_url = https:://myroms.org/svn/src ; :tiling = 036x010 ; :title = Great south Bay ; :type = ROMS/TOMS history file ; :var_info = varinfo.dat ; }
# Optional check (disabled): compare package versions across the cluster
#client.get_versions(check=True)
# zeta is dimensioned (ocean_time, eta_rho, xi_rho); unpack the three sizes
dt, n, m = ds['zeta'].shape
print(dt,n,m)
1441 324 1542
# Bare expression: show the zeta DataArray summary (dims, chunking, coords, attrs)
ds['zeta']
<xarray.DataArray 'zeta' (ocean_time: 1441, eta_rho: 324, xi_rho: 1542)> dask.array<shape=(1441, 324, 1542), dtype=float32, chunksize=(91, 41, 193)> Coordinates: lat_rho (eta_rho, xi_rho) float64 dask.array<shape=(324, 1542), chunksize=(81, 771)> lon_rho (eta_rho, xi_rho) float64 dask.array<shape=(324, 1542), chunksize=(81, 771)> * ocean_time (ocean_time) float64 0.0 1.8e+03 3.6e+03 5.4e+03 7.2e+03 ... Dimensions without coordinates: eta_rho, xi_rho Attributes: field: free-surface, scalar, series grid: grid location: face long_name: free-surface time: ocean_time units: meter
# Horizontal subsampling stride: keep every nsub-th point along both grid axes
nsub=4
# Load the subsampled cube (all time steps) and take it as a NumPy array
%time z = ds['zeta'][:,::nsub,::nsub].load().values
CPU times: user 5.82 s, sys: 877 ms, total: 6.7 s Wall time: 10.2 s
# ocean_time is stored in seconds; the tidal analysis below is fed time in days
seconds_per_day = 3600. * 24
t = ds['ocean_time'].values / seconds_per_day

# one nominal latitude is used for the tide calculations at every grid cell
lat = 40.7
%%time
# analyze tides at single cell as a test
with warnings.catch_warnings():
warnings.simplefilter("ignore")
acoef = solve(t=t, u=z[:,20,20], v=None, lat=lat,
trend=False, nodal=False, Rayleigh_min=0.95, method='ols',
conf_int='linear', verbose=False)
CPU times: user 111 ms, sys: 136 ms, total: 247 ms Wall time: 113 ms
# M2 amplitude
# NOTE(review): assumes the first entry of 'A' is the M2 constituent —
# confirm the constituent ordering of the result
acoef['A'][0]
0.6128407747092631
import dask.array as da
from dask import delayed
# Lazy wrapper: calling usolve(...) records a task instead of running solve now
usolve = delayed(solve)
# z is (time, rows, cols) after subsampling; jj and ii size the grid loops below
kk,jj,ii = z.shape
%%time
# set up the Dask delayed task list
coefs = [usolve(t=t, u=z[:,j,i], v=None, lat=lat,
trend=False, nodal=False, verbose=False, Rayleigh_min=0.95, method='ols',
conf_int='linear') for j in range(jj) for i in range(ii)]
CPU times: user 8.96 s, sys: 515 ms, total: 9.47 s Wall time: 8.85 s
# compute tidal analysis (parallel)
# Wrapping the task list in delayed() lets one compute() call evaluate every
# task; the results come back in the same order as coefs
%time total = delayed(coefs).compute()
CPU times: user 1min 24s, sys: 5.48 s, total: 1min 29s Wall time: 1min 52s
%%time
# compute tidal analysis in a regular loop (serial)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
bcoef = [solve(t=t, u=z[:,j,i], v=None, lat=lat,
trend=False, nodal=False, Rayleigh_min=0.95, method='ols',
conf_int='linear', verbose=False) for j in range(jj) for i in range(ii)]
CPU times: user 1h 10min 7s, sys: 1h 20min 31s, total: 2h 30min 38s Wall time: 1h 15min 56s
# Take the first entry of each result's 'A' array (treated here as the M2
# amplitude) for every analyzed cell
m2amp = [cell_coef['A'][0] for cell_coef in total]
# Bare expression: number of analyzed cells
len(m2amp)
31266
m2amp = np.array(m2amp).reshape((jj,ii))
%matplotlib inline
import matplotlib.pyplot as plt
plt.figure(figsize=(12,8))
plt.pcolormesh(m2amp)
plt.colorbar()
plt.title('M2 Elevation Amplitude');