!pip3 install plotly
# Output: Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (4.5.0) Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from plotly) (1.14.0) Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly) (1.3.3)
from random import random
from typing import Union, List
from plotly import express as px
from plotly import graph_objects as go
# The data set we're using (source: https://miabellaai.net/)
# Each row is [x1, x2, y]: two feature values followed by the target value
# (see the extraction loop below, which reads item[0], item[1], item[2]).
data: List[List[float]] = [
[1.15, 0.59, 4.18],
[-1.23, 1.65, -1],
[-4.77, 4.02, -1.77],
[0.57, 0.13, 6.06],
[3.29, 3.14, 2.43],
[2.22, 3.58, 4.49],
[-3.98, 1.15, -2.64],
[-0.98, 1.01, 1.87],
[-2.26, 1.09, 3.43],
[-0.96, 0.24, 1.03],
[-2.92, 1.62, 1.8],
[3.88, -1.13, 5.59],
[0.01, 2.66, 4.42],
[3.3, 1.04, 3.7],
[0.44, 0.14, 1.2],
[4.7, -0.73, 6.95],
[-0.05, 1.3, 0.93],
[3.74, -1.46, 3.97],
[-3.69, 2.85, -2.07],
[-4.39, 7.78, -2.8],
[2.95, -1.02, 2.7],
[1.19, -0.35, 4.24],
[3.83, -1.72, 3.25],
[-4.57, 2.72, -0.6],
[-2.07, 5.79, 0.4],
[-1.56, 1.34, -0.61],
[0.85, 0.07, 1.06],
[3.13, -0.98, 2.88],
[-2.22, 0.6, 1.53],
[-2.98, 2.43, 2.04],
[2.59, 4.8, 1.8],
[1.43, -0.91, 2.92],
[-3.48, 2.24, 2.44],
[2.69, 2.38, 7.48],
[0.42, 4.33, 4.32],
[1.75, -0.23, 3.57],
[-4.17, 2.25, -0.3],
[1.35, 0.13, 3.63],
[-3.68, 1.77, -1.43],
[-3.34, 4.32, 3.05],
[-0.79, 0.62, 1.33],
[4.56, -1.85, 3.36],
[-4.25, 6.17, 0.95],
[-2.96, 1.8, 4.44],
[3.36, -1.06, 2.76],
[1.13, 1.79, 4.03],
[0.07, 0.72, 3.46],
[3.94, 4.01, 7.62],
[-0.81, 6.04, 0.31],
[2.21, 4.37, 5.33],
[-3.11, 6.65, -0.5],
[3.88, -1.07, 7.86],
[0.82, -0.46, -0.07],
[4.27, -1.21, 3.77],
[-3.98, 8.22, -2.81],
[-0.54, 0.34, 2.92],
[-1.34, 2.23, 3.63],
[-4.96, 2.03, -2.55],
[3.2, -1.22, 3.18],
[-2.17, 5.18, 1.87],
[-4.13, 7.58, -1.77],
[2.82, 3.2, 7.1],
[-1.16, 1.14, 0.71],
[-4.22, 1.29, 1.58],
[-1.21, 0.9, 0.16],
[-2.53, 1.82, -1.66],
[-3.56, 5.63, -2.12],
[3.39, -0.33, 7.96],
[4.2, -0.8, 3.76],
[0.52, 2.22, 0.51],
[3.86, -0.22, 3.88],
[2.05, 5.4, 1.56],
[1.27, 3.06, 1.48],
[4.81, 0.65, 3.43],
[4.58, -0.91, 7.02],
[3.16, -0.23, 4.17],
[2.51, 0.19, 2.9],
[-4.09, 5.52, -2.09],
[2.61, -0.66, 1.98],
[4.86, 1.16, 5.41],
[4.24, 2.87, 5.67],
[-3.27, 3.01, 1.81],
[-2.43, 3.56, 4.22],
[1.34, 0.17, 3.5],
[-0.74, 1.17, 1.41],
[4.38, -2.08, 4.16],
[4.42, -0.21, 4.72],
[4.87, 2.71, 7.01],
[-1.69, 4.08, -0.38],
[0.34, 0.65, 1.18],
[1.4, 4.44, 0.79],
[4.28, 0.77, 7.04],
[1.36, 3.11, 0.87],
[0.42, 5.54, 2.76],
[0.61, 1.6, 2.93],
[-1.12, 2.63, 1.65],
[0.49, 2.54, -0.23],
[-3.19, 6.53, 2.05],
[-2.45, 4.7, 1.29],
[4.07, -1.54, 2.2]
]
# Turn the data into a list of x vectors (one for every pair of x items) and a vector containing all the y items
# Each row of `data` is [x1, x2, y]; unpack the rows directly instead of
# indexing and appending one element at a time.
xs: List[List[float]] = [[x1, x2] for x1, x2, _ in data]
ys: List[float] = [y for _, _, y in data]
# A convenience function which creates a scatter plot with an optional hyperplane
def plot(xs: List[List[float]], ys: List[float], ys_pred: Union[List[float], None] = None) -> None:
    """Show a 3D scatter of the data; when `ys_pred` is given, overlay it as a surface trace."""
    # Plotly wants one flat sequence per axis, so pull each feature column out of `xs`.
    first_feature: List[float] = [row[0] for row in xs]   # x1
    second_feature: List[float] = [row[1] for row in xs]  # x2
    fig = px.scatter_3d(x=first_feature, y=second_feature, z=ys, labels={'x': 'x1', 'y': 'x2', 'z': 'y'})
    # Overlay the predicted hyperplane when predictions were supplied (truthy check,
    # as in the original: an empty list is treated the same as None).
    if ys_pred:
        hyperplane = go.Scatter3d(
            x=first_feature, y=second_feature, z=ys_pred, name='Guess', surfaceaxis=1
        )
        fig.add_trace(hyperplane)
    fig.show()
plot(xs, ys)
# The function which predicts a `y` value based on `x` and the `alpha` and `beta` parameters
def predict(alpha: float, beta: List[float], x: List[float]) -> float:
    """Predict y for the feature vector `x` as the intercept `alpha` plus the
    dot product (https://en.wikipedia.org/wiki/Dot_product) of `beta` and `x`.

    `beta` and `x` must have the same length (asserted).
    """
    assert len(beta) == len(x)
    # Equivalent to the original's "prepend alpha to beta and 1 to x, then dot":
    # alpha * 1 + sum(beta_i * x_i) — without copying either input list.
    return alpha + sum(b * x_i for b, x_i in zip(beta, x))
# (5 * 1) + (1 * 3) + (2 * 4) = 16 <-- the 5 and 1 are the prepended `alpha` and constant values
assert predict(5, [1, 2], [3, 4]) == 16
# SSE (sum of squared estimate of errors), the function we use to calculate how "wrong" we are
# "How much do the actual y values (`ys`) differ from our predicted y values (`ys_pred`)?"
def sum_squared_error(ys: List[float], ys_pred: List[float]) -> float:
    """Return the sum of squared differences between `ys` and `ys_pred`.

    Both lists must have the same length (asserted). Returns 0 for empty inputs.
    """
    assert len(ys) == len(ys_pred)
    # Generator expression instead of the original sum([...]): no intermediate list.
    return sum((y - y_p) ** 2 for y, y_p in zip(ys, ys_pred))
assert sum_squared_error([1, 2, 3], [4, 5, 6]) == 27
# Find the best fitting hyperplane through the data points via Gradient Descent
alpha: float = random()
beta: List[float] = [random(), random()]
print(f'Starting with "alpha": {alpha}')
print(f'Starting with "beta": {beta}')
epochs: int = 1000
learning_rate: float = 0.00001
for epoch in range(epochs):
    # Calculate predictions for `y` values given the current `alpha` and `beta`
    ys_pred: List[float] = [predict(alpha, beta, x) for x in xs]
    # Report the loss every 100 epochs.
    # BUG FIX: the original tested `epoch % 100 == True`, which is `== 1`
    # (printing epochs 1, 101, 201, ...); `== 0` is what was intended.
    if epoch % 100 == 0:
        loss = sum_squared_error(ys, ys_pred)
        print(f'Epoch {epoch} --> loss: {loss}')
    # Taking the (partial) derivative of SSE with respect to `alpha` results in `2 (y_pred - y)`
    grad_alpha: float = sum(2 * (predict(alpha, beta, x) - y) for x, y in zip(xs, ys))
    # Taking the (partial) derivative of SSE with respect to `beta` results in `2 * x (y_pred - y)`
    # BUG FIX: the gradient must be ACCUMULATED over all data points. The original
    # initialized grad_beta to list(range(len(beta))) (i.e. [0, 1], not zeros) and
    # overwrote each component with `=` inside the loop, so only the last data
    # point ever contributed to the beta update. Start from zeros and use `+=`.
    grad_beta: List[float] = [0.0] * len(beta)
    for x, y in zip(xs, ys):
        error: float = predict(alpha, beta, x) - y
        # `x_i` instead of reusing `x`: the original shadowed the feature vector here.
        for i, x_i in enumerate(x):
            grad_beta[i] += 2 * error * x_i
    # Take a small step in the direction of greatest decrease
    alpha = alpha + (grad_alpha * -learning_rate)
    beta = [b + (gb * -learning_rate) for b, gb in zip(beta, grad_beta)]
print(f'Best estimate for "alpha": {alpha}')
print(f'Best estimate for "beta": {beta}')
# Output: Starting with "alpha": 0.11130581731209588 Starting with "beta": [0.9391701681748561, 0.18051073908894166] Epoch 1 --> loss: 634.0034428435199 Epoch 101 --> loss: 540.813687006875 Epoch 201 --> loss: 476.80912166653883 Epoch 301 --> loss: 432.667702845224 Epoch 401 --> loss: 402.139828655792 Epoch 501 --> loss: 381.03068824639536 Epoch 601 --> loss: 366.5184679981156 Epoch 701 --> loss: 356.69807632541597 Epoch 801 --> loss: 350.2761634751978 Epoch 901 --> loss: 346.36753875260837 Best estimate for "alpha": 1.4856456435223253 Best estimate for "beta": [0.7759898691708914, 0.24225463600936092]
# Plot the data together with the final epoch's predictions (the fitted hyperplane)
plot(xs, ys, ys_pred)