import xgboost as xgb
from sklearn.datasets import load_iris
import numpy as np
from matplotlib import pyplot as plt
from art.estimators.classification import XGBoostClassifier
from art.attacks.evasion import ZooAttack
from art.utils import load_mnist
import warnings
warnings.filterwarnings('ignore')
def get_adversarial_examples(x_train, y_train, num_classes):
    """Fit an XGBoost model and attack it with ART's Zeroth Order Optimization.

    Args:
        x_train: 2-D feature matrix. It is used to fit the model and is also
            the set of benign inputs handed to the attack.
        y_train: Class labels for ``x_train`` (passed as the ``DMatrix`` label).
        num_classes: Number of classes; configures the XGBoost objective and
            the ART classifier wrapper.

    Returns:
        Tuple ``(x_train_adv, model)``: the result of ``ZooAttack.generate``
        on ``x_train`` and the trained ``xgb.Booster``.
    """
    # Create and fit XGBoost model
    num_round = 10
    # XGBoost names the evaluation-metric parameter ``eval_metric`` and its
    # multi-class log loss ``mlogloss``. The previous ``metric`` key was not a
    # recognised parameter and only produced a "might not be used" warning.
    param = {
        'objective': 'multi:softprob',
        'eval_metric': 'mlogloss',
        'num_class': num_classes,
    }
    train_data = xgb.DMatrix(x_train, label=y_train)
    # Both watchlist entries point at the training data; there is no held-out set.
    evallist = [(train_data, 'eval'), (train_data, 'train')]
    model = xgb.train(param, train_data, num_round, evallist)

    # Create ART classifier for XGBoost. The class count follows the argument
    # instead of a hard-coded 10 so the wrapper matches the trained model.
    art_classifier = XGBoostClassifier(
        model=model,
        nb_features=x_train.shape[1],
        nb_classes=num_classes,
    )

    # Create ART Zeroth Order Optimization attack
    zoo = ZooAttack(
        classifier=art_classifier,
        confidence=0.0,
        targeted=False,
        learning_rate=1e-1,
        max_iter=20,
        binary_search_steps=10,
        initial_const=1e-3,
        abort_early=True,
        use_resize=False,
        use_importance=False,
        nb_parallel=1,
        batch_size=1,
        variable_h=0.2,
    )

    # Generate adversarial samples with ART Zeroth Order Optimization attack
    x_train_adv = zoo.generate(x_train)
    return x_train_adv, model
def get_data(num_classes):
    """Build a two-feature toy dataset from Iris.

    Only samples whose label is below ``num_classes`` are kept, and only the
    first two feature columns are used. Class 0 is shifted along the first
    feature and class 2 along the second, then both columns are rescaled with
    fixed constants.

    Args:
        num_classes: Labels ``0 .. num_classes - 1`` are retained.

    Returns:
        Tuple ``(x_train, y_train)`` of the transformed features and the
        matching labels.
    """
    features, labels = load_iris(return_X_y=True)

    # The row filter must use the unfiltered labels, so compute it once up front.
    keep = labels < num_classes
    x_train = features[keep][:, [0, 1]]
    y_train = labels[keep]

    # Stretch and shift class 0 along feature 0 and class 2 along feature 1.
    is_class_0 = y_train == 0
    is_class_2 = y_train == 2
    x_train[is_class_0, 0] = x_train[is_class_0, 0] * 2 - 3
    x_train[is_class_2, 1] = x_train[is_class_2, 1] * 2 - 2

    # Fixed-range rescaling: feature 0 from 4..9, feature 1 from 1..6.
    # NOTE(review): presumably chosen so the data fits the unit square that
    # plot_results draws -- confirm.
    x_train[:, 0] = (x_train[:, 0] - 4) / (9 - 4)
    x_train[:, 1] = (x_train[:, 1] - 1) / (6 - 1)

    return x_train, y_train
def plot_results(model, x_train, y_train, x_train_adv, num_classes):
    """Draw one panel per class showing benign and adversarial samples.

    Each panel shows the model's predicted probability for that class as a
    filled contour over the unit square, all benign samples coloured by
    label, and, for samples of the panel's class that the attack changed,
    a black segment from the benign to the adversarial point plus a red
    ``X`` at the adversarial point.

    Args:
        model: Trained booster whose ``predict`` accepts an ``xgb.DMatrix`` and
            returns one probability column per class.
        x_train: Benign samples; only the first two columns are plotted.
        y_train: Labels for ``x_train``.
        x_train_adv: Adversarial counterparts, row-aligned with ``x_train``.
        num_classes: Number of panels; at most 3 because of the colour list.
    """
    # squeeze=False keeps the axes container array-shaped even for a single
    # panel, so indexing works for num_classes == 1 as well.
    fig, axs = plt.subplots(1, num_classes, figsize=(num_classes * 5, 5), squeeze=False)
    axs = axs[0]

    colors = ['orange', 'blue', 'green']
    levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

    # The probability surface is identical for every panel: evaluate the model
    # on the grid once and pick the relevant column per class below.
    h = .01
    x_min, x_max = 0, 1
    y_min, y_max = 0, 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    grid_proba = model.predict(xgb.DMatrix(np.c_[xx.ravel(), yy.ravel()]))

    # One boolean mask per class, computed once instead of once per sample.
    class_masks = [y_train == i_class for i_class in range(num_classes)]

    for i_class in range(num_classes):
        ax = axs[i_class]
        benign = x_train[class_masks[i_class]]
        adversarial = x_train_adv[class_masks[i_class]]
        # A sample counts as attacked when either plotted coordinate changed.
        moved = np.any(benign[:, :2] != adversarial[:, :2], axis=1)

        # Plot difference vectors
        for start, end in zip(benign[moved], adversarial[moved]):
            ax.plot([start[0], end[0]], [start[1], end[1]], c='black', zorder=1)

        # Plot benign samples
        for i_class_2 in range(num_classes):
            members = x_train[class_masks[i_class_2]]
            ax.scatter(members[:, 0], members[:, 1], s=20, zorder=2, c=colors[i_class_2])
        ax.set_aspect('equal', adjustable='box')

        # Show predicted probability as contour plot
        Z_proba = grid_proba[:, i_class].reshape(xx.shape)
        im = ax.contourf(xx, yy, Z_proba, levels=levels, vmin=0, vmax=1)
        if i_class == num_classes - 1:
            # A single shared colour bar, placed to the right of the last panel.
            cax = fig.add_axes([0.95, 0.2, 0.025, 0.6])
            plt.colorbar(im, ax=ax, cax=cax)

        # Plot adversarial samples
        ax.scatter(adversarial[moved, 0], adversarial[moved, 1], zorder=2, c='red', marker='X')

        ax.set_xlim((x_min, x_max))
        ax.set_ylim((y_min, y_max))
        ax.set_title('class ' + str(i_class))
        ax.set_xlabel('feature 1')
        ax.set_ylabel('feature 2')
# Two-class experiment: build the data, attack the fitted model, plot the outcome.
num_classes = 2
x_train, y_train = get_data(num_classes)
x_train_adv, model = get_adversarial_examples(x_train, y_train, num_classes=num_classes)
plot_results(model, x_train, y_train, x_train_adv, num_classes=num_classes)
[17:27:44] WARNING: /workspace/src/learner.cc:480: Parameters: { metric } might not be used. This may not be accurate due to some parameters are only used in language bindings but passed down to XGBoost core. Or some parameters are not used but slip through this verification. Please open an issue if you find above cases. [0] eval-merror:0.11000 train-merror:0.11000 [1] eval-merror:0.11000 train-merror:0.11000 [2] eval-merror:0.11000 train-merror:0.11000 [3] eval-merror:0.10000 train-merror:0.10000 [4] eval-merror:0.08000 train-merror:0.08000 [5] eval-merror:0.08000 train-merror:0.08000 [6] eval-merror:0.06000 train-merror:0.06000 [7] eval-merror:0.06000 train-merror:0.06000 [8] eval-merror:0.06000 train-merror:0.06000 [9] eval-merror:0.06000 train-merror:0.06000
ZOO: 100%|██████████| 100/100 [00:04<00:00, 20.23it/s]
# Three-class experiment: same pipeline as above with all three Iris classes.
num_classes = 3
x_train, y_train = get_data(num_classes)
x_train_adv, model = get_adversarial_examples(x_train, y_train, num_classes=num_classes)
plot_results(model, x_train, y_train, x_train_adv, num_classes=num_classes)
[17:27:50] WARNING: /workspace/src/learner.cc:480: Parameters: { metric } might not be used. This may not be accurate due to some parameters are only used in language bindings but passed down to XGBoost core. Or some parameters are not used but slip through this verification. Please open an issue if you find above cases. [0] eval-merror:0.15333 train-merror:0.15333 [1] eval-merror:0.14000 train-merror:0.14000 [2] eval-merror:0.13333 train-merror:0.13333 [3] eval-merror:0.14000 train-merror:0.14000 [4] eval-merror:0.13333 train-merror:0.13333 [5] eval-merror:0.12667 train-merror:0.12667 [6] eval-merror:0.12000 train-merror:0.12000 [7] eval-merror:0.11333 train-merror:0.11333 [8] eval-merror:0.10000 train-merror:0.10000 [9] eval-merror:0.11333 train-merror:0.11333
ZOO: 100%|██████████| 150/150 [00:07<00:00, 19.87it/s]
# Load MNIST through ART's helper.
(x_train, y_train), (x_test, y_test), min_, max_ = load_mnist()

# Flatten every 4-D image tensor into one feature row per sample.
n_samples_train = x_train.shape[0]
n_features_train = x_train.shape[1] * x_train.shape[2] * x_train.shape[3]
n_samples_test = x_test.shape[0]
n_features_test = x_test.shape[1] * x_test.shape[2] * x_test.shape[3]
x_train = x_train.reshape(n_samples_train, n_features_train)
x_test = x_test.reshape(n_samples_test, n_features_test)

# Collapse the per-class label columns into integer class indices.
y_train = np.argmax(y_train, axis=1)
y_test = np.argmax(y_test, axis=1)

# Keep only the first samples so that the attack below stays tractable.
n_samples_max = 200
x_train = x_train[0:n_samples_max]
y_train = y_train[0:n_samples_max]
x_test = x_test[0:n_samples_max]
y_test = y_test[0:n_samples_max]

num_round = 10
# XGBoost names the evaluation-metric parameter ``eval_metric`` and its
# multi-class log loss ``mlogloss``. The previous ``metric`` key was not a
# recognised parameter and only produced a "might not be used" warning.
param = {'objective': 'multi:softprob', 'eval_metric': 'mlogloss', 'num_class': 10}
train_data = xgb.DMatrix(x_train, label=y_train)
# No separate validation split is used: both watchlist entries are the training data.
validation_data = train_data
evallist = [(train_data, 'eval'), (train_data, 'train')]
model = xgb.train(param, train_data, num_round, evallist)
[17:27:59] WARNING: /workspace/src/learner.cc:480: Parameters: { metric } might not be used. This may not be accurate due to some parameters are only used in language bindings but passed down to XGBoost core. Or some parameters are not used but slip through this verification. Please open an issue if you find above cases. [0] eval-merror:0.05500 train-merror:0.05500 [1] eval-merror:0.04000 train-merror:0.04000 [2] eval-merror:0.01500 train-merror:0.01500 [3] eval-merror:0.01000 train-merror:0.01000 [4] eval-merror:0.00500 train-merror:0.00500 [5] eval-merror:0.00000 train-merror:0.00000 [6] eval-merror:0.00000 train-merror:0.00000 [7] eval-merror:0.00000 train-merror:0.00000 [8] eval-merror:0.00000 train-merror:0.00000 [9] eval-merror:0.00000 train-merror:0.00000
# Wrap the trained booster so that ART attacks can query it.
art_classifier = XGBoostClassifier(
    model=model,
    nb_features=x_train.shape[1],
    nb_classes=10,
)

# Configure the untargeted Zeroth Order Optimization attack.
zoo = ZooAttack(
    classifier=art_classifier,
    confidence=0.0,
    targeted=False,
    learning_rate=1e-1,
    max_iter=100,
    binary_search_steps=20,
    initial_const=1e-3,
    abort_early=True,
    use_resize=False,
    use_importance=False,
    nb_parallel=10,
    batch_size=1,
    variable_h=0.05,
)

# Adversarial counterparts of the training samples.
x_train_adv = zoo.generate(x=x_train)
ZOO: 100%|██████████| 200/200 [05:51<00:00, 1.76s/it]
# Adversarial counterparts of the test samples, produced by the same attack object.
x_test_adv = zoo.generate(x=x_test)
ZOO: 100%|██████████| 200/200 [05:09<00:00, 1.55s/it]
# Accuracy on the unmodified training samples: the fraction of rows whose
# most probable class equals the true label.
y_pred = model.predict(xgb.DMatrix(x_train))
score = np.mean(y_pred.argmax(axis=1) == y_train)
print("Benign Training Score: %.4f" % score)
Benign Training Score: 1.0000
# Show the first benign training image and the label predicted for it.
plt.matshow(x_train[0, :].reshape((28, 28)))
plt.clim(0, 1)
prediction = np.argmax(model.predict(xgb.DMatrix(x_train[0:1, :])), axis=1)
# ``prediction`` is a length-1 array. Index it explicitly: formatting the
# array itself with %i relies on NumPy's deprecated array-to-scalar conversion.
print("Benign Training Predicted Label: %i" % prediction[0])
Benign Training Predicted Label: 5
# Accuracy on the adversarial training samples, measured against the true labels.
y_pred = model.predict(xgb.DMatrix(x_train_adv))
score = np.mean(y_pred.argmax(axis=1) == y_train)
print("Adversarial Training Score: %.4f" % score)
Adversarial Training Score: 0.5900
# Show the first adversarial training image and the label predicted for it.
plt.matshow(x_train_adv[0, :].reshape((28, 28)))
plt.clim(0, 1)
prediction = np.argmax(model.predict(xgb.DMatrix(x_train_adv[0:1, :])), axis=1)
# ``prediction`` is a length-1 array. Index it explicitly: formatting the
# array itself with %i relies on NumPy's deprecated array-to-scalar conversion.
print("Adversarial Training Predicted Label: %i" % prediction[0])
Adversarial Training Predicted Label: 8
# Accuracy on the unmodified test samples.
y_pred = model.predict(xgb.DMatrix(x_test))
score = np.mean(y_pred.argmax(axis=1) == y_test)
print("Benign Test Score: %.4f" % score)
Benign Test Score: 0.6450
# Show the first benign test image and the label predicted for it.
plt.matshow(x_test[0, :].reshape((28, 28)))
plt.clim(0, 1)
prediction = np.argmax(model.predict(xgb.DMatrix(x_test[0:1, :])), axis=1)
# ``prediction`` is a length-1 array. Index it explicitly: formatting the
# array itself with %i relies on NumPy's deprecated array-to-scalar conversion.
print("Benign Test Predicted Label: %i" % prediction[0])
Benign Test Predicted Label: 7
# Accuracy on the adversarial test samples, measured against the true labels.
y_pred = model.predict(xgb.DMatrix(x_test_adv))
score = np.mean(y_pred.argmax(axis=1) == y_test)
print("Adversarial Test Score: %.4f" % score)
Adversarial Test Score: 0.3750
# Show the first adversarial test image and the label predicted for it.
plt.matshow(x_test_adv[0, :].reshape((28, 28)))
plt.clim(0, 1)
prediction = np.argmax(model.predict(xgb.DMatrix(x_test_adv[0:1, :])), axis=1)
# ``prediction`` is a length-1 array. Index it explicitly: formatting the
# array itself with %i relies on NumPy's deprecated array-to-scalar conversion.
print("Adversarial Test Predicted Label: %i" % prediction[0])
Adversarial Test Predicted Label: 3