Identifying and predicting groups at high risk of suicide is important for suicide prevention. This analysis therefore investigates whether features such as age, sex, and population have the predictive power to classify groups as high or low risk.
I use the "Suicide Rates Overview 1985 to 2016" dataset from Kaggle to perform exploratory data analysis, followed by a classification analysis to find the best-performing classifier.
import numpy as np
import pandas as pd
import seaborn as sns
import geopandas as gpd
from sklearn.utils import resample
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn import (
linear_model, metrics, neural_network, pipeline, preprocessing, model_selection
)
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import eli5
from eli5.sklearn import PermutationImportance
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
import warnings
# Silence every warning for the rest of the notebook to keep cell output clean.
# NOTE(review): a blanket "ignore" also hides DeprecationWarning/FutureWarning
# raised by numpy/pandas/geopandas, so uses of deprecated APIs go unnoticed
# until they break; consider narrowing the filter by category or module.
warnings.filterwarnings("ignore")
# IPython magic: render matplotlib figures inline in the notebook. This is not
# plain Python, so this file only runs inside Jupyter/IPython.
%matplotlib inline
url = "https://raw.githubusercontent.com/uplotnik/Data607/master/Suicide.csv"
# Shorter, code-friendly column names, assigned by position: this list must
# stay in the same order as the columns of the CSV file.
_column_names = [
    "country",
    "year",
    "sex",
    "age",
    "suicides_no",
    "population",
    "suicides_per_100k_pop",
    "country-year",
    "HDI_for_year",
    "gdp_for_year",
    "gdp_per_capita",
    "generation",
]
data = pd.read_csv(url)
data.columns = _column_names
# First and last rows, as a quick sanity check of the load
data.head()
data.tail()
# Create the binary target "risk":
#   1 (high risk) if suicides_per_100k_pop >= mean(suicides_per_100k_pop)
#   0 (low risk)  otherwise
# The threshold is the mean rate over all rows of the dataset. Multiplying the
# boolean comparison by 1 turns True/False into 1/0 integers.
mean = data["suicides_per_100k_pop"].mean()
data['risk'] = (data["suicides_per_100k_pop"] >= mean) * 1
data.head()
# Count of missing values in each column
data.isnull().sum()
# --- Cleaning ---
# "gdp_for_year" holds text containing commas, which blocks numeric
# conversion; strip them here (the cast to float happens further down).
data['gdp_for_year'] = data['gdp_for_year'].str.replace(',', '')
# "HDI_for_year" has many missing values, so the whole column is dropped
data = data.drop(columns=["HDI_for_year"])
# Re-check: does any remaining column still contain missing values?
data.isnull().any()
# Column dtypes and non-null counts
data.info()
# With the commas gone, store "gdp_for_year" as floating-point numbers
data["gdp_for_year"] = data["gdp_for_year"].astype(float)
# Display floats with 2 decimal places in all subsequent DataFrame output
pd.set_option("display.float_format", "{:.2f}".format)
# Summary statistics for the numerical columns
data.describe()
# Number of distinct values in each column
data.nunique()
# Print the distinct values of the columns of interest, each under a bold
# header (ANSI escape codes) and followed by a blank line.
cols = data[["country", "year", "sex", "age", "generation", "risk"]]
for name in cols.columns:
    header = f"\033[1m {name} unique values: \033[0m"
    print(header, end="\n\n")
    print(data[name].unique(), end="\n\n")
# Correlation matrix for the numerical variables (the target "risk" is excluded)
continuous = data.drop(["risk"], axis = 1)
# Restrict to numeric columns explicitly. pandas < 2.0 silently dropped the
# text columns (country, sex, age, ...) inside corr(); pandas >= 2.0 tries to
# convert them to float instead and raises a ValueError.
corr = continuous.select_dtypes(include = "number").corr()
corr
# Correlation heatmap for the numerical variables, lower triangle only.
# The upper triangle mirrors the lower one, so those cells are set to NaN,
# which seaborn leaves blank. The builtin `bool` is used for the mask because
# the `np.bool` alias was removed in NumPy 1.24.
lower_mask = np.tril(np.ones(corr.shape)).astype(bool)
corr_lower = corr.where(lower_mask)
fig, ax = plt.subplots(figsize = (12,10))
ax = sns.heatmap(corr_lower, cmap = "Blues", annot = True,
                 linewidths = 0.5, cbar_kws={"shrink": .7}, ax = ax)
ax.tick_params(axis = "both", labelsize = 13)
# matplotlib 3.1.1 clipped half of the first and last heatmap rows. Pin the
# y-limits to the full row range (inverted, top row first) rather than nudging
# the current limits by 0.5. The nudge is only right on that one version and
# adds blank half-rows on every other version.
ax.set_ylim(len(corr_lower), 0)
fig.text(0.06,-0.16,"Fig. 1. Correlation heatmap", fontsize = 17, fontweight = "bold")
# Yearly totals of the numeric columns. numeric_only=True keeps the text
# columns (country, sex, age, ...) out of the aggregation. pandas < 2.0
# dropped them automatically, while pandas >= 2.0 tries to aggregate them too.
yearly_data = data.groupby("year").sum(numeric_only = True)
# Convert index into a column
yearly_data.reset_index(level = 0, inplace = True)
# Plotting (no legend: there is a single line)
plt.rcParams["figure.figsize"] = [8, 6]
yearly_suicide_graph = yearly_data.plot(kind = "line", x = "year", y = "suicides_no",
                                        legend = False)
# Setting the axis and title labels
yearly_suicide_graph.set_xlabel("Year", size = 14)
yearly_suicide_graph.set_ylabel("Total Suicides", size = 14)
yearly_suicide_graph.set_title("Fig. 2. Total number of suicides in the world from 1985 to 2016",
                               x = 0.45, y = -0.22,
                               size = 14, fontweight = "bold")
# Removing the top and right lines of the graph
yearly_suicide_graph.spines["right"].set_visible(False)
yearly_suicide_graph.spines["top"].set_visible(False)
# Dashed horizontal gridlines drawn behind the data
yearly_suicide_graph.set_axisbelow(True)
yearly_suicide_graph.yaxis.grid(color = "gray", linestyle = "dashed")
# Are the sexes balanced? Count the observations (rows) recorded for each sex.
count1 = data['sex'].value_counts()
print("Number of females and males:")
count1
Therefore, the dataset contains an equal number of observations for males and females.
# Totals of the numeric columns per (year, sex). numeric_only=True keeps the
# text columns out of the aggregation on every pandas version.
gender_data = data.groupby(["year", "sex"]).sum(numeric_only = True)
# Reshape to one row per year and one column per sex
gender_table = gender_data.pivot_table(index = "year", columns = "sex",
                                       values = "suicides_no", aggfunc = "sum")
# Plotting: style the returned Axes directly rather than via plt.* calls
# that depend on whichever Axes happens to be current.
plt.rcParams["figure.figsize"] = [8, 6]
gender_graph = gender_table.plot()
gender_graph.spines["right"].set_visible(False)
gender_graph.spines["top"].set_visible(False)
gender_graph.set_xlabel("Year", size = 13)
gender_graph.set_ylabel("Total Suicides", size = 13)
gender_graph.set_title("Fig. 3. Total number of suicides by gender from 1985 to 2016",
                       x = 0.45, y = -0.22,
                       size = 14, fontweight = "bold")
gender_graph.set_axisbelow(True)
gender_graph.yaxis.grid(color = "gray", linestyle = "dashed")
# Are the age bands balanced? Count the observations (rows) in each band.
count = data['age'].value_counts()
print("Number of observations for each age band:")
count
Therefore, the number of observations is roughly equal for all age bands.
# Total suicides per age band, summed over all countries, years and sexes
age_table = data.pivot_table('suicides_no', index='age', aggfunc='sum')


def _age_lower_bound(label):
    """Return the lower bound of an age-band label as an int.

    "5-14 years" -> 5 and open-ended bands such as "75+ years" -> 75.
    """
    return int(label.split("-")[0].split("+")[0])


# pivot_table sorts the band labels as strings, which places "5-14 years"
# between "35-54 years" and "55-74 years". Order the bars by age instead.
age_table = age_table.loc[sorted(age_table.index, key = _age_lower_bound)]
age_table.reset_index(level = 0, inplace = True)
# Plotting
# Colors of the bars
colors = ["b", "r", "g", "c", "m", "y"]
age_graph = age_table.plot(x = "age", y = "suicides_no", kind = "bar",
                           width = 0.7, color = colors,
                           rot = 30, legend = False)
age_graph.set_xlabel("Age Group", size = 13)
age_graph.set_ylabel("Total Suicides", size = 13)
age_graph.set_title("Fig. 4. Total number of suicides by age",
                    x = 0.33, y = -0.3,
                    size = 14, fontweight = "bold")
age_graph.spines["right"].set_visible(False)
age_graph.spines["top"].set_visible(False)
# Dashed horizontal gridlines drawn behind the bars
age_graph.set_axisbelow(True)
age_graph.yaxis.grid(color = "gray", linestyle = "dashed")
# Total suicides for every (year, sex, age band) combination, as a flat table
# with a default integer index.
gender_age_data = data.groupby(
    ["year", "sex", "age"], as_index = False
)["suicides_no"].sum()
# Plotting
def plot(df, sex, age, ax, color):
    """Draw the yearly suicide totals of one sex/age group as a line.

    Parameters
    ----------
    df : DataFrame with "year", "sex", "age" and "suicides_no" columns.
    sex, age : values selecting which rows of ``df`` to draw.
    ax : matplotlib Axes the line is drawn on.
    color : line colour.

    Returns the ``ax`` that was passed in.
    """
    selected = df[(df["sex"] == sex) & (df["age"] == age)]
    selected.plot.line(x = "year", y = "suicides_no", ax = ax, color = color)
    return ax
fig, ax = plt.subplots(1, 2, figsize=(11, 6), sharey=True)
# Line colour for each age band, youngest to oldest. The legend labels are
# derived from this same list, so the legend entries always line up with the
# order in which the lines are drawn.
age_colors = [
    ("5-14 years", "m"),
    ("15-24 years", "b"),
    ("25-34 years", "y"),
    ("35-54 years", "r"),
    ("55-74 years", "g"),
    ("75+ years", "k"),
]
# Naming the legends
labels = [age for age, _ in age_colors]
# One panel per sex, one line per age band
for i, sex in enumerate(gender_age_data.sex.unique()):
    for age, color in age_colors:
        plot(gender_age_data, sex, age, ax[i], color)
    ax[i].set_title(str(sex))
    ax[i].legend(labels = labels)
# Shared styling for both panels
for _ax in ax:
    _ax.set_xlabel("Year", size = 13)
    _ax.set_ylabel("Total Suicides", size = 13)
    _ax.spines['right'].set_visible(False)
    _ax.spines['top'].set_visible(False)
    _ax.yaxis.grid(True)
fig.suptitle("Fig. 5. Total number of suicides by age and sex from 1985 to 2016", x = 0.43,
             y = -0.05, fontsize=15, fontweight = "bold")
plt.show()
# Per-country, per-year totals of the numeric columns (text columns excluded)
country_year_data = data.groupby(['country', 'year'], sort = True).sum(numeric_only = True)
# Each country-year has one row per sex/age group. Summing the
# "suicides_per_100k_pop" column therefore adds up one rate per group,
# unweighted by group size, and is not a rate for the whole country.
# Recompute the rate from the summed counts so the values match the
# "Suicides per 100k Population" label.
country_year_rate = (country_year_data["suicides_no"]
                     / country_year_data["population"] * 100000)
country_year_data_suicides = country_year_rate.reset_index(
    name = "Suicides per 100k Population")
# Plotting: one small panel per country, sharing the y-axis scale
country_year_graph = sns.FacetGrid(country_year_data_suicides, col = "country", col_wrap = 4, sharey = True)
country_year_graph.map(plt.plot, "year", "Suicides per 100k Population", marker = ".")
plt.show()
# Per-country totals over all years and sex/age groups (numeric columns only;
# numeric_only=True keeps the text columns out on every pandas version).
# Note: "suicides_per_100k_pop" here is the SUM of the per-row group rates,
# i.e. a cumulative score, not a rate in itself.
country_data = data.groupby("country").sum(numeric_only = True)
country_data_suicides = country_data[["suicides_per_100k_pop"]].reset_index()
country_data_suicides = country_data_suicides.sort_values("suicides_per_100k_pop", ascending = False)
# Countries with the largest cumulative values
country_data_suicides.head()
# Load the low-resolution Natural Earth world map shipped with geopandas and
# list the dataset's countries whose names have no exact match in it.
# NOTE(review): geopandas deprecated `gpd.datasets` in 0.14 and removed it in
# 1.0. On newer versions this layer has to be downloaded from Natural Earth
# instead; confirm against the installed geopandas version.
world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
world = world.set_index("iso_a3")
data_country = list(data.country.unique())
world_country = list(world.name.unique())
# Set iteration order is arbitrary, so sort for a stable, readable listing
sorted(set(data_country) - set(world_country))
# Comparing the countries in the world dataset with those in the data shows
# that some countries are named differently in the world dataset, so they are
# renamed here to match. The map does not include all countries in the data:
# unmatched ones are dropped by the inner merge below.
# "Dominica" is deliberately NOT renamed. It is a different country from the
# Dominican Republic ("Dominican Rep."), and mapping one onto the other would
# colour the Dominican Republic with Dominica's figures.
country_name_fixes = {
    "Bosnia and Herzegovina": "Bosnia and Herz.",
    "Russian Federation": "Russia",
    "United States": "United States of America",
    "Czech Republic": "Czechia",
    "Republic of Korea": "South Korea",
}
country_data_suicides["country"] = country_data_suicides["country"].replace(country_name_fixes)
world_data = world.merge(country_data_suicides, left_on = "name", right_on = "country", how = "inner")
fig, ax = plt.subplots(figsize=(30,15))
# Base layer: every country in white with black borders
world.plot(ax=ax, edgecolor='black', color = "white")
# Top layer: matched countries coloured by "suicides_per_100k_pop" on a fixed
# 0-12000 colour scale
world_data.plot(ax = ax, edgecolor = "black", column = "suicides_per_100k_pop",
                legend = True, cmap = "YlOrRd", vmin = 0, vmax = 12000)
# Turn off the axes of the map itself
ax.set_axis_off()
# Adding text below the color bar
ax.annotate('Suicides per 100k',xy=(0.83, 0.055), xycoords='figure fraction', size = 14)
plt.suptitle("Fig. 6. Total suicides per 100k population from 1985 to 2016", x = 0.34, y = 0.1,
             fontsize=27, fontweight = "bold")
plt.show()