This notebook runs a set of illustrative examples of causal inference taken from Ferenc Huszar's blog post. For the explanation and interpretation of the results shown below, please refer to that excellent blog post.
Importing libraries.
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
Setting parameters for running inference algorithms.
n_samples=1000
Defining a couple of support functions for visualization.
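def jointplot(x,y,color,title):
    # Passing x and y as keywords avoids positional-argument errors in recent seaborn versions
    g = sns.jointplot(x=x, y=y, color=color)
    # JointGrid.annotate is deprecated: compute and display the Pearson correlation explicitly
    r, p = stats.pearsonr(x, y)
    g.ax_joint.annotate('pearsonr = {:.2f}; p = {:.2g}'.format(r, p), xy=(0.05, 0.95), xycoords='axes fraction', va='top')
    g.set_axis_labels(xlabel='x', ylabel='y')
    g.fig.suptitle(title)
def kdeplot(y,color,title):
    sns.kdeplot(y, color=color, label=title)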
def model1():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1)
        y = pm.Deterministic('y',x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['x'],trace['y']
def model2():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', (y-1)/4 + np.sqrt(3)*pm.Normal('n1', mu=0, sd=1)/2)
        trace = pm.sample(n_samples)
    return trace['x'],trace['y']
def model3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y',z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x',z)
        trace = pm.sample(n_samples)
    return trace['x'],trace['y']
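For reference, the three models coded above correspond to the following structural equations, where $N_0$ and $N_1$ are independent standard normal noise terms:

$$
\begin{aligned}
\text{model1:}\quad & X \sim \mathcal{N}(0,1), && Y = X + 1 + \sqrt{3}\,N_0\\
\text{model2:}\quad & Y = 1 + 2N_0, && X = \tfrac{Y-1}{4} + \tfrac{\sqrt{3}}{2}\,N_1\\
\text{model3:}\quad & Z \sim \mathcal{N}(0,1), && X = Z,\quad Y = Z + 1 + \sqrt{3}\,N_0
\end{aligned}
$$

Writing $\mathcal{N}(\mu, \sigma^2)$ with the variance as second argument, all three give $X \sim \mathcal{N}(0,1)$, $Y \sim \mathcal{N}(1,4)$ and $\mathrm{Cov}(X,Y) = 1$. For model2, for example, $\mathrm{Var}(X) = \tfrac{1}{16}\cdot 4 + \tfrac{3}{4} = 1$ and $\mathrm{Cov}(X,Y) = \tfrac{1}{4}\mathrm{Var}(Y) = 1$.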
We use PyMC3 to sample and plot the joint distribution $P(X,Y)$ for the three models.
x,y = model1()
jointplot(x, y, color='blue', title='Observational P(X,Y) for model1')
x,y = model2()
jointplot(x, y, color='green', title='Observational P(X,Y) for model2')
x,y = model3()
jointplot(x, y, color='red', title='Observational P(X,Y) for model3')
As expected, the three models show the same observational joint distribution.
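The shared joint distribution can also be checked without MCMC. The sketch below uses plain NumPy ancestral sampling; the helper `sample_observational` is hypothetical and not part of the original notebook. For every model, the sample means should be close to $(0, 1)$ and the covariance matrix close to `[[1, 1], [1, 4]]`.

```python
import numpy as np

def sample_observational(model, n, rng):
    """Ancestral sampling from one of the three SEMs (observational regime)."""
    n0 = rng.standard_normal(n)
    n1 = rng.standard_normal(n)
    if model == 1:
        x = rng.standard_normal(n)
        y = x + 1 + np.sqrt(3) * n0
    elif model == 2:
        y = 1 + 2 * n0
        x = (y - 1) / 4 + np.sqrt(3) * n1 / 2
    elif model == 3:
        z = rng.standard_normal(n)
        x = z
        y = z + 1 + np.sqrt(3) * n0
    else:
        raise ValueError("model must be 1, 2 or 3")
    return x, y

rng = np.random.default_rng(0)
for model in (1, 2, 3):
    x, y = sample_observational(model, 200_000, rng)
    # Sample means of X and Y, then the 2x2 sample covariance matrix of (X, Y)
    print(model, x.mean().round(2), y.mean().round(2), np.cov(x, y).round(2).tolist())
```

Sampling directly from the structural equations is possible here because every model is a fully specified SEM with no observed evidence; MCMC only becomes necessary once we condition on data.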
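We now analyze the behaviour of these three models under the observation $X=3$, using PyMC3 again to redefine the models with the conditioning $X=3$. (Notice that model2 has been slightly reformulated to get rid of the Deterministic object for $X$, which cannot be conditioned on in PyMC3: $X$ is now expressed directly as an observed Normal variable centred at $(Y-1)/4$. Similarly, in model3 the observation is placed on $Z$, since $X=Z$.)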
def model1_observe_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1,observed=3)
        y = pm.Deterministic('y',x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']
def model2_observe_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        # The noise on X in model2 has standard deviation sqrt(3)/2 (variance 3/4)
        x = pm.Normal('x', mu=(y-1)/4.0, sd=np.sqrt(3)/2, observed=3)
        trace = pm.sample(n_samples)
    return trace['y']
def model3_observe_X_3():
    with pm.Model():
        # X = Z, so observing X=3 amounts to observing Z=3
        z = pm.Normal('z', mu=0, sd=1,observed=3)
        y = pm.Deterministic('y',z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x',z)
        trace = pm.sample(n_samples)
    return trace['y']
We again use PyMC3 to sample from these models and estimate $P(Y \vert X=3)$ for the three models.
y1 = model1_observe_X_3()
#jointplot(3*np.ones(y1.shape[0]), y1, color='blue', title='Observational P(X,Y | X=3) for model1')
y2 = model2_observe_X_3()
#jointplot(3*np.ones(y2.shape[0]), y2, color='green', title='Observational P(X,Y | X=3) for model2')
y3 = model3_observe_X_3()
#jointplot(3*np.ones(y3.shape[0]), y3, color='red', title='Observational P(X,Y | X=3) for model3')
plt.figure()
plt.title('Observational P(Y | X=3)')
kdeplot(y1, color='blue', title='model1')
kdeplot(y2, color='green', title='model2')
kdeplot(y3, color='red', title='model3')
plt.legend()
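The observational conditional distributions $P(Y \vert X=3)$ of the three models are (approximately) the same; analytically, all three are Gaussian with mean 4 and variance 3. A visibly shifted model2 curve indicates a mismatch in its likelihood for $X$: the noise on $X$ in model2 has standard deviation $\sqrt{3}/2$ (variance $3/4$), and using $3/4$ as the standard deviation moves the conditional mean of $Y$ upward.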
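Since every model here is linear-Gaussian, $P(Y \vert X=3)$ can also be computed in closed form. The sketch below (the helper `model2_posterior_y` is hypothetical) applies standard Gaussian conditioning to model2, where $Y \sim \mathcal{N}(1, 4)$ and $X \vert Y \sim \mathcal{N}((Y-1)/4, s^2)$ for a noise standard deviation $s$. With $s = \sqrt{3}/2$ it gives mean 4 and variance 3, the same conditional as model1 and model3; with $s = 3/4$ the mean becomes $61/13 \approx 4.69$.

```python
import numpy as np

def model2_posterior_y(x_obs, noise_sd):
    """Posterior mean and variance of Y given X = x_obs in model2.

    Prior: Y ~ N(1, 4). Likelihood: X | Y ~ N(a*Y + b, noise_sd**2) with a = 1/4, b = -1/4.
    """
    prior_mean, prior_var = 1.0, 4.0
    a, b = 0.25, -0.25
    # For a linear-Gaussian observation the precisions add up
    post_prec = 1.0 / prior_var + a**2 / noise_sd**2
    post_var = 1.0 / post_prec
    post_mean = post_var * (prior_mean / prior_var + a * (x_obs - b) / noise_sd**2)
    return post_mean, post_var

print(model2_posterior_y(3.0, np.sqrt(3) / 2))  # noise sd of the structural equation for X
print(model2_posterior_y(3.0, 3 / 4))           # sd = 3/4 shifts the mean of Y upward
```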
We now analyze the behaviour of these three models under the intervention $X=3$. We redefine the models and force $X=3$ (formally, this amounts to mutilating the model: the structural equation for $X$ is replaced by the constant 3, while all other equations are left unchanged).
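The mutilation can also be sketched outside PyMC3. In the hypothetical helper `sample_do_x` below, the structural equation for $X$ is overwritten by a constant while every other equation and noise term is kept. With NumPy sampling, $Y$ should have mean close to 4 and variance close to 3 in model1, and mean close to 1 and variance close to 4 in model2 and model3.

```python
import numpy as np

def sample_do_x(model, x_value, n, rng):
    """Sample Y from the mutilated SEM in which X's equation is replaced by X := x_value."""
    n0 = rng.standard_normal(n)
    x = np.full(n, float(x_value))  # X no longer depends on its parents or on its own noise
    if model == 1:
        y = x + 1 + np.sqrt(3) * n0  # Y still depends on X
    elif model == 2:
        y = 1 + 2 * n0               # Y is a parent of X, so the intervention cannot reach it
    elif model == 3:
        z = rng.standard_normal(n)
        y = z + 1 + np.sqrt(3) * n0  # Y depends on Z only; the edge Z -> X has been cut
    else:
        raise ValueError("model must be 1, 2 or 3")
    return y

rng = np.random.default_rng(0)
for model in (1, 2, 3):
    y = sample_do_x(model, 3, 200_000, rng)
    print(model, round(y.mean(), 2), round(y.var(), 2))
```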
def model1_do_X_3():
    with pm.Model():
        x = 3
        y = pm.Deterministic('y',x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']
def model2_do_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = 3
        trace = pm.sample(n_samples)
    return trace['y']
def model3_do_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y',z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = 3
        trace = pm.sample(n_samples)
    return trace['y']
We now sample from the interventional models and estimate $P(Y \vert do(X=3))$.
y1 = model1_do_X_3()
#jointplot(3*np.ones(y1.shape[0]), y1, color='blue', title='Interventional P(X,Y | do(X=3)) for model1')
y2 = model2_do_X_3()
#jointplot(3*np.ones(y2.shape[0]), y2, color='green', title='Interventional P(X,Y | do(X=3)) for model2')
y3 = model3_do_X_3()
#jointplot(3*np.ones(y3.shape[0]), y3, color='red', title='Interventional P(X,Y | do(X=3)) for model3')
plt.figure()
plt.title('Interventional P(Y | do(X=3))')
kdeplot(y1, color='blue', title='model1')
kdeplot(y2, color='green', title='model2')
kdeplot(y3, color='red', title='model3')
plt.legend()
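The interventional distributions $P(Y \vert do(X=3))$ no longer coincide: in model1 $Y$ still depends on $X$ and is Gaussian with mean 4 and variance 3, whereas in model2 and model3 the intervention leaves $Y$ at its marginal distribution, Gaussian with mean 1 and variance 4.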
We now evaluate the three models under the intervention $X=3$ using only observational quantities, obtained by transforming $P(Y \vert do(X=3))$ via do-calculus. We redefine each model as the expression derived by do-calculus and then observe its behaviour under the conditioning $X=3$ (where that expression still involves $X$).
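A possible derivation uses Pearl's rules of do-calculus, where $G_{\overline{X}}$ denotes the causal graph with the edges into $X$ removed and $G_{\underline{X}}$ the graph with the edges out of $X$ removed:

$$
\begin{aligned}
\text{model1 } (X \to Y):\quad & (Y \perp X)_{G_{\underline{X}}} && \Rightarrow\; P(Y \vert do(X=3)) = P(Y \vert X=3) \quad \text{(rule 2)}\\
\text{model2 } (Y \to X):\quad & (Y \perp X)_{G_{\overline{X}}} && \Rightarrow\; P(Y \vert do(X=3)) = P(Y) \quad \text{(rule 3)}\\
\text{model3 } (X \leftarrow Z \to Y):\quad & (Y \perp X)_{G_{\overline{X}}} && \Rightarrow\; P(Y \vert do(X=3)) = P(Y) \quad \text{(rule 3)}
\end{aligned}
$$

In model3, removing the edge $Z \to X$ isolates $X$, so rule 3 applies just as in model2. Accordingly, the model1 version below conditions on $X=3$, while the model2 and model3 versions simply sample the marginal of $Y$.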
def model1_docalculus_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1,observed=3)
        y = pm.Deterministic('y',x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']
def model2_docalculus_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']
def model3_docalculus_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y',z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples)
    return trace['y']
We now sample from the models obtained via do-calculus and estimate $P(Y \vert do(X=3))$ from these observational expressions.
y1 = model1_docalculus_X_3()
#jointplot(3*np.ones(y1.shape[0]), y1, color='blue', title='Interventional P(X,Y | do(X=3)) for model1')
y2 = model2_docalculus_X_3()
#jointplot(3*np.ones(y2.shape[0]), y2, color='green', title='Interventional P(X,Y | do(X=3)) for model2')
y3 = model3_docalculus_X_3()
#jointplot(3*np.ones(y3.shape[0]), y3, color='red', title='Interventional P(X,Y | do(X=3)) for model3')
plt.figure()
plt.title('Interventional P(Y | do(X=3)) evaluated from observational data via do-calculus')
kdeplot(y1, color='blue', title='model1')
kdeplot(y2, color='green', title='model2')
kdeplot(y3, color='red', title='model3')
plt.legend()
The estimated distributions match the interventional distributions $P(Y \vert do(X=3))$ obtained above by direct intervention. In other words, we can estimate the interventional distribution from observational distributions once do-calculus has rewritten it: as the conditional $P(Y \vert X=3)$ for model1, and as the marginal $P(Y)$ for model2 and model3.
We now examine individual counterfactuals. To do this, we sample from a model, and then we perform the intervention $X=3$ while keeping everything else unmodified (formally, this amounts to sampling from the model, keeping the values of the exogenous noise variables of the SEM, performing the intervention, and then recomputing the variables of interest).
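When only a factual pair $(x, y)$ is available rather than the sampled noise, the same counterfactual can be obtained with the usual three steps: abduction (recover the exogenous noise from the observed values), action (set $X=3$) and prediction (recompute $Y$ with the recovered noise). The sketch below, with the hypothetical helper `counterfactual_y` written in NumPy only, relies on every structural equation here being invertible in its noise term; in model2 the noise $N_1$ on $X$ is not needed, because $Y$ does not depend on it.

```python
import numpy as np

SQRT3 = np.sqrt(3)

def counterfactual_y(model, x_fact, y_fact, x_cf=3.0):
    """Abduction-action-prediction for Y under do(X = x_cf), given a factual (x, y) pair."""
    if model == 1:
        n0 = (y_fact - x_fact - 1) / SQRT3  # abduction: recover the noise on Y
        return x_cf + 1 + SQRT3 * n0        # action + prediction
    if model == 2:
        n0 = (y_fact - 1) / 2               # abduction (n1 only affects X)
        return 1 + 2 * n0                   # Y is upstream of X: unchanged
    if model == 3:
        z = x_fact                          # X = Z, so Z is recovered exactly
        n0 = (y_fact - z - 1) / SQRT3
        return z + 1 + SQRT3 * n0           # Y does not depend on X: unchanged
    raise ValueError("model must be 1, 2 or 3")

# Factual pair from the first MODEL1 sample reported further below
print(counterfactual_y(1, 0.61062177, 3.01207834))  # ≈ 5.4015
```

For model1 the counterfactual shifts $Y$ by exactly $3 - x$, since $Y$ depends on $X$ with slope 1; for model2 and model3 the helper returns $y$ unchanged.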
def model1_counterfactual_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1)
        y = pm.Deterministic('y',x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(1,chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    # Intervention: replace the equation for X, keep the exogenous noise n0
    counterfactual_x = 3
    counterfactual_y = counterfactual_x + 1 + np.sqrt(3)*factual_n0
    return factual_x, factual_y, counterfactual_x, counterfactual_y
def model2_counterfactual_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', (y-1)/4 + np.sqrt(3)*pm.Normal('n1', mu=0, sd=1)/2)
        trace = pm.sample(1,chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    factual_n1 = trace['n1']
    # Y is a parent of X, so intervening on X leaves Y = 1 + 2*n0 unchanged
    counterfactual_y = factual_y
    counterfactual_x = 3
    return factual_x, factual_y, counterfactual_x, counterfactual_y
def model3_counterfactual_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y',z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x',z)
        trace = pm.sample(1,chains=1)
    factual_z = trace['z']
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    # Cutting Z -> X does not affect Y, which is recomputed from the same z and n0
    counterfactual_y = factual_z + 1 + np.sqrt(3)*factual_n0
    counterfactual_x = 3
    return factual_x, factual_y, counterfactual_x, counterfactual_y
We now run a few simple simulations to see how the values of the variables change.
factual_x, factual_y, counterfactual_x, counterfactual_y = model1_counterfactual_X_3()
print('MODEL1 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL1 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))
factual_x, factual_y, counterfactual_x, counterfactual_y = model1_counterfactual_X_3()
print('MODEL1 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL1 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))
factual_x, factual_y, counterfactual_x, counterfactual_y = model2_counterfactual_X_3()
print('MODEL2 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL2 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))
factual_x, factual_y, counterfactual_x, counterfactual_y = model2_counterfactual_X_3()
print('MODEL2 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL2 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))
factual_x, factual_y, counterfactual_x, counterfactual_y = model3_counterfactual_X_3()
print('MODEL3 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL3 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))
factual_x, factual_y, counterfactual_x, counterfactual_y = model3_counterfactual_X_3()
print('MODEL3 - Factual (x,y): {0}, {1}'.format(factual_x, factual_y))
print('MODEL3 - Counterfactual (x,y): {0}, {1}'.format(counterfactual_x, counterfactual_y))
MODEL1 - Factual (x,y): [0.61062177], [3.01207834]
MODEL1 - Counterfactual (x,y): 3, [4.80913133]
MODEL1 - Factual (x,y): [1.11526751], [1.73660493]
MODEL1 - Counterfactual (x,y): 3, [3.78137905]
MODEL2 - Factual (x,y): [0.35093419], [0.76245645]
MODEL2 - Counterfactual (x,y): 3, [0.76245645]
MODEL2 - Factual (x,y): [0.12764888], [-3.18352881]
MODEL2 - Counterfactual (x,y): 3, [-3.18352881]
MODEL3 - Factual (x,y): [1.12547449], [1.47806851]
MODEL3 - Counterfactual (x,y): 3, [1.47806851]
MODEL3 - Factual (x,y): [-1.16979628], [1.3040578]
MODEL3 - Counterfactual (x,y): 3, [1.3040578]
Notice how, in model1, the counterfactual value of $Y$ differs from its factual value: in model1, the value of $Y$ would change if we were to perform the intervention $X=3$ while keeping everything else the same.
By contrast, in model2 and model3 the counterfactual value of $Y$ does NOT change with respect to its factual value: in these models, the value of $Y$ would NOT change if we were to perform the intervention $X=3$ while keeping everything else the same.
These results make sense when we consider that, under the intervention $do(X=3)$, the variables $X$ and $Y$ become independent in model2 and model3 (formally, this can be seen by mutilating the SEM graph, i.e. removing the edges into $X$). It is therefore not surprising that, if we intervene on $X$, the value of $Y$ simply remains the same: $Y$ is not a descendant of $X$ in the mutilated graph.
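The graph surgery behind this argument can be sketched in a few lines of plain Python, with each model's causal graph written as an edge list (a representation chosen here for illustration). The hypothetical helper `mutilate` removes the edges into $X$, and `is_descendant` checks whether a directed path from $X$ to $Y$ survives; only model1 keeps one.

```python
def mutilate(edges, node):
    """Return the edge list with every edge pointing into `node` removed (surgery for do(node))."""
    return [(a, b) for (a, b) in edges if b != node]

def is_descendant(edges, source, target):
    """True if a directed path from source to target exists."""
    children = {}
    for a, b in edges:
        children.setdefault(a, []).append(b)
    stack, seen = [source], set()
    while stack:
        current = stack.pop()
        for child in children.get(current, []):
            if child == target:
                return True
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return False

graphs = {
    "model1": [("X", "Y")],
    "model2": [("Y", "X")],
    "model3": [("Z", "X"), ("Z", "Y")],
}
for name, edges in graphs.items():
    g = mutilate(edges, "X")
    print(name, g, "Y is a descendant of X after do(X):", is_descendant(g, "X", "Y"))
```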
def model1_counterfactual_X_3():
    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1)
        y = pm.Deterministic('y',x + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        trace = pm.sample(n_samples*2,chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    counterfactual_x = 3*np.ones(n_samples*2)
    # Same exogenous noise n0, scaled by sqrt(3) as in the structural equation for Y
    counterfactual_y = counterfactual_x + 1 + np.sqrt(3)*factual_n0
    return factual_x, factual_y, counterfactual_x, counterfactual_y
def model2_counterfactual_X_3():
    with pm.Model():
        y = pm.Deterministic('y', 1 + 2*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x', (y-1)/4 + np.sqrt(3)*pm.Normal('n1', mu=0, sd=1)/2)
        trace = pm.sample(n_samples*2,chains=1)
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    factual_n1 = trace['n1']
    # Y is a parent of X, so intervening on X leaves Y unchanged
    counterfactual_y = factual_y
    counterfactual_x = 3*np.ones(n_samples*2)
    return factual_x, factual_y, counterfactual_x, counterfactual_y
def model3_counterfactual_X_3():
    with pm.Model():
        z = pm.Normal('z', mu=0, sd=1)
        y = pm.Deterministic('y',z + 1 + np.sqrt(3)*pm.Normal('n0', mu=0, sd=1))
        x = pm.Deterministic('x',z)
        trace = pm.sample(n_samples*2,chains=1)
    factual_z = trace['z']
    factual_x = trace['x']
    factual_y = trace['y']
    factual_n0 = trace['n0']
    # Y is recomputed from the same z and n0; the intervention only cuts Z -> X
    counterfactual_y = factual_z + 1 + np.sqrt(3)*factual_n0
    counterfactual_x = 3*np.ones(n_samples*2)
    return factual_x, factual_y, counterfactual_x, counterfactual_y
We sample from the models and compare the observational marginal distribution $P(Y)$ with the corresponding counterfactual distribution $P(Y^*\vert X^*=3)$ under the intervention $do(X=3)$.
factual_x1, factual_y1, counterfactual_x1, counterfactual_y1 = model1_counterfactual_X_3()
factual_x2, factual_y2, counterfactual_x2, counterfactual_y2 = model2_counterfactual_X_3()
factual_x3, factual_y3, counterfactual_x3, counterfactual_y3 = model3_counterfactual_X_3()
plt.figure()
plt.title('Observational P(Y)')
kdeplot(factual_y1, color='blue', title='model1')
kdeplot(factual_y2, color='green', title='model2')
kdeplot(factual_y3, color='red', title='model3')
plt.legend()
plt.figure()
plt.title('Counterfactual P(Y* | X*=3)')
kdeplot(counterfactual_y1, color='blue', title='model1')
kdeplot(counterfactual_y2, color='green', title='model2')
kdeplot(counterfactual_y3, color='red', title='model3')
plt.legend()
Consistent with the previous results, for model2 and model3 the marginal distribution of $Y$ is the same in the observational and in the counterfactual model, because under the intervention $X$ and $Y$ are independent.
Instead, for model1 we observe a difference between the observational and the counterfactual distribution, because intervening on $X$ does affect the value of $Y$.
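The model1 comparison can be reproduced with a vectorized NumPy sketch, an illustration outside PyMC3 with its own random draws. The factual and counterfactual worlds share the noise $N_0$, so each counterfactual $Y^*$ equals the factual $Y$ shifted by $3 - X$, and the marginal moves from mean 1 and variance 4 to mean 4 and variance 3.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
sqrt3 = np.sqrt(3)

# Model1: factual world sampled from the SEM; the counterfactual world reuses the same noise n0
x = rng.standard_normal(n)
n0 = rng.standard_normal(n)
y_fact = x + 1 + sqrt3 * n0
y_cf = 3 + 1 + sqrt3 * n0

# Per-sample shift equals the change in X, since Y depends on X with slope 1
print(np.allclose(y_cf - y_fact, 3 - x))
print(round(y_fact.mean(), 2), round(y_fact.var(), 2))  # close to 1 and 4
print(round(y_cf.mean(), 2), round(y_cf.var(), 2))      # close to 4 and 3
```

Applying the same construction to model2 or model3 returns $Y^* = Y$ sample by sample, which is why their two curves coincide.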