This is the second post in the from scratch series of articles, this time on the k-means algorithm. K-means is a popular algorithm that groups data into a pre-defined number of clusters based on the distance to iteratively updated centroids (the cluster means).
The strong points of k-means are that it is rather fast and not very complex; the weak points are that it struggles with high-dimensional data (this is by no means exclusive to k-means and is known as the curse of dimensionality) and that you have to specify the number of clusters k a priori.
How do you know how many clusters you want? Ideally you have a strong hypothesis before you run your algorithm; otherwise you might want to work based on a metric like the elbow criterion (a rough sketch follows below). However, you should always be aware that for data with m rows you will in theory get great results from your clustering if you set k = m, since every observation then simply becomes its own cluster.
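As a rough, hedged sketch of the elbow approach (using the built-in stats::kmeans on the scaled iris data that we will also use below; the variable names here are just for illustration), you can plot the total within-cluster sum of squares against k and look for the bend where adding more clusters stops helping much:
# Sketch of the elbow criterion: run k-means for several values of k
# and inspect the total within-cluster sum of squares.
set.seed(42)
scaled_iris <- scale(iris[, 1:4])
wss <- sapply(1:10, function(k) kmeans(scaled_iris, centers = k)$tot.withinss)
plot(1:10, wss, type = "b", xlab = "k",
     ylab = "total within-cluster sum of squares")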
To characterise how close an observation is to another you need some sort of distance metric. A very popular metric here is the Euclidean Distance, a relatively simple, but nonetheless effective and useful characterisation of distance. Consider the following matrix:
mt <- matrix(c(1, 1, 10, 10), ncol = 2, byrow = TRUE)
mt
## [,1] [,2]
## [1,] 1 1
## [2,] 10 10
We want to calculate the distance between rows. Euclidean distance is the square root of the sum of squares of the differences between the rows, or to show it with the actual numbers:
sqrt(sum(c(((1 - 10) ^ 2), ((1 - 10) ^ 2))))
## [1] 12.72792
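In general terms, for two rows x and y with n columns this is:
\[ d(x, y) = \sqrt{\sum_{i = 1}^{n} (x_i - y_i)^2} \]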
We can verify this by using R's built-in distance function, dist:
dist(mt)
## 1
## 2 12.72792
However, in the from scratch series we want to implement things ourselves. Let’s return to our example data first. Doing operations on vectors is easy in R and one of the strong points of the language.
sqrt(sum((mt[1, ] - mt[2, ]) ^ 2))
## [1] 12.72792
So we just need to do that calculation for all rows in our data, in case we deal with something bigger than a 2 by 2 matrix, right? Right.
euclidean_dist <- function(data) {
  # For every row x, compute its distance to every row y,
  # giving an n-by-n matrix of pairwise distances.
  distances <- apply(data, 1, function(x) {
    apply(data, 1, function(y) {
      sqrt(sum((x - y) ^ 2))
    })
  })
  distances
}
The nested apply structure ensures that we work on all possible pairs of rows and is more or less the equivalent of nested for loops in other languages.
Now, you might be thinking that this approach is not super-efficient, and you are correct. For any given pair of rows the distance is calculated twice; if this were more than a toy example for learning, we should subset the data before the inner apply and adapt our output accordingly, so that each pairwise distance is computed only once (a rough sketch of that idea follows below).
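As a minimal sketch (the function name is just for illustration and it is not used further in this post), a version that computes each pairwise distance only once could look roughly like this:
# Sketch: compute each pairwise distance once and mirror it into both
# halves of the symmetric distance matrix.
euclidean_dist_once <- function(data) {
  n <- nrow(data)
  distances <- matrix(0, nrow = n, ncol = n)
  for (i in seq_len(n - 1)) {
    for (j in (i + 1):n) {
      d <- sqrt(sum((data[i, ] - data[j, ]) ^ 2))
      distances[i, j] <- d
      distances[j, i] <- d
    }
  }
  distances
}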
When we calculate the Euclidean distance between all observations (rows) of a data set, we can find which observations are closest to each other, or to some reference: some distances will turn out to be relatively small, others larger.
As mentioned above, there is a phenomenon related to finding groups in data that is known as the curse of dimensionality. It describes how, for very high-dimensional data (think, for example, 100 columns), it becomes very hard to successfully find groups. Looking at how we calculate Euclidean distance it is easy to see why: unless your data is as clean as textbook data, the noise and variance in the columns will eventually lead to distances becoming less selective and distinct. As far as I can see there is no magic number of columns where the curse of dimensionality begins, but if you want to cluster using many columns, chances are that you are affected by it.
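To get a feel for this effect, here is a small, hedged simulation on random data (not part of the original example): as the number of columns grows, the smallest and largest pairwise distances become more and more similar relative to each other.
# Rough illustration of distance concentration in high dimensions: the
# relative spread between the smallest and largest pairwise distance
# shrinks as the number of columns grows.
set.seed(1)
relative_spread <- function(n_cols) {
  x <- matrix(rnorm(100 * n_cols), ncol = n_cols)
  d <- as.matrix(dist(x))
  d <- d[upper.tri(d)]
  (max(d) - min(d)) / min(d)
}
sapply(c(2, 10, 100, 1000), relative_spread)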
Now that we have established how to characterise the distance between observations, we can move on to clustering. This is actually pretty easy in the case of k-means:
1. Pick k observations at random as the initial cluster centres.
2. Assign every observation to its closest centre.
3. Update each centre to the mean of the observations assigned to it.
4. Repeat steps 2 and 3 until the centres no longer change.
That does not sound too complex, does it? Despite the relative simplicity the approach works reasonably well, although there are datasets on which it fails quite spectacularly.
Let’s look at the code.
First we need some example data, and we set the number of clusters we want to find. The iris dataset might serve as a nice example here, as it is commonly used as an example dataset for clustering. To ensure that no single column has too much influence on the distance metric, the data is scaled as well. We also set a seed for reproducibility.
set.seed(42)
data(iris)
data <- scale(iris[, 1:4]) # omit the labels
k <- 3
Next we draw a random sample of size k and create two matrices: one holding the centres we just selected, the other to be filled with the updated values later, which we can then compare to the previous centres to see whether they are still changing.
center_idx <- sample(1:nrow(data), k)
centers <- data[center_idx, ]
center_mean <- as.data.frame(matrix(NA, ncol = ncol(centers),
nrow = nrow(centers)))
centers
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## [1,] -0.6561473 1.4744583 -1.27910398 -1.3110521482
## [2,] -0.2938574 -0.3609670 -0.08950329 0.1320672944
## [3,] 0.3099591 -0.5903951 0.53362088 0.0008746178
The next steps are straightforward. When putting things together later on, these will be enclosed in a while loop, but to have a look at the intermediate steps, we do one iteration manually.
The first step is calculating the distance between our initial selection and all observations.
distances <- euclidean_dist(rbind(centers, data))
If we take all but the first k rows and only the first k columns we have the distances of all observations to the cluster centres.
kdist <- distances[-c(1:k), 1:k]
head(kdist)
## [,1] [,2] [,3]
## [1,] 0.5216255 2.428017 3.041925
## [2,] 1.6780274 2.098326 2.743414
## [3,] 1.3615373 2.331341 3.021808
## [4,] 1.6154093 2.273031 2.960103
## [5,] 0.4325508 2.596145 3.217006
## [6,] 0.5539087 2.806177 3.327659
As each column in our subsetted distance matrix contains the distances to one of the cluster centres, we can determine which cluster each observation is closest to. Note that which.min returns only the first extreme value in case of ties, as the small example below shows.
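A quick illustration of this tie behaviour:
which.min(c(2, 1, 1))  # returns 2: the first of the two tied minima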
labels <- as.factor(apply(kdist, 1, which.min))
names(labels) <- NULL
head(labels)
## [1] 1 1 1 1 1 1
## Levels: 1 2 3
Now that each observation is associated with a cluster, we can update the means of the clusters. The function by is the base R equivalent of the group_by() %>% summarise() %>% ungroup() pattern in dplyr: you supply a dataset, a grouping vector and a function you want to apply to each group. As by returns a list-like object, we bind the results together using rbind with do.call.
center_mean <- by(data, labels, colMeans)
center_mean <- do.call(rbind, center_mean)
center_mean
## Sepal.Length Sepal.Width Petal.Length Petal.Width
## 1 -0.9987207 0.9032290 -1.29875725 -1.25214931
## 2 -0.4593478 -0.8623100 0.09092997 0.09319539
## 3 0.8289150 -0.2834575 0.82681076 0.79512217
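For comparison, here is a hedged sketch of what the same aggregation could look like with dplyr (assuming the package is installed; it is not used anywhere else in this post):
# Sketch: group the scaled data by the cluster labels and take column means,
# mirroring the by() + do.call(rbind, ...) step above.
library(dplyr)
as.data.frame(data) %>%
  mutate(label = labels) %>%
  group_by(label) %>%
  summarise(across(everything(), mean)) %>%
  ungroup()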
What's left is checking whether our new centres are equal to our previous centres. If they are, we leave the while loop at the end of the iteration; if not, we continue. Either way, we update the centers variable. Note that while our first selection of centres corresponded to actual data points, after the first update they will most likely lie in between data points.
if (all(center_mean == centers)) {
not_converged <- FALSE
}
centers <- center_mean
So, now that we went through all the steps once, we can put this together into a function:
k_means <- function(data, k, max_iterations = 100) {
  # Pick k random observations as the initial cluster centres.
  center_idx <- sample(1:nrow(data), k)
  centers <- data[center_idx, ]
  center_mean <- as.data.frame(matrix(NA, ncol = ncol(centers),
                                      nrow = nrow(centers)))
  not_converged <- TRUE
  iter_count <- 1
  while (not_converged) {
    # Distances of all observations to the current centres.
    distances <- euclidean_dist(rbind(centers, data))
    kdist <- distances[-c(1:k), 1:k]
    # Assign each observation to its closest centre.
    labels <- as.factor(apply(kdist, 1, which.min))
    names(labels) <- NULL
    # Update the centres to the means of their assigned observations.
    center_mean <- by(data, labels, colMeans)
    center_mean <- do.call(rbind, center_mean)
    # Stop once the centres no longer change.
    if (all(center_mean == centers)) {
      not_converged <- FALSE
    }
    centers <- center_mean
    iter_count <- iter_count + 1
    if (iter_count > max_iterations) {
      warning("Did not converge.")
      break
    }
  }
  list(labels = labels, centers = centers, iterations = iter_count)
}
Note a few additions here: everything is wrapped in a while loop that ends once our centres do not update any more, we keep an iteration counter, and a max_iterations safeguard stops the loop with a warning if the centres never converge. Now let's give this a try:
km <- k_means(data, 3)
And have a look at whether we were able to separate the iris species.
table(km$labels, iris$Species)
##
## setosa versicolor virginica
## 1 0 11 36
## 2 0 39 14
## 3 50 0 0
And plot it (we have four dimensions, so we choose only two for simplicity).
library("ggplot2")
ggplot(iris) +
aes(x = Sepal.Width, y = Petal.Length) +
geom_point(aes(colour = km$labels))
Apparently there is some confusion between versicolor and virginica, but the results are not too bad. Compare with stats::kmeans:
km2 <- kmeans(data, 3)
table(km2$cluster, iris$Species)
##
## setosa versicolor virginica
## 1 0 46 50
## 2 33 0 0
## 3 17 4 0
ggplot(iris) +
aes(x = Sepal.Width, y = Petal.Length) +
geom_point(aes(colour = as.factor(km2$cluster)))
The results we got here are even worse.
While k-means is pretty elegant and works reasonably well, it is very dependent on the initial selection of cluster centres. Thus, it can end up in a local optimum, which is not necessarily the best available solution. stats::kmeans has the argument nstart, which aims to mitigate this by trying several initial selections and returning only the best result.
km3 <- kmeans(data, 3, nstart = 10)
table(km3$cluster, iris$Species)
##
## setosa versicolor virginica
## 1 0 11 36
## 2 0 39 14
## 3 50 0 0
ggplot(iris) +
aes(x = Sepal.Width, y = Petal.Length) +
geom_point(aes(colour = as.factor(km3$cluster)))
This is still not perfect (and very similar to what we got with our own implementation), but there is only so much you can do.
Doing a basic implementation of k-means was a nice learning experience. We have seen that at its core the method is quite simple, but nonetheless is able to provide useful results.
Additionally, the Gitlab repo for the fromscratch package also contains implementations of both k-means and Euclidean distance in C++. Writing the C++ implementation was much harder than the R version, but in the end the C++ version runs considerably faster.
And finally, note that there are several variations available on how to implement the clustering, as reflected by the algorithm option in stats::kmeans.
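For instance, the Lloyd variant, which is essentially what we implemented here, can be selected like this:
# Run stats::kmeans with the Lloyd algorithm on the scaled iris data.
km_lloyd <- kmeans(data, 3, algorithm = "Lloyd")
table(km_lloyd$cluster, iris$Species)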