This notebook contains material from PyRosetta; content is available on Github.

Examples Using the dask Module

We can make use of the dask library to parallelize code

Note: This Jupyter notebook uses parallelization and is not meant to be executed within a Google Colab environment.

Note: This Jupyter notebook requires the PyRosetta distributed layer which is obtained by building PyRosetta with the --serialization flag or installing PyRosetta from the RosettaCommons conda channel

Please see Chapter 16.00 for setup instructions

In [2]:
import dask
import dask.array as da
import graphviz
import logging
logging.basicConfig(level=logging.INFO)
import numpy as np
import os
import pyrosetta
import pyrosetta.distributed
import pyrosetta.distributed.dask
import pyrosetta.distributed.io as io
import random
import sys

from dask.distributed import Client, LocalCluster, progress
from dask_jobqueue import SLURMCluster
from IPython.display import Image

if 'google.colab' in sys.modules:
    print("This Jupyter notebook uses parallelization and is therefore not set up for the Google Colab environment.")
    sys.exit(0)

Initialize PyRosetta within this Jupyter notebook using custom command line PyRosetta flags:

In [3]:
flags = """-out:level 100
-ignore_unrecognized_res   1
     -ignore_waters 0 
  -detect_disulf  0 # Do not automatically detect disulfides
""" # These can be unformatted for user convenience, but no spaces in file paths!
pyrosetta.distributed.init(flags)
INFO:pyrosetta.distributed:maybe_init performing pyrosetta initialization: {'extra_options': '-out:level 100 -ignore_unrecognized_res 1 -ignore_waters 0 -detect_disulf 0', 'silent': True}
INFO:pyrosetta.rosetta:Found rosetta database at: /home/klimaj/anaconda3/envs/PyRosetta.notebooks/lib/python3.7/site-packages/pyrosetta/database; using it....
INFO:pyrosetta.rosetta:PyRosetta-4 2020 [Rosetta PyRosetta4.conda.linux.CentOS.python37.Release 2020.02+release.22ef835b4a2647af94fcd6421a85720f07eddf12 2020-01-05T17:31:56] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.

If you are running this example on a high-performance computing (HPC) cluster with SLURM scheduling, use the SLURMCluster class described below. For more information, visit https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.SLURMCluster.html. Note: If you are running this example on a HPC cluster with a job scheduler other than SLURM, dask_jobqueue also works with other job schedulers: http://jobqueue.dask.org/en/latest/api.html

The SLURMCluster class in the dask_jobqueue module is very useful! In this case, we are requesting four workers using cluster.scale(4), and specifying each worker to have:

  • one thread per worker with cores=1
  • one process per worker with processes=1
  • one CPU per task per worker with job_cpu=1
  • a total of 4GB memory per worker with memory="4GB"
  • itself run on the "short" queue/partition on the SLURM scheduler with queue="short"
  • a maximum job walltime of 3 hours using walltime="03:00:00"
  • output dask files directed to local_directory
  • output SLURM log files directed to file path and file name (and any other SLURM commands) with the job_extra option
  • pre-initialization with the same custom command line PyRosetta flags used in this Jupyter notebook, using the extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags) option
In [4]:
if not os.getenv("DEBUG"):
    scratch_dir = os.path.join("/net/scratch", os.environ["USER"])
    cluster = SLURMCluster(
        cores=1,
        processes=1,
        job_cpu=1,
        memory="4GB",
        queue="short",
        walltime="02:59:00",
        local_directory=scratch_dir,
        job_extra=["-o {}".format(os.path.join(scratch_dir, "slurm-%j.out"))],
        extra=pyrosetta.distributed.dask.worker_extra(init_flags=flags)
    )
    cluster.scale(4)
    client = Client(cluster)
else:
    cluster = None
    client = None

Note: The actual sbatch script submitted to the Slurm scheduler under the hood was:

In [5]:
if not os.getenv("DEBUG"):
    print(cluster.job_script())
#!/usr/bin/env bash

#SBATCH -J dask-worker
#SBATCH -p short
#SBATCH -n 1
#SBATCH --cpus-per-task=1
#SBATCH --mem=4G
#SBATCH -t 02:59:00
#SBATCH -o /net/scratch/klimaj/slurm-%j.out

JOB_ID=${SLURM_JOB_ID%;*}

/home/klimaj/anaconda3/envs/PyRosetta.notebooks/bin/python -m distributed.cli.dask_worker tcp://172.16.131.107:19949 --nthreads 1 --memory-limit 4.00GB --name name --nanny --death-timeout 60 --local-directory /net/scratch/klimaj --preload pyrosetta.distributed.dask.worker ' -out:level 100 -ignore_unrecognized_res 1 -ignore_waters 0 -detect_disulf 0'

Otherwise, if you are running this example locally on your laptop, you can still spawn workers and take advantage of the dask module:

In [1]:
# cluster = LocalCluster(n_workers=1, threads_per_worker=1)
# client = Client(cluster)

Open the dask dashboard, which shows diagnostic information about the current state of your cluster and helps track progress, identify performance issues, and debug failures:

In [7]:
client
Out[7]:

Client

Cluster

  • Workers: 4
  • Cores: 4
  • Memory: 16.00 GB

Consider the following example that runs within this Jupyter notebook kernel just fine but could be parallelized:

In [8]:
def inc(x):
    return x + 1

def double(x):
    return x + 2

def add(x, y):
    return x + y
In [9]:
output = []
for x in range(10):
    a = inc(x)
    b = double(x)
    c = add(a, b)
    output.append(c)

total = sum(output)
print(total)
120

With a slight modification, we can parallelize it on the HPC cluster using the dask module

In [10]:
output = []
for x in range(10):
    a = dask.delayed(inc)(x)
    b = dask.delayed(double)(x)
    c = dask.delayed(add)(a, b)
    output.append(c)

delayed = dask.delayed(sum)(output)
print(delayed)
Delayed('sum-829ef956-75b9-40dd-8d06-a720d323f1c4')

We used the dask.delayed function to wrap the function calls that we want to turn into tasks. None of the inc, double, add, or sum calls have happened yet. Instead, the object total is a Delayed object that contains a task graph of the entire computation to be executed.

Let's visualize the task graph to see clear opportunities for parallel execution.

In [11]:
if not os.getenv("DEBUG"):
    delayed.visualize()
Out[11]:

We can now compute this lazy result to execute the graph in parallel:

In [12]:
if not os.getenv("DEBUG"):
    total = delayed.compute()
    print(total)
120

We can also use dask.delayed as a python function decorator for identical performance

In [13]:
@dask.delayed
def inc(x):
    return x + 1

@dask.delayed
def double(x):
    return x + 2

@dask.delayed
def add(x, y):
    return x + y

output = []
for x in range(10):
    a = inc(x)
    b = double(x)
    c = add(a, b)
    output.append(c)

total = dask.delayed(sum)(output).compute()
print(total)
120

We can also use the dask.array library, which implements a subset of the NumPy ndarray interface using blocked algorithms, cutting up the large array into many parallelizable small arrays.

See dask.array documentation: http://docs.dask.org/en/latest/array.html, along with that of dask.bag, dask.dataframe, dask.delayed, Futures, etc.

In [10]:
if not os.getenv("DEBUG"):
    x = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    y = da.random.random((10000, 10000, 10), chunks=(1000, 1000, 5))
    z = (da.arcsin(x) + da.arccos(y)).sum(axis=(1, 2))
    z.compute()

The dask dashboard allows visualizing parallel computation, including progress bars for tasks. Here is a snapshot of the dask dashboard while executing the previous cell:

In [12]:
Image(filename="inputs/dask_dashboard_example.png") 
Out[12]:

For more info on interpreting the dask dashboard, see: https://distributed.dask.org/en/latest/web.html

Example Using dask.delayed with PyRosetta

Let's look at a simple example of sending PyRosetta jobs to the dask-worker, and the dask-worker sending the results back to this Jupyter Notebook.

We will use the crystal structure of the de novo mini protein gEHEE_06 from PDB ID 5JG9

In [14]:
@dask.delayed
def mutate(ppose, target, new_res):
    import pyrosetta
    pose = io.to_pose(ppose)
    mutate = pyrosetta.rosetta.protocols.simple_moves.MutateResidue(target=target, new_res=new_res)
    mutate.apply(pose)
    return io.to_packed(pose)

@dask.delayed
def refine(ppose):
    import pyrosetta
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015_cart")
    mm = pyrosetta.rosetta.core.kinematics.MoveMap()
    mm.set_bb(True)
    mm.set_chi(True)
    min_mover = pyrosetta.rosetta.protocols.minimization_packing.MinMover()
    min_mover.set_movemap(mm)
    min_mover.score_function(scorefxn)
    min_mover.min_type("lbfgs_armijo_nonmonotone")
    min_mover.cartesian(True)
    min_mover.tolerance(0.01)
    min_mover.max_iter(200)
    min_mover.apply(pose)
    return io.to_packed(pose)

@dask.delayed
def score(ppose):
    import pyrosetta
    pose = io.to_pose(ppose)
    scorefxn = pyrosetta.create_score_function("ref2015")
    total_score = scorefxn(pose)
    return pose, total_score


if not os.getenv("DEBUG"):
    pose = pyrosetta.io.pose_from_file("inputs/5JG9.clean.pdb")
    keep_chA = pyrosetta.rosetta.protocols.grafting.simple_movers.KeepRegionMover(
        res_start=str(pose.chain_begin(1)), res_end=str(pose.chain_end(1))
    )
    keep_chA.apply(pose)

    #kwargs = {"extra_options": pyrosetta.distributed._normflags(flags)}

    output = []
    for target in random.sample(range(1, pose.size() + 1), 10):
        if pose.sequence()[target - 1] != "C":
            for new_res in ["ALA", "TRP"]:
                a = mutate(io.to_packed(pose), target, new_res)
                b = refine(a)
                c = score(b)
                output.append((target, new_res, c[0], c[1]))

    delayed_obj = dask.delayed(np.argmin)([x[-1] for x in output])
    delayed_obj.visualize()
Out[14]:
In [15]:
print(output)
[(24, 'ALA', Delayed('getitem-b11d1a339db967400c91571f44c08a76'), Delayed('getitem-7fa580c2d983f9fd10a538071ed44b61')), (24, 'TRP', Delayed('getitem-b55caad556de4e5c0b0929033852ca12'), Delayed('getitem-a965ab5ace297645356aaea7d2963e96')), (7, 'ALA', Delayed('getitem-8793683d48e3d7612c6ec9e931a53ccd'), Delayed('getitem-0ae1b4369cb92000bf6de6b910806dc5')), (7, 'TRP', Delayed('getitem-2975e00dd50bf05d06dc4393ba9d7fae'), Delayed('getitem-f5992d59a87529f2dec96ca28a8aac76')), (41, 'ALA', Delayed('getitem-5fcd893f41535432f13412653d21aa60'), Delayed('getitem-bbc17cd3c74e45d0a7781f50e73f90ec')), (41, 'TRP', Delayed('getitem-9d975eb7bdf9afce5911cab033207a98'), Delayed('getitem-7fa0047c7821e7059107c7035ea7990f')), (6, 'ALA', Delayed('getitem-d47790dba8d22a36e95b83b2195e89c5'), Delayed('getitem-3fa6b2a5a39372856d718a7e358c89f1')), (6, 'TRP', Delayed('getitem-fc621b75043ac0f709dbe19ea654452b'), Delayed('getitem-0458b9aae28d5854aab0182c570ed498')), (26, 'ALA', Delayed('getitem-13421915ecc7c3050e5b708aea9ad7e0'), Delayed('getitem-f4878511e2af67d4994018a8873e6e91')), (26, 'TRP', Delayed('getitem-aef267c93d8f46a4ef7cfcd0b3873c4b'), Delayed('getitem-b488f81f44089bfe60f34a77e48bf16a')), (17, 'ALA', Delayed('getitem-bd13a54d9a99b05f49a7cfdcb76cbfb5'), Delayed('getitem-8266f794c501eea737200511d076a85e')), (17, 'TRP', Delayed('getitem-97d4dfa8b04c79a932eb689435c83f53'), Delayed('getitem-7fb4d4dc8d0da3f6834bf1712df62143')), (35, 'ALA', Delayed('getitem-04a4482f64e23b102632b25e6dcc6ef1'), Delayed('getitem-2d2a75271eedad7faa8d6582358c9519')), (35, 'TRP', Delayed('getitem-e5c60b239e7cc8e8a57573120ba7eade'), Delayed('getitem-fc5efd391d999f6193593d7e255ff4ec')), (21, 'ALA', Delayed('getitem-77f183c4e33d54cf4602c0a268018412'), Delayed('getitem-27be0efefefaee8f79c6021be6d96f05')), (21, 'TRP', Delayed('getitem-21bdfadda4f376a0a0192e69a013fd7e'), Delayed('getitem-9bc54a391550ab42419e295c039b77fa')), (15, 'ALA', Delayed('getitem-9191b2432a22a6c689fc6839c3a7f201'), Delayed('getitem-4f3122e413b66e7c05dc28619aab5019')), (15, 'TRP', Delayed('getitem-a5596c6e66d50b619ad4fb9d8ed445d9'), Delayed('getitem-ed1e34f655b369fd42279a66ceae236c')), (44, 'ALA', Delayed('getitem-1a2f9dbc00089eea31a9c0446e22ec66'), Delayed('getitem-8a968c5bce68454010810950e463bf3b')), (44, 'TRP', Delayed('getitem-42dfe3b25b4cde932aa29bff995bcad4'), Delayed('getitem-4f908a18966804191b52f125c7add86d'))]
In [16]:
if not os.getenv("DEBUG"):
    delayed_result = delayed_obj.persist()
    progress(delayed_result)

The dask progress bar allows visualizing parallelization directly within the Jupyter notebook. Here is a snapshot of the dask progress bar while executing the previous cell:

In [20]:
Image(filename="inputs/dask_progress_bar_example.png") 
Out[20]:
In [17]:
if not os.getenv("DEBUG"):
    result = delayed_result.compute()
    print("The mutation with the lowest energy is residue {0} at position {1}".format(output[result][1], output[result][0]))
The mutation with the lowest energy is residue ALA at position 26

Note: For best practices while using dask.delayed, see: http://docs.dask.org/en/latest/delayed-best-practices.html