EMD Demo


In this tutorial, we demonstrate how to compute EMD values for particle physics events. The core optimal transport computation is done by a dedicated solver (the Python Optimal Transport library, or the Wasserstein library exposed as `ef.emd.wasserstein`, which produced the logs below), with EnergyFlow providing a convenient interface to particle physics events. Batched computations over many pairs of events are distributed across parallel workers (worker processes via the built-in multiprocessing library, or threads as reported in the logs below).

Energy Mover's Distance

The Energy Mover's Distance was introduced in arXiv:1902.02346 as a metric between particle physics events. Closely related to the Earth Mover's Distance, the EMD solves an optimal transport problem between two distributions of energy (or transverse momentum), and the associated distance is the "work" required to transport supply to demand according to the resulting flow. Mathematically, for events $\mathcal E$ and $\mathcal E'$ with particle energies $E_i$ and $E'_j$, where $\theta_{ij}$ is the angular distance between particle $i$ of $\mathcal E$ and particle $j$ of $\mathcal E'$ and $R$ is a radius parameter that sets the cost of transporting energy relative to creating or destroying it, we have

$$\text{EMD}(\mathcal E, \mathcal E') = \min_{\{f_{ij}\ge0\}}\sum_{ij} f_{ij} \frac{\theta_{ij}}{R} + \left|\sum_i E_i - \sum_j E'_j\right|,$$

subject to

$$\sum_{j} f_{ij} \le E_i,\quad \sum_i f_{ij} \le E'_j,\quad \sum_{ij}f_{ij}= \min\Big(\sum_iE_i,\,\sum_jE'_j\Big).$$
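To make the definition concrete, here is a minimal sketch that solves the linear program above directly with SciPy's `linprog` for tiny events. It is an illustrative stand-in, not how EnergyFlow computes EMDs. The helper name `emd_lp` is hypothetical, particles are assumed to be passed as an energy array plus a `(y, phi)` coordinate array, and $\theta_{ij}$ is assumed to be the Euclidean distance in the $(y, \phi)$ plane.

```python
import numpy as np
from scipy.optimize import linprog

def emd_lp(E0, X0, E1, X1, R=0.4):
    """Hypothetical helper: EMD between two small events via a generic LP solver.

    E0, E1: particle energies (or pT); X0, X1: (y, phi) coordinates, one row per particle.
    """
    E0, E1 = np.asarray(E0, dtype=float), np.asarray(E1, dtype=float)
    X0, X1 = np.asarray(X0, dtype=float), np.asarray(X1, dtype=float)
    n, m = len(E0), len(E1)

    # theta_ij: Euclidean distance in the (y, phi) plane; variables f_ij flattened as i*m + j
    theta = np.linalg.norm(X0[:, None, :] - X1[None, :, :], axis=-1)
    cost = (theta / R).ravel()

    # inequality constraints: sum_j f_ij <= E_i (rows) and sum_i f_ij <= E'_j (columns)
    A_ub = np.vstack([np.kron(np.eye(n), np.ones(m)), np.kron(np.ones(n), np.eye(m))])
    b_ub = np.concatenate([E0, E1])

    # equality constraint: total flow equals the smaller total energy
    A_eq = np.ones((1, n * m))
    b_eq = [min(E0.sum(), E1.sum())]

    res = linprog(cost, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq,
                  bounds=(0, None), method="highs")
    if not res.success:
        raise RuntimeError(res.message)
    # optimal transport cost plus the penalty for unequal total energies
    return res.fun + abs(E0.sum() - E1.sum())
```

For example, `emd_lp([1.0, 1.0], [[0.0, 0.0], [0.1, 0.0]], [2.0], [[0.0, 0.0]])` moves one unit of energy a distance 0.1 with $R = 0.4$, so it returns approximately 0.25. A dense LP like this is only meant for checking small cases by hand.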

Imports

In [1]:
import numpy as np
# wurlitzer captures C/C++-level output (such as the solver's progress logs) in the notebook
%load_ext wurlitzer
%matplotlib inline
import matplotlib.pyplot as plt

import energyflow as ef

Plot Style

In [2]:
plt.rcParams['figure.figsize'] = (4,4)
plt.rcParams['figure.dpi'] = 120
plt.rcParams['text.usetex'] = True  # requires a working LaTeX installation
plt.rcParams['font.family'] = 'serif'

Load EnergyFlow Quark/Gluon Jet Samples

In [3]:
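# load 2000 quark and gluon jets; with pad=False each jet is its own (n_particles, 4) array of (pt, y, phi, pid)
X, y = ef.qg_jets.load(2000, pad=False)
# number of jets of each type to keep
num = 750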

# the jet radius for these jets
R = 0.4

# process jets
Gs, Qs = [], []
for arr,events in [(Gs, X[y==0]), (Qs, X[y==1])]:
    for i,x in enumerate(events):
        if i >= num:
            break

        # remove any zero-pt (padded) particles and drop the particle-ID column
        x = x[x[:,0] > 0,:3]

        # center jet according to pt-centroid
        yphi_avg = np.average(x[:,1:3], weights=x[:,0], axis=0)
        x[:,1:3] -= yphi_avg

        # mask out any particles farther than R=0.4 away from center (rare)
        x = x[np.linalg.norm(x[:,1:3], axis=1) <= R]

        # add to list
        arr.append(x)
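The loop body above can be factored into a small function and checked on a toy jet. The name `preprocess_jet` and the toy particle values are hypothetical; the column layout (pt, y, phi, pid) is assumed from how the cell slices the arrays.

```python
import numpy as np

def preprocess_jet(x, R=0.4):
    """Hypothetical helper mirroring the loop body above; x has columns (pt, y, phi, pid)."""
    # keep particles with positive pt and drop the particle-ID column (this makes a copy)
    x = x[x[:, 0] > 0, :3]
    # shift (y, phi) so that the pt-weighted centroid sits at the origin
    x[:, 1:3] -= np.average(x[:, 1:3], weights=x[:, 0], axis=0)
    # discard particles farther than R from the new center
    return x[np.linalg.norm(x[:, 1:3], axis=1) <= R]

# toy jet: three particles plus one zero-pt padding row
jet = np.array([[10.0, 1.1, 2.0,   22.0],
                [30.0, 1.0, 2.1,  211.0],
                [ 5.0, 0.9, 1.9, -211.0],
                [ 0.0, 0.0, 0.0,    0.0]])
out = preprocess_jet(jet)
print(out.shape)
print(np.average(out[:, 1:3], weights=out[:, 0], axis=0))
```

Because the boolean mask produces a copy, the in-place centering does not modify the input array.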

Event Display with EMD Flow

In [4]:
# choose interesting events
ev0, ev1 = Gs[0], Gs[15]

# calculate the EMD and the optimal transport flow
R = 0.4
emdval, G = ef.emd.emd(ev0, ev1, R=R, return_flow=True)

# plot the two events
colors = ['red', 'blue']
labels = ['Gluon Jet 1', 'Gluon Jet 2']
for i,ev in enumerate([ev0, ev1]):
    pts, ys, phis = ev[:,0], ev[:,1], ev[:,2]
    # marker area proportional to particle pt
    plt.scatter(ys, phis, marker='o', s=2*pts, color=colors[i], lw=0, zorder=10, label=labels[i])

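# plot the flow: G[i, j] is the pt moved from particle i of ev0 to particle j of ev1,
# drawn as a line whose opacity is proportional to G[i, j]
mx = G.max()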
xs, xt = ev0[:,1:3], ev1[:,1:3]
for i in range(xs.shape[0]):
    for j in range(xt.shape[0]):
        if G[i, j] > 0:
            plt.plot([xs[i, 0], xt[j, 0]], [xs[i, 1], xt[j, 1]],
                     alpha=G[i, j]/mx, lw=1.25, color='black')

# plot settings
plt.xlim(-R, R); plt.ylim(-R, R)
plt.xlabel('Rapidity'); plt.ylabel('Azimuthal Angle')
plt.xticks(np.linspace(-R, R, 5)); plt.yticks(np.linspace(-R, R, 5))

plt.text(0.6, 0.03, 'EMD: {:.1f} GeV'.format(emdval), fontsize=10, transform=plt.gca().transAxes)
plt.legend(loc=(0.1, 1.0), frameon=False, ncol=2, handletextpad=0)

plt.show()
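The flow matrix also lets one verify the optimization by hand. The sketch below assumes, as the plotting loop does, that `G` has shape `(len(ev0), len(ev1))` with events stored as `(pt, y, phi)` rows, and that the angular distance is the Euclidean distance in the $(y, \phi)$ plane. The helper names `emd_from_flow` and `flow_is_feasible` are hypothetical; they recompute the objective and check the constraints from the definition of the EMD.

```python
import numpy as np

def emd_from_flow(ev0, ev1, G, R=0.4):
    """Hypothetical helper: recompute the EMD value from a flow matrix G."""
    # theta_ij: Euclidean distance in (y, phi) between particle i of ev0 and j of ev1
    theta = np.linalg.norm(ev0[:, None, 1:3] - ev1[None, :, 1:3], axis=-1)
    # transport cost sum_ij f_ij theta_ij / R plus the total-pt mismatch term
    return np.sum(G * theta) / R + abs(ev0[:, 0].sum() - ev1[:, 0].sum())

def flow_is_feasible(ev0, ev1, G, atol=1e-8):
    """Hypothetical helper: check f_ij >= 0, the row/column bounds, and the total flow."""
    pt0, pt1 = ev0[:, 0], ev1[:, 0]
    return bool(np.all(G >= -atol)
                and np.all(G.sum(axis=1) <= pt0 + atol)
                and np.all(G.sum(axis=0) <= pt1 + atol)
                and np.isclose(G.sum(), min(pt0.sum(), pt1.sum()), atol=atol))

# toy check: two unit-pt particles both flow into one particle of pt 2
ev_a = np.array([[1.0, 0.0, 0.0], [1.0, 0.1, 0.0]])
ev_b = np.array([[2.0, 0.0, 0.0]])
flow = np.array([[1.0], [1.0]])
print(emd_from_flow(ev_a, ev_b, flow), flow_is_feasible(ev_a, ev_b, flow))
```

With the objects from the cell above, `emd_from_flow(ev0, ev1, G, R)` would be expected to agree with `emdval` up to solver tolerance, assuming the default angular measure (Euclidean distance in $(y, \phi)$ with $\beta = 1$).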

Intrinsic Dimension of Quark and Gluon Jets

The correlation dimension of a dataset is a type of fractal dimension that quantifies the dimensionality of the space of events at different energy scales $Q$.

It is motivated by the fact that the number of neighbors within a ball of radius $Q$ around a point grows as $Q^\mathrm{dim}$, giving rise to the definition:

$$ \dim (Q) = Q\frac{\partial}{\partial Q} \ln \sum_{i<j} \Theta(\mathrm{EMD}(\mathcal E_i, \mathcal E_j) < Q).$$
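Because $Q\,\partial/\partial Q = \partial/\partial \ln Q$, the correlation dimension is the local slope of $\ln N(Q)$ versus $\ln Q$, where $N(Q)$ counts the pairs of events closer than $Q$. On log-spaced energy scales $Q_0 < Q_1 < \cdots$, this slope can be estimated by a finite difference, which is what the histogram code below approximates:

$$ N(Q) = \sum_{i<j} \Theta(\mathrm{EMD}(\mathcal E_i, \mathcal E_j) < Q), \qquad \dim\!\left(\sqrt{Q_k Q_{k+1}}\right) \approx \frac{\ln N(Q_{k+1}) - \ln N(Q_k)}{\ln Q_{k+1} - \ln Q_k}.$$

As a check, if $N(Q) = c\,Q^{d}$ then the right-hand side equals $d$ exactly, for any spacing of the $Q_k$.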
In [5]:
# compute pairwise EMDs between all jets; norm=True rescales each jet to unit total pT,
# and n_jobs=-1 uses all available cores (each sample takes under 20 s with 4 threads in the logs below)
g_emds = ef.emd.emds(Gs, R=R, norm=True, verbose=1, n_jobs=-1, print_every=25000)
q_emds = ef.emd.emds(Qs, R=R, norm=True, verbose=1, n_jobs=-1, print_every=25000)
PairwiseEMD
  ArrayEvent<8-byte float>
    norm - true

  EuclideanArrayDistance
    R - 0.4
    beta - 1

  NetworkSimplex
    n_iter_max - 100000
    epsilon_large - 2.22045e-12
    epsilon_small - 2.22045e-16

  num_threads - 4
  print_every - 25000
  store_sym_emds_flattened - true
  throw_on_error - true

  Pairwise EMD distance matrix stored internally

Finished preprocessing 750 events in 0.0076s
   25000 / 280875  EMDs computed  -   8.90% completed - 1.488s
   50000 / 280875  EMDs computed  -  17.80% completed - 3.000s
   75000 / 280875  EMDs computed  -  26.70% completed - 4.634s
  100000 / 280875  EMDs computed  -  35.60% completed - 6.266s
  125000 / 280875  EMDs computed  -  44.50% completed - 8.029s
  150000 / 280875  EMDs computed  -  53.40% completed - 9.649s
  175000 / 280875  EMDs computed  -  62.31% completed - 11.322s
  200000 / 280875  EMDs computed  -  71.21% completed - 12.904s
  225000 / 280875  EMDs computed  -  80.11% completed - 14.525s
  250000 / 280875  EMDs computed  -  89.01% completed - 16.075s
  275000 / 280875  EMDs computed  -  97.91% completed - 17.666s
  280875 / 280875  EMDs computed  - 100.00% completed - 19.079s
PairwiseEMD
  ArrayEvent<8-byte float>
    norm - true

  EuclideanArrayDistance
    R - 0.4
    beta - 1

  NetworkSimplex
    n_iter_max - 100000
    epsilon_large - 2.22045e-12
    epsilon_small - 2.22045e-16

  num_threads - 4
  print_every - 25000
  store_sym_emds_flattened - true
  throw_on_error - true

  Pairwise EMD distance matrix stored internally

Finished preprocessing 750 events in 0.0085s
   25000 / 280875  EMDs computed  -   8.90% completed - 0.714s
   50000 / 280875  EMDs computed  -  17.80% completed - 1.399s
   75000 / 280875  EMDs computed  -  26.70% completed - 2.119s
  100000 / 280875  EMDs computed  -  35.60% completed - 2.779s
  125000 / 280875  EMDs computed  -  44.50% completed - 3.479s
  150000 / 280875  EMDs computed  -  53.40% completed - 4.130s
  175000 / 280875  EMDs computed  -  62.31% completed - 4.838s
  200000 / 280875  EMDs computed  -  71.21% completed - 5.470s
  225000 / 280875  EMDs computed  -  80.11% completed - 6.147s
  250000 / 280875  EMDs computed  -  89.01% completed - 6.852s
  275000 / 280875  EMDs computed  -  97.91% completed - 7.559s
  280875 / 280875  EMDs computed  - 100.00% completed - 8.101s
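The logs report 280875 EMDs for 750 jets because only unique pairs are needed: the EMD is symmetric and vanishes on the diagonal, so $n(n-1)/2 = 750 \cdot 749 / 2 = 280875$ evaluations suffice. The sketch below shows one way to split those pairs into chunks for parallel workers. The helpers `pair_chunks` and `pairwise_matrix` are hypothetical and not EnergyFlow's implementation; Python threads are used only to illustrate the structure, since pure-Python distance functions will not run concurrently under the GIL, whereas a compiled backend can use threads effectively.

```python
import math
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def pair_chunks(n, nchunks):
    """Split the n*(n-1)/2 unique pairs (i, j), i < j, into at most nchunks contiguous chunks."""
    i, j = np.triu_indices(n, k=1)               # row-major order over the upper triangle
    size = max(1, math.ceil(len(i) / nchunks))   # chunk length, rounded up
    return [(i[k:k + size], j[k:k + size]) for k in range(0, len(i), size)]

def pairwise_matrix(events, dist, nchunks=10, workers=4):
    """Fill a symmetric distance matrix by evaluating dist on pair chunks in a thread pool."""
    n = len(events)
    out = np.zeros((n, n))

    def work(chunk):
        ii, jj = chunk
        return ii, jj, np.array([dist(events[a], events[b]) for a, b in zip(ii, jj)])

    with ThreadPoolExecutor(max_workers=workers) as pool:
        for ii, jj, vals in pool.map(work, pair_chunks(n, nchunks)):
            out[ii, jj] = vals
            out[jj, ii] = vals                   # distances are symmetric
    return out

chunks = pair_chunks(750, 10)
print(len(chunks), len(chunks[0][0]))
```

With 10 chunks, each chunk holds $\lceil 280875/10 \rceil = 28088$ pairs (the last one slightly fewer), which is consistent with the progress marks printed later when `print_every=-10` is used.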
In [6]:
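# prepare for histograms: 60 log-spaced bin edges from 0.01 to 1 (in units of the jet pT)
bins = 10**np.linspace(-2, 0, 60)
# tiny regulator to avoid taking log(0) in empty bins
reg = 10**-30
midbins = (bins[:-1] + bins[1:])/2
dmidbins = np.log(midbins[1:]) - np.log(midbins[:-1]) + reg
midbins2 = (midbins[:-1] + midbins[1:])/2

# compute the correlation dimensions
dims = []
for emd_vals in [q_emds, g_emds]:
    # unique pairs i < j from the symmetric EMD matrix
    uemds = emd_vals[np.triu_indices(len(emd_vals), k=1)]
    # N(Q): number of pairs with EMD below each upper bin edge, including pairs below bins[0]
    counts = np.cumsum(np.histogram(uemds, bins=bins)[0]) + np.count_nonzero(uemds < bins[0])
    dims.append((np.log(counts[1:] + reg) - np.log(counts[:-1] + reg))/dmidbins)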
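The same estimate can be packaged as a function of any list of unique pairwise distances and sanity-checked on data whose dimension is known. The sketch below uses a hypothetical helper `correlation_dimension` and, as the test case, points drawn uniformly in a unit square with Euclidean distances standing in for EMDs; the estimate there should come out close to 2 (slightly below, because of edge effects at larger $Q$).

```python
import numpy as np

def correlation_dimension(dists, qmin=0.01, qmax=1.0, nbins=59):
    """Hypothetical helper: finite-difference estimate of dim(Q) = d ln N(Q) / d ln Q."""
    edges = np.geomspace(qmin, qmax, nbins + 1)       # log-spaced bin edges
    dists = np.asarray(dists, dtype=float).ravel()
    # N(Q) at each upper edge: pairs inside the histogram range plus pairs below qmin
    counts = np.cumsum(np.histogram(dists, bins=edges)[0]) + np.count_nonzero(dists < qmin)
    with np.errstate(divide="ignore", invalid="ignore"):
        dims = np.diff(np.log(counts)) / np.diff(np.log(edges[1:]))
    qmid = np.sqrt(edges[1:-1] * edges[2:])            # geometric midpoint of each difference
    return qmid, dims

# toy data with known dimension: uniform points in the unit square
rng = np.random.default_rng(0)
pts = rng.random((1500, 2))
D = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=-1)
q, d = correlation_dimension(D[np.triu_indices(len(pts), k=1)], qmin=0.01, qmax=0.1, nbins=10)
print(np.median(d))
```

Applied to the jets, `correlation_dimension(q_emds[np.triu_indices(len(q_emds), k=1)])` should give a curve comparable to the quark curve above, with x positions at geometric rather than arithmetic bin midpoints.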
In [7]:
# plot the correlation dimensions
plt.plot(midbins2, dims[0], '-', color='blue', label='Quarks')
plt.plot(midbins2, dims[1], '-', color='red', label='Gluons')

# labels
plt.legend(loc='center right', frameon=False)

# plot style
plt.xscale('log')
plt.xlabel('Energy Scale Q/pT'); plt.ylabel('Correlation Dimension')
plt.xlim(0.02, 1); plt.ylim(0, 5)

plt.show()

Using Built-In Correlation Dimension

In [9]:
# create external EMD handlers that will compute the correlation dimensions on the fly
gcorrdim = ef.emd.wasserstein.CorrelationDimension(0.01, 1, 60)
qcorrdim = ef.emd.wasserstein.CorrelationDimension(0.01, 1, 60)

# compute pairwise EMDs again, passing each EMD to the handlers as it is computed;
# print_every=-10 reports progress in 10 equal chunks, as the log below shows
ef.emd.emds(Gs, R=R, norm=True, verbose=1, n_jobs=-1, print_every=-10, external_emd_handler=gcorrdim)
ef.emd.emds(Qs, R=R, norm=True, verbose=1, n_jobs=-1, print_every=-10, external_emd_handler=qcorrdim)
PairwiseEMD
  ArrayEvent<8-byte float>
    norm - true

  EuclideanArrayDistance
    R - 0.4
    beta - 1

  NetworkSimplex
    n_iter_max - 100000
    epsilon_large - 2.22045e-12
    epsilon_small - 2.22045e-16

  num_threads - 4
  print_every - auto, 10 total chunks
  store_sym_emds_flattened - true
  throw_on_error - true

  Pairwise EMD distance matrix stored internally

Finished preprocessing 750 events in 0.0084s
   28088 / 280875  EMDs computed  -  10.00% completed - 1.697s
   56176 / 280875  EMDs computed  -  20.00% completed - 3.400s
   84264 / 280875  EMDs computed  -  30.00% completed - 5.245s
  112352 / 280875  EMDs computed  -  40.00% completed - 7.010s
  140440 / 280875  EMDs computed  -  50.00% completed - 8.970s
  168528 / 280875  EMDs computed  -  60.00% completed - 10.727s
  196616 / 280875  EMDs computed  -  70.00% completed - 12.637s
  224704 / 280875  EMDs computed  -  80.00% completed - 14.448s
  252792 / 280875  EMDs computed  -  90.00% completed - 16.191s
  280875 / 280875  EMDs computed  - 100.00% completed - 17.979s
PairwiseEMD
  ArrayEvent<8-byte float>
    norm - true

  EuclideanArrayDistance
    R - 0.4
    beta - 1

  NetworkSimplex
    n_iter_max - 100000
    epsilon_large - 2.22045e-12
    epsilon_small - 2.22045e-16

  num_threads - 4
  print_every - auto, 10 total chunks
  store_sym_emds_flattened - true
  throw_on_error - true

  Pairwise EMD distance matrix stored internally

Finished preprocessing 750 events in 0.0075s
   28088 / 280875  EMDs computed  -  10.00% completed - 0.799s
   56176 / 280875  EMDs computed  -  20.00% completed - 1.569s
   84264 / 280875  EMDs computed  -  30.00% completed - 2.285s
  112352 / 280875  EMDs computed  -  40.00% completed - 3.050s
  140440 / 280875  EMDs computed  -  50.00% completed - 3.824s
  168528 / 280875  EMDs computed  -  60.00% completed - 4.564s
  196616 / 280875  EMDs computed  -  70.00% completed - 5.354s
  224704 / 280875  EMDs computed  -  80.00% completed - 6.103s
  252792 / 280875  EMDs computed  -  90.00% completed - 6.927s
  280875 / 280875  EMDs computed  - 100.00% completed - 7.729s
In [20]:
# plot the correlation dimensions
plt.plot(qcorrdim.corrdim_bins(), qcorrdim.corrdims()[0], '-', color='blue', label='Quarks')
plt.plot(gcorrdim.corrdim_bins(), gcorrdim.corrdims()[0], '-', color='red', label='Gluons')

# labels
plt.legend(loc='center right', frameon=False)

# plot style
plt.xscale('log')
plt.xlabel('Energy Scale Q/pT'); plt.ylabel('Correlation Dimension')
plt.xlim(0.02, 1); plt.ylim(0, 5)

plt.show()