#!/usr/bin/env python
# coding: utf-8

# # USGS Water Balance Model: Create cloud-optimized output
#
# Data from a 2.5 arc-minute CONUS model covering 1895 to 2020.
#
# The provided files were fixed-width ASCII, with year and month in the first two columns and the data in the remaining columns. The raster data from each time step is written to a single row, with only the non-missing values included. There is one file for each variable, plus a separate CSV file that contains the lon,lat location for each column of data.
#
# To parallelize the workflow, we split the original file for each variable (e.g. `tmean.monthly.all.gz`, `prcp.monthly.all.gz`) into many smaller text files using `split`, choosing the number of lines (120) to match the time chunk size of the output:
# ```
# #!/bin/bash
# for var in prcp tmean aet pet rain runoff snow soilstorage swe
# do
#     mkdir $var
#     zcat $var.monthly.all.gz | split -l 120 --numeric-suffixes - $var/$var
# done
# ```

# In[1]:

import fsspec
import xarray as xr
import pandas as pd
import numpy as np
import datetime as dt
import hvplot.xarray
from dask.distributed import Client
import dask
import dask.array  # not pulled in by `import dask`; needed for dask.array.zeros below

# In[2]:

fs = fsspec.filesystem('')   # local filesystem

# In[3]:

inpath = '/scratch/mike/'
outpath = '/scratch/mike/wbm.zarr'

# In[4]:

fs.ls(inpath)

# #### Read the grid cell locations

# In[5]:

df = pd.read_csv(f'{inpath}/LatLongs.csv')

# Determine the i,j indices on the 2.5 arc-minute grid corresponding to each lon,lat point:

# In[6]:

ii = np.round((df['X'] - df['X'].min()) / (2.5/60)).astype('int').to_numpy()
jj = np.round((df['Y'] - df['Y'].min()) / (2.5/60)).astype('int').to_numpy()

# In[7]:

nx = ii.max() + 1
ny = jj.max() + 1
print(nx, ny)

# In[8]:

lon = np.linspace(df['X'].min(), df['X'].max(), nx)
lat = np.linspace(df['Y'].min(), df['Y'].max(), ny)

# #### Create the empty Zarr dataset to fill with chunks

# In[9]:

# Month-start stamps, to match the timestamps written by write_chunk below
dates = pd.date_range(start='1895-01-01 00:00', end='2020-12-01 00:00', freq='MS')

# In[10]:

nt = len(dates)
print(nt)

# In[11]:

chunk_lon = 700
chunk_lat = 300
chunk_time = 120   # must match the `split -l 120` used above

# In[12]:

fs.ls(f'{inpath}/gzfiles/')

# In[13]:

d = dask.array.zeros((nt, ny, nx), chunks=(chunk_time, chunk_lat, chunk_lon), dtype='float32')

# In[14]:

ds0 = xr.Dataset(
    {
        "prcp":        (['time', 'lat', 'lon'], d),
        "tmean":       (['time', 'lat', 'lon'], d),
        "aet":         (['time', 'lat', 'lon'], d),
        "pet":         (['time', 'lat', 'lon'], d),
        "rain":        (['time', 'lat', 'lon'], d),
        "runoff":      (['time', 'lat', 'lon'], d),
        "snow":        (['time', 'lat', 'lon'], d),
        "soilstorage": (['time', 'lat', 'lon'], d),
        "swe":         (['time', 'lat', 'lon'], d),
    },
    coords={
        "lon": (["lon"], lon),
        "lat": (["lat"], lat),
        "time": dates,
    },
)

# In[15]:

# With compute=False this writes only the metadata and coordinates;
# the data chunks are filled later by the region writes.
ds0.to_zarr(outpath, mode='w', compute=False, consolidated=True)

# In[16]:

def write_chunk(var, f, istart):
    """Read one split text file and write it into the matching time region of the Zarr store."""
    a = np.loadtxt(f, dtype='float32')
    year = a[:, 0].astype('int')
    mon = a[:, 1].astype('int')
    t = [np.datetime64(dt.datetime(year[k], mon[k], 1)) for k in range(len(mon))]
    data = a[:, 2:]
    nt_chunk, nr = data.shape
    # Scatter the non-missing columns onto the full 2D grid
    b = np.full((nt_chunk, ny, nx), np.nan, dtype='float32')
    b[:, jj[:nr], ii[:nr]] = data   # vectorized equivalent of a per-column loop
    da = xr.DataArray(data=b, dims=['time', 'lat', 'lon'],
                      coords=dict(lon=('lon', lon),
                                  lat=('lat', lat),
                                  time=('time', t)))
    ds = da.to_dataset(name=var)
    ds = ds.chunk(chunks={'time': chunk_time, 'lat': chunk_lat, 'lon': chunk_lon})
    # Drop lat/lon so only variables spanning the written region remain
    ds.drop_vars(['lon', 'lat']).to_zarr(outpath, region={'time': slice(istart, istart + nt_chunk)})
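# Before queuing all the tasks, it can help to smoke-test `write_chunk` on one
# split file and confirm that the first time region gets filled. This is just a
# sketch: the file name `tmean00` is an assumption based on the
# `split --numeric-suffixes` naming scheme shown above.

# In[ ]:

write_chunk('tmean', f'{inpath}/gzfiles/tmean/tmean00', 0)
# count of non-NaN cells in the first time step; should be > 0 if the region was written
xr.open_dataset(outpath, engine='zarr', chunks={}).tmean.isel(time=0).count().values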
# In[17]:

client = Client()

# In[18]:

# %%time
tasks = []
for var in ['tmean', 'prcp', 'aet', 'pet', 'rain', 'runoff', 'snow', 'soilstorage', 'swe']:
    # sort so that file order matches time order when computing istart
    flist = sorted(fs.glob(f'/scratch/mike/gzfiles/{var}/{var}??'))
    for i, f in enumerate(flist):
        print(f)
        istart = i * chunk_time
        tasks.append(dask.delayed(write_chunk)(var, f, istart))

# In[19]:

# %%time
# Note: scheduler='processes' bypasses the distributed client created above.
dask.compute(tasks, scheduler='processes', num_workers=4)

# #### Let's see what we produced!

# In[28]:

ds2 = xr.open_dataset(outpath, engine='zarr', chunks={})

# In[29]:

ds2

# In[30]:

ds2.tmean.sel(time='1925-01-01').hvplot.quadmesh(x='lon', y='lat', geo=True, tiles='OSM',
                                                 cmap='turbo', rasterize=True, alpha=0.7)

# In[23]:

ds2.prcp.hvplot.quadmesh(x='lon', y='lat', geo=True, tiles='OSM',
                         cmap='turbo', rasterize=True, alpha=0.7)

# In[24]:

ds2.tmean

# In[25]:

ds2.tmean.sel(lon=-90., lat=35., method='nearest').plot()
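# As a final sanity check, we can verify that no time region was silently
# skipped: any region whose write failed would still be all-NaN. A sketch
# using the names defined above:

# In[ ]:

# fraction of valid (non-NaN) cells at each time step; should never drop to zero
valid = ds2.tmean.notnull().mean(dim=['lat', 'lon']).compute()
assert (valid > 0).all(), 'at least one time region was never written'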