# Lecture 5: Logistic Regression and Decision Trees

In the previous lecture, we learned about linear regression, which models a linear relationship between the independent and dependent variables. Logistic regression is analogous to linear regression in that it is also a generalized linear model. However, there are several key differences that set logistic regression apart. Let us explore those differences and understand how logistic regression works.

## Categorical Classification

One key difference between logistic regression and linear regression is that the final output of the logistic regression model is binary, that is, 0 or 1, whereas linear regression has no such property. Thus a logistic regression model will always map from the real number space to the binary space {0, 1}. Let us examine how logistic regression does this.

### The Logistic Equation

The model is built on the logistic equation, formulated as below:

$$F(t) = \frac{1}{1 + e^{-t}}$$

In this case, the $t$ in the equation is some linear combination of $n$ variables, or a linear function in an $n$-dimensional feature space:

$$t = \beta_0 + \beta_1 x_1 + \dots + \beta_n x_n$$

The formulation of $t$ is therefore identical to the linear regression formula.

To summarise the logistic equation:

1. Takes an input of n variables
2. Takes a linear combination of the variables as parameter t
3. Using the parameter, outputs a value that always lies between 0 and 1
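These three steps can be sketched in a few lines of R. This is a minimal illustration only; the coefficients and the observation below are made-up numbers for demonstration, not fitted values:

```r
# Hypothetical coefficients for a 2-variable example: intercept, beta1, beta2
beta <- c(-1.5, 0.8, 2.0)
# One hypothetical observation with 2 features
x <- c(1.2, 0.5)

# Step 2: linear combination t = beta0 + beta1*x1 + beta2*x2
t <- beta[1] + sum(beta[-1] * x)

# Step 3: the logistic function maps any real t into (0, 1)
logistic <- function(t) 1 / (1 + exp(-t))
logistic(t)   # always lies strictly between 0 and 1
```

No matter how large or small `t` gets, `exp(-t)` stays positive, so the output can approach but never reach 0 or 1.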

A visualization of the outputs of the logistic equation is shown below (note that this is but one possible output of a logistic regression model):

### Threshold Value

It's important to realize that logistic regression should output a binary set of numbers, namely 0 and 1. While the logistic equation does have an output between 0 and 1, that output is continuous. So how do we convert it to 0 and 1?

We use something called a threshold value, such that if the output $F(x)$ is greater than the threshold, we predict 1; otherwise, we predict 0. As a general formula:

$$\hat{y} = \begin{cases} 1 & \text{if } F(x) > \epsilon \\ 0 & \text{otherwise} \end{cases}$$

The threshold value is the $\epsilon$ in the equation, and it is a key parameter in logistic regression because it determines two key characteristics of a logistic regression classifier:

1. Sensitivity
2. Specificity
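As a quick sketch, the thresholding step is just an element-wise comparison. The probability vector here is hypothetical, standing in for the outputs of a fitted model:

```r
# Hypothetical outputs of the logistic equation for four observations
p <- c(0.10, 0.45, 0.62, 0.91)
threshold <- 0.5                 # the epsilon value

# Predict 1 when the output exceeds the threshold, 0 otherwise
ifelse(p > threshold, 1, 0)      # -> 0 0 1 1
```

Raising or lowering `threshold` changes which observations land on each side, which is exactly the trade-off discussed next.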

### Sensitivity and Specificity

**The confusion matrix** is a good representation of the predictive power of a logistic regression model.

Sensitivity, otherwise known as the True Positive Rate, is the proportion of true positives out of the entire pool of "actual positives."

The formula is True Positives / (True Positives + False Negatives)



Specificity, otherwise known as the True Negative Rate, is the proportion of true negatives out of the entire pool of "actual negatives."

The formula is True Negatives / (True Negatives + False Positives)
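Both formulas can be computed directly from a confusion matrix built with R's `table` function. The labels below are made up for illustration:

```r
actual    <- c(1, 1, 1, 1, 0, 0, 0, 0)  # hypothetical ground truth
predicted <- c(1, 1, 1, 0, 0, 0, 1, 1)  # hypothetical model output

cm <- table(actual, predicted)   # rows = actual, columns = predicted
tp <- cm["1", "1"]; fn <- cm["1", "0"]
tn <- cm["0", "0"]; fp <- cm["0", "1"]

sensitivity <- tp / (tp + fn)    # 3 / (3 + 1) = 0.75
specificity <- tn / (tn + fp)    # 2 / (2 + 2) = 0.50
```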



It is important to understand that there will always be a trade-off between the two characteristics. This trade-off is best understood in terms of how we set our threshold value.

### ROC curve

Let us consider the trivial cases: classify everything as the same value. If we classify all points as positive, sensitivity = 1 and specificity = 0: every actual positive is classified correctly, but so is every actual negative, so no negatives are caught. On the contrary, if everything is classified as negative, sensitivity = 0 and specificity = 1: all negative points have been classified correctly, but no positive ones.

Sensitivity decreases as the threshold grows, since the predictor will classify more and more positive points incorrectly. Specificity increases as the threshold grows, since the predictor will classify more and more negative points correctly.

This trade-off is represented by the ROC curve, which tells us how well a model performs in terms of specificity and sensitivity. A sample ROC curve is shown below.

In [10]:
# Preprocess iris to create a binary case
iris_mod <- iris
iris_mod <- dplyr::mutate(iris_mod, is_setosa = as.numeric(Species == 'setosa'))[-5]
print(iris_mod)

# Split into train (80%) and test (20%) sets
ind <- sample(nrow(iris_mod), 0.8 * nrow(iris_mod))
train <- iris_mod[ind, ]
test <- iris_mod[-ind, ]

    Sepal.Length Sepal.Width Petal.Length Petal.Width is_setosa
1            5.1         3.5          1.4         0.2         1
2            4.9         3.0          1.4         0.2         1
3            4.7         3.2          1.3         0.2         1
4            4.6         3.1          1.5         0.2         1
5            5.0         3.6          1.4         0.2         1
6            5.4         3.9          1.7         0.4         1
...
150          5.9         3.0          5.1         1.8         0

In [11]:
# Use glm function to predict species just with Sepal.Width
fit_logit <- glm(data=train,family=binomial,formula = is_setosa ~ Sepal.Width)
print(summary(fit_logit))

Call:
glm(formula = is_setosa ~ Sepal.Width, family = binomial, data = train)

Deviance Residuals:
    Min       1Q   Median       3Q      Max
-2.6015  -0.5468  -0.2140   0.2627   1.9871

Coefficients:
            Estimate Std. Error z value Pr(>|z|)
(Intercept)  -21.228      3.967  -5.351 8.76e-08 ***
Sepal.Width    6.468      1.237   5.227 1.72e-07 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for binomial family taken to be 1)

Null deviance: 146.607  on 119  degrees of freedom
Residual deviance:  78.476  on 118  degrees of freedom
AIC: 82.476

Number of Fisher Scoring iterations: 6


In [14]:
# predict probabilities on the test set
pred_logit <- predict(object = fit_logit, newdata = test, type = "response")
# use the roc function in the pROC library
library(pROC)
roc_curve <- roc(response = test$is_setosa, predictor = pred_logit)
plot(roc_curve)
roc_curve

Call:
roc.default(response = test$is_setosa, predictor = pred_logit)

Data: pred_logit in 16 controls (test$is_setosa 0) < 14 cases (test$is_setosa 1).
Area under the curve: 0.6987

### Area Under Curve

The area under the ROC curve for a classifier with no predictive power is 0.5, so a useful model should always score above that baseline. The closer the area is to 1, the stronger the model.
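To see why 0.5 is the baseline, we can score a predictor that is pure noise. This is a hypothetical check using the pROC library loaded above; the exact value depends on the random seed:

```r
library(pROC)
set.seed(42)
labels <- rbinom(200, 1, 0.5)   # random ground truth
noise  <- runif(200)            # a predictor carrying no information

# The AUC of an uninformative predictor hovers around 0.5
auc(roc(response = labels, predictor = noise))
```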

In [15]:
auc(roc_curve)

0.698660714285714

# Decision Trees

Linear and logistic regression aren't the only ways to make predictions; we can also use a method called classification and regression trees, or CART.

## CART

CART builds what is called a tree by splitting on the values of the independent variables. To predict the outcome for a new observation, you follow the splits in the tree and, at the end, predict the most frequent outcome among the training observations that followed the same path. Some advantages of CART are that it does not assume a linear model, unlike logistic regression or linear regression, and that it is very easy to interpret how the model works.

Let's make a simple CART model. We'll be attempting to predict Supreme Court decisions, as mentioned in lecture.

In [2]:
install.packages("rpart")
library(rpart)
install.packages("rpart.plot")
library(rpart.plot)

# Train is assumed to already contain the Supreme Court data from lecture
SupremeCourtTree = rpart(Reverse ~ Circuit + Issue + Petitioner + Respondent + LowerCourt +
                         Unconst, data = Train, method = "class", minbucket = 25)


In the above code, we're trying to predict the value of Reverse (whether or not the Supreme Court will reverse a lower court's decision) using the Circuit, Issue, Petitioner, Respondent, LowerCourt, and Unconst features, all found in the Train dataset.

The minbucket parameter controls how many splits are made in our tree by setting the minimum number of observations allowed in each terminal branch of the tree. If it's too small, overfitting will occur (variance). If it's too large, the model will be too simple and inaccurate (bias). You'll learn more about bias vs. variance in future lectures.
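The effect of minbucket is easy to see on a dataset that ships with R. This sketch grows one loosely constrained tree and one heavily constrained tree on `iris` (not the Supreme Court data, which is only available in the lecture environment):

```r
library(rpart)

# Small minbucket: branches may hold as few as 2 observations (risk: overfitting)
tree_small <- rpart(Species ~ ., data = iris, method = "class", minbucket = 2)

# Large minbucket: every terminal branch needs at least 50 observations
# (risk: underfitting)
tree_large <- rpart(Species ~ ., data = iris, method = "class", minbucket = 50)

# The constrained tree has no more nodes than the loose one
nrow(tree_small$frame) >= nrow(tree_large$frame)   # TRUE
```

With `minbucket = 50`, any split that would leave a branch with fewer than 50 of the 150 observations is disallowed, so the tree stops growing much earlier.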

It's very easy to visualize our CART model, which can be done with the prp function. This is one of the reasons CART is more interpretable than logistic regression: we can see exactly how it works.

In [8]:
prp(SupremeCourtTree)


The predict() function allows us to apply our model and predict future cases.

In [16]:
PredictCART = predict(SupremeCourtTree, newdata = TestCourt, type = "class")


## Terms to Review

1. Logistic Equation / Logistic Function
2. Threshold Value
3. Confusion Matrix
4. Sensitivity
5. Specificity
6. ROC Curve
7. Area Under Curve (AUC)
8. CART