Creates a binary classification model that predicts whether a white wine is of good quality (a quality score of 6 or higher) from 11 physicochemical features. The DataFrame API is used to read the raw training data and prepare it.
#r "nuget:Microsoft.ML, 1.4.0"
#r "nuget:XPlot.Plotly, 3.0.1"
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using XPlot.Plotly;
/// <summary>
/// One raw row of the wine-quality data set: eleven physicochemical
/// measurements followed by the Quality score, each bound to a column by its
/// zero-based position in the file.
/// </summary>
public class BinaryClassificationData
{
    [LoadColumn(0)] public float FixedAcidity;
    [LoadColumn(1)] public float VolatileAcidity;
    [LoadColumn(2)] public float CitricAcid;
    [LoadColumn(3)] public float ResidualSugar;
    [LoadColumn(4)] public float Chlorides;
    [LoadColumn(5)] public float FreeSulfurDioxide;
    [LoadColumn(6)] public float TotalSulfurDioxide;
    [LoadColumn(7)] public float Density;
    [LoadColumn(8)] public float Ph;
    [LoadColumn(9)] public float Sulphates;
    [LoadColumn(10)] public float Alcohol;

    /// <summary>Quality score (column 11); the binary label is derived from it.</summary>
    [LoadColumn(11)] public float Quality;
}
/// <summary>
/// A raw data row extended with a computed boolean <see cref="Label"/>.
/// </summary>
public class RichBinaryClassificationData : BinaryClassificationData
{
    /// <summary>
    /// Binary target: <c>true</c> when <see cref="BinaryClassificationData.Quality"/>
    /// is greater than 5. Read-only; it is always recomputed from Quality.
    /// </summary>
    public bool Label
    {
        get { return Quality > 5; }
    }
}
/// <summary>
/// Output row produced by the trained model's prediction engine.
/// </summary>
public class BinaryClassificationPrediction
{
    /// <summary>
    /// Label column of the scored row (presumably carried through from the
    /// input row; verify against the pipeline).
    /// </summary>
    public bool Label;

    /// <summary>The model's predicted class, bound to the "PredictedLabel" column.</summary>
    [ColumnName("PredictedLabel")] public bool PredictedLabel;

    /// <summary>The predicted class as a number: 1 for <c>true</c>, 0 for <c>false</c>.</summary>
    public int LabelAsNumber
    {
        get { return PredictedLabel ? 1 : 0; }
    }
}
#r "nuget:Microsoft.Data.Analysis,0.2.0"
using Microsoft.Data.Analysis;
using Microsoft.AspNetCore.Html;
// Convenient custom formatter: renders a DataFrame as an HTML table with an
// "index" column followed by every DataFrame column, showing at most the first
// five rows.
Formatter<DataFrame>.Register((df, writer) =>
{
    const int take = 5;

    // Header row: italic "index" label, then one <th> per column name.
    var headers = new List<IHtmlContent> { th(i("index")) };
    headers.AddRange(df.Columns.Select(column => (IHtmlContent) th(column.Name)));

    // Body rows: the row position followed by each cell value. The loop
    // variable is deliberately not called `i`, to avoid shadowing the
    // PocketView <i> tag used in the header above.
    var rowCount = Math.Min(take, df.Rows.Count);
    var rows = new List<List<IHtmlContent>>();
    for (var rowIndex = 0; rowIndex < rowCount; rowIndex++)
    {
        var cells = new List<IHtmlContent> { td(rowIndex) };
        cells.AddRange(df.Rows[rowIndex].Select(value => (IHtmlContent) td(value)));
        rows.Add(cells);
    }

    var t = table(thead(headers), tbody(rows.Select(r => tr(r))));
    writer.Write(t);
}, "text/html");
// Read the white-wine training set into a DataFrame. The file is
// semicolon-delimited, and the column names given here match the field names
// of BinaryClassificationData.
var trainingData = DataFrame.LoadCsv(
    "./WineQuality_White_Train.csv",
    columnNames: new[]
    {
        "FixedAcidity", "VolatileAcidity", "CitricAcid", "ResidualSugar",
        "Chlorides", "FreeSulfurDioxide", "TotalSulfurDioxide", "Density",
        "Ph", "Sulphates", "Alcohol", "Quality"
    },
    separator: ';');

display(trainingData);
index | FixedAcidity | VolatileAcidity | CitricAcid | ResidualSugar | Chlorides | FreeSulfurDioxide | TotalSulfurDioxide | Density | Ph | Sulphates | Alcohol | Quality |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 7 | 0.27 | 0.36 | 20.7 | 0.045 | 45 | 170 | 1.001 | 3 | 0.45 | 8.8 | 6 |
1 | 6.3 | 0.3 | 0.34 | 1.6 | 0.049 | 14 | 132 | 0.994 | 3.3 | 0.49 | 9.5 | 6 |
2 | 8.1 | 0.28 | 0.4 | 6.9 | 0.05 | 30 | 97 | 0.9951 | 3.26 | 0.44 | 10.1 | 6 |
3 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47 | 186 | 0.9956 | 3.19 | 0.4 | 9.9 | 6 |
4 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47 | 186 | 0.9956 | 3.19 | 0.4 | 9.9 | 6 |
// Append a boolean "Label" column derived from Quality: true for a score of 6 or more.
var qualityColumn = trainingData["Quality"];
var labelCol = qualityColumn.ElementwiseGreaterThanOrEqual(6);
labelCol.SetName("Label");
trainingData.Columns.Add(labelCol);

// Quality is intentionally kept in the DataFrame: later cells materialize rows
// as BinaryClassificationData / RichBinaryClassificationData, and both read it.
display(trainingData);
index | FixedAcidity | VolatileAcidity | CitricAcid | ResidualSugar | Chlorides | FreeSulfurDioxide | TotalSulfurDioxide | Density | Ph | Sulphates | Alcohol | Quality | Label |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 7 | 0.27 | 0.36 | 20.7 | 0.045 | 45 | 170 | 1.001 | 3 | 0.45 | 8.8 | 6 | True |
1 | 6.3 | 0.3 | 0.34 | 1.6 | 0.049 | 14 | 132 | 0.994 | 3.3 | 0.49 | 9.5 | 6 | True |
2 | 8.1 | 0.28 | 0.4 | 6.9 | 0.05 | 30 | 97 | 0.9951 | 3.26 | 0.44 | 10.1 | 6 | True |
3 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47 | 186 | 0.9956 | 3.19 | 0.4 | 9.9 | 6 | True |
4 | 7.2 | 0.23 | 0.32 | 8.5 | 0.058 | 47 | 186 | 0.9956 | 3.19 | 0.4 | 9.9 | 6 | True |
// No fixed seed: randomized operations (such as ShuffleRows further down) are
// not guaranteed to be reproducible between runs. Pass an int for repeatable results.
var mlContext = new MLContext(seed: null);

// Define the pipeline:
//   1. gather the eleven physicochemical measurements into one "Features" vector
//      (Quality and Label are excluded: Label is the target, and Quality is the
//      value it was derived from);
//   2. replace missing values across the whole feature vector with the mean.
//      Previously only FixedAcidity was imputed, so a NaN in any of the other
//      ten inputs was not imputed;
//   3. train an L-BFGS logistic regression binary classifier on "Label".
var pipeline =
    mlContext.Transforms.Concatenate("Features",
        new[]
        {
            "FixedAcidity",
            "VolatileAcidity",
            "CitricAcid",
            "ResidualSugar",
            "Chlorides",
            "FreeSulfurDioxide",
            "TotalSulfurDioxide",
            "Density",
            "Ph",
            "Sulphates",
            "Alcohol"
        })
    .Append(mlContext.Transforms.ReplaceMissingValues(
        outputColumnName: "Features",
        replacementMode: MissingValueReplacingEstimator.ReplacementMode.Mean))
    .Append(mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression());

var model = pipeline.Fit(trainingData);
// Load the raw test data. The file only holds the twelve raw columns, so the
// boolean Label still has to be derived.
var testData = mlContext.Data.LoadFromTextFile<BinaryClassificationData>(
    "./WineQuality_White_Test.csv",
    separatorChar: ';',
    hasHeader: true);

// Calculate the Label (IDataView to IEnumerable to IDataView). Round-tripping
// the *test* rows through RichBinaryClassificationData exposes its computed
// Label property (Quality > 5) as a column.
// This must read testData: it previously read trainingData, so every metric
// below was computed on the rows the model had been fitted on.
var stronglyTypedTestData = mlContext.Data.CreateEnumerable<RichBinaryClassificationData>(testData, reuseRowObject: false);
testData = mlContext.Data.LoadFromEnumerable(stronglyTypedTestData);

// Score the held-out data and calculate the metrics against its "Label" column.
var scoredData = model.Transform(testData);
var qualityMetrics = mlContext.BinaryClassification.Evaluate(scoredData);
display(qualityMetrics);
LogLoss | LogLossReduction | Entropy | AreaUnderRocCurve | Accuracy | PositivePrecision | PositiveRecall | NegativePrecision | NegativeRecall | F1Score | AreaUnderPrecisionRecallCurve | ConfusionMatrix |
---|---|---|---|---|---|---|---|---|---|---|---|
0.7452045259897784 | 0.1974004033343476 | 0.928488537853348 | 0.7909630566845476 | 0.7390860352310442 | 0.764524948735475 | 0.8704280155642024 | 0.6639757820383451 | 0.4884929472902747 | 0.8140465793304221 | 0.8749940309174482 | { Microsoft.ML.Data.ConfusionMatrix: PerClassPrecision: [ 0.764524948735475, 0.6639757820383451 ], PerClassRecall: [ 0.8704280155642024, 0.4884929472902747 ], Counts: [ [ 2237, 333 ], [ 689, 658 ] ], NumberOfClasses: 2 } |
// Pair each display name with its metric value so the two lists cannot drift
// out of order; the separate name/value arrays are projected from the pairs.
var namedMetrics = new (string Name, double Value)[]
{
    ("Log Loss", qualityMetrics.LogLoss),
    ("Log Loss Reduction", qualityMetrics.LogLossReduction),
    ("Entropy", qualityMetrics.Entropy),
    ("Area Under Curve", qualityMetrics.AreaUnderRocCurve),
    ("Accuracy", qualityMetrics.Accuracy),
    ("Positive Recall", qualityMetrics.PositiveRecall),
    ("Negative Recall", qualityMetrics.NegativeRecall),
    ("F1 Score", qualityMetrics.F1Score),
};
string[] metricNames = namedMetrics.Select(m => m.Name).ToArray();
double[] metricValues = namedMetrics.Select(m => m.Value).ToArray();

// Horizontal bar chart: one dark-red bar per metric, value on the x axis.
var graph = new Graph.Bar
{
    x = metricValues,
    y = metricNames,
    orientation = "h",
    marker = new Graph.Marker { color = "darkred" },
};
var layout = new Layout.Layout { title = "Quality Metrics" };
var chart = Chart.Plot(graph);
chart.WithLayout(layout);
display(chart);
// Show the confusion matrix using whatever formatter is registered for ConfusionMatrix at this point.
display(qualityMetrics.ConfusionMatrix);
PerClassPrecision | PerClassRecall | Counts | NumberOfClasses |
---|---|---|---|
[ 0.764524948735475, 0.6639757820383451 ] | [ 0.8704280155642024, 0.4884929472902747 ] | [ [ 2237, 333 ], [ 689, 658 ] ] | 2 |
// Custom HTML formatter for a binary ConfusionMatrix: a 2x2 grid with
// "Predicted" across the top, "Actual" down the side, and the total number of
// scored rows (n) in the corner. Counts is indexed [actual][predicted], and
// index 0 is rendered as "True" and index 1 as "False".
// NOTE(review): only the first two rows/columns of Counts are rendered, so a
// multiclass matrix would be shown incompletely.
Formatter<ConfusionMatrix>.Register((matrix, writer) =>
{
    // Every row is forced transparent so the notebook's table striping does not apply.
    const string transparentRow = "background-color: transparent";
    const string axisHeader = "border: 1px solid black; text-align: center; padding: 24px; background-color: lightsteelblue";
    const string columnHeader = "border: 1px solid black; padding: 24px; background-color: #E3EAF3";
    const string rowHeader = "border: 1px solid black; text-align: center; padding: 24px; background-color: #E3EAF3";
    const string countCell = "border: 1px solid black; padding: 24px";

    // Total number of evaluated rows: the sum of every cell in the matrix.
    var n = matrix.Counts.Sum(row => row.Sum());

    var rows = new List<IHtmlContent>();

    // Row 1: corner cell with n, then the "Predicted" axis header over both class columns.
    var cells = new List<IHtmlContent>();
    cells.Add(td[rowspan: 2, colspan: 2, style: "text-align: center; background-color: transparent"]("n = " + n));
    cells.Add(td[colspan: 2, style: axisHeader](b("Predicted")));
    rows.Add(tr[style: transparentRow](cells));

    // Row 2: predicted-class headers.
    cells = new List<IHtmlContent>();
    cells.Add(td[style: columnHeader](b("True")));
    cells.Add(td[style: columnHeader](b("False")));
    rows.Add(tr[style: transparentRow](cells));

    // Row 3: "Actual" axis header (spanning both class rows), then actual = True counts.
    cells = new List<IHtmlContent>();
    cells.Add(td[rowspan: 2, style: axisHeader](b("Actual")));
    cells.Add(td[style: rowHeader](b("True")));
    cells.Add(td[style: countCell](matrix.Counts[0][0]));
    cells.Add(td[style: countCell](matrix.Counts[0][1]));
    rows.Add(tr[style: transparentRow](cells));

    // Row 4: actual = False counts. This row previously lacked the transparent
    // style that every other row has.
    cells = new List<IHtmlContent>();
    cells.Add(td[style: rowHeader](b("False")));
    cells.Add(td[style: countCell](matrix.Counts[1][0]));
    cells.Add(td[style: countCell](matrix.Counts[1][1]));
    rows.Add(tr[style: transparentRow](cells));

    var t = table(
        tbody(
            rows));
    writer.Write(t);
}, "text/html");

// Re-display the matrix; it now renders through the formatter registered above.
display(qualityMetrics.ConfusionMatrix);
n = 3917 | Predicted | ||
True | False | ||
Actual | True | 2237 | 333 |
False | 689 | 658 |
// Single-row prediction engine over the trained model: it takes
// RichBinaryClassificationData and returns BinaryClassificationPrediction.
var predictionEngine = mlContext.Model.CreatePredictionEngine<RichBinaryClassificationData, BinaryClassificationPrediction>(model);

// Pick one random row of the training data: shuffle, keep a single row, and
// materialize it as a strongly typed object.
var shuffledData = mlContext.Data.ShuffleRows(trainingData);
var rawSample = mlContext.Data.TakeRows(shuffledData, count: 1);
var sample = mlContext.Data
    .CreateEnumerable<RichBinaryClassificationData>(rawSample, reuseRowObject: false)
    .First();
display(sample);

// Score the sample; the output shows its Label next to the PredictedLabel.
var prediction = predictionEngine.Predict(sample);
display(prediction);
Label | FixedAcidity | VolatileAcidity | CitricAcid | ResidualSugar | Chlorides | FreeSulfurDioxide | TotalSulfurDioxide | Density | Ph | Sulphates | Alcohol | Quality |
---|---|---|---|---|---|---|---|---|---|---|---|---|
False | 7.1 | 0.37 | 0.67 | 10.5 | 0.045 | 49 | 155 | 0.9975 | 3.16 | 0.44 | 8.7 | 5 |
LabelAsNumber | Label | PredictedLabel |
---|---|---|
0 | False | False |