ML.NET Binary Classification

Creates a binary classification model to predict the quality of wine using 11 physicochemical features. Uses the DataFrame API to read the raw data and prepare it.

NuGet package installation

In [1]:
#r "nuget:Microsoft.ML, 1.4.0"
#r "nuget:XPlot.Plotly, 3.0.1"
Installing package Microsoft.ML, version 1.4.0................done!
Successfully added reference to package Microsoft.ML, version 1.4.0
Installing package XPlot.Plotly, version 3.0.1........done!
Successfully added reference to package XPlot.Plotly, version 3.0.1

Namespaces

In [2]:
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using XPlot.Plotly;

Input class definition

In [3]:
public class BinaryClassificationData
{
    [LoadColumn(0)]
    public float FixedAcidity;

    [LoadColumn(1)]
    public float VolatileAcidity;

    [LoadColumn(2)]
    public float CitricAcid;

    [LoadColumn(3)]
    public float ResidualSugar;

    [LoadColumn(4)]
    public float Chlorides;

    [LoadColumn(5)]
    public float FreeSulfurDioxide;

    [LoadColumn(6)]
    public float TotalSulfurDioxide;

    [LoadColumn(7)]
    public float Density;

    [LoadColumn(8)]
    public float Ph;

    [LoadColumn(9)]
    public float Sulphates;

    [LoadColumn(10)]
    public float Alcohol;

    [LoadColumn(11)]
    public float Quality;
}

public class RichBinaryClassificationData: BinaryClassificationData
{
    public bool Label => Quality > 5;
}

Output class definition

In [4]:
public class BinaryClassificationPrediction
{
    public bool Label;

    [ColumnName("PredictedLabel")]
    public bool PredictedLabel;

    public int LabelAsNumber => PredictedLabel ? 1 : 0;
}

Bring in the DataFrame

In [5]:
#r "nuget:Microsoft.Data.Analysis,0.2.0"
using Microsoft.Data.Analysis;
using Microsoft.AspNetCore.Html;

// Convenient custom formatter.
Formatter<DataFrame>.Register((df, writer) =>
{
    var headers = new List<IHtmlContent>();
    headers.Add(th(i("index")));
    headers.AddRange(df.Columns.Select(c => (IHtmlContent) th(c.Name)));
    var rows = new List<List<IHtmlContent>>();
    var take = 5;
    for (var i = 0; i < Math.Min(take, df.Rows.Count); i++)
    {
        var cells = new List<IHtmlContent>();
        cells.Add(td(i));
        foreach (var obj in df.Rows[i])
        {
            cells.Add(td(obj));
        }
        rows.Add(cells);
    }

    var t = table(
        thead(
            headers),
        tbody(
            rows.Select(
                r => tr(r))));

    writer.Write(t);
}, "text/html");
Installing package Microsoft.Data.Analysis, version 0.2.0......done!
Successfully added reference to package Microsoft.Data.Analysis, version 0.2.0

Read the raw data

In [6]:
var trainingData = DataFrame.LoadCsv(
    "./WineQuality_White_Train.csv",
    separator: ';',
    columnNames: new[]
                {
                    "FixedAcidity",
                    "VolatileAcidity",
                    "CitricAcid",
                    "ResidualSugar",
                    "Chlorides",
                    "FreeSulfurDioxide",
                    "TotalSulfurDioxide",
                    "Density",
                    "Ph",
                    "Sulphates",
                    "Alcohol",
                    "Quality"
                });

display(trainingData);
indexFixedAcidityVolatileAcidityCitricAcidResidualSugarChloridesFreeSulfurDioxideTotalSulfurDioxideDensityPhSulphatesAlcoholQuality
070.270.3620.70.045451701.00130.458.86
16.30.30.341.60.049141320.9943.30.499.56
28.10.280.46.90.0530970.99513.260.4410.16
37.20.230.328.50.058471860.99563.190.49.96
47.20.230.328.50.058471860.99563.190.49.96

Prepare the data

In [7]:
// Create the Label column and add it to the data.
var labelCol = trainingData["Quality"].ElementwiseGreaterThanOrEqual(6);
labelCol.SetName("Label");
trainingData.Columns.Add(labelCol);

// This works, but we need the Quality column in later cells ...
// trainingData.Columns.Remove(trainingData["Quality"]);

display(trainingData);
indexFixedAcidityVolatileAcidityCitricAcidResidualSugarChloridesFreeSulfurDioxideTotalSulfurDioxideDensityPhSulphatesAlcoholQualityLabel
070.270.3620.70.045451701.00130.458.86True
16.30.30.341.60.049141320.9943.30.499.56True
28.10.280.46.90.0530970.99513.260.4410.16True
37.20.230.328.50.058471860.99563.190.49.96True
47.20.230.328.50.058471860.99563.190.49.96True
In [8]:
var mlContext = new MLContext(seed: null);

// Define the pipeline.
var pipeline =
        mlContext.Transforms.ReplaceMissingValues(
            outputColumnName: "FixedAcidity",
            replacementMode: MissingValueReplacingEstimator.ReplacementMode.Mean)
        .Append(mlContext.Transforms.Concatenate("Features",
            new[]
            {
                "FixedAcidity",
                "VolatileAcidity",
                "CitricAcid",
                "ResidualSugar",
                "Chlorides",
                "FreeSulfurDioxide",
                "TotalSulfurDioxide",
                "Density",
                "Ph",
                "Sulphates",
                "Alcohol"
            }))
        .Append(mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression());

Train the model

In [9]:
var model = pipeline.Fit(trainingData);

Evaluate the model

In [10]:
// Load the raw test data.
var testData = mlContext.Data.LoadFromTextFile<BinaryClassificationData>(
    "./WineQuality_White_Test.csv", 
    separatorChar: ';',
    hasHeader: true);
    
// Calculate the Label (IDataView to IEnumerable to IDataView).    
var stronglyTypedTestData = mlContext.Data.CreateEnumerable<RichBinaryClassificationData>(trainingData, false);
testData = mlContext.Data.LoadFromEnumerable(stronglyTypedTestData);

// Score the test data and calculate the metrics.
var scoredData = model.Transform(testData);
var qualityMetrics =  mlContext.BinaryClassification.Evaluate(scoredData);
display(qualityMetrics);
LogLossLogLossReductionEntropyAreaUnderRocCurveAccuracyPositivePrecisionPositiveRecallNegativePrecisionNegativeRecallF1ScoreAreaUnderPrecisionRecallCurveConfusionMatrix
0.74520452598977840.19740040333434760.9284885378533480.79096305668454760.73908603523104420.7645249487354750.87042801556420240.66397578203834510.48849294729027470.81404657933042210.8749940309174482{ Microsoft.ML.Data.ConfusionMatrix: PerClassPrecision: [ 0.764524948735475, 0.6639757820383451 ], PerClassRecall: [ 0.8704280155642024, 0.4884929472902747 ], Counts: [ [ 2237, 333 ], [ 689, 658 ] ], NumberOfClasses: 2 }

Visualize the quality metrics

In [11]:
string[] metricNames = 
    { 
        "Log Loss", 
        "Log Loss Reduction", 
        "Entropy", 
        "Area Under Curve", 
        "Accuracy",
        "Positive Recall", 
        "Negative Recall",
        "F1 Score"
    };

double[] metricValues = 
    { 
        qualityMetrics.LogLoss, 
        qualityMetrics.LogLossReduction, 
        qualityMetrics.Entropy, 
        qualityMetrics.AreaUnderRocCurve, 
        qualityMetrics.Accuracy,
        qualityMetrics.PositiveRecall, 
        qualityMetrics.NegativeRecall,
        qualityMetrics.F1Score
    };

var graph = new Graph.Bar()
{
    x = metricValues,
    y = metricNames,
    orientation = "h",
    marker = new Graph.Marker { color = "darkred" }
};

var chart = Chart.Plot(graph);

var layout = new Layout.Layout(){ title="Quality Metrics" };
chart.WithLayout(layout);

display(chart);

Drawing the Confusion Matrix

Default

In [12]:
display(qualityMetrics.ConfusionMatrix);
PerClassPrecisionPerClassRecallCountsNumberOfClasses
[ 0.764524948735475, 0.6639757820383451 ][ 0.8704280155642024, 0.4884929472902747 ][ [ 2237, 333 ], [ 689, 658 ] ]2

Custom Formatter for Binary Confusion Matrix

In [13]:
Formatter<ConfusionMatrix>.Register((df, writer) =>
{
    var rows = new List<IHtmlContent>();

    var cells = new List<IHtmlContent>();
    var n = df.Counts[0][0] + df.Counts[0][1] + df.Counts[1][0] + df.Counts[1][1];
    cells.Add(td[rowspan: 2, colspan: 2, style: "text-align: center; background-color: transparent"]("n = " + n));
    cells.Add(td[colspan: 2, style: "border: 1px solid black; text-align: center; padding: 24px; background-color: lightsteelblue"](b("Predicted")));
    rows.Add(tr[style: "background-color: transparent"](cells));
    
    cells = new List<IHtmlContent>();
    cells.Add(td[style:"border: 1px solid black; padding: 24px; background-color: #E3EAF3"](b("True")));
    cells.Add(td[style:"border: 1px solid black; padding: 24px; background-color: #E3EAF3"](b("False")));
    rows.Add(tr[style: "background-color: transparent"](cells));
    
    cells = new List<IHtmlContent>();
    cells.Add(td[rowspan: 2, style:"border: 1px solid black; text-align: center; padding: 24px;  background-color: lightsteelblue"](b("Actual")));
    cells.Add(td[style:"border: 1px solid black; text-align: center; padding: 24px; background-color: #E3EAF3"](b("True")));    
    cells.Add(td[style:"border: 1px solid black; padding: 24px"](df.Counts[0][0]));
    cells.Add(td[style:"border: 1px solid black; padding: 24px"](df.Counts[0][1]));
    rows.Add(tr[style: "background-color: transparent"](cells));
    
    cells = new List<IHtmlContent>();
    cells.Add(td[style:"border: 1px solid black; text-align: center; padding: 24px; background-color: #E3EAF3"](b("False")));
    cells.Add(td[style:"border: 1px solid black; padding: 24px"](df.Counts[1][0]));
    cells.Add(td[style:"border: 1px solid black; padding: 24px"](df.Counts[1][1]));
    rows.Add(tr(cells));

    var t = table(
        tbody(
            rows));

    writer.Write(t);
}, "text/html");

Tadaa

In [14]:
display(qualityMetrics.ConfusionMatrix);
n = 3917Predicted
TrueFalse
ActualTrue2237333
False689658

Create a prediction engine and use it on a random sample

In [15]:
// Create prediction engine
var predictionEngine = mlContext.Model.CreatePredictionEngine<RichBinaryClassificationData, BinaryClassificationPrediction>(model);

// Get a random data sample
var shuffledData = mlContext.Data.ShuffleRows(trainingData);
var rawSample = mlContext.Data.TakeRows(shuffledData, 1);
var sample = mlContext.Data.CreateEnumerable<RichBinaryClassificationData>(rawSample, false).First();
display(sample);

// Predict quality of sample
var prediction = predictionEngine.Predict(sample);
display(prediction);
LabelFixedAcidityVolatileAcidityCitricAcidResidualSugarChloridesFreeSulfurDioxideTotalSulfurDioxideDensityPhSulphatesAlcoholQuality
False7.10.370.6710.50.045491550.99753.160.448.75
LabelAsNumberLabelPredictedLabel
0FalseFalse