In [1]:
import decimal

import matplotlib.pyplot as plt
import numpy
import pandas as pd
import scipy.sparse
import scipy.special
import scipy.stats
import tqdm
In [2]:
import hetmech.hetmat
import hetmech.degree_group
import hetmech.degree_weight
import hetmech.pipeline
In [3]:
# Open the Hetionet v1.0 HetMat stored at this relative directory path.
hetmat = hetmech.hetmat.HetMat('../data/hetionet-v1.0.hetmat/')
In [4]:
# Metapath abbreviations examined in the cells below (indexed by position later on).
metapaths = ['DaGbC', 'SpDpS', 'SEcCrCtD',]
In [5]:
# Used in the next cell
def matrix_to_dgp(matrix, dwpc_scaler, source_degree_to_ind, target_degree_to_ind):
    """
    Group the scaled, arcsinh-transformed entries of `matrix` by degree pair.

    Parameters
    ----------
    matrix : numpy.ndarray or scipy.sparse matrix
        Two-dimensional matrix; sparse input is converted to a dense array.
    dwpc_scaler : float
        Every selected entry is divided by this value before
        `numpy.arcsinh` is applied.
    source_degree_to_ind : dict
        Maps each source degree to the row indices selected for it.
    target_degree_to_ind : dict
        Maps each target degree to the column indices selected for it.

    Returns
    -------
    dict
        `(source_degree, target_degree) -> list of float`, the flattened
        transformed sub-matrix for that pair. Pairs in which either degree
        is 0 are omitted.
    """
    if scipy.sparse.issparse(matrix):
        matrix = matrix.toarray()

    degree_to_values = dict()
    for source_degree, row_inds in source_degree_to_ind.items():
        # Skip before slicing: every pair with a zero source degree is dropped.
        if source_degree == 0:
            continue
        row_matrix = matrix[row_inds, :]
        for target_degree, col_inds in target_degree_to_ind.items():
            if target_degree == 0:
                continue
            values = numpy.arcsinh(row_matrix[:, col_inds] / dwpc_scaler)
            degree_to_values[(source_degree, target_degree)] = values.flatten().tolist()
    return degree_to_values
In [6]:
def metapath_to_full_dgp(hetmat, metapath, damping=0.5, dense_threshold=0.7):
    """
    Pool transformed DWPC values from all permutations of `hetmat` by degree pair.

    Parameters
    ----------
    hetmat : hetmech.hetmat.HetMat
        Unpermuted network. Its `permutations` attribute is iterated with
        `.items()`; each value is passed to `dwpc` as a graph
        (presumably a permuted HetMat -- confirm against hetmech).
    metapath : str
        Metapath passed through to `hetmech.degree_weight.dwpc`.
    damping : float
        Damping passed to every `dwpc` call (default 0.5).
    dense_threshold : float
        Dense threshold passed to every `dwpc` call (default 0.7).

    Returns
    -------
    dict or None
        `(source_degree, target_degree) -> list of float`, concatenated over
        permutations in iteration order; `None` if there are no permutations.

    Notes
    -----
    Only the damped DWPC matrix of the unpermuted network is computed, and it
    is used solely to derive the scaling constant; an undamped path-count
    matrix is not needed for anything returned here.
    """
    _, _, dwpc_matrix = hetmech.degree_weight.dwpc(
        hetmat, metapath, dense_threshold=dense_threshold, damping=damping)
    # Mean over all entries (zeros included) of the unpermuted DWPC matrix.
    scaler = dwpc_matrix.mean()

    # Degree groupings come from the unpermuted network and are reused for every permutation.
    source_degree_to_ind, target_degree_to_ind = hetmech.degree_group.metapath_to_degree_dicts(hetmat, metapath)

    perm_dgp = None
    for _name, permat in tqdm.tqdm(hetmat.permutations.items()):
        _, _, matrix = hetmech.degree_weight.dwpc(
            permat, metapath, damping=damping, dense_threshold=dense_threshold)
        degree_to_dgp = matrix_to_dgp(matrix, scaler, source_degree_to_ind, target_degree_to_ind)
        if perm_dgp is None:
            # The dict and its lists are freshly built, so they can be adopted directly.
            perm_dgp = degree_to_dgp
        else:
            for degree_combo, dgp_list in perm_dgp.items():
                dgp_list.extend(degree_to_dgp[degree_combo])
    return perm_dgp
In [17]:
class GammaHurdle:
    """
    Hurdle model: a point mass at zero plus a gamma distribution on positive values.

    `alpha` (shape) and `beta` (rate) are method-of-moments estimates from the
    positive observations; `_p_nnz` is the fraction of observations that are
    positive. All are `None` until `fit` has been called.
    """

    def __init__(self):
        self.alpha = None
        self.beta = None
        self._gamma_coef = None
        self._log_gamma_coef = None
        self._p_nnz = None

    def fit(self, values):
        """
        Estimate the model parameters from `values`.

        Parameters
        ----------
        values : array-like of float
            Observations; only entries strictly greater than zero contribute
            to the gamma moments.

        Raises
        ------
        ValueError
            If fewer than two positive values are present, or if the positive
            values have no variance (the moments are then undefined).
        """
        values = numpy.asarray(values, dtype=float)
        positive = values[values > 0]
        nnz = positive.size
        if nnz < 2:
            raise ValueError('GammaHurdle.fit requires at least two positive values')
        mean_nz = positive.mean()
        # Sample variance; avoids the cancellation of the one-pass sum-of-squares formula.
        var_nz = positive.var(ddof=1)
        if not var_nz > 0:
            raise ValueError('GammaHurdle.fit requires positive values that are not all identical')
        self.beta = mean_nz / var_nz
        self.alpha = mean_nz * self.beta
        # log(beta**alpha / Gamma(alpha)); kept in log space because Gamma(alpha)
        # overflows a double once alpha exceeds roughly 171.
        self._log_gamma_coef = self.alpha * numpy.log(self.beta) - scipy.special.gammaln(self.alpha)
        with numpy.errstate(over='ignore'):
            self._gamma_coef = numpy.exp(self._log_gamma_coef)
        self._p_nnz = nnz / values.size

    def pdf(self, x):
        """
        Density of the continuous (gamma) part, weighted by `_p_nnz`.

        Scalar input returns a scalar; array-like input returns a list.
        Negative `x` gives 0.
        """
        arr = numpy.asarray(x, dtype=float)
        with numpy.errstate(over='ignore', invalid='ignore'):
            # xlogy returns 0 for (alpha - 1) == 0 even at x == 0.
            log_density = (
                self._log_gamma_coef
                + scipy.special.xlogy(self.alpha - 1, arr)
                - self.beta * arr
            )
            density = self._p_nnz * numpy.exp(log_density)
        density = numpy.where(arr < 0, 0.0, density)
        return density.tolist() if arr.ndim else density[()]

    def cdf(self, x):
        """
        Cumulative distribution: mass at zero plus the weighted gamma CDF.

        Scalar input returns a scalar; array-like input returns a list.
        Negative `x` gives 0.
        """
        arr = numpy.asarray(x, dtype=float)
        gamma_cdf = scipy.special.gammainc(self.alpha, self.beta * numpy.maximum(arr, 0.0))
        cumulative = numpy.where(arr < 0, 0.0, (1 - self._p_nnz) + self._p_nnz * gamma_cdf)
        return cumulative.tolist() if arr.ndim else cumulative[()]
In [8]:
def check_fit(values):
    """
    Fit a GammaHurdle to `values`, draw the fit over a histogram, and KS-test it.

    A new figure is opened showing the fitted density on [0, 10] over a
    density-normalised histogram of all `values`. The y-axis limit is taken
    from the histogram bars after the first one.

    Returns the `scipy.stats.kstest` result comparing the positive values with
    a gamma distribution parameterised by the fitted shape and scale (loc 0).
    """
    observed = numpy.array(values)
    positive = observed[observed > 0]

    model = GammaHurdle()
    model.fit(values)

    grid = numpy.linspace(0, 10, 100)
    density = model.pdf(grid)

    plt.figure()
    plt.plot(grid, density, linewidth=2, label='gamma fit')
    heights, _, _ = plt.hist(observed, density=True, label='true dist')
    plt.legend()
    upper = 1.1 * max(heights[1:])
    plt.ylim((0, upper))

    gamma_args = (model.alpha, 0, 1/model.beta)
    return scipy.stats.kstest(positive, 'gamma', args=gamma_args)
In [9]:
# For each metapath, pool the permuted values and run check_fit on a few
# (source degree, target degree) combinations. The loop variables stay in the
# notebook namespace afterwards; the next cell reads `dgp_values`.
for metapath in metapaths:
    perm_dgp = metapath_to_full_dgp(hetmat, metapath)
    for degree_combo in [(1,1), (1,3), (3,3), (3, 10),]:
        dgp_values = perm_dgp[degree_combo]
        # Positive values; used here only for the count shown in the title.
        values_nz = [i for i in dgp_values if i > 0]
        ks_result = check_fit(dgp_values)

        # Index 1 of the KS result is the p-value.
        p_value = decimal.Decimal(ks_result[1])
        # Applies to the figure that check_fit opened.
        plt.title(f'{metapath} - {degree_combo} - p={p_value :.2E} - {len(values_nz)} nonzero values')
100%|██████████| 200/200 [00:30<00:00,  6.53it/s]
100%|██████████| 200/200 [00:12<00:00, 16.05it/s]
100%|██████████| 200/200 [01:01<00:00,  3.26it/s]
In [10]:
# Positive values from whatever `dgp_values` the preceding loop assigned last.
values_nz = [i for i in dgp_values if i > 0]

# distribution object -> (KS statistic, KS p-value)
distribution_performance = dict()
# distribution name -> repr of the exception raised while fitting or testing
fit_failures = dict()

for distribution in tqdm.tqdm([
    scipy.stats.alpha,
    scipy.stats.anglit,
    scipy.stats.arcsine,
    scipy.stats.beta,
    scipy.stats.betaprime,
    scipy.stats.bradford,
    scipy.stats.burr,
    scipy.stats.burr12,
    scipy.stats.cauchy,
    scipy.stats.chi,
    scipy.stats.chi2,
    scipy.stats.cosine,
    scipy.stats.dgamma,
    scipy.stats.dweibull,
#     scipy.stats.erlang,
    scipy.stats.expon,
    scipy.stats.exponnorm,
    scipy.stats.exponweib,
    scipy.stats.exponpow,
    scipy.stats.f,
    scipy.stats.fatiguelife,
    scipy.stats.fisk,
    scipy.stats.foldcauchy,
    scipy.stats.foldnorm,
#     scipy.stats.frechet_r,
#     scipy.stats.frechet_l,
    scipy.stats.genlogistic,
    scipy.stats.gennorm,
    scipy.stats.genpareto,
    scipy.stats.genexpon,
    scipy.stats.genextreme,
    scipy.stats.gausshyper,
    scipy.stats.gamma,
    scipy.stats.gengamma,
    scipy.stats.genhalflogistic,
    scipy.stats.gilbrat,
    scipy.stats.gompertz,
    scipy.stats.gumbel_r,
    scipy.stats.gumbel_l,
    scipy.stats.halfcauchy,
    scipy.stats.halflogistic,
    scipy.stats.halfnorm,
    scipy.stats.halfgennorm,
    scipy.stats.hypsecant,
    scipy.stats.invgamma,
    scipy.stats.invgauss,
    scipy.stats.invweibull,
    scipy.stats.johnsonsb,
    scipy.stats.johnsonsu,
    scipy.stats.kappa4,
    scipy.stats.kappa3,
    scipy.stats.ksone,
    scipy.stats.kstwobign,
    scipy.stats.laplace,
    scipy.stats.levy,
    scipy.stats.levy_l,
    scipy.stats.levy_stable,
    scipy.stats.logistic,
    scipy.stats.loggamma,
    scipy.stats.loglaplace,
    scipy.stats.lognorm,
    scipy.stats.lomax,
    scipy.stats.maxwell,
    scipy.stats.mielke,
    scipy.stats.nakagami,
    scipy.stats.ncx2,
    scipy.stats.ncf,
    scipy.stats.nct,
    scipy.stats.norm,
    scipy.stats.pareto,
    scipy.stats.pearson3,
    scipy.stats.powerlaw,
    scipy.stats.powerlognorm,
    scipy.stats.powernorm,
    scipy.stats.rdist,
    scipy.stats.reciprocal,
    scipy.stats.rayleigh,
    scipy.stats.rice,
    scipy.stats.recipinvgauss,
    scipy.stats.semicircular,
    scipy.stats.skewnorm,
    scipy.stats.t,
    scipy.stats.trapz,
    scipy.stats.triang,
    scipy.stats.truncexpon,
    scipy.stats.truncnorm,
    scipy.stats.tukeylambda,
    scipy.stats.uniform,
    scipy.stats.vonmises,
    scipy.stats.vonmises_line,
    scipy.stats.wald,
    scipy.stats.weibull_min,
    scipy.stats.weibull_max,
    scipy.stats.wrapcauchy
]):
    try:
        params = distribution.fit(values_nz)
        ks, p = scipy.stats.kstest(values_nz, distribution.cdf, args=params)
    except Exception as error:
        # Best-effort scan: a distribution that cannot be fitted is recorded and
        # skipped. `Exception` (not a bare except) keeps KeyboardInterrupt usable
        # for aborting this long-running cell.
        fit_failures[distribution.name] = repr(error)
        continue
    distribution_performance[distribution] = (ks, p)
  0%|          | 0/91 [00:00<?, ?it/s]/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:2305: RuntimeWarning: invalid value encountered in double_scalars
  Lhat = muhat - Shat*mu
  2%|▏         | 2/91 [00:00<00:41,  2.16it/s]/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_continuous_distns.py:312: RuntimeWarning: divide by zero encountered in true_divide
  return 1.0/np.pi/np.sqrt(x*(1-x))
  4%|▍         | 4/91 [00:02<00:50,  1.71it/s]/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_continuous_distns.py:589: RuntimeWarning: divide by zero encountered in true_divide
  a/(b-1.0),
/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_continuous_distns.py:593: RuntimeWarning: divide by zero encountered in true_divide
  a*(a+1.0)/((b-2.0)*(b-1.0)),
/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1037: RuntimeWarning: invalid value encountered in subtract
  mu2 = mu2p - mu * mu
 27%|██▋       | 25/91 [00:46<02:02,  1.85s/it]/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_continuous_distns.py:2159: RuntimeWarning: divide by zero encountered in true_divide
  val = val + cnk * (-1) ** ki / (1.0 - c * ki)
 53%|█████▎    | 48/91 [03:12<02:52,  4.01s/it]/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/integrate/quadpack.py:385: IntegrationWarning: The maximum number of subdivisions (50) has been achieved.
  If increasing the limit yields no improvement it is advised to analyze 
  the integrand in order to determine the difficulties.  If the position of a 
  local difficulty can be determined (singularity, discontinuity) one will 
  probably gain from splitting up the interval and calling the integrator 
  on the subranges.  Perhaps a special-purpose integrator should be used.
  warnings.warn(msg, IntegrationWarning)
/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_continuous_distns.py:44: RuntimeWarning: floating point number truncated to an integer
  return 1.0 - sc.smirnov(n, x)
 57%|█████▋    | 52/91 [04:08<03:06,  4.77s/it]/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/stats/_distn_infrastructure.py:1615: RuntimeWarning: divide by zero encountered in log
  return log(self._pdf(x, *args))
 66%|██████▌   | 60/91 [04:18<02:13,  4.30s/it]/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/integrate/quadpack.py:385: IntegrationWarning: Extremely bad integrand behavior occurs at some points of the
  integration interval.
  warnings.warn(msg, IntegrationWarning)
/home/michael/.conda/envs/hetmech/lib/python3.6/site-packages/scipy/integrate/quadpack.py:385: IntegrationWarning: The integral is probably divergent, or slowly convergent.
  warnings.warn(msg, IntegrationWarning)
100%|██████████| 91/91 [14:48<00:00,  9.76s/it]
In [11]:
# One row per fitted distribution; column 0 is the KS statistic, column 1 the p-value.
model_df = pd.DataFrame.from_dict(data=distribution_performance, orient='index')

# Five highest p-values first.
model_df.sort_values(by=1, ascending=False).iloc[:5]
Out[11]:
0 1
<scipy.stats._continuous_distns.exponweib_gen object at 0x7fef1d0af5f8> 0.019945 3.357472e-13
<scipy.stats._continuous_distns.gengamma_gen object at 0x7fef1d0def98> 0.020486 6.654501e-14
<scipy.stats._continuous_distns.exponpow_gen object at 0x7fef1d0afc50> 0.021329 4.914416e-15
<scipy.stats._continuous_distns.gompertz_gen object at 0x7fef1d0e44e0> 0.026101 2.640552e-22
<scipy.stats._continuous_distns.gausshyper_gen object at 0x7fef1d06d978> 0.026233 1.586766e-22

Fit the 'best-fitting' distribution according to the KS test above (`exponweib`) to the same values (`SEcCrCtD`, degree combination (3, 10)).

In [12]:
# Fit exponweib to the positive pooled values for metapaths[2] ('SEcCrCtD')
# at degree combination (3, 10), then plot the fit over a histogram.
metapath = metapaths[2]
degree_combo = (3,10)
dist = scipy.stats.exponweib

# Recomputes the pooled permutation values for this metapath.
perm_dgp = metapath_to_full_dgp(hetmat, metapath)

dgp_values = perm_dgp[degree_combo]
values_nz = [i for i in dgp_values if i > 0]

# Fitted parameters are passed positionally to cdf/pdf below.
params = dist.fit(values_nz)

ks_result = scipy.stats.kstest(values_nz, dist.cdf, args=params)
print(ks_result)

x = numpy.linspace(0, 10, 100)
y = [dist.pdf(i, *params) for i in x]

plt.figure()
plt.plot(x, y, linewidth=2, label='exponweib fit')
bars, _, _ = plt.hist(values_nz, density=True, label='true dist')
plt.legend()
# The y-limit is taken from the bars after the first one.
plt.ylim((0, 1.1 * max(bars[1:])));

# Index 1 of the KS result is the p-value.
p_value = decimal.Decimal(ks_result[1])
plt.title(f'{metapath} - {degree_combo} - p={p_value :.2E} - {len(values_nz)} nonzero values');
100%|██████████| 200/200 [01:13<00:00,  2.71it/s]
KstestResult(statistic=0.019944602384085824, pvalue=3.3574715198648447e-13)

`exponweib` does not perform well on other empirical distributions, e.g. `SpDpS` with degree combination (3, 3).

In [13]:
# Same procedure as for 'SEcCrCtD', now on metapaths[1] ('SpDpS') at degree
# combination (3, 3), still with exponweib.
metapath = metapaths[1]
degree_combo = (3,3)
dist = scipy.stats.exponweib

# Recomputes the pooled permutation values for this metapath.
perm_dgp = metapath_to_full_dgp(hetmat, metapath)

dgp_values = perm_dgp[degree_combo]
values_nz = [i for i in dgp_values if i > 0]

# Fitted parameters are passed positionally to cdf/pdf below.
params = dist.fit(values_nz)

ks_result = scipy.stats.kstest(values_nz, dist.cdf, args=params)
print(ks_result)

x = numpy.linspace(0, 10, 100)
y = [dist.pdf(i, *params) for i in x]

plt.figure()
plt.plot(x, y, linewidth=2, label='exponweib fit')
bars, _, _ = plt.hist(values_nz, density=True, label='true dist', bins=30)
plt.legend()
# The y-limit is taken from the bars after the first one.
plt.ylim((0, 1.1 * max(bars[1:])));

# Index 1 of the KS result is the p-value.
p_value = decimal.Decimal(ks_result[1])
plt.title(f'{metapath} - {degree_combo} - p={p_value :.2E} - {len(values_nz)} nonzero values');
100%|██████████| 200/200 [00:12<00:00, 15.78it/s]
KstestResult(statistic=0.5494956784587973, pvalue=0.0)

The second-best distribution (`gengamma`) on the same `SpDpS` (3, 3) values

In [14]:
# Fit gengamma to the positive pooled values for metapaths[1] ('SpDpS') at
# degree combination (3, 3), then plot the fit over a histogram.
metapath = metapaths[1]
degree_combo = (3,3)
dist = scipy.stats.gengamma

# Recomputes the pooled permutation values for this metapath.
perm_dgp = metapath_to_full_dgp(hetmat, metapath)

dgp_values = perm_dgp[degree_combo]
values_nz = [i for i in dgp_values if i > 0]

# Fitted parameters are passed positionally to cdf/pdf below.
params = dist.fit(values_nz)

ks_result = scipy.stats.kstest(values_nz, dist.cdf, args=params)
print(ks_result)

x = numpy.linspace(0, 10, 100)
y = [dist.pdf(i, *params) for i in x]

plt.figure()
# Label follows the distribution actually fitted in this cell.
plt.plot(x, y, linewidth=2, label=f'{dist.name} fit')
bars, _, _ = plt.hist(values_nz, density=True, label='true dist', bins=30)
plt.legend()
# The y-limit is taken from the bars after the first one.
plt.ylim((0, 1.1 * max(bars[1:])));

# Index 1 of the KS result is the p-value.
p_value = decimal.Decimal(ks_result[1])
plt.title(f'{metapath} - {degree_combo} - p={p_value :.2E} - {len(values_nz)} nonzero values');
100%|██████████| 200/200 [00:12<00:00, 15.85it/s]
KstestResult(statistic=0.5599760844619381, pvalue=0.0)
In [15]:
# Fit the GammaHurdle model to the current `values_nz` (set by an earlier cell)
# and compare its density with a histogram of the same values.
g = GammaHurdle()
g.fit(values_nz)
ks_result = scipy.stats.kstest(values_nz, g.cdf)
# Printed explicitly: a bare expression in the middle of a cell is not displayed.
print(ks_result)

x = numpy.linspace(0, 10, 100)
y = g.pdf(x)

plt.figure()
plt.plot(x, y, linewidth=2, label='gamma fit')
bars, _, _ = plt.hist(values_nz, density=True, label='true dist', bins=30)
plt.legend()
# The y-limit is taken from the bars after the first one.
plt.ylim((0, 1.1 * max(bars[1:])));

# Index 1 of the KS result is the p-value; `metapath` and `degree_combo` are
# whatever earlier cells last assigned.
p_value = decimal.Decimal(ks_result[1])
plt.title(f'{metapath} - {degree_combo} - p={p_value :.2E} - {len(values_nz)} nonzero values');
In [28]:
# Fit the GammaHurdle model to the current `values_nz` (set by an earlier cell)
# and compare its CDF with a cumulative histogram of the same values.
g = GammaHurdle()
g.fit(values_nz)
ks_result = scipy.stats.kstest(values_nz, g.cdf)
# NOTE(review): this bare expression is not the cell's last statement, so nothing is displayed.
ks_result

x = numpy.linspace(0, 5, 100)
y = g.cdf(x)

plt.figure()
plt.plot(x, y, linewidth=2, label='gamma fit')
bars, _, _ = plt.hist(values_nz, density=True, label='true dist', bins=30, cumulative=True)
plt.legend()
# The y-limit is taken from the bars after the first one.
plt.ylim((0, 1.1 * max(bars[1:])));

# Index 1 of the KS result is the p-value; `metapath` and `degree_combo` are
# whatever earlier cells last assigned.
p_value = decimal.Decimal(ks_result[1])
plt.title(f'{metapath} - {degree_combo} - p={p_value :.2E} - {len(values_nz)} nonzero values');
In [29]:
# Same CDF comparison as the previous cell, with the axes restricted to
# x in [2, 5] and y in [0.7, 1.05].
g = GammaHurdle()
g.fit(values_nz)
ks_result = scipy.stats.kstest(values_nz, g.cdf)
# NOTE(review): this bare expression is not the cell's last statement, so nothing is displayed.
ks_result

x = numpy.linspace(0, 5, 100)
y = g.cdf(x)

plt.figure()
plt.plot(x, y, linewidth=2, label='gamma fit')
bars, _, _ = plt.hist(values_nz, density=True, label='true dist', bins=30, cumulative=True)
plt.legend()
plt.ylim((0.7, 1.05));
plt.xlim((2, 5))

# Index 1 of the KS result is the p-value; `metapath` and `degree_combo` are
# whatever earlier cells last assigned.
p_value = decimal.Decimal(ks_result[1])
plt.title(f'{metapath} - {degree_combo} - p={p_value :.2E} - {len(values_nz)} nonzero values');
In [ ]: