Multi-Class Multi-Label classification

We now turn to multi-label classification, whereby multiple labels can be assigned to each example. As a first example of the reach of LTNs, we shall see how the previous example can be extended naturally using LTN to account for multiple labels, not always a trivial extension for most ML algorithms.

The standard approach to the multi-label problem is to provide explicit negative examples for each class. By contrast, LTN can use background knowledge to relate classes directly to each other, thus becoming a powerful tool in the case of the multi-label problem when typically the labelled data is scarce.

We use the Leptograpsus crabs data set consisting of 200 examples of 5 morphological measurements of 50 crabs. The task is to classify the crabs according to their colour and sex. There are four labels: blue, orange, male and female.

The colour labels are mutually-exclusive, and so are the labels for sex. LTN will be used to specify such information logically.

In [1]:
import logging; logging.basicConfig(level=logging.INFO)
import tensorflow as tf
import logictensornetworks as ltn
import pandas as pd
Init Plugin
Init Graph Optimizer
Init Kernel


Crabs dataset from:

The crabs data frame has 200 rows and 8 columns, describing 5 morphological measurements on 50 crabs each of two colour forms and both sexes, of the species Leptograpsus variegatus collected at Fremantle, W. Australia.

  • Multi-class: Male, Female, Blue, Orange.
  • Multi-label: Only Male-Female and Blue-Orange are mutually exclusive.
In [2]:
df = pd.read_csv("crabs.dat",sep=" ", skipinitialspace=True)
df = df.sample(frac=1) #shuffle
    sp sex  index    FL    RW    CL    CW    BD
47   B   M     48  19.8  14.2  43.2  49.7  18.6
59   B   F     10  10.8   9.5  22.5  26.3   9.1
31   B   M     32  16.2  13.3  36.0  41.7  15.4
184  O   F     35  19.1  16.3  37.9  42.6  17.2
0    B   M      1   8.1   6.7  16.1  19.0   7.0

We use 160 samples for training and 40 samples for testing.

In [3]:
features = df[['FL','RW','CL','CW','BD']]
labels_sex = df['sex']
labels_color = df['sp']

ds_train =[:160],labels_sex[:160],labels_color[:160])).batch(batch_size)
ds_test =[160:],labels_sex[160:],labels_color[160:])).batch(batch_size)
Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

2021-08-30 14:45:10.156116: I tensorflow/core/common_runtime/pluggable_device/] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2021-08-30 14:45:10.156306: I tensorflow/core/common_runtime/pluggable_device/] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)



index class
0 Male
1 Female
2 Blue
3 Orange

Let's note that, since the classes are not mutually exclusive, the last layer of the model will be a sigmoid and not a softmax.

In [4]:
class MLP(tf.keras.Model):
    """Model that returns logits."""
    def __init__(self, n_classes, hidden_layer_sizes=(16,16,8)):
        super(MLP, self).__init__()
        self.denses = [tf.keras.layers.Dense(s, activation="elu") for s in hidden_layer_sizes]
        self.dense_class = tf.keras.layers.Dense(n_classes)
    def call(self, inputs):
        x = inputs
        for dense in self.denses:
            x = dense(x)
        return self.dense_class(x)

logits_model = MLP(4)
p = ltn.Predicate(ltn.utils.LogitsToPredicateModel(logits_model,single_label=False))

Constants to index the classes

In [5]:
class_male = ltn.Constant(0, trainable=False)
class_female = ltn.Constant(1, trainable=False)
class_blue = ltn.Constant(2, trainable=False)
class_orange = ltn.Constant(3, trainable=False)


forall x_blue: C(x_blue,blue)
forall x_orange: C(x_orange,orange)
forall x_male: C(x_male,male)
forall x_female: C(x_female,female)
forall x: ~(C(x,male) & C(x,female))
forall x: ~(C(x,blue) & C(x,orange))
In [6]:
Not = ltn.Wrapper_Connective(ltn.fuzzy_ops.Not_Std())
And = ltn.Wrapper_Connective(ltn.fuzzy_ops.And_Prod())
Or = ltn.Wrapper_Connective(ltn.fuzzy_ops.Or_ProbSum())
Implies = ltn.Wrapper_Connective(ltn.fuzzy_ops.Implies_Reichenbach())
Forall = ltn.Wrapper_Quantifier(ltn.fuzzy_ops.Aggreg_pMeanError(p=2),semantics="forall")
In [7]:
formula_aggregator = ltn.Wrapper_Formula_Aggregator(ltn.fuzzy_ops.Aggreg_pMeanError(p=2))

def axioms(features,labels_sex,labels_color):
    x = ltn.Variable("x",features)
    x_blue = ltn.Variable("x_blue",features[labels_color=="B"])
    x_orange = ltn.Variable("x_orange",features[labels_color=="O"])
    x_male = ltn.Variable("x_blue",features[labels_sex=="M"])
    x_female = ltn.Variable("x_blue",features[labels_sex=="F"])
    axioms = [
        Forall(x_blue, p([x_blue,class_blue])),
        Forall(x_orange, p([x_orange,class_orange])),
        Forall(x_male, p([x_male,class_male])),
        Forall(x_female, p([x_female,class_female])),
    sat_level = formula_aggregator(axioms).tensor
    return sat_level

Initialize all layers and the static graph.

In [8]:
for features, labels_sex, labels_color in ds_train:
    print("Initial sat level %.5f"%axioms(features,labels_sex,labels_color))
2021-08-30 14:46:45.306657: I tensorflow/compiler/mlir/] None of the MLIR Optimization Passes are enabled (registered 2)
2021-08-30 14:46:45.311761: W tensorflow/core/platform/profile_utils/] Failed to get CPU frequency: 0 Hz
2021-08-30 14:46:45.311945: I tensorflow/core/grappler/optimizers/] Plugin optimizer for device_type GPU is enabled.
Initial sat level 0.37398


Define the metrics. While training, we measure:

  1. The level of satisfiability of the Knowledge Base of the training data.
  2. The level of satisfiability of the Knowledge Base of the test data.
  3. The training accuracy.
  4. The test accuracy.
  5. The level of satisfiability of a formula $\phi_1$ we expect to be true. forall x (p(x,blue)->~p(x,orange)) (every blue crab cannot be orange and vice-versa)
  6. The level of satisfiability of a formula $\phi_2$ we expect to be false. forall x (p(x,blue)->p(x,orange)) (every blue crab is also orange)
  7. The level of satisfiability of a formula $\phi_3$ we expect to be false. forall x (p(x,blue)->p(x,male)) (every blue crab is male)

For the last 3 queries, we use $p=5$ when approximating the universal quantifier. A higher $p$ denotes a stricter universal quantification with a stronger focus on outliers (see turorial on operators for more details). Training should usually not focus on outliers, as optimizers would struggle to generalize and tend to get stuck in local minima. However, when querying $\phi_1$,$\phi_2$,$\phi_3$, we wish to be more careful about the interpretation of our statement.

In [12]:
metrics_dict = {
    'train_sat_kb': tf.keras.metrics.Mean(name='train_sat_kb'),
    'test_sat_kb': tf.keras.metrics.Mean(name='test_sat_kb'),
    'train_accuracy': tf.keras.metrics.Mean(name="train_accuracy"),
    'test_accuracy': tf.keras.metrics.Mean(name="test_accuracy"),
    'test_sat_phi1': tf.keras.metrics.Mean(name='test_sat_phi1'),
    'test_sat_phi2': tf.keras.metrics.Mean(name='test_sat_phi2'),
    'test_sat_phi3': tf.keras.metrics.Mean(name='test_sat_phi3')

def sat_phi1(features):
    x = ltn.Variable("x",features)
    phi1 = Forall(x, Implies(p([x,class_blue]),Not(p([x,class_orange]))),p=5)
    return phi1.tensor

def sat_phi2(features):
    x = ltn.Variable("x",features)
    phi2 = Forall(x, Implies(p([x,class_blue]),p([x,class_orange])),p=5)
    return phi2.tensor

def sat_phi3(features):
    x = ltn.Variable("x",features)
    phi3 = Forall(x, Implies(p([x,class_blue]),p([x,class_male])),p=5)
    return phi3.tensor

def multilabel_hamming_loss(y_true, y_pred, threshold=0.5,from_logits=False):
    if from_logits:
        y_pred = tf.math.sigmoid(y_pred)
    y_pred = y_pred > threshold
    y_true = tf.cast(y_true, tf.int32)
    y_pred = tf.cast(y_pred, tf.int32)
    nonzero = tf.cast(tf.math.count_nonzero(y_true-y_pred,axis=-1),tf.float32)
    return nonzero/y_true.get_shape()[-1]
In [13]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
def train_step(features, labels_sex, labels_color):
    # sat and update
    with tf.GradientTape() as tape:
        sat = axioms(features, labels_sex, labels_color)
        loss = 1.-sat
    gradients = tape.gradient(loss, p.trainable_variables)
    optimizer.apply_gradients(zip(gradients, p.trainable_variables))
    # accuracy
    predictions = logits_model(features)
    labels_male = (labels_sex == "M")
    labels_female = (labels_sex == "F")
    labels_blue = (labels_color == "B")
    labels_orange = (labels_color == "O")
    onehot = tf.stack([labels_male,labels_female,labels_blue,labels_orange],axis=-1)
def test_step(features, labels_sex, labels_color):
    # sat
    sat_kb = axioms(features, labels_sex, labels_color)
    # accuracy
    predictions = logits_model(features)
    labels_male = (labels_sex == "M")
    labels_female = (labels_sex == "F")
    labels_blue = (labels_color == "B")
    labels_orange = (labels_color == "O")
    onehot = tf.stack([labels_male,labels_female,labels_blue,labels_orange],axis=-1)
In [14]:
import commons

EPOCHS = 200

2021-08-30 14:49:36.979684: I tensorflow/core/grappler/optimizers/] Plugin optimizer for device_type GPU is enabled.
2021-08-30 14:49:37.551709: I tensorflow/core/grappler/optimizers/] Plugin optimizer for device_type GPU is enabled.
2021-08-30 14:49:38.343604: I tensorflow/core/grappler/optimizers/] Plugin optimizer for device_type GPU is enabled.
Epoch 0, train_sat_kb: 0.4152, test_sat_kb: 0.4329, train_accuracy: 0.5063, test_accuracy: 0.4750, test_sat_phi1: 0.9116, test_sat_phi2: 0.0802, test_sat_phi3: 0.0007
Epoch 20, train_sat_kb: 0.5053, test_sat_kb: 0.5054, train_accuracy: 0.5578, test_accuracy: 0.5312, test_sat_phi1: 0.5040, test_sat_phi2: 0.4654, test_sat_phi3: 0.0062
Epoch 40, train_sat_kb: 0.6822, test_sat_kb: 0.6806, train_accuracy: 0.7438, test_accuracy: 0.7312, test_sat_phi1: 0.6448, test_sat_phi2: 0.5052, test_sat_phi3: 0.7595
Epoch 60, train_sat_kb: 0.7399, test_sat_kb: 0.7318, train_accuracy: 0.8188, test_accuracy: 0.7563, test_sat_phi1: 0.5998, test_sat_phi2: 0.2873, test_sat_phi3: 0.7574
Epoch 80, train_sat_kb: 0.7876, test_sat_kb: 0.8001, train_accuracy: 0.9016, test_accuracy: 0.9438, test_sat_phi1: 0.6944, test_sat_phi2: 0.2684, test_sat_phi3: 0.5755
Epoch 100, train_sat_kb: 0.8447, test_sat_kb: 0.8754, train_accuracy: 0.9641, test_accuracy: 0.9938, test_sat_phi1: 0.8330, test_sat_phi2: 0.2429, test_sat_phi3: 0.4324
Epoch 120, train_sat_kb: 0.8735, test_sat_kb: 0.9043, train_accuracy: 0.9672, test_accuracy: 0.9938, test_sat_phi1: 0.9028, test_sat_phi2: 0.1991, test_sat_phi3: 0.3741
In [ ]: