Classification Analysis of Suicide Rates

The classification and prediction of groups at high risks of suicide is important for preventing suicides. Thus, this analysis aims to investigate whether several features such as age, sex, population, etc. have classification power for identifying high risk groups.

I will be using the "Suicides Rates Overview 1985 to 2016" dataset from Kaggle to perform exploratory data analysis and classification analysis in order to find the best classifier.

In [186]:
import numpy as np
import pandas as pd
import seaborn as sns
import geopandas as gpd
from sklearn.utils import resample
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from sklearn import (
    linear_model, metrics, neural_network, pipeline, preprocessing, model_selection
)
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import eli5
from eli5.sklearn import PermutationImportance
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline

Loading the Dataset

In [136]:
url = "https://raw.githubusercontent.com/uplotnik/Data607/master/Suicide.csv"
data = pd.read_csv(url)

# Changing the names of the columns
data.columns = ["country", "year", "sex", "age", "suicides_no", "population", "suicides_per_100k_pop", "country-year", "HDI_for_year", "gdp_for_year", "gdp_per_capita", "generation"]

data.head()
Out[136]:
country year sex age suicides_no population suicides_per_100k_pop country-year HDI_for_year gdp_for_year gdp_per_capita generation
0 Albania 1987 male 15-24 years 21 312900 6.71 Albania1987 nan 2,156,624,900 796 Generation X
1 Albania 1987 male 35-54 years 16 308000 5.19 Albania1987 nan 2,156,624,900 796 Silent
2 Albania 1987 female 15-24 years 14 289700 4.83 Albania1987 nan 2,156,624,900 796 Generation X
3 Albania 1987 male 75+ years 1 21800 4.59 Albania1987 nan 2,156,624,900 796 G.I. Generation
4 Albania 1987 male 25-34 years 9 274300 3.28 Albania1987 nan 2,156,624,900 796 Boomers
In [137]:
data.tail()
Out[137]:
country year sex age suicides_no population suicides_per_100k_pop country-year HDI_for_year gdp_for_year gdp_per_capita generation
27815 Uzbekistan 2014 female 35-54 years 107 3620833 2.96 Uzbekistan2014 0.68 63,067,077,179 2309 Generation X
27816 Uzbekistan 2014 female 75+ years 9 348465 2.58 Uzbekistan2014 0.68 63,067,077,179 2309 Silent
27817 Uzbekistan 2014 male 5-14 years 60 2762158 2.17 Uzbekistan2014 0.68 63,067,077,179 2309 Generation Z
27818 Uzbekistan 2014 female 5-14 years 44 2631600 1.67 Uzbekistan2014 0.68 63,067,077,179 2309 Generation Z
27819 Uzbekistan 2014 female 55-74 years 21 1438935 1.46 Uzbekistan2014 0.68 63,067,077,179 2309 Boomers

Data Wrangling and Exploratory Data Analysis

In [138]:
# Creating a binary class 

# If suicides_per_100k_pop > mean(suicides_per_100k_pop), then high risk
# Else, low risk

mean = data["suicides_per_100k_pop"].mean()
data['risk'] = (data["suicides_per_100k_pop"] >= mean) *1
data.head()
Out[138]:
country year sex age suicides_no population suicides_per_100k_pop country-year HDI_for_year gdp_for_year gdp_per_capita generation risk
0 Albania 1987 male 15-24 years 21 312900 6.71 Albania1987 nan 2,156,624,900 796 Generation X 0
1 Albania 1987 male 35-54 years 16 308000 5.19 Albania1987 nan 2,156,624,900 796 Silent 0
2 Albania 1987 female 15-24 years 14 289700 4.83 Albania1987 nan 2,156,624,900 796 Generation X 0
3 Albania 1987 male 75+ years 1 21800 4.59 Albania1987 nan 2,156,624,900 796 G.I. Generation 0
4 Albania 1987 male 25-34 years 9 274300 3.28 Albania1987 nan 2,156,624,900 796 Boomers 0
In [139]:
# Checking for missing values
data.isnull().sum()
Out[139]:
country                      0
year                         0
sex                          0
age                          0
suicides_no                  0
population                   0
suicides_per_100k_pop        0
country-year                 0
HDI_for_year             19456
gdp_for_year                 0
gdp_per_capita               0
generation                   0
risk                         0
dtype: int64
In [140]:
# cleaning the data

# Removing commas from the "gdp_for_year" column
data['gdp_for_year'] = data['gdp_for_year'].str.replace(',', '')


# Dropping HDI_for_year because of a lot of missing values
data = data.drop(["HDI_for_year"], axis = 1)
In [141]:
# Checking if there is any missing values again
data.isnull().any()
Out[141]:
country                  False
year                     False
sex                      False
age                      False
suicides_no              False
population               False
suicides_per_100k_pop    False
country-year             False
gdp_for_year             False
gdp_per_capita           False
generation               False
risk                     False
dtype: bool
In [142]:
# Information about the dataset
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 27820 entries, 0 to 27819
Data columns (total 12 columns):
 #   Column                 Non-Null Count  Dtype  
---  ------                 --------------  -----  
 0   country                27820 non-null  object 
 1   year                   27820 non-null  int64  
 2   sex                    27820 non-null  object 
 3   age                    27820 non-null  object 
 4   suicides_no            27820 non-null  int64  
 5   population             27820 non-null  int64  
 6   suicides_per_100k_pop  27820 non-null  float64
 7   country-year           27820 non-null  object 
 8   gdp_for_year           27820 non-null  object 
 9   gdp_per_capita         27820 non-null  int64  
 10  generation             27820 non-null  object 
 11  risk                   27820 non-null  int64  
dtypes: float64(1), int64(5), object(6)
memory usage: 2.5+ MB
In [143]:
# Changing the data type of the "gdp_for_year" column

data["gdp_for_year"] = data["gdp_for_year"].astype(float)
In [144]:
# Summary of statistics for numerical variables

# 2 decimal places
pd.options.display.float_format = "{:.2f}".format
data.describe()
Out[144]:
year suicides_no population suicides_per_100k_pop gdp_for_year gdp_per_capita risk
count 27820.00 27820.00 27820.00 27820.00 27820.00 27820.00 27820.00
mean 2001.26 242.57 1844793.62 12.82 445580969025.73 16866.46 0.31
std 8.47 902.05 3911779.44 18.96 1453609985940.92 18887.58 0.46
min 1985.00 0.00 278.00 0.00 46919625.00 251.00 0.00
25% 1995.00 3.00 97498.50 0.92 8985352832.00 3447.00 0.00
50% 2002.00 25.00 430150.00 5.99 48114688201.00 9372.00 0.00
75% 2008.00 131.00 1486143.25 16.62 260202429150.00 24874.00 1.00
max 2016.00 22338.00 43805214.00 224.97 18120714000000.00 126352.00 1.00
In [145]:
# Number of unique values 
data.nunique()
Out[145]:
country                    101
year                        32
sex                          2
age                          6
suicides_no               2084
population               25564
suicides_per_100k_pop     5298
country-year              2321
gdp_for_year              2321
gdp_per_capita            2233
generation                   6
risk                         2
dtype: int64
In [146]:
# Unique values for relevant columns
cols = data[["country", "year", "sex", "age", "generation", "risk"]]

for column in list(cols):
    print(f"\033[1m  {column} unique values: \033[0m")
    print("")
    print(data[column].unique())
    print("")
  country unique values: 

['Albania' 'Antigua and Barbuda' 'Argentina' 'Armenia' 'Aruba' 'Australia'
 'Austria' 'Azerbaijan' 'Bahamas' 'Bahrain' 'Barbados' 'Belarus' 'Belgium'
 'Belize' 'Bosnia and Herzegovina' 'Brazil' 'Bulgaria' 'Cabo Verde'
 'Canada' 'Chile' 'Colombia' 'Costa Rica' 'Croatia' 'Cuba' 'Cyprus'
 'Czech Republic' 'Denmark' 'Dominica' 'Ecuador' 'El Salvador' 'Estonia'
 'Fiji' 'Finland' 'France' 'Georgia' 'Germany' 'Greece' 'Grenada'
 'Guatemala' 'Guyana' 'Hungary' 'Iceland' 'Ireland' 'Israel' 'Italy'
 'Jamaica' 'Japan' 'Kazakhstan' 'Kiribati' 'Kuwait' 'Kyrgyzstan' 'Latvia'
 'Lithuania' 'Luxembourg' 'Macau' 'Maldives' 'Malta' 'Mauritius' 'Mexico'
 'Mongolia' 'Montenegro' 'Netherlands' 'New Zealand' 'Nicaragua' 'Norway'
 'Oman' 'Panama' 'Paraguay' 'Philippines' 'Poland' 'Portugal'
 'Puerto Rico' 'Qatar' 'Republic of Korea' 'Romania' 'Russian Federation'
 'Saint Kitts and Nevis' 'Saint Lucia' 'Saint Vincent and Grenadines'
 'San Marino' 'Serbia' 'Seychelles' 'Singapore' 'Slovakia' 'Slovenia'
 'South Africa' 'Spain' 'Sri Lanka' 'Suriname' 'Sweden' 'Switzerland'
 'Thailand' 'Trinidad and Tobago' 'Turkey' 'Turkmenistan' 'Ukraine'
 'United Arab Emirates' 'United Kingdom' 'United States' 'Uruguay'
 'Uzbekistan']

  year unique values: 

[1987 1988 1989 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002
 2003 2004 2005 2006 2007 2008 2009 2010 1985 1986 1990 1991 2012 2013
 2014 2015 2011 2016]

  sex unique values: 

['male' 'female']

  age unique values: 

['15-24 years' '35-54 years' '75+ years' '25-34 years' '55-74 years'
 '5-14 years']

  generation unique values: 

['Generation X' 'Silent' 'G.I. Generation' 'Boomers' 'Millenials'
 'Generation Z']

  risk unique values: 

[0 1]

Correlation Matrix and Heatmap

In [147]:
# Correlation matrix for numerical variables
continuous = data.drop(["risk"], axis = 1)
corr = continuous.corr()
corr
Out[147]:
year suicides_no population suicides_per_100k_pop gdp_for_year gdp_per_capita
year 1.00 -0.00 0.01 -0.04 0.09 0.34
suicides_no -0.00 1.00 0.62 0.31 0.43 0.06
population 0.01 0.62 1.00 0.01 0.71 0.08
suicides_per_100k_pop -0.04 0.31 0.01 1.00 0.03 0.00
gdp_for_year 0.09 0.43 0.71 0.03 1.00 0.30
gdp_per_capita 0.34 0.06 0.08 0.00 0.30 1.00
In [151]:
# Correlation heatmap for numerical variables

# Lower triangle
corr_lower = corr.where(np.tril(np.ones(corr.shape)).astype(np.bool))

fig, ax = plt.subplots(figsize = (12,10))

ax = sns.heatmap(corr_lower, cmap = "Blues",annot = True, 
                 linewidths = 0.5, cbar_kws={"shrink": .7})

ax.tick_params(axis = "both", labelsize = 13)

# Fixing a bug
bottom, top = ax.get_ylim()
ax.set_ylim(bottom + 0.5, top - 0.5)

fig.text(0.06,-0.16,"Fig. 1. Correlation heatmap", fontsize = 17, fontweight = "bold")
Out[151]:
Text(0.06, -0.16, 'Fig. 1. Correlation heatmap')

Visualization of the Features

Total Suicides and Year

In [153]:
yearly_data = data.groupby("year").sum()

# Convert index into a column
yearly_data.reset_index(level = 0, inplace = True)

# Plotting
plt.rcParams["figure.figsize"] = [8, 6]
yearly_suicide_graph = yearly_data.plot(kind = "line", x = "year", y = "suicides_no")

# Setting the axis and title labels
yearly_suicide_graph.set_xlabel("Year", size = 14)
yearly_suicide_graph.set_ylabel("Total Suicides", size = 14)
yearly_suicide_graph.set_title("Fig. 2. Total number of suicides in the world from 1985 to 2016"
                               , x = 0.45, y = -0.22, 
                               size = 14, fontweight = "bold")

# Removing the top and right lines of the graph
yearly_suicide_graph.spines["right"].set_visible(False)
yearly_suicide_graph.spines["top"].set_visible(False)

# Removing the legend
yearly_suicide_graph.get_legend().remove()

# horizontal gridlines
yearly_suicide_graph.set_axisbelow(True)
yearly_suicide_graph.yaxis.grid(color = "gray", linestyle = "dashed")

Total Suicides and Year by Gender

In [260]:
# Checking to see if the number of females and males are balanced

print("Number of females and males:")
count1 = data['sex'].value_counts()
count1
Number of females and males:
Out[260]:
male      13910
female    13910
Name: sex, dtype: int64

Therefore, the number of men and women is equal.

In [155]:
gender_data = data.groupby(["year", "sex"]).sum()
gender_table = gender_data.pivot_table(index = "year", columns = "sex", 
                        values = "suicides_no", aggfunc = "sum")
# Plotting
plt.rcParams["figure.figsize"] = [8, 6]
gender_graph = gender_table.plot()

gender_graph.spines["right"].set_visible(False)
gender_graph.spines["top"].set_visible(False)

plt.xlabel("Year", size = 13)
plt.ylabel("Total Suicides", size = 13)
plt.title("Fig. 3. Total number of suicides by gender from 1985 to 2016", 
          x = 0.45, y = -0.22,
          size = 14, fontweight = "bold")

gender_graph.set_axisbelow(True)
gender_graph.yaxis.grid(color = "gray", linestyle = "dashed")

Total Suicides and Year by Age

In [267]:
# Checking to see if the age groups are balanced in terms of number of observations

print("Number of observations for each age band:")
count = data['age'].value_counts()
count
Number of observations for each age band:
Out[267]:
35-54 years    4642
55-74 years    4642
25-34 years    4642
75+ years      4642
15-24 years    4642
5-14 years     4610
Name: age, dtype: int64

Therefore, the number of observations is roughly equal for all age bands.

In [156]:
age_table = data.pivot_table('suicides_no', index='age', aggfunc='sum')
age_table.reset_index(level = 0, inplace = True)

# Plotting
# Colors of the bars
colors = ["b", "r", "g", "c", "m", "y"]
age_graph = age_table.plot(x = "age", y = "suicides_no", kind = "bar", 
                           width = 0.7, color = colors)

plt.xlabel("Age Group", size = 13)
plt.ylabel("Total Suicides", size = 13)
plt.title("Fig. 4. Total number of suicides by age", 
          x = 0.33, y = -0.3,
          size = 14, fontweight = "bold")
plt.xticks(rotation = 30)

age_graph.spines["right"].set_visible(False)
age_graph.spines["top"].set_visible(False)

# Removing the legend
age_graph.get_legend().remove()

age_graph.set_axisbelow(True)
age_graph.yaxis.grid(color = "gray", linestyle = "dashed")

Total Suicides and Year by Gender and Age

In [157]:
gender_age_data = data.groupby(["year", "sex", "age"]).suicides_no.sum().reset_index()

# Plotting
def plot(df, sex, age, ax, color):
    _df = df.query("(sex == @sex) & (age == @age)")
    _df.plot(
        kind = "line", x = "year", y = "suicides_no", ax = ax, color = color
)
    return ax
fig, ax = plt.subplots(1, 2, figsize=(11, 6), sharey=True)

# Naming the legends
labels = ["15-24 years", "35-54 years", "75+ years", "25-34 years", "55-74 years", "5-14 years"]

for (i, sex) in enumerate(gender_age_data.sex.unique()):
    plot(gender_age_data, sex,"15-24 years", ax[i], "b")
    plot(gender_age_data, sex,"35-54 years", ax[i], "r")
    plot(gender_age_data, sex,"75+ years", ax[i], "k")
    plot(gender_age_data, sex,"25-34 years", ax[i], "y")
    plot(gender_age_data, sex,"55-74 years", ax[i], "g")
    plot(gender_age_data, sex,"5-14 years", ax[i], "m")
    ax[i].set_title(str(sex))
    ax[i].legend(labels = labels)
    

for (i, _ax) in enumerate(ax):
    
    # setting axis labels
    _ax.set_xlabel("Year", size = 13)
    _ax.set_ylabel("Total Suicides", size = 13)
    
    _ax.spines['right'].set_visible(False)
    _ax.spines['top'].set_visible(False)
    
    _ax.yaxis.grid(True)
    

plt.suptitle("Fig. 5. Total number of suicides by age and sex from 1985 to 2016", x = 0.43, 
             y = -0.05, fontsize=15, fontweight = "bold")
plt.show()

Suicide Rates and Year for Different Countries

In [158]:
country_year_data = data.groupby(['country', 'year'], sort = True).sum()

country_year_data_suicides = country_year_data[["suicides_per_100k_pop"]].reset_index()
country_year_data_suicides.columns = ["country", "year", "Suicides per 100k Population"]

# Plotting
country_year_graph = sns.FacetGrid(country_year_data_suicides, col = "country", col_wrap = 4, sharey = True)
country_year_graph.map(plt.plot, "year", "Suicides per 100k Population", marker = ".")
plt.show()

World Map of Total Suicide Rates

In [160]:
country_data = data.groupby("country").sum()
country_data_suicides = country_data[["suicides_per_100k_pop"]].reset_index()
country_data_suicides = country_data_suicides.sort_values("suicides_per_100k_pop", ascending = False)
country_data_suicides.head()
Out[160]:
country suicides_per_100k_pop
75 Russian Federation 11305.13
52 Lithuania 10588.88
40 Hungary 10156.07
47 Kazakhstan 9519.52
73 Republic of Korea 9350.45
In [161]:
# Finding the list of countries that are not in the world dataset

world = gpd.read_file(gpd.datasets.get_path("naturalearth_lowres"))
world = world.set_index("iso_a3")

data_country = list(data.country.unique())
world_country = list(world.name.unique())

list(set(data_country)-set(world_country)) 
Out[161]:
['Macau',
 'Antigua and Barbuda',
 'Republic of Korea',
 'Maldives',
 'Bosnia and Herzegovina',
 'Seychelles',
 'Malta',
 'Saint Vincent and Grenadines',
 'Czech Republic',
 'Mauritius',
 'Singapore',
 'Saint Lucia',
 'Aruba',
 'Cabo Verde',
 'United States',
 'San Marino',
 'Barbados',
 'Bahrain',
 'Grenada',
 'Kiribati',
 'Dominica',
 'Russian Federation',
 'Saint Kitts and Nevis']
In [162]:
# Comparing the countries in the world dataset and the countries in the data, 
# we see that some countries have different names in the world dataset. 
# So, I will change the names of some countries in the dataset.
# The map below does not include all the countries in the dataset.

country_data_suicides.loc[country_data_suicides.country == "Bosnia and Herzegovina", "country"] = "Bosnia and Herz."
country_data_suicides.loc[country_data_suicides.country == "Russian Federation", "country"] = "Russia"
country_data_suicides.loc[country_data_suicides.country == "United States", "country"] = "United States of America"
country_data_suicides.loc[country_data_suicides.country == "Dominica", "country"] = "Dominican Rep."
country_data_suicides.loc[country_data_suicides.country == "Czech Republic", "country"] = "Czechia"
country_data_suicides.loc[country_data_suicides.country == "Republic of Korea", "country"] = "South Korea"
In [163]:
world_data = world.merge(country_data_suicides, left_on = "name", right_on = "country", how = "inner")

fig, ax = plt.subplots(figsize=(30,15))

# Plot the world
world.plot(ax=ax, edgecolor='black', color = "white")

world_data.plot(ax = ax, edgecolor = "black", column = "suicides_per_100k_pop", 
           legend = True, cmap = "YlOrRd", vmin = 0, vmax = 12000)

# turning off the axes
plt.axis("off")

# Adding text below the color bar
ax.annotate('Suicides per 100k',xy=(0.83, 0.055),  xycoords='figure fraction', size = 14)

plt.suptitle("Fig. 6. Total suicides per 100k population from 1985 to 2016", x = 0.34, y = 0.1, 
             fontsize=27, fontweight = "bold")
plt.show()