The natural extension of binary classification is a multi-class classification task. We first approach multi-class single-label classification, which makes the assumption that each example is assigned to one and only one label.
We use the Iris flower data set, which poses a classification task over three mutually exclusive classes; call these $A$, $B$ and $C$.
While one could train three unary predicates $A(x)$, $B(x)$ and $C(x)$, it turns out to be more effective if this problem is modelled by a single binary predicate $P(x,l)$, where $l$ is a variable denoting a multi-class label, in this case classes $A$, $B$ or $C$.
import logging; logging.basicConfig(level=logging.INFO)
import tensorflow as tf
import pandas as pd
import ltn
Init Plugin Init Graph Optimizer Init Kernel
Load the iris dataset: 50 samples from each of three species of iris flowers (setosa, virginica, versicolor), measured with four features.
# Read the pre-split Iris CSV files into dataframes.
df_test = pd.read_csv("iris_test.csv")
df_train = pd.read_csv("iris_training.csv")
# Peek at the first few training rows to sanity-check the columns.
print(df_train.head(5))
sepal_length sepal_width petal_length petal_width species 0 6.4 2.8 5.6 2.2 2 1 5.0 2.3 3.3 1.0 1 2 4.9 2.5 4.5 1.7 2 3 4.9 3.1 1.5 0.1 0 4 5.7 3.8 1.7 0.3 0
# Split the target column off from the feature columns.
labels_train = df_train.pop("species")
labels_test = df_test.pop("species")

# Wrap features and labels into batched tf.data pipelines.
BATCH = 64
ds_train = tf.data.Dataset.from_tensor_slices((df_train, labels_train)).batch(BATCH)
ds_test = tf.data.Dataset.from_tensor_slices((df_test, labels_test)).batch(BATCH)
Metal device set to: Apple M1 systemMemory: 16.00 GB maxCacheSize: 5.33 GB
2021-08-30 14:38:15.642262: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support. 2021-08-30 14:38:15.642359: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
Predicate with softmax P(x,class)
class MLP(tf.keras.Model):
    """Feed-forward network that returns raw (pre-softmax) class logits.

    Args:
        n_classes: width of the output layer (one logit per class).
        hidden_layer_sizes: widths of the ELU-activated hidden layers.
    """

    def __init__(self, n_classes, hidden_layer_sizes=(16,16,8)):
        super().__init__()
        self.denses = [
            tf.keras.layers.Dense(width, activation="elu")
            for width in hidden_layer_sizes
        ]
        self.dense_class = tf.keras.layers.Dense(n_classes)
        self.dropout = tf.keras.layers.Dropout(0.2)

    def call(self, inputs, training=False):
        # inputs is a list; the sample tensor is its first element.
        hidden = inputs[0]
        for layer in self.denses:
            hidden = layer(hidden)
            # Dropout is only active while training.
            hidden = self.dropout(hidden, training=training)
        return self.dense_class(hidden)
# Instantiate the logits network for the 3 Iris classes and wrap it as an
# LTN predicate P(x, l). A softmax is applied over the logits so the
# predicate's truth value is the probability assigned to class l.
# NOTE(review): with_class_indexing=True presumably makes the second argument
# select the class-th softmax output — confirm against the LTN API docs.
logits_model = MLP(3)
p = ltn.Predicate.FromLogits(logits_model, activation_function="softmax", with_class_indexing=True)
Constants to index/iterate on the classes
# Non-trainable LTN constants denoting the three class labels; their integer
# values index the corresponding logit of P(x, l).
class_A = ltn.Constant(0, trainable=False)
class_B = ltn.Constant(1, trainable=False)
class_C = ltn.Constant(2, trainable=False)
Operators and axioms
# Fuzzy-logic connectives and quantifiers used to build the knowledge base:
# standard negation, product t-norm conjunction, probabilistic-sum disjunction,
# Reichenbach implication, and p-mean-error aggregation (p=2) for "forall".
Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())
And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())
Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())
Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())
Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(p=2),semantics="forall")
# Aggregates the truth values of all axioms into one satisfaction level.
formula_aggregator = ltn.Wrapper_Formula_Aggregator(ltn.fuzzy_ops.Aggreg_pMeanError(p=2))
@tf.function
def axioms(features, labels, training=False):
    """Satisfaction level of the knowledge base on one batch.

    The KB contains one axiom per class: forall x labelled k, P(x, k).
    Returns a scalar tensor in [0, 1].
    """
    # Pair each class's examples (as an LTN variable) with its class constant.
    grouped = [
        (ltn.Variable("x_A", features[labels == 0]), class_A),
        (ltn.Variable("x_B", features[labels == 1]), class_B),
        (ltn.Variable("x_C", features[labels == 2]), class_C),
    ]
    formulas = [
        Forall(var, p([var, cls], training=training))
        for var, cls in grouped
    ]
    return formula_aggregator(formulas).tensor
Initialize all layers and the static graph
# Run a single batch through axioms() to build all layers and trace the graph.
features, labels = next(iter(ds_test))
print("Initial sat level %.5f"%axioms(features,labels))
Initial sat level 0.25581
2021-08-30 14:38:20.990753: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2) 2021-08-30 14:38:20.992807: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz 2021-08-30 14:38:20.992905: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
Define the metrics. While training, we measure the satisfaction level of the knowledge base and the classification accuracy, each on both the training and the test set.
# Epoch-level metrics: mean KB satisfaction and categorical accuracy, tracked
# separately on the train and test splits. NOTE(review): presumably reset each
# epoch by the loop in commons.train — confirm there.
metrics_dict = {
'train_sat_kb': tf.keras.metrics.Mean(name='train_sat_kb'),
'test_sat_kb': tf.keras.metrics.Mean(name='test_sat_kb'),
'train_accuracy': tf.keras.metrics.CategoricalAccuracy(name="train_accuracy"),
'test_accuracy': tf.keras.metrics.CategoricalAccuracy(name="test_accuracy")
}
Define the training and test step
# Adam optimizer; it updates only the predicate's trainable weights (see train_step).
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
@tf.function
def train_step(features, labels):
    """One gradient step maximizing KB satisfaction, then metric updates."""
    # Loss is the dissatisfaction of the knowledge base (dropout active).
    with tf.GradientTape() as tape:
        sat_level = axioms(features, labels, training=True)
        loss = 1.-sat_level
    grads = tape.gradient(loss, p.trainable_variables)
    optimizer.apply_gradients(zip(grads, p.trainable_variables))
    # Re-evaluate with dropout off for a clean satisfaction reading.
    metrics_dict['train_sat_kb'](axioms(features, labels))
    # Accuracy: raw logits vs. one-hot ground truth.
    onehot = tf.one_hot(labels, 3)
    metrics_dict['train_accuracy'](onehot, logits_model([features]))
@tf.function
def test_step(features, labels):
    """Update evaluation metrics (no weight updates) on one held-out batch."""
    metrics_dict['test_sat_kb'](axioms(features, labels))
    onehot = tf.one_hot(labels, 3)
    metrics_dict['test_accuracy'](onehot, logits_model([features]))
Train
import commons  # project-local helper module providing the generic training loop
EPOCHS = 500
# NOTE(review): commons.train presumably iterates ds_train/ds_test for EPOCHS
# epochs, calling train_step/test_step per batch, logging metrics_dict to
# csv_path and printing every `track_metrics` epochs — confirm in commons.py.
commons.train(
EPOCHS,
metrics_dict,
ds_train,
ds_test,
train_step,
test_step,
csv_path="iris_results.csv",
track_metrics=20
)
2021-08-30 14:39:58.964336: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled. 2021-08-30 14:39:59.951405: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled. 2021-08-30 14:40:00.487437: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
Epoch 0, train_sat_kb: 0.2620, test_sat_kb: 0.2640, train_accuracy: 0.3000, test_accuracy: 0.4667 Epoch 20, train_sat_kb: 0.4088, test_sat_kb: 0.4085, train_accuracy: 0.7333, test_accuracy: 0.5667 Epoch 40, train_sat_kb: 0.5422, test_sat_kb: 0.5404, train_accuracy: 0.9417, test_accuracy: 0.9000 Epoch 60, train_sat_kb: 0.6432, test_sat_kb: 0.6381, train_accuracy: 0.9417, test_accuracy: 0.9000 Epoch 80, train_sat_kb: 0.7105, test_sat_kb: 0.7041, train_accuracy: 0.9583, test_accuracy: 0.9000 Epoch 100, train_sat_kb: 0.7486, test_sat_kb: 0.7443, train_accuracy: 0.9667, test_accuracy: 0.9333 Epoch 120, train_sat_kb: 0.7888, test_sat_kb: 0.7884, train_accuracy: 0.9667, test_accuracy: 0.9667 Epoch 140, train_sat_kb: 0.8182, test_sat_kb: 0.8197, train_accuracy: 0.9750, test_accuracy: 0.9667 Epoch 160, train_sat_kb: 0.8356, test_sat_kb: 0.8374, train_accuracy: 0.9750, test_accuracy: 1.0000 Epoch 180, train_sat_kb: 0.8525, test_sat_kb: 0.8457, train_accuracy: 0.9750, test_accuracy: 0.9667 Epoch 200, train_sat_kb: 0.8561, test_sat_kb: 0.8563, train_accuracy: 0.9833, test_accuracy: 0.9667 Epoch 220, train_sat_kb: 0.8706, test_sat_kb: 0.8541, train_accuracy: 0.9833, test_accuracy: 0.9667 Epoch 240, train_sat_kb: 0.8739, test_sat_kb: 0.8587, train_accuracy: 0.9833, test_accuracy: 0.9667 Epoch 260, train_sat_kb: 0.8694, test_sat_kb: 0.8635, train_accuracy: 0.9750, test_accuracy: 0.9667 Epoch 280, train_sat_kb: 0.8709, test_sat_kb: 0.8625, train_accuracy: 0.9750, test_accuracy: 0.9667 Epoch 300, train_sat_kb: 0.8782, test_sat_kb: 0.8429, train_accuracy: 0.9833, test_accuracy: 0.9667 Epoch 320, train_sat_kb: 0.8780, test_sat_kb: 0.8387, train_accuracy: 0.9833, test_accuracy: 0.9667 Epoch 340, train_sat_kb: 0.8791, test_sat_kb: 0.8614, train_accuracy: 0.9750, test_accuracy: 0.9667 Epoch 360, train_sat_kb: 0.8880, test_sat_kb: 0.8497, train_accuracy: 0.9833, test_accuracy: 0.9333 Epoch 380, train_sat_kb: 0.8894, test_sat_kb: 0.8541, train_accuracy: 0.9750, test_accuracy: 0.9333 Epoch 
400, train_sat_kb: 0.8870, test_sat_kb: 0.8401, train_accuracy: 0.9917, test_accuracy: 0.9667 Epoch 420, train_sat_kb: 0.8894, test_sat_kb: 0.8402, train_accuracy: 0.9917, test_accuracy: 0.9667 Epoch 440, train_sat_kb: 0.8912, test_sat_kb: 0.8557, train_accuracy: 0.9750, test_accuracy: 0.9667 Epoch 460, train_sat_kb: 0.8953, test_sat_kb: 0.8519, train_accuracy: 0.9750, test_accuracy: 0.9333 Epoch 480, train_sat_kb: 0.8810, test_sat_kb: 0.8593, train_accuracy: 0.9750, test_accuracy: 0.9667