Parallel GST using MPI

The purpose of this tutorial is to demonstrate how to compute GST estimates in parallel (using multiple CPUs or "processors"). The core pyGSTi computational routines are written to take advantage of multiple processors via the MPI communication framework, so to run pyGSTi calculations in parallel you must have an MPI implementation and the mpi4py Python package installed.
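A quick way to confirm that MPI and mpi4py are working before running pyGSTi is a tiny "hello" script. The sketch below is not part of pyGSTi; the helper name `mpi_rank_and_size` is hypothetical, and it falls back to a single serial process when mpi4py cannot be imported.

```python
def mpi_rank_and_size():
    """Return (rank, size) for MPI.COMM_WORLD, or (0, 1) if mpi4py is unavailable."""
    try:
        from mpi4py import MPI
    except ImportError:
        # No mpi4py installed: behave like a single serial process.
        return 0, 1
    comm = MPI.COMM_WORLD
    return comm.Get_rank(), comm.Get_size()


if __name__ == "__main__":
    rank, size = mpi_rank_and_size()
    print("Hello from rank %d of %d" % (rank, size))
```

Saved as, say, `check_mpi.py` and launched with `mpiexec -n 3 python3 check_mpi.py`, each of the 3 ranks should print its own line, in no particular order. If every line reports "of 1", the processes are probably not being started through the same MPI that mpi4py was built against.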

Because mpi4py doesn't play nicely with Jupyter notebooks, this tutorial is a bit clunkier than the others: we write a standalone Python script that imports mpi4py, and then execute that script with MPI.

As an example, we use the same "standard" single-qubit gate set as in the first tutorial. We first create a dataset, and then a script, to be run in parallel, that loads the data. The simulated data is created in the same way as in the first tutorial. Because generate_fake_data draws random numbers and uses them as the simulated counts, it is important that this step is not done in a parallel environment; otherwise different CPUs may end up with different data sets. (This isn't an issue in the typical situation where the data is obtained experimentally.)
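This tutorial avoids the problem by generating the data serially in the notebook and having every rank load the same file. If you ever do need to simulate data inside an MPI script, one common pattern is to draw the random numbers on rank 0 only and broadcast the result with mpi4py's `comm.bcast`. The sketch below uses a toy coin-flip simulation in place of pyGSTi's `generate_fake_data`; both helper names are hypothetical.

```python
import random


def fake_binomial_counts(p, n_samples, seed):
    """Toy stand-in for simulated counts: successes in n_samples Bernoulli(p) trials."""
    rng = random.Random(seed)
    return sum(1 for _ in range(n_samples) if rng.random() < p)


def counts_shared_by_all_ranks(p, n_samples, seed, comm=None):
    """Generate counts on rank 0 only, then broadcast them to every other rank.

    comm is an mpi4py communicator, or None for a plain serial run.
    """
    rank = comm.Get_rank() if comm is not None else 0
    counts = fake_binomial_counts(p, n_samples, seed) if rank == 0 else None
    if comm is not None:
        # Every rank receives rank 0's result, so all ranks see identical data.
        counts = comm.bcast(counts, root=0)
    return counts
```

Inside an MPI script you would pass `MPI.COMM_WORLD` as `comm`; with `comm=None` the function simply runs serially.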

In [1]:
#Import pyGSTi and the "standard 1-qubit quantities for a gateset with X(pi/2), Y(pi/2), and idle gates"
import pygsti
from pygsti.construction import std1Q_XYI

#Create a data set
gs_target = std1Q_XYI.gs_target
fiducials = std1Q_XYI.fiducials
germs = std1Q_XYI.germs
maxLengths = [1,2,4,8,16,32]

gs_datagen = gs_target.depolarize(gate_noise=0.1, spam_noise=0.001)
listOfExperiments = pygsti.construction.make_lsgst_experiment_list(gs_target.gates.keys(), fiducials, fiducials, germs, maxLengths)
ds = pygsti.construction.generate_fake_data(gs_datagen, listOfExperiments, nSamples=1000,
                                            sampleError="binomial", seed=1234)
import os
os.makedirs("example_files", exist_ok=True)  # ensure the output directory exists
pygsti.io.write_dataset("example_files/mpi_example_dataset.txt", ds)

Next, we'll write a Python script that loads the just-created DataSet, runs GST on it, and writes the output to a file. The only major difference between this script and previous examples is that it imports mpi4py and passes an MPI communicator object (comm) to the do_long_sequence_gst function. Since parallel computing is most useful for computationally intensive GST calculations, we also demonstrate how to set a per-processor memory limit, which tells pyGSTi to partition its computations so as not to exceed that memory usage. Lastly, note the gaugeOptParams argument of do_long_sequence_gst, which can be used to weight different gate set members differently during gauge optimization.
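pyGSTi decides internally how to split its work to respect `memLimit`; the sketch below is not its algorithm, only an illustration of the idea. It assumes a hypothetical, fixed memory cost per gate string and groups strings into contiguous chunks that each fit under the limit. The limit uses the same expression as `memLim` in the script, which pyGSTi reports as 2.10GB in the output further down.

```python
def chunk_for_memory_limit(items, bytes_per_item, mem_limit_bytes):
    """Split items into contiguous chunks whose estimated memory fits the limit.

    bytes_per_item is an assumed, fixed per-item cost (a simplification).
    """
    max_per_chunk = int(mem_limit_bytes // bytes_per_item)
    if max_per_chunk < 1:
        raise ValueError("a single item does not fit within the memory limit")
    return [items[i:i + max_per_chunk] for i in range(0, len(items), max_per_chunk)]


mem_limit = 2.1 * (1024) ** 3                # same expression as memLim in the script
print("%.2f GB" % (mem_limit / 1024 ** 3))   # prints 2.10 GB

# 1702 gate strings (as in the final GST iteration), at a hypothetical 2 MB each
chunks = chunk_for_memory_limit(list(range(1702)), 2e6, mem_limit)
print([len(c) for c in chunks])              # prints [1127, 575]
```

Real per-string costs vary with sequence length and the quantities being computed, which is why a library would estimate them rather than use a single constant.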

In [2]:
mpiScript = """
import time
import pygsti
from pygsti.construction import std1Q_XYI

#get MPI comm
from mpi4py import MPI
comm = MPI.COMM_WORLD

print("Rank %d started" % comm.Get_rank())

#define target gateset, fiducials, and germs as before
gs_target = std1Q_XYI.gs_target
fiducials = std1Q_XYI.fiducials
germs = std1Q_XYI.germs
maxLengths = [1,2,4,8,16,32]

#tell gauge optimization to weight the gate matrix
# elements 100x more heavily than the SPAM vector elements, and
# to specifically weight the Gx gate twice as heavily as the other
# gates.
goParams = {'itemWeights':{'spam': 0.01, 'gates': 1.0, 'Gx': 2.0} }

#Specify a per-core memory limit (useful for larger GST calculations)
memLim = 2.1*(1024)**3  # 2.1 GB

#Perform TP-constrained GST
gs_target.set_all_parameterizations("TP")
    
#load the dataset
ds = pygsti.io.load_dataset("example_files/mpi_example_dataset.txt")

start = time.time()
results = pygsti.do_long_sequence_gst(ds, gs_target, fiducials, fiducials,
                                      germs, maxLengths, memLimit=memLim,
                                      gaugeOptParams=goParams, comm=comm,
                                      verbosity=2)
end = time.time()
print("Rank %d finished in %.1fs" % (comm.Get_rank(), end-start))
if comm.Get_rank() == 0:
    # only rank 0 saves the results, so the file is written once
    import pickle
    with open("example_files/mpi_example_results.pkl", "wb") as f:
        pickle.dump(results, f)
"""
with open("example_files/mpi_example_script.py","w") as f:
    f.write(mpiScript)
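To make the `itemWeights` entries in `goParams` more concrete, here is a toy weighted squared-distance between a gate set and its target. This is not pyGSTi's gauge-optimization objective; it only shows how weights of this shape scale each term. The rule that a gate-specific key such as `'Gx'` overrides the generic `'gates'` weight is an assumption based on the comment in the script, the 2x2 matrices are toy values, and the function and argument names are hypothetical.

```python
def sq_diff(a, b):
    """Sum of squared element-wise differences between two equal-shape nested lists."""
    if isinstance(a, (list, tuple)):
        return sum(sq_diff(x, y) for x, y in zip(a, b))
    return (a - b) ** 2


def weighted_distance(gates, target_gates, spam, target_spam, item_weights):
    """Weighted sum of squared deviations of gate matrices and SPAM vectors from target."""
    gate_default = item_weights.get("gates", 1.0)
    spam_weight = item_weights.get("spam", 1.0)
    total = 0.0
    for label, mx in gates.items():
        # A per-gate weight (e.g. 'Gx') takes precedence over the generic 'gates' weight.
        total += item_weights.get(label, gate_default) * sq_diff(mx, target_gates[label])
    for label, vec in spam.items():
        total += spam_weight * sq_diff(vec, target_spam[label])
    return total


weights = {'spam': 0.01, 'gates': 1.0, 'Gx': 2.0}   # same values as goParams
identity = [[1, 0], [0, 1]]
shifted = [[1, 0], [0, 0]]
d = weighted_distance({"Gx": identity, "Gi": identity},
                      {"Gx": shifted, "Gi": shifted},
                      {"rho0": [1, 0]}, {"rho0": [0, 0]}, weights)
print(round(d, 6))   # prints 3.01
```

Each gate here deviates from its target by the same amount, yet Gx contributes twice as much as Gi, and the equally sized SPAM deviation contributes 100x less than Gi.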

Next, we run the script with 3 processors using mpiexec. The mpiexec executable should have been installed with your MPI distribution; if it is not available, try replacing mpiexec with mpirun.
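If you would rather launch the script from Python than through the notebook's `!` shell escape, a small helper can pick whichever launcher is on your `PATH`. The function name is hypothetical; it only builds the command list and does not check that your MPI installation actually works.

```python
import shutil


def mpi_launch_command(script, n_procs=3, python="python3"):
    """Build an MPI launch command, preferring mpiexec and falling back to mpirun."""
    launcher = shutil.which("mpiexec") or shutil.which("mpirun")
    if launcher is None:
        # Neither was found on PATH; keep the plain name so the eventual error is clear.
        launcher = "mpiexec"
    return [launcher, "-n", str(n_procs), python, script]


cmd = mpi_launch_command("example_files/mpi_example_script.py")
print(" ".join(cmd))
```

Passing `cmd` to `subprocess.run(cmd, check=True)` would then run the same command as the cell below.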

In [3]:
! mpiexec -n 3 python3 "example_files/mpi_example_script.py"
Rank 1 started
Rank 2 started
Rank 0 started
--- Gate Sequence Creation ---
   1702 sequences created
   Dataset has 1702 entries: 1702 utilized, 0 requested sequences were missing
--- LGST ---
  Singular values of I_tilde (truncating to first 4 of 6) = 
  4.245058951635965
  1.1585845663892917
  0.9677945789630839
  0.9223871580426515
  0.07048465733093594
  0.0213024962973271
  
  Singular values of target I_tilde (truncating to first 4 of 6) = 
  4.242640687119285
  1.4142135623730954
  1.4142135623730947
  1.4142135623730945
  3.1723744950054595e-16
  1.0852733691121267e-16
  
--- Iterative MLGST: Iter 1 of 6  92 gate strings ---: 
  --- Minimum Chi^2 GST ---
  Memory limit = 2.10GB
  Cur, Persist, Gather = 0.11, 0.00, 0.21 GB
  Finding num_nongauge_params is too expensive: using total params.
  Sum of Chi^2 = 86.4516 (92 data params - 43 model params = expected mean of 49; p-value = 0.000769292)
  Completed in 0.3s
  2*Delta(log(L)) = 86.7954
  Iteration 1 took 0.4s
  
--- Iterative MLGST: Iter 2 of 6  168 gate strings ---: 
  --- Minimum Chi^2 GST ---
  Memory limit = 2.10GB
  Cur, Persist, Gather = 0.11, 0.00, 0.21 GB
  Finding num_nongauge_params is too expensive: using total params.
  Sum of Chi^2 = 167.265 (168 data params - 43 model params = expected mean of 125; p-value = 0.00692365)
  Completed in 0.3s
  2*Delta(log(L)) = 167.759
  Iteration 2 took 0.5s
  
--- Iterative MLGST: Iter 3 of 6  450 gate strings ---: 
  --- Minimum Chi^2 GST ---
  Memory limit = 2.10GB
  Cur, Persist, Gather = 0.11, 0.00, 0.21 GB
  Finding num_nongauge_params is too expensive: using total params.
  Sum of Chi^2 = 477.858 (450 data params - 43 model params = expected mean of 407; p-value = 0.00876085)
  Completed in 0.6s
  2*Delta(log(L)) = 478.93
  Iteration 3 took 1.2s
  
--- Iterative MLGST: Iter 4 of 6  862 gate strings ---: 
  --- Minimum Chi^2 GST ---
  Memory limit = 2.10GB
  Cur, Persist, Gather = 0.11, 0.00, 0.21 GB
  Finding num_nongauge_params is too expensive: using total params.
  Sum of Chi^2 = 890.263 (862 data params - 43 model params = expected mean of 819; p-value = 0.0419469)
  Completed in 1.1s
  2*Delta(log(L)) = 891.425
  Iteration 4 took 2.0s
  
--- Iterative MLGST: Iter 5 of 6  1282 gate strings ---: 
  --- Minimum Chi^2 GST ---
  Memory limit = 2.10GB
  Cur, Persist, Gather = 0.12, 0.00, 0.21 GB
  Finding num_nongauge_params is too expensive: using total params.
  Sum of Chi^2 = 1343.35 (1282 data params - 43 model params = expected mean of 1239; p-value = 0.0200193)
  Completed in 1.5s
  2*Delta(log(L)) = 1344.71
  Iteration 5 took 2.9s
  
--- Iterative MLGST: Iter 6 of 6  1702 gate strings ---: 
  --- Minimum Chi^2 GST ---
  Memory limit = 2.10GB
  Cur, Persist, Gather = 0.13, 0.00, 0.21 GB
  Finding num_nongauge_params is too expensive: using total params.
  Sum of Chi^2 = 1791.49 (1702 data params - 43 model params = expected mean of 1659; p-value = 0.0121291)
  Completed in 1.9s
  2*Delta(log(L)) = 1793.12
  Iteration 6 took 3.8s
  
  Switching to ML objective (last iteration)
  --- MLGST ---
  Memory: limit = 2.10GB(cur, persist, gthr = 0.13, 0.00, 0.21 GB)
  Finding num_nongauge_params is too expensive: using total params.
    Maximum log(L) = 896.511 below upper bound of -2.84686e+06
      2*Delta(log(L)) = 1793.02 (1702 data params - 43 model params = expected mean of 1659; p-value = 0.011354)
    Completed in 4.4s
  2*Delta(log(L)) = 1793.02
  Final MLGST took 4.4s
  
Iterative MLGST Total Time: 15.2s
  -- Adding Gauge Optimized (go0) --
--- Re-optimizing logl after robust data scaling ---
  --- MLGST ---
  Memory: limit = 2.10GB(cur, persist, gthr = 0.13, 0.00, 0.21 GB)
  Finding num_nongauge_params is too expensive: using total params.
    Maximum log(L) = 847.34 below upper bound of -2.84686e+06
      2*Delta(log(L)) = 1694.68 (1702 data params - 43 model params = expected mean of 1659; p-value = 0.265464)
    Completed in 4.5s
Rank 1 finished in 41.4s
Rank 2 finished in 41.5s
  -- Adding Gauge Optimized (go0) --
Rank 0 finished in 41.5s

Notice in the output above that the output from within do_long_sequence_gst is not duplicated: only the rank-0 processor writes it to stdout, so it looks identical to a single-processor run. Finally, we just need to read the pickled Results object from file and proceed with any post-processing analysis. In this case, we'll just create a report.
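pyGSTi suppresses duplicate output internally, but statements in your own script run on every rank, which is why the "Rank N started" and "Rank N finished" lines above appear once per processor. For script-level messages you only want once, a minimal pattern is to print only on rank 0. The helper name `root_print` and the `FakeComm` stand-in are hypothetical.

```python
def root_print(comm, *args, **kwargs):
    """Print only on rank 0 of comm (or always, if comm is None); return whether it printed."""
    rank = comm.Get_rank() if comm is not None else 0
    if rank == 0:
        print(*args, **kwargs)
        return True
    return False


class FakeComm:
    """Stand-in for an mpi4py communicator, so the example runs without MPI."""

    def __init__(self, rank):
        self._rank = rank

    def Get_rank(self):
        return self._rank


for r in range(3):
    root_print(FakeComm(r), "Loaded dataset (printed by rank %d)" % r)
```

Only the rank-0 call prints. In the MPI script itself you would pass the real `MPI.COMM_WORLD` as `comm`.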

In [4]:
import pickle
with open("example_files/mpi_example_results.pkl", "rb") as f:
    results = pickle.load(f)
pygsti.report.create_standard_report(results, "example_files/mpi_example_brief",
                                    title="MPI Example Report", verbosity=2, auto_open=True)
*** Creating workspace ***
*** Generating switchboard ***
Found standard clifford compilation from std1Q_XYI
*** Generating tables ***
  targetSpamBriefTable                          took 0.036339 seconds
  targetGatesBoxTable                           took 0.042545 seconds
  datasetOverviewTable                          took 0.074418 seconds
  bestGatesetSpamParametersTable                took 0.000744 seconds
  bestGatesetSpamBriefTable                     took 0.044855 seconds
  bestGatesetSpamVsTargetTable                  took 0.10052 seconds
  bestGatesetGaugeOptParamsTable                took 0.000615 seconds
  bestGatesetGatesBoxTable                      took 0.043628 seconds
  bestGatesetChoiEvalTable                      took 0.064883 seconds
  bestGatesetDecompTable                        took 0.053313 seconds
  bestGatesetEvalTable                          took 0.004118 seconds
  bestGermsEvalTable                            took 0.02304 seconds
  bestGatesetVsTargetTable                      took 1.045753 seconds
  bestGatesVsTargetTable_gv                     took 0.19339 seconds
  bestGatesVsTargetTable_gvgerms                took 0.069318 seconds
  bestGatesVsTargetTable_gi                     took 0.012074 seconds
  bestGatesVsTargetTable_gigerms                took 0.027005 seconds
  bestGatesVsTargetTable_sum                    took 0.18262 seconds
  bestGatesetErrGenBoxTable                     took 0.2096 seconds
  metadataTable                                 took 0.098872 seconds
  stdoutBlock                                   took 0.001928 seconds
  profilerTable                                 took 0.00188 seconds
  softwareEnvTable                              took 0.068032 seconds
  exampleTable                                  took 0.017096 seconds
  singleMetricTable_gv                          took 0.174502 seconds
  singleMetricTable_gi                          took 0.015208 seconds
  fiducialListTable                             took 0.001194 seconds
  prepStrListTable                              took 0.000576 seconds
  effectStrListTable                            took 0.000275 seconds
  colorBoxPlotKeyPlot                           took 0.019118 seconds
  germList2ColTable                             took 0.000623 seconds
  progressTable                                 took 5.78273 seconds
*** Generating plots ***
  gramBarPlot                                   took 0.052711 seconds
  progressBarPlot                               took 0.381747 seconds
  progressBarPlot_sum                           took 0.000671 seconds
  finalFitComparePlot                           took 0.134434 seconds
  bestEstimateColorBoxPlot                      took 21.550004 seconds
  bestEstimateTVDColorBoxPlot                   took 18.409911 seconds
  bestEstimateColorScatterPlot                  took 21.664449 seconds
  bestEstimateColorHistogram                    took 19.050734 seconds
  progressTable_scl                             took 4.790609 seconds
  progressBarPlot_scl                           took 0.295817 seconds
  bestEstimateColorBoxPlot_scl                  took 18.492381 seconds
  bestEstimateColorScatterPlot_scl              took 21.518567 seconds
  bestEstimateColorHistogram_scl                took 18.695508 seconds
  dataScalingColorBoxPlot                       took 0.143529 seconds
*** Merging into template file ***
  Rendering bestGatesetChoiEvalTable            took 0.033745 seconds
  Rendering stdoutBlock                         took 0.001235 seconds
  Rendering dataScalingColorBoxPlot             took 0.019759 seconds
  Rendering bestGatesetGatesBoxTable            took 0.04592 seconds
  Rendering targetSpamBriefTable                took 0.022889 seconds
  Rendering bestGatesetVsTargetTable            took 0.001444 seconds
  Rendering singleMetricTable_gi                took 0.007747 seconds
  Rendering profilerTable                       took 0.002771 seconds
  Rendering softwareEnvTable                    took 0.00474 seconds
  Rendering targetGatesBoxTable                 took 0.02136 seconds
  Rendering bestEstimateColorHistogram          took 0.048994 seconds
  Rendering bestEstimateColorBoxPlot_scl        took 0.073053 seconds
  Rendering metricSwitchboard_gv                took 0.000124 seconds
  Rendering bestGatesetDecompTable              took 0.02705 seconds
  Rendering germList2ColTable                   took 0.003683 seconds
  Rendering bestGatesetSpamVsTargetTable        took 0.003372 seconds
  Rendering fiducialListTable                   took 0.003846 seconds
  Rendering exampleTable                        took 0.009154 seconds
  Rendering bestGatesetErrGenBoxTable           took 0.088269 seconds
  Rendering metricSwitchboard_gi                took 9.4e-05 seconds
  Rendering metadataTable                       took 0.005963 seconds
  Rendering bestEstimateColorHistogram_scl      took 0.045217 seconds
  Rendering bestGatesVsTargetTable_gigerms      took 0.005504 seconds
  Rendering singleMetricTable_gv                took 0.008137 seconds
  Rendering progressBarPlot_sum                 took 0.003947 seconds
  Rendering bestGatesVsTargetTable_gvgerms      took 0.007567 seconds
  Rendering effectStrListTable                  took 0.002663 seconds
  Rendering maxLSwitchboard1                    took 0.000205 seconds
  Rendering bestEstimateColorBoxPlot            took 0.0746 seconds
  Rendering progressTable_scl                   took 0.008232 seconds
  Rendering gramBarPlot                         took 0.00424 seconds
  Rendering bestGatesetSpamParametersTable      took 0.002144 seconds
  Rendering colorBoxPlotKeyPlot                 took 0.012292 seconds
  Rendering bestEstimateColorScatterPlot        took 0.058579 seconds
  Rendering bestGermsEvalTable                  took 0.083469 seconds
  Rendering finalFitComparePlot                 took 0.004606 seconds
  Rendering bestGatesVsTargetTable_sum          took 0.004716 seconds
  Rendering datasetOverviewTable                took 0.001223 seconds
  Rendering bestGatesVsTargetTable_gv           took 0.004931 seconds
  Rendering progressTable                       took 0.008913 seconds
  Rendering progressBarPlot_scl                 took 0.003431 seconds
  Rendering progressBarPlot                     took 0.00421 seconds
  Rendering bestGatesVsTargetTable_gi           took 0.00529 seconds
  Rendering prepStrListTable                    took 0.003234 seconds
  Rendering topSwitchboard                      took 0.000173 seconds
  Rendering bestGatesetSpamBriefTable           took 0.054296 seconds
  Rendering bestEstimateTVDColorBoxPlot         took 0.076042 seconds
  Rendering bestGatesetGaugeOptParamsTable      took 0.001321 seconds
  Rendering bestEstimateColorScatterPlot_scl    took 0.069485 seconds
  Rendering bestGatesetEvalTable                took 0.027865 seconds
Output written to example_files/mpi_example_brief directory
Opening example_files/mpi_example_brief/main.html...
*** Report Generation Complete!  Total time 155.457s ***
Out[4]:
<pygsti.report.workspace.Workspace at 0x10bec69b0>