%matplotlib inline
import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from prml import bayesnet as bn
# Fix NumPy's global RNG so the random noise and the random message schedule
# used further down are reproducible.
np.random.seed(1234)
# Two binary parent variables, each with prior [0.1, 0.9] over states (0, 1).
b = bn.discrete([0.1, 0.9])
f = bn.discrete([0.1, 0.9])
# Conditional table for g given its parents (b, f); the outermost index is
# g's state (consistent with the printed marginal p(g=0) = 0.315). The inner
# order presumably follows the argument order b, f — the table is symmetric
# in b and f, so this cannot be told from the output.
g = bn.discrete(
    [
        [[0.9, 0.8], [0.8, 0.2]],
        [[0.1, 0.2], [0.2, 0.8]],
    ],
    b,
    f,
)
for label, var in (("b:", b), ("f:", f), ("g:", g)):
    print(label, var)
b: DiscreteVariable(proba=[0.1 0.9]) f: DiscreteVariable(proba=[0.1 0.9]) g: DiscreteVariable(proba=[0.315 0.685])
# Clamp g to state 0 and show the updated beliefs: both parents' probability
# of being in state 0 rises from the prior 0.1 (to about 0.257 in the output).
g.observe(0)
for label, var in (("b:", b), ("f:", f), ("g:", g)):
    print(label, var)
b: DiscreteVariable(proba=[0.25714286 0.74285714]) f: DiscreteVariable(proba=[0.25714286 0.74285714]) g: DiscreteVariable(observed=[1. 0.])
# Additionally clamp b to state 0. With b now accounting for g = 0, f's
# probability of being 0 falls back (about 0.257 -> 0.111 in the output),
# i.e. the "explaining away" effect.
b.observe(0)
for label, var in (("b:", b), ("f:", f), ("g:", g)):
    print(label, var)
b: DiscreteVariable(observed=[1. 0.]) f: DiscreteVariable(proba=[0.11111111 0.88888889]) g: DiscreteVariable(observed=[1. 0.])
# Fetch MNIST as a plain NumPy array (as_frame=False) and keep only the
# first image: 784 = 28 x 28 pixel values.
x, _ = fetch_openml("mnist_784", return_X_y=True, as_frame=False)
x = x[0]
# Threshold at 127 to get a {0, 1} image. Use the builtin ``int`` as the
# dtype: the ``np.int`` alias was deprecated in NumPy 1.20 and removed in
# NumPy 1.24 (where it raises AttributeError); ``int`` yields the same dtype.
binarized_img = (x > 127).astype(int).reshape(28, 28)
plt.imshow(binarized_img, cmap="gray")
/var/folders/9s/lky4p_js2czgsr4_5962ffbw0000gn/T/ipykernel_11247/693879585.py:3: DeprecationWarning: `np.int` is a deprecated alias for the builtin `int`. To silence this warning, use `int` by itself. Doing this will not modify any behavior and is safe. When replacing `np.int`, you may wish to use e.g. `np.int64` or `np.int32` to specify the precision. If you wish to review your current use, check the release note link for additional information. Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations binarized_img = (x > 127).astype(np.int).reshape(28, 28)
<matplotlib.image.AxesImage at 0x14909fc40>
# Corrupt the clean image: flip a randomly chosen 10% of its pixels
# (sampled without replacement) between 0 and 1.
n_pixels = binarized_img.size
indices = np.random.choice(n_pixels, size=int(n_pixels * 0.1), replace=False)
noisy_img = binarized_img.copy()
# ``.flat`` indexes the array in C order and always writes through to it.
noisy_img.flat[indices] = 1 - noisy_img.flat[indices]
plt.imshow(noisy_img, cmap="gray")
<matplotlib.image.AxesImage at 0x14918b370>
# Grid of variables for the denoising model, indexed by later cells as
# markov_random_field[layer, i, j] with i, j in range(28):
#   layer 0 -> latent (clean) pixel z_(i,j), each with a uniform [0.5, 0.5] prior;
#   layer 1 -> observed (noisy) pixel x_(i,j), created as a bare
#              DiscreteVariable(2) (2 presumably = number of states) whose
#              factors are attached later.
# NOTE(review): this relies on np.array producing an object array of shape
# (2, 28, 28) from these variable objects — confirm DiscreteVariable is not
# sequence-like.
markov_random_field = np.array([
    [[bn.discrete([0.5, 0.5], name=f"p(z_({i},{j}))") for j in range(28)] for i in range(28)],
    [[bn.DiscreteVariable(2) for _ in range(28)] for _ in range(28)]])
# Model parameters: ``a`` couples neighbouring latent pixels, ``b`` ties each
# observed pixel to its latent pixel. Each 2x2 table puts weight a (or b) on
# equal states and 1 - a (or 1 - b) on differing states.
# NOTE: this rebinds ``b``, which earlier cells used for a network variable.
a = 0.9
b = 0.9
pa = [[a, 1 - a], [1 - a, a]]
pb = [[b, 1 - b], [1 - b, b]]
for i in range(28):
    for j in range(28):
        latent = markov_random_field[0, i, j]
        observed = markov_random_field[1, i, j]
        # Factor linking the observed pixel to its latent pixel.
        bn.discrete(pb, latent, out=observed, name=f"p(x_({i},{j})|z_({i},{j}))")
        # Pairwise factors only towards the neighbour below and the one to the
        # right, so each lattice edge is created exactly once.
        if i < 27:
            bn.discrete(pa, out=[latent, markov_random_field[0, i + 1, j]], name=f"p(z_({i},{j}), z_({i+1},{j}))")
        if j < 27:
            bn.discrete(pa, out=[latent, markov_random_field[0, i, j + 1]], name=f"p(z_({i},{j}), z_({i},{j+1}))")
        # Clamp the observed pixel to its noisy value; proprange=0 presumably
        # suppresses immediate message propagation — confirm in prml.bayesnet.
        observed.observe(noisy_img[i, j], proprange=0)
# Random message schedule: 10,000 times, draw a row and a column
# independently (with replacement, NumPy's default) and let that observed
# pixel send its messages. proprange=3 presumably bounds how far each update
# propagates through the graph — confirm in prml.bayesnet.
for _ in range(10000):
    i, j = np.random.choice(28, size=2)
    markov_random_field[1, i, j].send_message(proprange=3)
# Decode the denoised image: each pixel takes the state with the largest
# value in its latent variable's ``proba`` (per-pixel argmax).
restored_img = np.zeros_like(noisy_img)
# np.ndenumerate walks the latent layer in C order, yielding ((i, j), variable).
for (i, j), latent in np.ndenumerate(markov_random_field[0]):
    restored_img[i, j] = np.argmax(latent.proba)
plt.imshow(restored_img, cmap="gray")
<matplotlib.image.AxesImage at 0x149605a30>