After recently brushing up my knowledge of Bayesian statistics, I also learned about naïve Bayes classification. Naïve Bayes is a classifier built on Bayes' theorem with the simplifying ("naïve") assumption that all features are conditionally independent of each other once we know the category they belong to. A good summary can be found in the scikit-learn documentation.
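To make that assumption concrete: for a message consisting of the words $w_1, \dots, w_n$ and a category $c$ (spam or ham in our case), it simply states that

$$
P(w_1, \dots, w_n \mid c) = \prod_{i=1}^{n} P(w_i \mid c),
$$

i.e. once we know the category, the words are treated as if they occurred independently of each other.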
The main appeal of the method is that it is easy to implement and works surprisingly well given how simple it is.
I am surely not the first person to do this, but for this project I used the UCI Machine Learning SMS Spam classification dataset to build and test a naïve Bayes spam filter. To challenge myself a bit more I also added the option to update the filter and to classify messages from the command line. Below I'll walk you through the core parts of the project (i.e. building and testing the filter); the full project can be found on my Gitlab.
First we load the packages we will be using extensively:
library("dplyr")
library("stringr")
Also set a seed for reproducibility (you would not want this in a production scenario):
set.seed(2417)
And then read the raw data:
spam <- readr::read_csv("spam.csv")
Now examine it:
spam
## # A tibble: 5,572 x 5
## v1 v2 X3 X4 X5
## <chr> <chr> <chr> <chr> <chr>
## 1 ham Go until jurong point, crazy.. Available only i… <NA> <NA> <NA>
## 2 ham Ok lar... Joking wif u oni... <NA> <NA> <NA>
## 3 spam Free entry in 2 a wkly comp to win FA Cup final… <NA> <NA> <NA>
## 4 ham U dun say so early hor... U c already then say.… <NA> <NA> <NA>
## 5 ham Nah I don't think he goes to usf, he lives arou… <NA> <NA> <NA>
## 6 spam "FreeMsg Hey there darling it's been 3 week's n… <NA> <NA> <NA>
## 7 ham Even my brother is not like to speak with me. T… <NA> <NA> <NA>
## 8 ham As per your request 'Melle Melle (Oru Minnaminu… <NA> <NA> <NA>
## 9 spam "WINNER!! As a valued network customer you have… <NA> <NA> <NA>
## 10 spam Had your mobile 11 months or more? U R entitled… <NA> <NA> <NA>
## # … with 5,562 more rows
The columns are messy. There are no meaningful names and it seems like some messages are spread across multiple columns. So we merge columns 2 to 5 into a single message column and rename the first.
spam <- spam %>%
  tidyr::unite(col = "msg", 2:5, sep = " ", na.rm = TRUE) %>%
  rename("label" = v1)
Before doing anything else the data should be split into a training and a test set, so we have some fresh data later on to validate our classifier against. The `rsample` package makes this process really straightforward.
split <- rsample::initial_split(spam, strata = label)
train_spam <- rsample::training(split)
test_spam <- rsample::testing(split)
We use stratified sampling so that we neither over- nor under-sample spam in either set. We can verify this by looking at the distribution of the categories in each data set:
prop.table(table(train_spam$label))
##
## ham spam
## 0.8693467 0.1306533
prop.table(table(test_spam$label))
##
## ham spam
## 0.8557071 0.1442929
To build a naïve Bayes classifier we need to clean up all the messages and split them into individual words, so we can then calculate the probability of each word appearing in each message category. The assumption here is that the wording of spam and non-spam messages differs and that certain words are more likely to appear in one category than the other. This seems plausible if you have ever received spam messages, but it does not automatically follow that it holds up when classifying a large number of messages automatically.
To process the messages, we build a function that takes a character vector as input and returns a tidy result. A function is useful here, as the test data will need to be processed in exactly the same way later on.
string_cleaner <- function(text_vector) {
  tx <- text_vector %>%
    str_replace_all("[^[:alnum:] ]+", "") %>%
    str_to_lower() %>%
    str_replace_all("\\b(http|www.+)\\b", "_url_") %>%
    str_replace_all("\\b(\\d{7,})\\b", "_longnum_") %>%
    str_split(" ")
  tx <- lapply(tx, function(x) x[nchar(x) > 1])
  tx
}
What we do here is as follows:

- remove everything that is not a letter, a number or a space,
- convert the text to lower case,
- replace anything that looks like a URL (matching http or starting with www) with the placeholder `_url_`,
- replace long numbers of seven or more digits (e.g. phone numbers) with the placeholder `_longnum_`,
- split the messages into individual words at the spaces,
- and drop all words that are only one character long.
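Just to get a feel for the function, we can run it on a made-up example message (the message below is invented purely for illustration and is not part of the dataset):

# an invented example message
string_cleaner("WIN a FREE prize!! Call 08001234567 or visit www.example.com")
# punctuation is stripped, everything is lower case, the single character "a"
# is dropped, and the long number and the URL are replaced by their placeholders:
# [[1]]
# [1] "win" "free" "prize" "call" "_longnum_" "or" "visit" "_url_"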
And then this can be applied to the messages in the train set.
train_spam <- train_spam %>%
  mutate(msg_list = string_cleaner(.$msg))
Leading to the following result:
train_spam$msg_list[1:3]
## [[1]]
## [1] "ok" "lar" "joking" "wif" "oni"
##
## [[2]]
## [1] "free" "entry" "in"
## [4] "wkly" "comp" "to"
## [7] "win" "fa" "cup"
## [10] "final" "tkts" "21st"
## [13] "may" "2005" "text"
## [16] "fa" "to" "87121"
## [19] "to" "receive" "entry"
## [22] "questionstd" "txt" "ratetcs"
## [25] "apply" "08452810075over18s"
##
## [[3]]
## [1] "dun" "say" "so" "early" "hor" "already" "then"
## [8] "say"
With the messages split into tidy words the next step is to calculate the probabilities of any given word appearing in each type of message, either spam or ham.
For generalisability you might want to put this step into a function (which I did in the project on my Gitlab).
The first step is to extract all unique words in the dataset. This will come in handy very soon, as we need the number of unique words to calculate the probability of a word appearing in a category.
vocab <- train_spam %>%
  select(msg_list) %>%
  unlist() %>%
  unique() %>%
  tibble::enframe(name = NULL, value = "word")
vocab
## # A tibble: 7,803 x 1
## word
## <chr>
## 1 ok
## 2 lar
## 3 joking
## 4 wif
## 5 oni
## 6 free
## 7 entry
## 8 in
## 9 wkly
## 10 comp
## # … with 7,793 more rows
And now we do something similar for the words in ham and spam messages. The main difference to the previous step is that we do not use `unique`, meaning the result contains all words in the respective category regardless of the number of times they appear in the data. Also, we explicitly `deframe` and `unlist` the data to get a clean vector of words to work with:
ham_vocab <- train_spam %>%
  filter(label == "ham") %>%
  select(msg_list) %>%
  tibble::deframe() %>%
  unlist()

spam_vocab <- train_spam %>%
  filter(label == "spam") %>%
  select(msg_list) %>%
  tibble::deframe() %>%
  unlist()
head(ham_vocab)
## [1] "ok" "lar" "joking" "wif" "oni" "dun"
Now we can count how often the words appear in each category and join that information to the previously extracted unique vocabulary:
vocab <- table(ham_vocab) %>%
  tibble::as_tibble() %>%
  rename(ham_n = n) %>%
  left_join(vocab, ., by = c("word" = "ham_vocab"))

vocab <- table(spam_vocab) %>%
  tibble::as_tibble() %>%
  rename(spam_n = n) %>%
  left_join(vocab, ., by = c("word" = "spam_vocab"))
vocab
## # A tibble: 7,803 x 3
## word ham_n spam_n
## <chr> <int> <int>
## 1 ok 199 3
## 2 lar 33 NA
## 3 joking 6 NA
## 4 wif 23 NA
## 5 oni 3 NA
## 6 free 43 170
## 7 entry NA 20
## 8 in 615 59
## 9 wkly NA 11
## 10 comp NA 7
## # … with 7,793 more rows
The next step is to turn these counts into probabilities, as that is what we ultimately need for the classification.
First get some baseline data:
word_n <- c("unique" = nrow(vocab),
"ham" = length(ham_vocab),
"spam" = length(spam_vocab))
class_probs <- prop.table(table(train_spam$label))
And then calculate the probability of a word appearing in a given category. To do so we first set up a little helper function. The function relates the count of a word in a category to the total number of words in that category and also adds Laplace smoothing, which ensures that the result is never 0. Since we later multiply these probabilities, a single 0 would turn the entire product into 0.
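Written as a formula, this is the usual Laplace (additive) smoothing estimate, where $n_{w,c}$ is how often word $w$ appears in category $c$, $N_c$ is the total number of words in that category, $V$ is the number of unique words in the vocabulary and $\alpha$ is the smoothing value (1 by default):

$$
P(w \mid c) = \frac{n_{w,c} + \alpha}{N_c + \alpha \, V}
$$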
word_probabilities <- function(word_n, category_n, vocab_n, smooth = 1) {
  prob <- (word_n + smooth) / (category_n + smooth * vocab_n)
  prob
}
The function takes the frequency of a word in a category, the total number of words in that category, the number of unique words overall and a smoothing value, and returns the probability of the word appearing in that category, given the data we have.
And now replace NA counts with 0 (because that’s what they are) and apply our helper function to each row of the data.
vocab <- vocab %>%
  tidyr::replace_na(list(ham_n = 0, spam_n = 0)) %>%
  rowwise() %>%
  mutate(ham_prob = word_probabilities(
    ham_n, word_n["ham"], word_n["unique"])) %>%
  mutate(spam_prob = word_probabilities(
    spam_n, word_n["spam"], word_n["unique"])) %>%
  ungroup()
vocab
## # A tibble: 7,803 x 5
## word ham_n spam_n ham_prob spam_prob
## <chr> <dbl> <dbl> <dbl> <dbl>
## 1 ok 199 3 0.00368 0.000203
## 2 lar 33 0 0.000626 0.0000506
## 3 joking 6 0 0.000129 0.0000506
## 4 wif 23 0 0.000442 0.0000506
## 5 oni 3 0 0.0000736 0.0000506
## 6 free 43 170 0.000810 0.00866
## 7 entry 0 20 0.0000184 0.00106
## 8 in 615 59 0.0113 0.00304
## 9 wkly 0 11 0.0000184 0.000608
## 10 comp 0 7 0.0000184 0.000405
## # … with 7,793 more rows
The entire computation above takes less than two seconds on a somewhat recent machine.
Now that we have the probabilities we can classify messages: for each category we multiply the probabilities of all words in the message and also include the baseline probability of the category in the product.
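Written out, for a message with words $w_1, \dots, w_n$ we compute a score per category and pick the category with the larger one:

$$
\text{score}(c) = P(c) \prod_{i=1}^{n} P(w_i \mid c), \qquad c \in \{\text{ham}, \text{spam}\}
$$

(A common refinement is to sum log-probabilities instead of multiplying raw probabilities to avoid numerical underflow; for messages as short as these the plain product is fine.)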
To classify new data we define another function which takes a raw message as input and ultimately returns a classification:
classifier <- function(msg, prob_df, ham_p = 0.5, spam_p = 0.5) {
  # clean and tokenise the message with the same function used for the training data
  clean_message <- string_cleaner(msg) %>% unlist()
  # look up the ham and spam probabilities for every word of the message
  probs <- sapply(clean_message, function(x) {
    filter(prob_df, word == x) %>%
      select(ham_prob, spam_prob)
  })
  # only attempt a classification if the cleaned message still contains any words
  if (!is.null(dim(probs))) {
    # multiply the word probabilities per category ...
    ham_prob <- prod(unlist(as.numeric(probs[1, ])), na.rm = TRUE)
    spam_prob <- prod(unlist(as.numeric(probs[2, ])), na.rm = TRUE)
    # ... and weight them with the baseline probability of each category
    ham_prob <- ham_p * ham_prob
    spam_prob <- spam_p * spam_prob
    # the category with the larger value wins
    if (ham_prob > spam_prob) {
      classification <- "ham"
    } else if (ham_prob < spam_prob) {
      classification <- "spam"
    } else {
      classification <- "unknown"
    }
  } else {
    classification <- "unknown"
  }
  classification
}
The function takes four inputs: the message, a data frame of probabilities (`vocab` in our case) and the baseline probabilities for ham and spam messages. It first tidies and splits the message, retrieves the probabilities for its words, forms the product of those probabilities per category and multiplies it with the baseline probability of that category. Depending on which value is larger it returns a classification, or "unknown" when both values are equal (or when the cleaned message contains no words at all).
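As a quick sanity check we can run the classifier on a single made-up message (again invented for illustration, not part of the dataset); one would expect a blatantly promotional message like this to come out as spam:

# an invented, obviously promotional message
classifier("Congratulations! You have won a FREE prize, call 08001234567 now!",
           vocab, class_probs["ham"], class_probs["spam"])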
This can then be applied to the test data set.
spam_classification <- sapply(test_spam$msg,
                              function(x) classifier(x, vocab, class_probs["ham"],
                                                     class_probs["spam"]),
                              USE.NAMES = FALSE)
And now we can see how we are performing.
To use the `metrics` function of the `yardstick` package we need to convert both the prediction and the label to factors with the same levels:
fct_levels <- c("ham", "spam", "unknown")
test_spam <- test_spam %>%
  mutate(label = factor(.$label, levels = fct_levels),
         .pred = factor(spam_classification, levels = fct_levels))
performance <- yardstick::metrics(test_spam, label, .pred)
And the moment of truth:
performance
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.986
## 2 kap multiclass 0.944
This is pretty good for an approach which is rather simple.
Also have a look at the confusion matrix:
table(paste("actual", test_spam$label), paste("pred", test_spam$.pred))
##
## pred ham pred spam pred unknown
## actual ham 1186 5 1
## actual spam 13 188 0
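Reading the matrix by hand also gives a feel for precision and recall of the spam class (the numbers below are taken directly from the confusion matrix above, treating spam as the "positive" class):

# precision for spam: of the 193 messages predicted as spam, 188 really were spam
188 / (188 + 5)   # roughly 0.97
# recall for spam: of the 201 actual spam messages, 188 were caught
188 / (188 + 13)  # roughly 0.94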
Finally, let’s also compare the performance of a simple approach labelling everything as ham, against the naïve Bayes classification:
test_spam %>%
  mutate(all_ham = "ham") %>%
  mutate(all_ham = factor(all_ham, levels = fct_levels)) %>%
  yardstick::metrics(label, all_ham)
## # A tibble: 2 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 accuracy multiclass 0.856
## 2 kap multiclass 0
Obviously, given that only about 13% of the messages in this data set are spam, labelling everything as ham already yields a reasonably high accuracy, but the kappa value of 0 shows that such a "classifier" has no real predictive value.