FM(Factorization Machine)系列主要参考美团的《深入FFM原理与实践》这篇博客内容做介绍并实现,首先介绍一下原理

一.FM原理

大家可能用过sklearn中的这个多项式特征处理函数:sklearn.preprocessing.PolynomialFeatures,它作用是就是将原始特征扩展为多项式特征,比如原始特征为$a,b$,那么它会做如下扩展:

$$ [a,b]\rightarrow [1,a,b,a^2,ab,b^2] $$

而FM的初衷便是对这组新特征做线性建模,一般地,它可以表示如下:

$$ y(x)=w_0+\sum_{i=1}^nw_ix_i+\sum_{i=1}^{n-1}\sum_{j=i+1}^nw_{ij}x_ix_j $$

FM通常不会对平方项建模(比如上面的$a^2,b^2$),这里$n$代表样本的特征数量,$x_i$是第$i$个特征值,$w_0,w_i,w_{ij}$是模型参数,到这里大家可能会有疑惑,我们干嘛不先通过多项式特征处理函数做转换,然后再接着做一个线性回归或者logistic回归之类的不就行了吗?那这个...FM拿它何用?如果按照刚刚的操作其实是有很大问题的,主要有两点问题:

(1)参数爆炸

(2)高维稀疏

第一点比较容易理解,对于$n$个特征,$w_{ij}$将会有$\frac{n(n-1)}{2}$项,形象点说就是平常随意用到的100个特征,扩展后将会有5000个,而参数越多,如果没有足够的数据做训练,模型很容易陷入过拟合,而对于第二点,经常处理离散特征的同学会很容易理解,比如下图 avatar

包含有三个特征(最左侧的是标签),且都是离散特征,而对于离散特征我们经常做的一种操作便是one-hot转换,转换后的结果如下图:

avatar

如果我们在对这些特征做多项式转换,可以发现转后的20多个特征中,仅仅只有3个非零特征,这就意味着绝大部分的$x_ix_j$将会是0,而损失函数关于$w_{ij}$的导数必然会包含有$x_ix_j$这一项,这就意味$w_{ij}$大部分情况下就是个摆设,很难被更新到,而FM便可以解决这两个问题,它假设$w_{ij}$由两个向量的内积生成:

$$ w_{ij}:=<v_i,v_j> $$

这里,$v_i$表示第$i$个特征的隐向量,其向量长度为$k(k<<n)$,通常$k=4$即可,这时FM模型方程如下:

$$ y(x)=w_0+\sum_{i=1}^nw_ix_i+\sum_{i=1}^{n-1}\sum_{j=i+1}^n<v_i,v_j>x_ix_j $$


进一步的,我们可以将其化简为如下形式:

$$ y(x)=w_0+\sum_{i=1}^nw_ix_i+\frac{1}{2}\sum_{f=1}^k((\sum_{i=1}^nv_{i,f}x_i)^2-\sum_{i=1}^nv_{i,f}^2x_i^2) $$

这里,$v_{i,f}$表示向量$v_i$的第$f$个元素,上述化简用到了这样的关系:$ab+ac+bc=\frac{1}{2}((a+b+c)^2-(a^2+b^2+c^2))$,接下来我们可以进一步看看梯度:

$$ \frac{\partial}{\partial\theta}y(x)=\left\{\begin{matrix} 1 &\theta=w_0 \\ x_i &\theta=w_i \\ x_i\sum_{j=1}^nv_{j,f}x_j-v_{i,f}x_i^2 & \theta=v_{i,f} \end{matrix}\right. $$

可以发现前面的两个问题可以被FM解决了,第一个问题,参数量从$n(n-1)/2$降低到了$kn$,第二个高维稀疏导致参数无法被训练的问题,对于$v_{i,f}$只要$x_i$不为0,且所有$x_j,j=1,2,...,n$中有一个不为0,那么梯度$\frac{\partial}{\partial v_{i,f}}y(x)$就不为0,这比$x_ix_j$不为0的条件松了很多

二.代码实现

这里就对FM应用到回归任务做简单实现,更多的功能扩展放到FFM中,下面推导一下损失函数对参数的梯度,假设样本$x$对应的标签为$t$,那么损失函数可以表示如下:

$$ L(\theta)=\frac{1}{2}(y(x)-t)^2 $$

那么:

$$ \frac{\partial L(\theta)}{\partial y(x)}=y(x)-t $$

再根据链式求导有:

$$ \frac{\partial L(\theta)}{\partial \theta}=\frac{\partial L(\theta)}{\partial y(x)}\frac{\partial y(x)}{\partial\theta}\\ =(y(x)-t)\cdot \left\{\begin{matrix} 1 &\theta=w_0 \\ x_i &\theta=w_i \\ x_i\sum_{j=1}^nv_{j,f}x_j-v_{i,f}x_i^2 & \theta=v_{i,f} \end{matrix}\right. $$
In [1]:
"""
FM因子分解机的简单实现,只实现了损失函数为平方损失的回归任务,更多功能扩展请使用后续的FFM
代码封装到ml_models.fm中
"""
import numpy as np


class FM(object):
    def __init__(self, epochs=1, lr=1e-3, adjust_lr=True, batch_size=1, hidden_dim=4, lamb=1e-3, alpha=1e-3,
                 normal=True, solver='adam', rho_1=0.9, rho_2=0.999, early_stopping_rounds=100):
        """

        :param epochs: 迭代轮数
        :param lr: 学习率
        :param adjust_lr:是否根据特征数量再次调整学习率 max(lr,1/n_feature)
        :param batch_size:
        :param hidden_dim:隐变量维度
        :param lamb:l2正则项系数
        :param alpha:l1正则项系数
        :param normal:是否归一化,默认用min-max归一化
        :param solver:优化方式,包括sgd,adam,默认adam
        :param rho_1:adam的rho_1的权重衰减,solver=adam时生效
        :param rho_2:adam的rho_2的权重衰减,solver=adam时生效
        :param early_stopping_rounds:对early_stopping进行支持,使用rmse作为评估指标,默认20
        """
        self.epochs = epochs
        self.lr = lr
        self.adjust_lr = adjust_lr
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.lamb = lamb
        self.alpha = alpha
        self.solver = solver
        self.rho_1 = rho_1
        self.rho_2 = rho_2
        self.early_stopping_rounds = early_stopping_rounds
        # 初始化参数
        self.w = None  # w_0,w_i
        self.V = None  # v_{i,f}
        # 归一化
        self.normal = normal
        if normal:
            self.xmin = None
            self.xmax = None

    def _y(self, X):
        """
        实现y(x)的功能
        :param X:
        :return:
        """
        # 去掉第一列bias
        X_ = X[:, 1:]
        X_V = X_ @ self.V
        X_V_2 = X_V * X_V
        X_2_V_2 = (X_ * X_) @ (self.V * self.V)
        pol = 0.5 * np.sum(X_V_2 - X_2_V_2, axis=1)
        return X @ self.w.reshape(-1) + pol

    def fit(self, X, y, eval_set=None, show_log=True):
        X_o = X.copy()
        if self.normal:
            self.xmin = X.min(axis=0)
            self.xmax = X.max(axis=0)
            X = (X - self.xmin) / self.xmax
        n_sample, n_feature = X.shape
        x_y = np.c_[np.ones(n_sample), X, y]
        # 记录loss
        train_losses = []
        eval_losses = []
        # 调整一下学习率
        if self.adjust_lr:
            self.lr = max(self.lr, 1 / n_feature)
        # 初始化参数
        self.w = np.random.random((n_feature + 1, 1)) * 1e-3
        self.V = np.random.random((n_feature, self.hidden_dim)) * 1e-3
        if self.solver == 'adam':
            # 缓存梯度一阶,二阶估计
            w_1 = np.zeros_like(self.w)
            V_1 = np.zeros_like(self.V)
            w_2 = np.zeros_like(self.w)
            V_2 = np.zeros_like(self.V)
        # 更新参数
        count = 0
        best_eval_value = np.power(2., 1023)
        eval_count = 0
        for epoch in range(self.epochs):
            np.random.shuffle(x_y)
            for index in range(x_y.shape[0] // self.batch_size):
                count += 1
                batch_x_y = x_y[self.batch_size * index:self.batch_size * (index + 1)]
                batch_x = batch_x_y[:, :-1]
                batch_y = batch_x_y[:, -1:]
                # 计算y(x)-t
                y_x_t = self._y(batch_x).reshape((-1, 1)) - batch_y
                # 更新w
                if self.solver == 'sgd':
                    self.w = self.w - (self.lr * (np.sum(y_x_t * batch_x, axis=0) / self.batch_size).reshape(
                        (-1, 1)) + self.lamb * self.w + self.alpha * np.where(self.w > 0, 1, 0))
                elif self.solver == 'adam':
                    w_reg = self.lamb * self.w + self.alpha * np.where(self.w > 0, 1, 0)
                    w_grad = (np.sum(y_x_t * batch_x, axis=0) / self.batch_size).reshape(
                        (-1, 1)) + w_reg
                    w_1 = self.rho_1 * w_1 + (1 - self.rho_1) * w_grad
                    w_2 = self.rho_2 * w_2 + (1 - self.rho_2) * w_grad * w_grad
                    w_1_ = w_1 / (1 - np.power(self.rho_1, count))
                    w_2_ = w_2 / (1 - np.power(self.rho_2, count))
                    self.w = self.w - (self.lr * w_1_) / (np.sqrt(w_2_) + 1e-8)

                # 更新 V
                batch_x_ = batch_x[:, 1:]
                V_X = batch_x_ @ self.V
                X_2 = batch_x_ * batch_x_
                # 从i,f单个元素逐步更新有点慢
                # for i in range(self.V.shape[0]):
                #     for f in range(self.V.shape[1]):
                #         if self.solver == "sgd":
                #             self.V[i, f] -= self.lr * (
                #                 np.sum(y_x_t.reshape(-1) * (batch_x_[:, i] * V_X[:, f] - self.V[i, f] * X_2[:, i]))
                #                 / self.batch_size + self.lamb * self.V[i, f] + self.alpha * (self.V[i, f] > 0))
                #         elif self.solver == "adam":
                #             v_reg = self.lamb * self.V[i, f] + self.alpha * (self.V[i, f] > 0)
                #             v_grad = np.sum(y_x_t.reshape(-1) * (
                #                 batch_x_[:, i] * V_X[:, f] - self.V[i, f] * X_2[:, i])) / self.batch_size + v_reg
                #             V_1[i, f] = self.rho_1 * V_1[i, f] + (1 - self.rho_1) * v_grad
                #             V_2[i, f] = self.rho_2 * V_2[i, f] + (1 - self.rho_2) * v_grad * v_grad
                #             v_1_ = V_1[i, f] / (1 - np.power(self.rho_1, count))
                #             v_2_ = V_2[i, f] / (1 - np.power(self.rho_2, count))
                #             self.V[i, f] = self.V[i, f] - (self.lr * v_1_) / (np.sqrt(v_2_) + 1e-8)

                # 从隐变量的维度进行更新
                for f in range(self.V.shape[1]):
                    if self.solver == 'sgd':
                        V_grad = np.sum(
                            y_x_t.reshape((-1, 1)) * (batch_x_ * V_X[:, f].reshape((-1, 1)) - X_2 * self.V[:, f]),
                            axis=0)
                        self.V[:, f] = self.V[:, f] - self.lr * V_grad - self.lamb * self.V[:, f] - self.alpha * (
                            self.V[:, f] > 0)
                    elif self.solver == 'adam':
                        V_reg = self.lamb * self.V[:, f] + self.alpha * (self.V[:, f] > 0)
                        V_grad = np.sum(
                            y_x_t.reshape((-1, 1)) * (batch_x_ * V_X[:, f].reshape((-1, 1)) - X_2 * self.V[:, f]),
                            axis=0) + V_reg
                        V_1[:, f] = self.rho_1 * V_1[:, f] + (1 - self.rho_1) * V_grad
                        V_2[:, f] = self.rho_2 * V_2[:, f] + (1 - self.rho_2) * V_grad * V_grad
                        V_1_ = V_1[:, f] / (1 - np.power(self.rho_1, count))
                        V_2_ = V_2[:, f] / (1 - np.power(self.rho_2, count))
                        self.V[:, f] = self.V[:, f] - (self.lr * V_1_) / (np.sqrt(V_2_) + 1e-8)

                # 计算eval loss
                eval_loss = None
                if eval_set is not None:
                    eval_x, eval_y = eval_set
                    eval_loss = np.std(eval_y - self.predict(eval_x))
                    eval_losses.append(eval_loss)
                # 是否显示
                if show_log:
                    train_loss = np.std(y - self.predict(X_o))
                    print("epoch:", epoch + 1, "/", self.epochs, ",samples:", (index + 1) * self.batch_size, "/",
                          n_sample,
                          ",train loss:",
                          train_loss, ",eval loss:", eval_loss)
                    train_losses.append(train_loss)
                # 是否早停
                if eval_loss is not None and self.early_stopping_rounds is not None:
                    if eval_loss < best_eval_value:
                        eval_count = 0
                        best_eval_value = eval_loss
                    else:
                        eval_count += 1
                    if eval_count >= self.early_stopping_rounds:
                        print("---------------early_stopping-----------------------------")
                        break

        return train_losses, eval_losses

    def predict(self, X):
        """
        :param X:
        :return:
        """
        if self.normal:
            X = (X - self.xmin) / self.xmax
        n_sample, n_feature = X.shape
        X_V = X @ self.V
        X_V_2 = X_V * X_V
        X_2_V_2 = (X * X) @ (self.V * self.V)
        pol = 0.5 * np.sum(X_V_2 - X_2_V_2, axis=1)
        return np.c_[np.ones(n_sample), X] @ self.w.reshape(-1) + pol

三.测试

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
In [3]:
#造伪数据
data1 = np.linspace(1, 10, num=200)
data2 = np.linspace(1, 10, num=200) + np.random.random(size=200)
target = data1 * 2 + data2 * 1 + 10 * data1 * data2 + np.random.random(size=200)
data = np.c_[data1, data2]
In [4]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.4, random_state=0)
In [5]:
#训练模型
model = FM()
train_losses,eval_losses = model.fit(X_train, y_train, eval_set=(X_test,y_test))
epoch: 1 / 1 ,samples: 1 / 120 ,train loss: 327.43304536928594 ,eval loss: 291.9721877542085
epoch: 1 / 1 ,samples: 2 / 120 ,train loss: 326.7160901918784 ,eval loss: 291.3271189485855
epoch: 1 / 1 ,samples: 3 / 120 ,train loss: 325.7151828483661 ,eval loss: 290.43528663690176
epoch: 1 / 1 ,samples: 4 / 120 ,train loss: 324.4152130238985 ,eval loss: 289.2810645127806
epoch: 1 / 1 ,samples: 5 / 120 ,train loss: 323.01897436328363 ,eval loss: 288.0444933403834
epoch: 1 / 1 ,samples: 6 / 120 ,train loss: 321.3838207470574 ,eval loss: 286.5995936954016
epoch: 1 / 1 ,samples: 7 / 120 ,train loss: 319.94927792261774 ,eval loss: 285.3310103840077
epoch: 1 / 1 ,samples: 8 / 120 ,train loss: 318.5719337115685 ,eval loss: 284.11400223694636
epoch: 1 / 1 ,samples: 9 / 120 ,train loss: 316.9476360205337 ,eval loss: 282.6800903373831
epoch: 1 / 1 ,samples: 10 / 120 ,train loss: 314.7864859352729 ,eval loss: 280.7747654767163
epoch: 1 / 1 ,samples: 11 / 120 ,train loss: 312.13341354273456 ,eval loss: 278.43789101462926
epoch: 1 / 1 ,samples: 12 / 120 ,train loss: 309.16581289509793 ,eval loss: 275.8247460891347
epoch: 1 / 1 ,samples: 13 / 120 ,train loss: 306.27255717177525 ,eval loss: 273.2781949183259
epoch: 1 / 1 ,samples: 14 / 120 ,train loss: 302.69416496771913 ,eval loss: 270.1306586526588
epoch: 1 / 1 ,samples: 15 / 120 ,train loss: 298.94444644455086 ,eval loss: 266.8335058027848
epoch: 1 / 1 ,samples: 16 / 120 ,train loss: 295.3143220143411 ,eval loss: 263.6423561128022
epoch: 1 / 1 ,samples: 17 / 120 ,train loss: 291.3237509064271 ,eval loss: 260.1354974551523
epoch: 1 / 1 ,samples: 18 / 120 ,train loss: 287.49921467259674 ,eval loss: 256.77523476633075
epoch: 1 / 1 ,samples: 19 / 120 ,train loss: 283.7042670379911 ,eval loss: 253.4414837569022
epoch: 1 / 1 ,samples: 20 / 120 ,train loss: 278.9662785039226 ,eval loss: 249.28103217515456
epoch: 1 / 1 ,samples: 21 / 120 ,train loss: 274.36280682169763 ,eval loss: 245.23921862961456
epoch: 1 / 1 ,samples: 22 / 120 ,train loss: 269.9916328968292 ,eval loss: 241.4018533233914
epoch: 1 / 1 ,samples: 23 / 120 ,train loss: 265.6497926902622 ,eval loss: 237.59063822710482
epoch: 1 / 1 ,samples: 24 / 120 ,train loss: 261.5304618788339 ,eval loss: 233.97503851467806
epoch: 1 / 1 ,samples: 25 / 120 ,train loss: 255.9712297539595 ,eval loss: 229.0974061155879
epoch: 1 / 1 ,samples: 26 / 120 ,train loss: 249.63664207715678 ,eval loss: 223.540478679283
epoch: 1 / 1 ,samples: 27 / 120 ,train loss: 243.5177699482245 ,eval loss: 218.1731976241854
epoch: 1 / 1 ,samples: 28 / 120 ,train loss: 235.93035644162356 ,eval loss: 211.51827263729555
epoch: 1 / 1 ,samples: 29 / 120 ,train loss: 228.29334367815406 ,eval loss: 204.8203973571134
epoch: 1 / 1 ,samples: 30 / 120 ,train loss: 221.0354471284559 ,eval loss: 198.45549752245287
epoch: 1 / 1 ,samples: 31 / 120 ,train loss: 213.95307377288907 ,eval loss: 192.2449043683773
epoch: 1 / 1 ,samples: 32 / 120 ,train loss: 207.1184199693237 ,eval loss: 186.25187133472122
epoch: 1 / 1 ,samples: 33 / 120 ,train loss: 200.57104162351598 ,eval loss: 180.5110271166571
epoch: 1 / 1 ,samples: 34 / 120 ,train loss: 194.30667243270833 ,eval loss: 175.01858282042767
epoch: 1 / 1 ,samples: 35 / 120 ,train loss: 185.93707682966206 ,eval loss: 167.68207164353436
epoch: 1 / 1 ,samples: 36 / 120 ,train loss: 175.7761789726375 ,eval loss: 158.77665170450192
epoch: 1 / 1 ,samples: 37 / 120 ,train loss: 164.91722726673692 ,eval loss: 149.26036101318223
epoch: 1 / 1 ,samples: 38 / 120 ,train loss: 154.08437475131961 ,eval loss: 139.76767558665097
epoch: 1 / 1 ,samples: 39 / 120 ,train loss: 143.8890967396074 ,eval loss: 130.8343368901387
epoch: 1 / 1 ,samples: 40 / 120 ,train loss: 132.60651489715613 ,eval loss: 120.94947293213396
epoch: 1 / 1 ,samples: 41 / 120 ,train loss: 122.05404130856687 ,eval loss: 111.70494507865556
epoch: 1 / 1 ,samples: 42 / 120 ,train loss: 112.12358484286929 ,eval loss: 103.00579214468047
epoch: 1 / 1 ,samples: 43 / 120 ,train loss: 102.89744660707402 ,eval loss: 94.92394460267941
epoch: 1 / 1 ,samples: 44 / 120 ,train loss: 91.42133543792461 ,eval loss: 84.87225024380368
epoch: 1 / 1 ,samples: 45 / 120 ,train loss: 78.13817406765556 ,eval loss: 73.23715667480447
epoch: 1 / 1 ,samples: 46 / 120 ,train loss: 65.70062355560859 ,eval loss: 62.33800496293836
epoch: 1 / 1 ,samples: 47 / 120 ,train loss: 54.02535849002556 ,eval loss: 52.094392295236844
epoch: 1 / 1 ,samples: 48 / 120 ,train loss: 41.683807766260934 ,eval loss: 41.22614231474926
epoch: 1 / 1 ,samples: 49 / 120 ,train loss: 30.433511697626553 ,eval loss: 31.185882623013303
epoch: 1 / 1 ,samples: 50 / 120 ,train loss: 21.514229448428743 ,eval loss: 22.771017028293713
epoch: 1 / 1 ,samples: 51 / 120 ,train loss: 17.16571530025454 ,eval loss: 17.13932905161945
epoch: 1 / 1 ,samples: 52 / 120 ,train loss: 19.682964325018517 ,eval loss: 16.627893626619425
epoch: 1 / 1 ,samples: 53 / 120 ,train loss: 26.17980787137188 ,eval loss: 20.579384922979518
epoch: 1 / 1 ,samples: 54 / 120 ,train loss: 33.98060706754525 ,eval loss: 26.580680020029057
epoch: 1 / 1 ,samples: 55 / 120 ,train loss: 40.718522653083056 ,eval loss: 32.11709555436113
epoch: 1 / 1 ,samples: 56 / 120 ,train loss: 47.1280969103238 ,eval loss: 37.51589197307846
epoch: 1 / 1 ,samples: 57 / 120 ,train loss: 53.23655823745947 ,eval loss: 42.72628701847929
epoch: 1 / 1 ,samples: 58 / 120 ,train loss: 57.91439837653235 ,eval loss: 46.7435156480082
epoch: 1 / 1 ,samples: 59 / 120 ,train loss: 62.2916257861008 ,eval loss: 50.517296176954176
epoch: 1 / 1 ,samples: 60 / 120 ,train loss: 63.37465187759853 ,eval loss: 51.453547885856835
epoch: 1 / 1 ,samples: 61 / 120 ,train loss: 62.79759163667555 ,eval loss: 50.95608888525005
epoch: 1 / 1 ,samples: 62 / 120 ,train loss: 62.14927336845161 ,eval loss: 50.397184853260356
epoch: 1 / 1 ,samples: 63 / 120 ,train loss: 58.68121179556597 ,eval loss: 47.407939684556275
epoch: 1 / 1 ,samples: 64 / 120 ,train loss: 54.521834742756766 ,eval loss: 43.83360698214427
epoch: 1 / 1 ,samples: 65 / 120 ,train loss: 50.152024574195885 ,eval loss: 40.09547496724038
epoch: 1 / 1 ,samples: 66 / 120 ,train loss: 44.39992720800713 ,eval loss: 35.213398197916526
epoch: 1 / 1 ,samples: 67 / 120 ,train loss: 39.333586356144004 ,eval loss: 30.969734556859894
epoch: 1 / 1 ,samples: 68 / 120 ,train loss: 34.594628478215704 ,eval loss: 27.079303861985274
epoch: 1 / 1 ,samples: 69 / 120 ,train loss: 30.59771443493049 ,eval loss: 23.897471633271614
epoch: 1 / 1 ,samples: 70 / 120 ,train loss: 27.012269759755636 ,eval loss: 21.175346759436962
epoch: 1 / 1 ,samples: 71 / 120 ,train loss: 24.206812085666503 ,eval loss: 19.193485550720503
epoch: 1 / 1 ,samples: 72 / 120 ,train loss: 21.92448958652032 ,eval loss: 17.744576448472227
epoch: 1 / 1 ,samples: 73 / 120 ,train loss: 20.28332003228646 ,eval loss: 16.855437337583528
epoch: 1 / 1 ,samples: 74 / 120 ,train loss: 19.048206816991755 ,eval loss: 16.332043336716456
epoch: 1 / 1 ,samples: 75 / 120 ,train loss: 18.153218019182912 ,eval loss: 16.097736138464715
epoch: 1 / 1 ,samples: 76 / 120 ,train loss: 17.62019020726425 ,eval loss: 16.078084905499313
epoch: 1 / 1 ,samples: 77 / 120 ,train loss: 17.276138511016793 ,eval loss: 16.184496145420404
epoch: 1 / 1 ,samples: 78 / 120 ,train loss: 17.095845317846123 ,eval loss: 16.35939284722499
epoch: 1 / 1 ,samples: 79 / 120 ,train loss: 17.01710369315616 ,eval loss: 16.582774084657874
epoch: 1 / 1 ,samples: 80 / 120 ,train loss: 17.010142091853233 ,eval loss: 16.78141646069649
epoch: 1 / 1 ,samples: 81 / 120 ,train loss: 17.039469465368413 ,eval loss: 16.983741883163493
epoch: 1 / 1 ,samples: 82 / 120 ,train loss: 17.09361648072788 ,eval loss: 17.183288214670156
epoch: 1 / 1 ,samples: 83 / 120 ,train loss: 17.103607965717508 ,eval loss: 17.215825153711435
epoch: 1 / 1 ,samples: 84 / 120 ,train loss: 17.054551444977335 ,eval loss: 17.058868161589743
epoch: 1 / 1 ,samples: 85 / 120 ,train loss: 17.002073634940558 ,eval loss: 16.763590578211996
epoch: 1 / 1 ,samples: 86 / 120 ,train loss: 17.0432834771113 ,eval loss: 16.429314526411407
epoch: 1 / 1 ,samples: 87 / 120 ,train loss: 17.20566174022564 ,eval loss: 16.206362628856045
epoch: 1 / 1 ,samples: 88 / 120 ,train loss: 17.454192113839586 ,eval loss: 16.086808426814954
epoch: 1 / 1 ,samples: 89 / 120 ,train loss: 17.810801283058517 ,eval loss: 16.047768615374864
epoch: 1 / 1 ,samples: 90 / 120 ,train loss: 18.352418415549227 ,eval loss: 16.11499606954419
epoch: 1 / 1 ,samples: 91 / 120 ,train loss: 18.949218549079465 ,eval loss: 16.281230192264882
epoch: 1 / 1 ,samples: 92 / 120 ,train loss: 19.745142441168305 ,eval loss: 16.591784659561966
epoch: 1 / 1 ,samples: 93 / 120 ,train loss: 20.56392873099257 ,eval loss: 16.981097554999526
epoch: 1 / 1 ,samples: 94 / 120 ,train loss: 21.492873758070026 ,eval loss: 17.483056184086117
epoch: 1 / 1 ,samples: 95 / 120 ,train loss: 22.638913137139614 ,eval loss: 18.166733085093927
epoch: 1 / 1 ,samples: 96 / 120 ,train loss: 23.784450245546704 ,eval loss: 18.903240572621787
epoch: 1 / 1 ,samples: 97 / 120 ,train loss: 25.096632758570667 ,eval loss: 19.79634156759791
epoch: 1 / 1 ,samples: 98 / 120 ,train loss: 26.379321474881408 ,eval loss: 20.708880523495196
epoch: 1 / 1 ,samples: 99 / 120 ,train loss: 27.730190859825026 ,eval loss: 21.703041579812098
epoch: 1 / 1 ,samples: 100 / 120 ,train loss: 29.024156936889867 ,eval loss: 22.680837567070597
epoch: 1 / 1 ,samples: 101 / 120 ,train loss: 30.446320541768557 ,eval loss: 23.778754692244608
epoch: 1 / 1 ,samples: 102 / 120 ,train loss: 31.90272032223926 ,eval loss: 24.923654288406865
epoch: 1 / 1 ,samples: 103 / 120 ,train loss: 33.25186049429429 ,eval loss: 25.999494771611396
epoch: 1 / 1 ,samples: 104 / 120 ,train loss: 34.67854715054938 ,eval loss: 27.150439790282746
epoch: 1 / 1 ,samples: 105 / 120 ,train loss: 34.645435697943064 ,eval loss: 27.123685246246136
epoch: 1 / 1 ,samples: 106 / 120 ,train loss: 34.68994743526686 ,eval loss: 27.159874803596903
epoch: 1 / 1 ,samples: 107 / 120 ,train loss: 34.81225420024342 ,eval loss: 27.259213281522232
epoch: 1 / 1 ,samples: 108 / 120 ,train loss: 34.93032281889466 ,eval loss: 27.355191146766472
epoch: 1 / 1 ,samples: 109 / 120 ,train loss: 35.214471007739654 ,eval loss: 27.58635316409818
epoch: 1 / 1 ,samples: 110 / 120 ,train loss: 35.65887123004236 ,eval loss: 27.94868223998747
epoch: 1 / 1 ,samples: 111 / 120 ,train loss: 35.95282974874901 ,eval loss: 28.188932498818
epoch: 1 / 1 ,samples: 112 / 120 ,train loss: 36.40533677831546 ,eval loss: 28.559526013999065
epoch: 1 / 1 ,samples: 113 / 120 ,train loss: 36.878308753734025 ,eval loss: 28.947889054255956
epoch: 1 / 1 ,samples: 114 / 120 ,train loss: 37.38011945871737 ,eval loss: 29.360998690623084
epoch: 1 / 1 ,samples: 115 / 120 ,train loss: 37.86842339837046 ,eval loss: 29.763993967202698
epoch: 1 / 1 ,samples: 116 / 120 ,train loss: 38.40383736296162 ,eval loss: 30.20691691837426
epoch: 1 / 1 ,samples: 117 / 120 ,train loss: 38.93430977980122 ,eval loss: 30.64680085481013
epoch: 1 / 1 ,samples: 118 / 120 ,train loss: 39.50290701669599 ,eval loss: 31.119384447400858
epoch: 1 / 1 ,samples: 119 / 120 ,train loss: 39.83255578849347 ,eval loss: 31.39391286683079
epoch: 1 / 1 ,samples: 120 / 120 ,train loss: 40.15889993779814 ,eval loss: 31.66603426567604
In [6]:
#查看拟合效果
plt.scatter(data[:, 0], target)
plt.plot(data[:, 0], model.predict(data), color='r')
Out[6]:
[<matplotlib.lines.Line2D at 0x2dac366ed30>]
In [7]:
#查看loss
plt.plot(range(0,len(train_losses)),train_losses,label='train loss')
plt.plot(range(0,len(eval_losses)),eval_losses,label='eval loss')
plt.legend()
Out[7]:
<matplotlib.legend.Legend at 0x2dadc90a8d0>