import glob
import re
from collections import Counter, defaultdict
from math import exp, isclose, log
from pathlib import Path
from random import shuffle
from typing import DefaultDict, List, NamedTuple, Set, Tuple
# Ensure that we have a `data` directory we use to store downloaded data.
# `Path.mkdir(parents=True, exist_ok=True)` behaves like `mkdir -p` but is plain
# Python, so it does not depend on IPython's `!` shell escape or a POSIX `mkdir`.
data_dir: Path = Path('data')
data_dir.mkdir(parents=True, exist_ok=True)
# We're using the "Enron Spam" data set (the preprocessed `enron1` subset).
# `-nc` (no-clobber) skips the download when `data/enron1.tar.gz` already exists,
# so re-running this cell does not fetch the archive again; `-P data` saves it into `data/`.
!wget -nc -P data http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz
--2020-02-09 12:03:06-- http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz Resolving nlp.cs.aueb.gr (nlp.cs.aueb.gr)... 195.251.248.252 Connecting to nlp.cs.aueb.gr (nlp.cs.aueb.gr)|195.251.248.252|:80... connected. HTTP request sent, awaiting response... 200 OK Length: 1802573 (1.7M) [application/x-gzip] Saving to: ‘data/enron1.tar.gz’ enron1.tar.gz 100%[===================>] 1.72M 920KB/s in 1.9s 2020-02-09 12:03:08 (920 KB/s) - ‘data/enron1.tar.gz’ saved [1802573/1802573]
# Extract the gzipped tarball into `data/`; this yields `data/enron1/` with one
# subdirectory per label (`spam/` and `ham/`)
!tar -xzf data/enron1.tar.gz -C data
# The extracted data set keeps one directory per label: `enron1/spam` and `enron1/ham`
spam_data_path: Path = data_dir.joinpath('enron1', 'spam')
ham_data_path: Path = data_dir.joinpath('enron1', 'ham')
# Our data container for `spam` and `ham` messages
class Message(NamedTuple):
    """A single labelled e-mail: its text and whether it is spam."""

    # The message text
    text: str
    # True for a `spam` message, False for a `ham` message
    is_spam: bool
# Collect every `.txt` file from the `spam` and the `ham` directory
spam_message_paths: List[str] = glob.glob(str(spam_data_path.joinpath('*.txt')))
ham_message_paths: List[str] = glob.glob(str(ham_data_path.joinpath('*.txt')))
# All paths in one list: the spam files first, followed by the ham files
message_paths: List[str] = [*spam_message_paths, *ham_message_paths]
# Peek at the first five paths
message_paths[:5]
['data/enron1/spam/4743.2005-06-25.GP.spam.txt', 'data/enron1/spam/1309.2004-06-08.GP.spam.txt', 'data/enron1/spam/0726.2004-03-26.GP.spam.txt', 'data/enron1/spam/0202.2004-01-13.GP.spam.txt', 'data/enron1/spam/3988.2005-03-06.GP.spam.txt']
# The list which eventually contains all the parsed Enron `spam` and `ham` messages
messages: List[Message] = []
# Label each file by the directory it was globbed from instead of searching the path
# string for 'spam': the substring test would mark every ham file as spam if any
# parent directory name happened to contain "spam". Spam files are read first and
# ham files second, i.e. in the same order as `message_paths`.
for is_spam, paths in ((True, spam_message_paths), (False, ham_message_paths)):
    for path in paths:
        # `errors='ignore'` drops bytes that cannot be decoded instead of raising
        with open(path, errors='ignore') as file:
            # We're only interested in the subject (first line) for the time being
            text: str = file.readline().replace('Subject:', '').strip()
        messages.append(Message(text, is_spam))
shuffle(messages)
messages[:5]
[Message(text='january production estimate', is_spam=False), Message(text='re : your code # 5 g 6878', is_spam=True), Message(text='account # 20367 s tue , 28 jun 2005 11 : 41 : 41 - 0800', is_spam=True), Message(text='congratulations', is_spam=True), Message(text='fw : hpl imbalance payback', is_spam=False)]
# Total number of parsed messages (spam + ham)
len(messages)
5172
# Tokens are maximal runs of ASCII letters, digits and apostrophes ("don't" stays one token)
_TOKEN_PATTERN = re.compile(r'[A-Za-z0-9\']+')


# Given a string, normalize (lowercase) and extract all words that are at least 2 characters long
def tokenize(text: str, min_length: int = 2) -> Set[str]:
    """Return the set of lowercased tokens in `text` with at least `min_length` characters.

    Any character outside ``[A-Za-z0-9']`` (whitespace, punctuation, non-ASCII)
    separates tokens. Duplicates collapse because a set is returned.
    """
    return {word.lower() for word in _TOKEN_PATTERN.findall(text) if len(word) >= min_length}


assert tokenize('Is this a text? If so, Tokenize this text!...') == {'is', 'this', 'text', 'if', 'so', 'tokenize'}
# Token set of the first message's subject (the list was shuffled above)
tokenize(messages[0].text)
{'estimate', 'january', 'production'}
# Split the list of messages into a `train` and `test` set (defaults to 80/20 train/test split)
def train_test_split(messages: List[Message], pct: float = 0.8) -> Tuple[List[Message], List[Message]]:
    """Shuffle a copy of `messages` and split it into ``(train, test)``.

    `pct` is the fraction of messages that go into the training set; the resulting
    count is rounded to the nearest integer. The caller's list is not modified.

    Raises:
        ValueError: if `pct` is not within [0, 1].
    """
    if not 0 <= pct <= 1:
        raise ValueError(f'pct must be between 0 and 1, got {pct!r}')
    # Shuffle a copy so the caller's list keeps its order
    shuffled: List[Message] = list(messages)
    shuffle(shuffled)
    num_train: int = round(len(shuffled) * pct)
    return shuffled[:num_train], shuffled[num_train:]


# Sanity check: the two parts together contain every message exactly once
assert sum(map(len, train_test_split(messages))) == len(messages)
# The Naive Bayes classifier
class NaiveBayes:
    """Bernoulli Naive Bayes spam classifier over the token *sets* from `tokenize`.

    Training counts, per class, how many messages contain each token. `predict`
    walks the whole training vocabulary and, for every word, multiplies in the
    smoothed probability of that word being present (if it is in the text) or
    absent (if it is not) for each class. No class prior is applied, i.e. spam
    and ham are treated as equally likely a priori.
    """

    def __init__(self, k: float = 1) -> None:
        # `k` is the additive smoothing factor: each word is treated as if it had been
        # seen present `k` extra times and absent `k` extra times in every class, so
        # (for k > 0) no estimated probability is exactly 0 or 1
        self._k: float = k
        # Number of training messages per class
        self._num_spam_messages: int = 0
        self._num_ham_messages: int = 0
        # Per class: token -> number of training messages that contain the token
        self._num_word_in_spam: DefaultDict[str, int] = defaultdict(int)
        self._num_word_in_ham: DefaultDict[str, int] = defaultdict(int)
        # Per-class vocabularies and the combined vocabulary iterated by `predict`
        self._spam_words: Set[str] = set()
        self._ham_words: Set[str] = set()
        self._words: Set[str] = set()

    # Iterate through the given messages and gather the necessary statistics
    def train(self, messages: List[Message]) -> None:
        """Accumulate per-class message and token counts from `messages`.

        Counts accumulate across calls, so `train` may be called repeatedly.
        """
        for msg in messages:
            # A set: a token is counted at most once per message
            tokens: Set[str] = tokenize(msg.text)
            self._words.update(tokens)
            if msg.is_spam:
                self._num_spam_messages += 1
                self._spam_words.update(tokens)
                counts = self._num_word_in_spam
            else:
                self._num_ham_messages += 1
                self._ham_words.update(tokens)
                counts = self._num_word_in_ham
            for token in tokens:
                counts[token] += 1

    # Smoothed probability that a spam message contains `word`, i.e. P(word | spam)
    def _p_word_spam(self, word: str) -> float:
        # `.get` instead of indexing: indexing the defaultdict would insert a 0 entry
        # for every unseen word and silently grow the counts on each prediction
        count = self._num_word_in_spam.get(word, 0)
        return (self._k + count) / ((2 * self._k) + self._num_spam_messages)

    # Smoothed probability that a ham message contains `word`, i.e. P(word | ham)
    def _p_word_ham(self, word: str) -> float:
        count = self._num_word_in_ham.get(word, 0)
        return (self._k + count) / ((2 * self._k) + self._num_ham_messages)

    # Given a `text`, how likely is it spam?
    def predict(self, text: str) -> float:
        """Return the probability, in [0, 1], that `text` is spam.

        Words of `text` that never appeared during training are ignored; every
        vocabulary word absent from `text` contributes its "absent" probability.
        """
        text_words: Set[str] = tokenize(text)
        # Work in log space to avoid multiplying many small probabilities
        log_p_spam: float = 0.0
        log_p_ham: float = 0.0
        for word in self._words:
            p_spam: float = self._p_word_spam(word)
            p_ham: float = self._p_word_ham(word)
            if word in text_words:
                log_p_spam += log(p_spam)
                log_p_ham += log(p_ham)
            else:
                log_p_spam += log(1 - p_spam)
                log_p_ham += log(1 - p_ham)
        p_if_spam: float = exp(log_p_spam)
        p_if_ham: float = exp(log_p_ham)
        total: float = p_if_spam + p_if_ham
        if total > 0.0:
            return p_if_spam / total
        # Both likelihoods underflowed to 0.0 (very large vocabularies). Use the
        # algebraically equivalent logistic form 1 / (1 + e^(log_p_ham - log_p_spam)),
        # arranged so that `exp` never overflows.
        diff: float = log_p_ham - log_p_spam
        if diff >= 0:
            z: float = exp(-diff)
            return z / (1.0 + z)
        return 1.0 / (1.0 + exp(diff))
# Tests
def test_naive_bayes():
    """Check `NaiveBayes` training counts and `predict` on a tiny hand-computed example."""
    messages: List[Message] = [
        Message('Spam message', is_spam=True),
        Message('Ham message', is_spam=False),
        Message('Ham message about Spam', is_spam=False)]

    nb: NaiveBayes = NaiveBayes()
    nb.train(messages)

    assert nb._num_spam_messages == 1
    assert nb._num_ham_messages == 2
    assert nb._spam_words == {'spam', 'message'}
    assert nb._ham_words == {'ham', 'message', 'about', 'spam'}
    assert nb._num_word_in_spam == {'spam': 1, 'message': 1}
    assert nb._num_word_in_ham == {'ham': 2, 'message': 2, 'about': 1, 'spam': 1}
    assert nb._words == {'spam', 'message', 'ham', 'about'}

    # Our test message ('a' is shorter than 2 characters, so `tokenize` drops it)
    text: str = 'A spam message'

    # Reminder: The `_words` we iterate over are: {'spam', 'message', 'ham', 'about'}

    # Calculate how spammy the `text` might be (k = 1, 1 spam message)
    p_if_spam: float = exp(sum([
        log(     (1 + 1) / ((2 * 1) + 1)),   # `spam` (also in `text`)
        log(     (1 + 1) / ((2 * 1) + 1)),   # `message` (also in `text`)
        log(1 - ((1 + 0) / ((2 * 1) + 1))),  # `ham` (NOT in `text`)
        log(1 - ((1 + 0) / ((2 * 1) + 1))),  # `about` (NOT in `text`)
    ]))

    # Calculate how hammy the `text` might be (k = 1, 2 ham messages)
    p_if_ham: float = exp(sum([
        log(     (1 + 1) / ((2 * 1) + 2)),   # `spam` (also in `text`)
        log(     (1 + 2) / ((2 * 1) + 2)),   # `message` (also in `text`)
        log(1 - ((1 + 2) / ((2 * 1) + 2))),  # `ham` (NOT in `text`)
        log(1 - ((1 + 1) / ((2 * 1) + 2))),  # `about` (NOT in `text`)
    ]))

    p_spam: float = p_if_spam / (p_if_spam + p_if_ham)

    # `predict` adds the same log terms in set-iteration order, which changes with
    # string-hash randomization (PYTHONHASHSEED). Float addition is not associative,
    # so the last bits can differ from the hand-ordered sum: compare with a tolerance.
    assert isclose(p_spam, nb.predict(text), rel_tol=1e-9)


test_naive_bayes()
train: List[Message]
test: List[Message]
# Hold out part of the Enron messages for evaluation (80/20 by default)
train, test = train_test_split(messages)

# Fit a fresh Naive Bayes classifier on the training portion only
nb: NaiveBayes = NaiveBayes()
nb.train(train)

# Summarize the training data. NOTE: the "spammy" list ranks tokens by how many spam
# messages contain them (raw frequency), so common words such as 'you' or 'the'
# dominate; it is not a spam-vs-ham ratio.
print(
    f'Spam messages in training data: {nb._num_spam_messages}',
    f'Ham messages in training data: {nb._num_ham_messages}',
    f'Most spammy words: {Counter(nb._num_word_in_spam).most_common(20)}',
    sep='\n',
)
Spam messages in training data: 1227 Ham messages in training data: 2911 Most spammy words: [('you', 115), ('the', 104), ('your', 104), ('for', 86), ('to', 83), ('re', 81), ('on', 56), ('and', 51), ('get', 48), ('is', 48), ('in', 43), ('with', 40), ('of', 38), ('it', 35), ('at', 35), ('online', 34), ('all', 33), ('from', 33), ('this', 32), ('new', 31)]
# Keep only the spam messages of the held-out `test` set
spam_messages: List[Message] = list(filter(lambda msg: msg.is_spam, test))
# Peek at the first five
spam_messages[:5]
[Message(text="a witch . i don ' t", is_spam=True), Message(text='active and strong', is_spam=True), Message(text='get great prices on medications', is_spam=True), Message(text='', is_spam=True), Message(text='popular software at low low prices . misunderstand developments', is_spam=True)]
# Using our trained Naive Bayes classifier to classify a spam message
# (index 10 is an arbitrary pick; it raises IndexError if the test split holds fewer than 11 spam messages)
message: str = spam_messages[10].text
print(f'Predicting likelihood of "{message}" being spam.')
# Displayed value: the classifier's probability (in [0, 1]) that `message` is spam
nb.predict(message)
Predicting likelihood of "get your hand clock repliacs todday carson" being spam.
0.9884313222593173
# Keep only the ham (i.e. not spam) messages of the held-out `test` set
ham_messages: List[Message] = list(filter(lambda msg: not msg.is_spam, test))
# Peek at the first five
ham_messages[:5]
[Message(text='new update for buybacks', is_spam=False), Message(text='enron and blockbuster to launch entertainment on - demand service', is_spam=False), Message(text='re : astros web site comments', is_spam=False), Message(text='re : formosa meter # : 1000', is_spam=False), Message(text='re : deal extension for 11 / 21 / 2000 for 98 - 439', is_spam=False)]
# Using our trained Naive Bayes classifier to classify a ham message
# (index 10 is an arbitrary pick; it raises IndexError if the test split holds fewer than 11 ham messages)
message: str = ham_messages[10].text
# Interpolate `message`, the text actually being scored -- not the module-level `text`,
# which is merely the subject of the last file read by the loading loop
print(f'Predicting likelihood of "{message}" being spam.')
# Displayed value: the classifier's probability (in [0, 1]) that `message` is spam
nb.predict(message)
Predicting likelihood of "associate & analyst mid - year 2001 prc process" being spam.
5.3089147140900964e-05