This notebook creates a regression model that predicts the quality of white wine from 11 physicochemical features.
#r "nuget:Microsoft.ML, 1.4.0"
#r "nuget:XPlot.Plotly, 3.0.1"
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Transforms;
using XPlot.Plotly;
/// <summary>
/// One row of the wine-quality input files: eleven physicochemical
/// measurements followed by the quality score used as the regression label.
/// </summary>
/// <remarks>
/// Each <c>LoadColumn</c> index is the zero-based position of that value
/// in a row of the delimited text file being loaded.
/// </remarks>
public class RegressionData
{
    [LoadColumn(0)] public float FixedAcidity;
    [LoadColumn(1)] public float VolatileAcidity;
    [LoadColumn(2)] public float CitricAcid;
    [LoadColumn(3)] public float ResidualSugar;
    [LoadColumn(4)] public float Chlorides;
    [LoadColumn(5)] public float FreeSulfurDioxide;
    [LoadColumn(6)] public float TotalSulfurDioxide;
    [LoadColumn(7)] public float Density;
    [LoadColumn(8)] public float Ph;
    [LoadColumn(9)] public float Sulphates;
    [LoadColumn(10)] public float Alcohol;

    // Final column: the target value the regression model learns to predict.
    [LoadColumn(11)] public float Label;
}
/// <summary>
/// Output row read back from the trained model: the row's original label
/// together with the value the model predicted for it.
/// </summary>
public class RegressionPrediction
{
    // Ground-truth value, read from the "Label" column of the scored data.
    [ColumnName("Label")]
    public float Label;

    // The regression trainer writes its prediction into the "Score" column;
    // expose it under a more descriptive field name.
    [ColumnName("Score")]
    public float PredictedLabel;
}
// Display labels for the 11 input features, listed in the same order as
// LoadColumn indices 0-10 on RegressionData.
var featureNames = new[]
{
    "FixedAcidity", "VolatileAcidity", "CitricAcid", "ResidualSugar",
    "Chlorides", "FreeSulfurDioxide", "TotalSulfurDioxide", "Density",
    "Ph", "Sulphates", "Alcohol"
};

// Axis labels for the evaluation chart; the order must match the order in
// which the metric values are collected.
var metricNames = new[]
{
    "Mean Absolute Error", "Mean Squared Error", "Root Mean Squared Error",
    "Loss Function", "R Squared"
};
// Root ML.NET context. With seed: null the context's random number generator
// is not seeded with a fixed value, so runs are not guaranteed to be reproducible.
var mlContext = new MLContext(seed: null);

// Training split: ';'-separated text with a header row.
var trainingData = mlContext.Data.LoadFromTextFile<RegressionData>(
    path: "./WineQuality_White_Train.csv",
    separatorChar: ';',
    hasHeader: true);

// Pipeline stages, in order:
//   1. copy FixedAcidity into PreparedFixedAcidity, filling missing values with the column mean;
//   2. drop the raw FixedAcidity column so only the imputed copy remains;
//   3. concatenate the 11 feature columns into a single "Features" vector;
//   4. normalize "Features" in place (mean/variance);
//   5. train an SDCA linear regressor on "Label" vs. "Features".
var pipeline = mlContext.Transforms
    .ReplaceMissingValues(
        outputColumnName: "PreparedFixedAcidity",
        inputColumnName: "FixedAcidity",
        replacementMode: MissingValueReplacingEstimator.ReplacementMode.Mean)
    .Append(mlContext.Transforms.DropColumns("FixedAcidity"))
    .Append(mlContext.Transforms.Concatenate(
        "Features",
        "PreparedFixedAcidity", "VolatileAcidity", "CitricAcid", "ResidualSugar",
        "Chlorides", "FreeSulfurDioxide", "TotalSulfurDioxide", "Density",
        "Ph", "Sulphates", "Alcohol"))
    .Append(mlContext.Transforms.NormalizeMeanVariance("Features"))
    .Append(mlContext.Regression.Trainers.Sdca(
        labelColumnName: "Label",
        featureColumnName: "Features"));

// Fit every stage against the training split.
var model = pipeline.Fit(trainingData);
// Held-out test split, read with the same schema and text settings as training.
var testData = mlContext.Data.LoadFromTextFile<RegressionData>(
    path: "./WineQuality_White_Test.csv",
    separatorChar: ';',
    hasHeader: true);

// Run the fitted pipeline over the test rows, then compare the "Score"
// column against "Label" to compute the regression metrics.
var scoredData = model.Transform(testData);
var qualityMetrics = mlContext.Regression.Evaluate(
    scoredData,
    labelColumnName: "Label",
    scoreColumnName: "Score");
display(qualityMetrics);
MeanAbsoluteError | MeanSquaredError | RootMeanSquaredError | LossFunction | RSquared |
---|---|---|---|---|
0.5416236609129366 | 0.5025304962757207 | 0.7088938540259188 | 0.502530496086152 | 0.16448854936302226 |
// Metric values, collected in the same order as the metricNames labels so
// each bar lines up with its name.
var metricValues = new[]
{
    qualityMetrics.MeanAbsoluteError,
    qualityMetrics.MeanSquaredError,
    qualityMetrics.RootMeanSquaredError,
    qualityMetrics.LossFunction,
    qualityMetrics.RSquared,
};

// Horizontal bar chart: metric values along x, metric names along y.
var graph = new Graph.Bar
{
    x = metricValues,
    y = metricNames,
    orientation = "h",
    marker = new Graph.Marker { color = "darkred" },
};

var chart = Chart.Plot(graph);
var layout = new Layout.Layout { title = "Quality Metrics" };
chart.WithLayout(layout);
display(chart);
// The last stage of the fitted chain is the SDCA regression transformer.
// TransformerChain<T>.LastTransformer is already statically typed as that
// transformer, so no `as` cast is needed. A failed `as` cast would silently
// yield null and only fail later when .Model is dereferenced.
var regressionModel = model.LastTransformer;

// Learned linear weights, one per slot of the "Features" vector. The
// features were mean/variance normalized before training, so weight
// magnitudes are roughly comparable across features.
var contributions = regressionModel.Model.Weights;

// Guard against labeling the bars wrongly if the feature list and the
// model's weight vector ever drift apart.
if (contributions.Count != featureNames.Length)
{
    throw new System.InvalidOperationException(
        $"Model has {contributions.Count} weights but {featureNames.Length} feature names were supplied.");
}

// Horizontal bar chart: one bar per feature weight.
var graph2 = new Graph.Bar()
{
    x = contributions,
    y = featureNames,
    orientation = "h",
    marker = new Graph.Marker { color = "darkblue" }
};
var chart2 = Chart.Plot(graph2);
var layout2 = new Layout.Layout(){ title="Feature Contributions" };
chart2.WithLayout(layout2);
display(chart2);
// Single-row prediction API over the fitted pipeline: takes a RegressionData
// row and returns a RegressionPrediction (actual Label plus PredictedLabel).
var predictionEngine = mlContext.Model.CreatePredictionEngine<RegressionData, RegressionPrediction>(model);

// Draw one random row from the held-out TEST split. Sampling from the
// training split (as before) demonstrates a prediction on data the model
// was fitted on, which overstates how well it generalizes.
var shuffledData = mlContext.Data.ShuffleRows(testData);
var rawSample = mlContext.Data.TakeRows(shuffledData, 1);
var sample = mlContext.Data.CreateEnumerable<RegressionData>(rawSample, reuseRowObject: false).First();
display(sample);

// Predict the quality of the sampled wine and show it next to the actual label.
var prediction = predictionEngine.Predict(sample);
display(prediction);
FixedAcidity | VolatileAcidity | CitricAcid | ResidualSugar | Chlorides | FreeSulfurDioxide | TotalSulfurDioxide | Density | Ph | Sulphates | Alcohol | Label |
---|---|---|---|---|---|---|---|---|---|---|---|
6.4 | 0.33 | 0.28 | 1.1 | 0.038 | 30 | 110 | 0.9917 | 3.12 | 0.42 | 10.5 | 6 |
Label | PredictedLabel |
---|---|
6 | 5.665756 |