%matplotlib inline
import matplotlib.pyplot as plt
import xarray as xr
from dask.distributed import Client, progress
# Select the compute environment; the output directory and the SLURM cluster
# are only configured when running on the USGS Yeti HPC system.
pangeo = 'USGS-HPC-YETI'
if pangeo == 'USGS-HPC-YETI':
    fdir = '/lustre/projects/hazards/cmgp/woodshole/rsignell/crs_maps'
    from dask_jobqueue import SLURMCluster
    import os
    # One SLURM job per worker bundle: 4 processes x 1 thread, 8 GB each,
    # using the Infiniband interface (ib0) for scheduler/worker traffic.
    cluster = SLURMCluster(processes=4, threads=1, memory='8GB',
                           project='woodshole', walltime='00:30:00', queue='normal',
                           interface='ib0')
    # NOTE(review): start_workers() is the legacy dask-jobqueue API; newer
    # releases use cluster.scale(n) -- confirm against the installed version.
    workers = cluster.start_workers(5)
    print(cluster.job_script())
# Example output of cluster.job_script() above:
# #!/bin/bash #SBATCH -J dask-worker #SBATCH -e dask-worker.err #SBATCH -o dask-worker.out #SBATCH -p normal #SBATCH -A woodshole #SBATCH -n 1 #SBATCH --cpus-per-task=4 #SBATCH --mem=30G #SBATCH -t 00:30:00 /home/rsignell/miniconda3/envs/pangeo/bin/dask-worker tcp://10.12.0.20:35955 --nthreads 1 --nprocs 4 --memory-limit 8GB --name dask-worker-7 --death-timeout 60 --interface ib0
# Connect a distributed client to the SLURM cluster created above.
client = Client(cluster)
# Alternative (disabled): attach to an already-running scheduler via its
# scheduler file instead of the cluster object.
if False:
    client = Client(scheduler_file='scheduler.json')
client
# Read data with NetCDF4 for speed (low overhead)
# Open the ROMS history file lazily.  decode_times=False keeps ocean_time as
# raw numeric seconds.  (The original passed decode_times=None, which xarray
# treats as "use the default" and therefore still decodes times -- False is
# what actually disables decoding.)
ds = xr.open_dataset('/cxfs/projects/usgs/hazards/cmgp/woodshole/aaretxabaleta/projects/GSB_tides_55nb/ocean_his_gsb_tides_55nb.nc',
                     decode_times=False)
ds
# zeta is (time, eta, xi): number of time steps and horizontal grid size.
dt, n, m = ds['zeta'].shape
print(dt, n, m)
from utide import solve
import numpy as np
import warnings

lat = 40.7        # latitude of the domain, used by utide for nodal/inference
rayleigh = 0.95   # Rayleigh criterion for constituent resolution
istart = 50       # start at step 50 to avoid 1st day of spinup

# BUG FIX: the original used istart before it was defined and then sliced the
# already-sliced series a second time (t = t[istart:]).  Define istart first
# and slice exactly once.  ocean_time is in seconds; convert to days.
t = ds['ocean_time'][istart:] / (3600 * 24)
import dask.array as da
from dask import delayed, compute
# Notebook cell magic: time the subset-and-load step below.
%%time
# Spatial subsample factor: take every 2nd grid point in each horizontal dim.
nsub = 2
# Load the subsampled free-surface elevation into client memory, dropping the
# first istart spinup steps.
z = ds['zeta'][istart:,::nsub,::nsub].compute()
# In-memory size in GB.
print(z.nbytes/1e9)
z.shape
%%time
with warnings.catch_warnings():
warnings.simplefilter("ignore")
acoef = solve(t, z[:,10,10], 0*t, lat, trend=False,
nodal=False, Rayleigh_min=0.95, method='ols', conf_int='linear')
val = acoef['Lsmaj']
acoef['name']
import dask.array as da
from dask import delayed
dt, n, m = z.shape
%%time
with warnings.catch_warnings():
warnings.simplefilter("ignore")
coefs = [delayed(solve)(t, z[:,j,i], t*0.0, lat,
trend=False, nodal=False, Rayleigh_min=0.95, method='ols',
conf_int='linear') for j in range(n) for i in range(m)]
arrays = [da.from_delayed(coef['Lsmaj'], dtype=val.dtype, shape=val.shape) for coef in coefs]
stack = da.stack(arrays, axis=0) # Stack all small Dask arrays into one
stack
%%time
amps = np.array(stack)
m2amp = amps[:,0].reshape((n,m))
plt.figure(figsize=(12,8))
plt.pcolormesh(m2amp)
plt.colorbar()
plt.title('M2 Elevation Amplitude');
cluster.stop_workers(workers)