dask Module
Using the dask library to parallelize code

Note: This Jupyter notebook uses parallelization and is not meant to be executed within a Google Colab environment.
Note: This Jupyter notebook requires the PyRosetta distributed layer, which is obtained by building PyRosetta with the --serialization flag or by installing PyRosetta from the RosettaCommons conda channel. Please see Chapter 16.00 for setup instructions.
import dask
import dask.array as da
import graphviz
import logging
logging.basicConfig(level=logging.INFO)
import numpy as np
import os
import pyrosetta
import pyrosetta.distributed
import pyrosetta.distributed.dask
import pyrosetta.distributed.io as io
import random
import sys
from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from IPython.display import Image
if 'google.colab' in sys.modules:
    print("This Jupyter notebook uses parallelization and is therefore not set up for the Google Colab environment.")
    sys.exit(0)
Initialize PyRosetta within this Jupyter notebook using custom command line PyRosetta flags:
flags = """-out:level 100
-ignore_unrecognized_res 1
-ignore_waters 0
-detect_disulf 0 # Do not automatically detect disulfides
""" # These can be unformatted for user convenience, but no spaces in file paths!
pyrosetta.distributed.init(flags)
INFO:pyrosetta.distributed:maybe_init performing pyrosetta initialization: {'extra_options': '-out:level 100 -ignore_unrecognized_res 1 -ignore_waters 0 -detect_disulf 0', 'silent': True}
INFO:pyrosetta.rosetta:Found rosetta database at: /home/klimaj/anaconda3/envs/PyRosetta.notebooks/lib/python3.7/site-packages/pyrosetta/database; using it....
INFO:pyrosetta.rosetta:PyRosetta-4 2020 [Rosetta PyRosetta4.conda.linux.CentOS.python37.Release 2020.02+release.22ef835b4a2647af94fcd6421a85720f07eddf12 2020-01-05T17:31:56] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.
If you are running this example on a high-performance computing (HPC) cluster with SLURM scheduling, use the SLURMCluster class described below. For more information, visit https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.SLURMCluster.html.
Note: If you are running this example on an HPC cluster with a job scheduler other than SLURM, dask_jobqueue also works with other job schedulers: http://jobqueue.dask.org/en/latest/api.html
The SLURMCluster class in the dask_jobqueue module is very useful! In this case, we are requesting four workers using cluster.scale(4), and specifying that each worker have:
cores=1
processes=1
job_cpu=1
memory="4GB"
queue="short"
walltime="03:00:00"
a local_directory option
a job_extra option
an extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags) option
if not os.getenv("DEBUG"):
    scratch_dir = os.path.join("/net/scratch", os.environ["USER"])
    cluster = SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="short",
        walltime="02:59:00",
        local_directory=scratch_dir,
        job_extra=["-o {}".format(os.path.join(scratch_dir, "slurm-%j.out"))],
        extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags)
    )
    cluster.scale(4)
    client = Client(cluster)
else:
    cluster = None
    client = None
Note: The actual sbatch script submitted to the Slurm scheduler under the hood was:
if not os.getenv("DEBUG"):
    print(cluster.job_script())
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -p short
#SBATCH -n 1
#SBATCH --cpus-per-task=1
#SBATCH --mem=4G
#SBATCH -t 02:59:00
#SBATCH -o /net/scratch/klimaj/slurm-%j.out
JOB_ID=${SLURM_JOB_ID%;*}

/home/klimaj/anaconda3/envs/PyRosetta.notebooks/bin/python -m distributed.cli.dask_worker tcp://172.16.131.107:19949 --nthreads 1 --memory-limit 4.00GB --name name --nanny --death-timeout 60 --local-directory /net/scratch/klimaj --preload pyrosetta.distributed.dask.worker ' -out:level 100 -ignore_unrecognized_res 1 -ignore_waters 0 -detect_disulf 0'
Otherwise, if you are running this example locally on your laptop, you can still spawn workers and take advantage of the dask
module:
# cluster = LocalCluster(n_workers=1, threads_per_worker=1)
# client = Client(cluster)
Open the dask
dashboard, which shows diagnostic information about the current state of your cluster and helps track progress, identify performance issues, and debug failures:
client
def inc(x):
return x + 1
def double(x):
return x + 2
def add(x, y):
return x + y
output = []
for x in range(10):
a = inc(x)
b = double(x)
c = add(a, b)
output.append(c)
total = sum(output)
print(total)
120
With a slight modification, we can parallelize it on the HPC cluster using the dask module:
output = []
for x in range(10):
a = dask.delayed(inc)(x)
b = dask.delayed(double)(x)
c = dask.delayed(add)(a, b)
output.append(c)
delayed = dask.delayed(sum)(output)
print(delayed)
Delayed('sum-829ef956-75b9-40dd-8d06-a720d323f1c4')
We used the dask.delayed function to wrap the function calls that we want to turn into tasks. None of the inc, double, add, or sum calls have happened yet. Instead, the object delayed is a Delayed object that contains a task graph of the entire computation to be executed.
Let's visualize the task graph to see clear opportunities for parallel execution.
if not os.getenv("DEBUG"):
    delayed.visualize()
We can now compute this lazy result to execute the graph in parallel:
if not os.getenv("DEBUG"):
    total = delayed.compute()
    print(total)
120
We can also use dask.delayed as a Python function decorator, with identical performance:
@dask.delayed
def inc(x):
return x + 1
@dask.delayed
def double(x):
return x + 2
@dask.delayed
def add(x, y):
return x + y
output = []
for x in range(10):
a = inc(x)
b = double(x)
c = add(a, b)
output.append(c)
total = dask.delayed(sum)(output).compute()
print(total)
120
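As a side note, several independent Delayed objects can also be evaluated together with dask.compute, which merges them into one shared task graph. Here is a minimal sketch (not part of the original notebook), run with the synchronous scheduler for reproducibility:

```python
import dask

@dask.delayed
def inc(x):
    # Same toy increment function as above
    return x + 1

# dask.compute evaluates multiple Delayed objects in a single call,
# so any shared intermediate tasks are only executed once.
a, b = dask.compute(inc(1), inc(2), scheduler="synchronous")
print(a, b)  # 2 3
```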
We can also use the dask.array
library, which implements a subset of the NumPy ndarray interface using blocked algorithms, cutting up the large array into many parallelizable small arrays.
See dask.array
documentation: http://docs.dask.org/en/latest/array.html, along with that of dask.bag
, dask.dataframe
, dask.delayed
, Futures
, etc.
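For instance, dask.bag exposes map/filter/reduce-style operations over partitioned sequences of plain Python objects. A minimal sketch (illustrative, not part of the original notebook), run here with the single-threaded scheduler for reproducibility:

```python
import dask.bag as db

# Partition 0..9 into two blocks, square each element,
# keep the even squares, and sum them lazily.
b = db.from_sequence(range(10), npartitions=2)
result = (b.map(lambda x: x ** 2)
           .filter(lambda x: x % 2 == 0)
           .sum()
           .compute(scheduler="synchronous"))
print(result)  # 0 + 4 + 16 + 36 + 64 = 120
```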
if not os.getenv("DEBUG"):
    x = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    y = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    z = (da.arcsin(x) + da.arccos(y)).sum(axis=(1, 2))
    z.compute()
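The blocked decomposition is easier to see on a small deterministic array. In this minimal sketch (assuming only dask.array and NumPy, not part of the original notebook), a 10x10 array is split into four 5x5 chunks that can be reduced independently:

```python
import numpy as np
import dask.array as da

# Wrap an existing NumPy array; chunks=(5, 5) yields a 2x2 grid of blocks.
x = da.from_array(np.arange(100).reshape(10, 10), chunks=(5, 5))
print(x.numblocks)  # (2, 2)

# Each block is summed separately, then the partial sums are combined.
total = x.sum().compute(scheduler="synchronous")
print(total)  # 4950, i.e. sum(range(100))
```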
The dask dashboard allows visualizing parallel computation, including progress bars for tasks. Here is a snapshot of the dask dashboard while executing the previous cell:
Image(filename="inputs/dask_dashboard_example.png")
For more info on interpreting the dask dashboard, see: https://distributed.dask.org/en/latest/web.html
dask.delayed with PyRosetta
Let's look at a simple example of sending PyRosetta jobs to the dask-worker, and the dask-worker sending the results back to this Jupyter Notebook.
We will use the crystal structure of the de novo mini protein gEHEE_06 from PDB ID 5JG9.
@dask.delayed
def mutate(ppose, target, new_res):
    import pyrosetta
    import pyrosetta.distributed.io as io
    pose = io.to_pose(ppose)
    mutate = pyrosetta.rosetta.protocols.simple_moves.MutateResidue(target=target, new_res=new_res)
    mutate.apply(pose)
    return io.to_packed(pose)
@dask.delayed
def refine(ppose):
    import pyrosetta
    import pyrosetta.distributed.io as io
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015_cart")
    mm = pyrosetta.rosetta.core.kinematics.MoveMap()
    mm.set_bb(True)
    mm.set_chi(True)
    min_mover = pyrosetta.rosetta.protocols.minimization_packing.MinMover()
    min_mover.set_movemap(mm)
    min_mover.score_function(scorefxn)
    min_mover.min_type("lbfgs_armijo_nonmonotone")
    min_mover.cartesian(True)
    min_mover.tolerance(0.01)
    min_mover.max_iter(200)
    min_mover.apply(pose)
    return io.to_packed(pose)
@dask.delayed
def score(ppose):
    import pyrosetta
    import pyrosetta.distributed.io as io
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015")
    total_score = scorefxn(pose)
    return pose, total_score
if not os.getenv("DEBUG"):
    pose = pyrosetta.io.pose_from_file("inputs/5JG9.clean.pdb")
    keep_chA = pyrosetta.rosetta.protocols.grafting.simple_movers.KeepRegionMover(
        res_start=str(pose.chain_begin(1)), res_end=str(pose.chain_end(1))
    )
    keep_chA.apply(pose)

    #kwargs = {"extra_options": pyrosetta.distributed._normflags(flags)}

    output = []
    for target in random.sample(range(1, pose.size() + 1), 10):
        if pose.sequence()[target - 1] != "C":
            for new_res in ["ALA", "TRP"]:
                a = mutate(io.to_packed(pose), target, new_res)
                b = refine(a)
                c = score(b)
                output.append((target, new_res, c[0], c[1]))
    delayed_obj = dask.delayed(np.argmin)([x[-1] for x in output])
    delayed_obj.visualize()
    print(output)
[(24, 'ALA', Delayed('getitem-b11d1a339db967400c91571f44c08a76'), Delayed('getitem-7fa580c2d983f9fd10a538071ed44b61')), (24, 'TRP', Delayed('getitem-b55caad556de4e5c0b0929033852ca12'), Delayed('getitem-a965ab5ace297645356aaea7d2963e96')), (7, 'ALA', Delayed('getitem-8793683d48e3d7612c6ec9e931a53ccd'), Delayed('getitem-0ae1b4369cb92000bf6de6b910806dc5')), (7, 'TRP', Delayed('getitem-2975e00dd50bf05d06dc4393ba9d7fae'), Delayed('getitem-f5992d59a87529f2dec96ca28a8aac76')), (41, 'ALA', Delayed('getitem-5fcd893f41535432f13412653d21aa60'), Delayed('getitem-bbc17cd3c74e45d0a7781f50e73f90ec')), (41, 'TRP', Delayed('getitem-9d975eb7bdf9afce5911cab033207a98'), Delayed('getitem-7fa0047c7821e7059107c7035ea7990f')), (6, 'ALA', Delayed('getitem-d47790dba8d22a36e95b83b2195e89c5'), Delayed('getitem-3fa6b2a5a39372856d718a7e358c89f1')), (6, 'TRP', Delayed('getitem-fc621b75043ac0f709dbe19ea654452b'), Delayed('getitem-0458b9aae28d5854aab0182c570ed498')), (26, 'ALA', Delayed('getitem-13421915ecc7c3050e5b708aea9ad7e0'), Delayed('getitem-f4878511e2af67d4994018a8873e6e91')), (26, 'TRP', Delayed('getitem-aef267c93d8f46a4ef7cfcd0b3873c4b'), Delayed('getitem-b488f81f44089bfe60f34a77e48bf16a')), (17, 'ALA', Delayed('getitem-bd13a54d9a99b05f49a7cfdcb76cbfb5'), Delayed('getitem-8266f794c501eea737200511d076a85e')), (17, 'TRP', Delayed('getitem-97d4dfa8b04c79a932eb689435c83f53'), Delayed('getitem-7fb4d4dc8d0da3f6834bf1712df62143')), (35, 'ALA', Delayed('getitem-04a4482f64e23b102632b25e6dcc6ef1'), Delayed('getitem-2d2a75271eedad7faa8d6582358c9519')), (35, 'TRP', Delayed('getitem-e5c60b239e7cc8e8a57573120ba7eade'), Delayed('getitem-fc5efd391d999f6193593d7e255ff4ec')), (21, 'ALA', Delayed('getitem-77f183c4e33d54cf4602c0a268018412'), Delayed('getitem-27be0efefefaee8f79c6021be6d96f05')), (21, 'TRP', Delayed('getitem-21bdfadda4f376a0a0192e69a013fd7e'), Delayed('getitem-9bc54a391550ab42419e295c039b77fa')), (15, 'ALA', Delayed('getitem-9191b2432a22a6c689fc6839c3a7f201'), 
Delayed('getitem-4f3122e413b66e7c05dc28619aab5019')), (15, 'TRP', Delayed('getitem-a5596c6e66d50b619ad4fb9d8ed445d9'), Delayed('getitem-ed1e34f655b369fd42279a66ceae236c')), (44, 'ALA', Delayed('getitem-1a2f9dbc00089eea31a9c0446e22ec66'), Delayed('getitem-8a968c5bce68454010810950e463bf3b')), (44, 'TRP', Delayed('getitem-42dfe3b25b4cde932aa29bff995bcad4'), Delayed('getitem-4f908a18966804191b52f125c7add86d'))]
if not os.getenv("DEBUG"):
    delayed_result = delayed_obj.persist()
    progress(delayed_result)
The dask progress bar allows visualizing parallelization directly within the Jupyter notebook. Here is a snapshot of the dask progress bar while executing the previous cell:
Image(filename="inputs/dask_progress_bar_example.png")
if not os.getenv("DEBUG"):
    result = delayed_result.compute()
    print("The mutation with the lowest energy is residue {0} at position {1}".format(output[result][1], output[result][0]))
The mutation with the lowest energy is residue ALA at position 26
Note: For best practices while using dask.delayed
, see: http://docs.dask.org/en/latest/delayed-best-practices.html
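One of those best practices is worth illustrating here: build the entire task graph first and call compute() once at the end, rather than computing inside the loop, which would block on every iteration and serialize the work. A minimal sketch (illustrative only, using the synchronous scheduler for reproducibility):

```python
import dask

@dask.delayed
def square(x):
    return x * x

# Anti-pattern: calling square(i).compute() inside the loop forces each
# task to finish before the next starts, removing all parallelism.
# Instead, collect the Delayed objects and compute once at the end:
results = [square(i) for i in range(5)]
total = dask.delayed(sum)(results).compute(scheduler="synchronous")
print(total)  # 0 + 1 + 4 + 9 + 16 = 30
```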