This lesson is in the early stages of development (Alpha version)

K-means

Overview

Teaching: 60 min
Exercises: 20 min
Questions
  • How do we detect real clusters in high-dimensional data?

  • How does K-means work and when should it be used?

  • How can we perform K-means in R?

  • How can we appraise a clustering and test cluster robustness?

Objectives
  • Understand what clusters look like in high-dimensional data.

  • Understand and perform K-means clustering in R.

  • Assess clustering performance using silhouette scores.

  • Assess cluster robustness using bootstrapping.

Introduction

As we saw in previous episodes, visualising high-dimensional data with a large number of features is difficult and can limit our understanding of the data and associated processes. Such data are often heterogeneous. In some cases, a known grouping causes this heterogeneity (sex, treatment groups, etc.). In other cases, heterogeneity may arise from the presence of unknown subgroups in the data. While PCA can be used to reduce the dimension of the dataset into a smaller set of uncorrelated variables, and factor analysis can be used to identify underlying factors, clustering is a set of techniques that allow us to discover unknown groupings.

Cluster analysis involves finding groups of observations that are more similar to each other (according to some feature) than they are to observations in other groups and are thus likely to represent the same source of heterogeneity. Once groups (or clusters) of observations have been identified using cluster analysis, further analyses or interpretation can be carried out on the groups, for example, using metadata to further explore groups.

Cluster analysis is commonly used to discover unknown groupings in fields such as bioinformatics, genomics, and image processing, in which large datasets that include many features are often produced.

There are various ways to look for clusters of observations in a dataset using different clustering algorithms. One way of clustering data is to minimise distance between observations within a cluster and maximise distance between proposed clusters. Using this process, we can also iteratively update clusters so that we become more confident about the shape and size of the clusters.

What is K-means clustering?

K-means clustering groups data points into a user-defined number of distinct, non-overlapping clusters. To create clusters of ‘similar’ data points, K-means clustering creates clusters that minimise the within-cluster variation and thus the amount that data points within a cluster differ from each other. The distance between data points within a cluster is used as a measure of within-cluster variation.

To carry out K-means clustering, we first pick $k$ initial points as centres or “centroids” of our clusters. There are a few ways to choose these initial “centroids” and this is discussed below. Once we have picked initial points, we then follow these two steps until appropriate clusters have been formed:

  1. Assign each data point to the cluster with the closest centroid
  2. Update centroid positions as the average of the points in that cluster.
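
The two steps above can be sketched in a few lines of R. This is a hand-rolled illustration on simulated data, not the implementation used by kmeans(); the function name simple_kmeans(), the toy data, and the choice to initialise from randomly chosen observations are all assumptions made for this example.

```r
# A minimal sketch of the assign/update loop (assumed helper, not part of the lesson)
simple_kmeans <- function(x, k, max_iter = 100) {
    # Initialise: pick k observations at random as starting centroids
    centroids <- x[sample(nrow(x), k), , drop = FALSE]
    assignment <- rep(0L, nrow(x))
    for (iter in seq_len(max_iter)) {
        # Step 1: squared Euclidean distance from every point to every centroid,
        # then assign each point to its closest centroid
        d <- sapply(seq_len(k), function(j) {
            rowSums(sweep(x, 2, centroids[j, ])^2)
        })
        new_assignment <- max.col(-d, ties.method = "first")
        # Stop once no point changes cluster
        if (all(new_assignment == assignment)) break
        assignment <- new_assignment
        # Step 2: move each centroid to the average of the points assigned to it
        for (j in seq_len(k)) {
            members <- x[assignment == j, , drop = FALSE]
            if (nrow(members) > 0) centroids[j, ] <- colMeans(members)
        }
    }
    list(cluster = assignment, centers = centroids)
}

# Simulated data: two groups of 30 points in two dimensions
set.seed(42)
toy <- rbind(
    matrix(rnorm(60, mean = 0, sd = 0.5), ncol = 2),
    matrix(rnorm(60, mean = 4, sd = 0.5), ncol = 2)
)
fit <- simple_kmeans(toy, k = 2)
table(fit$cluster)
fit$centers
```

In practice, the built-in kmeans() function should be preferred, since it is faster and handles edge cases such as empty clusters more carefully.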

We can see this process in action in this animation:

An animated scatter plot of data y versus x. The animation starts by identifying 3 initial points, delineated by different colours. The animation then colour-codes all points by their associated cluster colour, delineating three distinct and non-overlapping clusters in the space of the scatter plot.

Animation showing the iterative process of K-means clustering on data y versus x.

While K-means has some advantages over other clustering methods (it is easy to implement and to understand), it does have some disadvantages: the result depends on how the clusters are initialised, and the user must specify the number of clusters that the data should be partitioned into.

Initialisation

The algorithm used in K-means clustering finds a local rather than a global optimum, so that results of clustering are dependent on the initial cluster that each observation is randomly assigned to. This initial configuration can have a significant effect on the final configuration of the clusters, so dealing with this limitation is an important part of K-means clustering. Some strategies to deal with this problem are:

  • Choose $k$ points at random from the data as the cluster centroids.
  • Randomly split the data into $k$ groups, and then use the average of each group as its initial centroid.
  • Use the K-means++ algorithm to choose initial values.

These each have advantages and disadvantages. In general, it’s good to be aware of this limitation of K-means clustering and that this limitation can be addressed by choosing a good initialisation method, initialising clusters manually, or running the algorithm from multiple different starting points.
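
As a rough illustration of running the algorithm from multiple starting points, the nstart argument of kmeans() repeats the clustering from several random sets of initial centroids and keeps the run with the lowest total within-cluster sum of squares. The simulated data and the value of 25 starts below are arbitrary choices for this sketch.

```r
# Simulated data: three groups of 30 points in two dimensions
set.seed(42)
toy <- rbind(
    matrix(rnorm(60, mean = 0), ncol = 2),
    matrix(rnorm(60, mean = 3), ncol = 2),
    matrix(rnorm(60, mean = 6), ncol = 2)
)

# A single random start: the result depends on which points were drawn
set.seed(1)
single <- kmeans(toy, centers = 3, nstart = 1)

# 25 random starts; kmeans() returns the best of these runs
set.seed(1)
multi <- kmeans(toy, centers = 3, nstart = 25)

# Compare the total within-cluster sum of squares (lower is better)
c(single = single$tot.withinss, multi = multi$tot.withinss)
```

If the two values are the same, the single start happened to find a solution as good as the best of the 25. On harder data sets the difference can be larger.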

Believing in clusters

When using clustering, it’s important to realise that data may seem to group together even when these groups are created randomly. It’s especially important to remember this when making plots that add extra visual aids to distinguish clusters. For example, if we cluster data from a single 2D normal distribution and draw ellipses around the points, these clusters suddenly become almost visually convincing. This is a somewhat extreme example, since there is genuinely no heterogeneity in the data, but it does reflect what can happen if you allow yourself to read too much into faint signals.

Let’s explore this further using an example. We create two columns of data (‘x’ and ‘y’) and partition these data into three groups (‘a’, ‘b’, ‘c’) according to data values. We then plot these data and their allocated clusters and put ellipses around the clusters using the stat_ellipse() function in ggplot.
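
A figure like this could be produced along the following lines. This is only a sketch: the number of points, the seed, and the use of cut() to slice the x axis into three equal-width bins are assumptions for illustration, and the exact code behind the figure may differ.

```r
# Data with no real structure: x and y are independent standard normal draws
set.seed(11)
fake <- data.frame(x = rnorm(200), y = rnorm(200))

# Invent "clusters" by slicing the x axis into three bins
fake$group <- cut(fake$x, breaks = 3, labels = c("a", "b", "c"))
table(fake$group)

# Plot the points and draw an ellipse around each invented group
if (requireNamespace("ggplot2", quietly = TRUE)) {
    library("ggplot2")
    p <- ggplot(fake, aes(x, y, colour = group)) +
        geom_point() +
        stat_ellipse()
    print(p)
}
```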

A scatter plot of random data y versus x. The points are horizontally partitioned at 2 random groups, forming three colour coded clusters. Circles are drawn around each cluster. The data shown appears to have no clusters but the colours and circles give the appearance of clusters artificially.

Example of artificial clusters fitted to data points.

The randomly created data used here appear to form three clusters when we plot the data. Putting ellipses around the clusters can further convince us that the clusters are ‘real’. But how do we tell if clusters identified visually are ‘real’?

K-means clustering applied to the single-cell RNA sequencing data

Let’s carry out K-means clustering in R using some real high-dimensional data. We’re going to work with single-cell RNA sequencing data, scRNAseq, in these clustering challenges, which is often very high-dimensional. Commonly, experiments profile the expression level of 10,000+ genes in thousands of cells. Even after filtering the data to remove low quality observations, the dataset we’re using in this episode contains measurements for over 9,000 genes in over 3,000 cells.

One way to get a handle on a dataset of this size is to use something we covered earlier in the course: dimensionality reduction. Dimensionality reduction allows us to visualise this incredibly complex data in a small number of dimensions. In this case, we’ll be using principal component analysis (PCA) to compress the data by identifying the major axes of variation in the data, before running our clustering algorithms on this lower-dimensional data to explore the similarity of the observations (cells).

The scater package has some easy-to-use tools to calculate a PCA for SummarizedExperiment objects. Let’s load the scRNAseq data and calculate some principal components.

library("SingleCellExperiment")
library("scater")

scrnaseq <- readRDS(here::here("data/scrnaseq.rds"))
scrnaseq <- runPCA(scrnaseq, ncomponents = 15)
pcs <- reducedDim(scrnaseq)[, 1:2]

The first two principal components capture almost 50% of the variation within the data. For now, we’ll work with just these two principal components, since we can visualise those easily, and they’re a quantitative representation of the underlying data, representing the two largest axes of variation.

We can now run K-means clustering on the first and second principal components of the scRNAseq data using the kmeans() function.

set.seed(42)
cluster <- kmeans(pcs, centers = 4)
scrnaseq$kmeans <- as.character(cluster$cluster)
plotReducedDim(scrnaseq, "PCA", colour_by = "kmeans")
A scatter plot of principal component 2 versus principal component 1 of the `scrnaseq` data. Each point is one of four colours, representing cluster membership. Points of the same colour appear in the same areas of the plot, showing four distinct clusters in the data.

Scatter plot of principal component 2 versus principal component 1 with points colour coded according to the cluster to which they belong.

We can see that this produces a sensible-looking partition of the data. However, is it totally clear whether there might be more or fewer clusters here?
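
Beyond the plot, it can help to look at the object that kmeans() returns. The sketch below uses simulated data rather than the scRNAseq principal components, so the numbers it prints are not those of the lesson data.

```r
# Simulated data: two groups of 30 points in two dimensions
set.seed(42)
toy <- rbind(
    matrix(rnorm(60, mean = 0, sd = 0.5), ncol = 2),
    matrix(rnorm(60, mean = 4, sd = 0.5), ncol = 2)
)
fit <- kmeans(toy, centers = 2, nstart = 10)

head(fit$cluster)   # cluster label assigned to each observation
fit$size            # number of observations in each cluster
fit$centers         # coordinates of each centroid, one row per cluster
fit$withinss        # within-cluster sum of squares, one value per cluster
fit$tot.withinss    # their total, which K-means tries to make small
```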

Challenge 1

Perform clustering to group the data into $k=5$ clusters, and plot it using plotReducedDim(). Save this with a variable name that’s different to what we just used, because we’ll use this again later.

Solution

set.seed(42)
cluster5 <- kmeans(pcs, centers = 5)
scrnaseq$kmeans5 <- as.character(cluster5$cluster)
plotReducedDim(scrnaseq, "PCA", colour_by = "kmeans5")
A scatter plot of principal component 1 versus principal component 2 of the scrnaseq data. Each point is one of five colours, representing cluster membership. Points of the same colour appear in the same areas of the plot, showing five distinct clusters in the data.

Scatter plot of principal component 2 against principal component 1 in the scRNAseq data, coloured by clusters produced by k-means clustering.

K-medoids (PAM)

One problem with K-means is that using the mean to define cluster centroids makes clusters very sensitive to outlying observations. K-medoids, also known as “partitioning around medoids” (PAM), is similar to K-means, but represents each cluster by a medoid rather than a mean. A medoid is the actual observation in the cluster whose total dissimilarity to the other members of the cluster is smallest. In the same way that a median is less affected by extreme values than a mean, using medoids reduces the sensitivity of clusters to outliers in the data. K-medoids has seen application in genomics. The following example illustrates the underlying idea using the simpler case of the median and the mean: a single outlier pulls the mean much further than it pulls the median.

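set.seed(42)  # fix the random numbers so the example is reproducible
x <- rnorm(20)
y <- rnorm(20)
x[10] <- x[10] + 10  # turn one observation into an outlier
plot(x, y, pch = 16)
points(mean(x), mean(y), pch = 16, col = "firebrick")  # mean: pulled towards the outlier
points(median(x), median(y), pch = 16, col = "dodgerblue")  # median: largely unaffected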
Scatter plot of random data y versus x. There are many black points on the plot representing the data. Two additional points are shown: the (mean(x), mean(y)) co-ordinate point in red and the (median(x), median(y)) co-ordinate point in blue. The median co-ordinate point in blue has a lower x value and is shown to the left of the red mean co-ordinate point.

Scatter plot of random data y versus x with the (mean(x), mean(y)) co-ordinate point shown in red and the (median(x), median(y)) co-ordinate point shown in blue.

PAM can be carried out using pam() from the cluster package.
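
A minimal sketch of calling pam() is shown below, using simulated data with one artificial outlier so that the medoids can be compared with the K-means centroids. The data and the position of the outlier are made up for this example.

```r
library("cluster")

# Simulated data: two groups of 30 points, with one point moved far away
set.seed(42)
toy <- rbind(
    matrix(rnorm(60, mean = 0, sd = 0.5), ncol = 2),
    matrix(rnorm(60, mean = 4, sd = 0.5), ncol = 2)
)
toy[60, ] <- c(10, 10)  # artificial outlier

pam_fit <- pam(toy, k = 2)
km_fit <- kmeans(toy, centers = 2, nstart = 10)

pam_fit$id.med     # row numbers of the observations chosen as medoids
pam_fit$medoids    # their coordinates (always actual observations)
km_fit$centers     # K-means centroids (averages, not necessarily observations)
table(pam = pam_fit$clustering, kmeans = km_fit$cluster)
```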

Cluster separation

When performing clustering, it is important for us to be able to measure how well our clusters are separated. One measure to test this is silhouette width, which measures how similar a data point is to points in the same cluster compared to other clusters. The silhouette width is computed for every observation and can range from -1 to 1. A high silhouette width means an observation is closer to other observations within the same cluster than to observations in the nearest other cluster. For each cluster, the silhouette widths can then be averaged or an overall average can be taken.

More detail on silhouette widths

In more detail, each observation’s silhouette width is computed as follows:

  1. Compute the average distance between the focal observation and all other observations in the same cluster.
  2. For each of the other clusters, compute the average distance between focal observation and all observations in the other cluster. Keep the smallest of these average distances.
  3. Subtract (1.) from (2.), then divide by whichever is larger, (1.) or (2.).
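
These steps can be checked by hand. The sketch below computes silhouette widths step by step on simulated data and compares them with the values from silhouette() in the cluster package. The toy data and the cluster labels are assumptions for this example.

```r
library("cluster")

# Simulated data: two groups of 20 points, labelled by the group they came from
set.seed(42)
toy <- rbind(
    matrix(rnorm(40, mean = 0), ncol = 2),
    matrix(rnorm(40, mean = 3), ncol = 2)
)
cl <- rep(1:2, each = 20)     # pretend these are our cluster labels
d <- as.matrix(dist(toy))     # pairwise Euclidean distances

manual_sil <- sapply(seq_len(nrow(toy)), function(i) {
    own <- cl == cl[i]
    # (1.) average distance to the other members of the same cluster;
    # d[i, i] is 0, so we divide by the cluster size minus one
    a <- sum(d[i, own]) / (sum(own) - 1)
    # (2.) smallest average distance to the members of any other cluster
    others <- setdiff(unique(cl), cl[i])
    b <- min(sapply(others, function(k) mean(d[i, cl == k])))
    # (3.) the difference, scaled by the larger of the two
    (b - a) / max(a, b)
})

# Compare with the values computed by the cluster package
pkg_sil <- silhouette(cl, dist(toy))[, "sil_width"]
all.equal(manual_sil, unname(pkg_sil))
```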

Ideally, we would have only large positive silhouette widths, indicating that each data point is much more similar to points within its cluster than it is to the points in any other cluster. However, this is rarely the case. Often, clusters are very fuzzy, and even if we are relatively sure about the existence of discrete groupings in the data, observations on the boundaries can be difficult to confidently place in either cluster.

Here we use the silhouette() function from the cluster package to calculate the silhouette width of our K-means clustering, using a matrix of the pairwise distances between observations.

library("cluster")
dist_mat <- dist(pcs)
sil <- silhouette(cluster$cluster, dist = dist_mat)
plot(sil, border = NA)
Plot with horizontal axis silhouette width. The plot shows the silhouette width for each point in the data set according to cluster. Cluster 4 contains over half of the points in the data set and largely consists of points with a large silhouette width, leading to a bar that extends to the right side of the graph. The other clusters contain many fewer points and have slightly lower silhouette widths. The bars therefore reach further to the left than cluster 4.

Silhouette plot for each point according to cluster.

Let’s plot the silhouette score on the original dimensions used to cluster the data. Here, we’re mapping cluster membership to point shape, and silhouette width to colour.

pc <- as.data.frame(pcs)
colnames(pc) <- c("x", "y")
pc$sil <- sil[, "sil_width"]
pc$clust <- factor(cluster$cluster)
mean(sil[, "sil_width"])
[1] 0.7065662
ggplot(pc) +
    aes(x, y, shape = clust, colour = sil) +
    geom_point() +
    scale_colour_gradient2(
        low = "dodgerblue", high = "firebrick"
    ) +
    scale_shape_manual(
        values = setNames(1:4, 1:4)
    )
A scatter plot of the random y versus x data. Cluster membership is delineated using different point characters. Data points in the same cluster have the same point character. Each point is coloured by its silhouette width: solid red delineating a silhouette width of 1 and white delineating a silhouette width of 0. Colours in between delineate the intermediate colours. Many points are red and fade to white at the boundaries of each cluster.

Scatter plot of y versus x coloured according to silhouette width and point characters grouped according to cluster membership.

This plot shows that silhouette values for individual observations tend to be very high in the centre of clusters, but become quite low towards the edges. This makes sense, as points that are “between” two clusters may be more similar to points in another cluster than they are to the points in the cluster they belong to.

Challenge 2

Calculate the silhouette width for the k of 5 clustering we did earlier. Are 5 clusters appropriate? Why/why not?

Can you identify where the differences lie?

Solution

sil5 <- silhouette(cluster5$cluster, dist = dist_mat)
scrnaseq$kmeans5 <- as.character(cluster5$cluster)
plotReducedDim(scrnaseq, "PCA", colour_by = "kmeans5")
A scatter plot of principal component 1 versus principal component 2 of the scrnaseq data. Each point is one of five colours, representing cluster membership. Points of the same colour appear in the same areas of the plot, showing five distinct clusters in the data.

Scatter plot of principal component 2 against principal component 1 in the scRNAseq data, coloured by clusters produced by k-means clustering.

mean(sil5[, "sil_width"])
[1] 0.5849979

The average silhouette width is lower when k=5.

plot(sil5, border = NA)
Plot with horizontal axis silhouette width. The plot shows the silhouette width for each point in the data set according to cluster. Cluster 4 contains almost half of the points in the data set and largely consists of points with a large silhouette width, leading to a bar that extends to the right side of the graph. The other clusters contain many fewer points and have similar silhouette widths. The bars for cluster 5 are much smaller, with a small number extending to the left of the origin, indicating negative silhouette widths.

Silhouette plot for each point according to cluster.

This seems to be because some observations in clusters 3 and 5 are more similar to other clusters than to the one they have been assigned to. This may indicate that k is too high.

Gap statistic

Another measure of clustering quality is the “gap statistic”. This compares the observed squared distances between the observations in a cluster and the centre of that cluster to the “expected” squared distances. The expected distances are calculated by randomly distributing cells within the range of the original data. Larger values represent lower squared distances within clusters, and thus better clustering. We can see how this is calculated in the following example.

library("cluster")
gaps <- clusGap(pcs, kmeans, K.max = 20, iter.max = 20)
best_k <- maxSE(gaps$Tab[, "gap"], gaps$Tab[, "SE.sim"])
best_k
plot(gaps$Tab[,"gap"], xlab = "Number of clusters", ylab = "Gap statistic")
abline(v = best_k, col = "red")
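
For reference, the gap statistic is usually written as follows, where $W_k$ denotes the pooled within-cluster sum of squared distances when the data are split into $k$ clusters, and $\mathrm{E}^{*}$ denotes an average over the randomly generated reference data sets. This notation is not used elsewhere in this lesson and is introduced here only for illustration.

$$\mathrm{Gap}(k) = \mathrm{E}^{*}\left[\log W_k\right] - \log W_k$$

A large gap means that the observed within-cluster dispersion is much smaller than would be expected for data with no cluster structure.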

Cluster robustness: bootstrapping

When we cluster data, we want to be sure that the clusters we identify are not a result of the exact properties of the input data. That is, we want to ensure that the clusters identified do not change substantially if the observed data change slightly. This makes it more likely that the clusters can be reproduced.

To assess this, we can bootstrap. We first resample the data with replacement to produce a ‘new’ data set. We can then calculate new clusters for this data set and compare these to those on the original data set, thus helping us to see how the clusters may change for small changes in the data. This is maybe easier to see with an example. First, we define some data:

data <- 1:5

Then, we can take a sample from this data without replacement:

sample(data, 5)
[1] 4 1 3 5 2

This sample is just a reordering of the original data, and each point is present exactly once. This is the case every time, even if we do it many times:

## Each column is a sample
replicate(10, sample(data, 5))
     [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
[1,]    5    2    5    2    3    1    3    5    5     3
[2,]    4    5    4    5    4    4    1    3    1     2
[3,]    2    1    1    3    2    5    2    2    3     4
[4,]    1    4    2    1    5    3    5    1    2     5
[5,]    3    3    3    4    1    2    4    4    4     1

However, if we sample with replacement, then sometimes individual data points are present more than once.

replicate(10, sample(data, 5, replace = TRUE))
     [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
[1,]    3    1    2    2    1    3    3    2    4     2
[2,]    1    3    2    4    2    5    2    1    2     5
[3,]    5    5    4    4    2    2    1    1    1     3
[4,]    1    1    4    2    1    4    4    5    5     4
[5,]    3    1    2    1    4    2    5    3    3     2

Bootstrapping

Bootstrapping is a powerful and common statistical technique.

It is useful when we would like to know about the sampling distribution of a statistic, but we don’t have any knowledge of its behaviour under the null hypothesis.

For example, we might want to understand the uncertainty around an estimate of the mean of our data. To do this, we could resample the data with replacement and calculate the mean of each resample.

boots <- replicate(1000, mean(sample(data, 5, replace = TRUE)))
hist(boots,
    breaks = "FD",
    main = "1,000 bootstrap samples",
    xlab = "Mean of sample"
 )
A histogram of the mean of each bootstrapped sample. The histogram appears roughly symmetric around 2.8 on the x axis.

Histogram of mean of bootstrapped samples.
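
One common way to summarise such bootstrap results is with their standard deviation or a percentile interval. The sketch below repeats the resampling with a fixed seed (an assumption, so that the numbers are reproducible) and reports both.

```r
set.seed(42)
data <- 1:5

# Mean of each of 1,000 bootstrap resamples
boots <- replicate(1000, mean(sample(data, 5, replace = TRUE)))

sd(boots)                                   # bootstrap estimate of the standard error
ci <- quantile(boots, probs = c(0.025, 0.975))  # 95% percentile interval
ci
```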

In this case, the example is simple, but it’s possible to devise more complex statistical tests using this kind of approach.

The bootstrap, along with permutation testing, can be a very flexible and general solution to many statistical problems.

In applying the bootstrap to clustering, we want to see two things:

  1. Will observations within a cluster consistently cluster together in different bootstrap replicates?
  2. Will observations frequently swap between clusters?
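
To make these two questions concrete, here is a hand-rolled sketch on simulated data. It is not how bootstrapStability() computes its scores. For each bootstrap replicate, it records the fraction of pairs from the same original cluster that stay together, and the fraction of pairs from different original clusters that end up merged. The data, the number of replicates, and the row names "together" and "merged" are assumptions for this example.

```r
# Simulated data: two well-separated groups of 30 points
set.seed(42)
toy <- rbind(
    matrix(rnorm(60, mean = 0, sd = 0.5), ncol = 2),
    matrix(rnorm(60, mean = 4, sd = 0.5), ncol = 2)
)
orig <- kmeans(toy, centers = 2, nstart = 10)$cluster

res <- replicate(50, {
    idx <- sample(nrow(toy), replace = TRUE)        # bootstrap resample of rows
    boot <- kmeans(toy[idx, ], centers = 2, nstart = 10)$cluster
    same_orig <- outer(orig[idx], orig[idx], "==")  # pairs sharing an original cluster
    same_boot <- outer(boot, boot, "==")            # pairs sharing a bootstrap cluster
    pairs <- upper.tri(same_orig)                   # count each pair once, skip self-pairs
    c(
        together = mean(same_boot[same_orig & pairs]),  # question 1
        merged = mean(same_boot[!same_orig & pairs])    # question 2
    )
})

# Average over the bootstrap replicates
rowMeans(res)
```

Values of "together" close to 1 and values of "merged" close to 0 would suggest stable clusters.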

In the plot below, the diagonal of the plot shows how often the clusters are reproduced in bootstrap replicates. High scores on the diagonal mean that the clusters are consistently reproduced in each bootstrap replicate. Similarly, the off-diagonal elements represent how often observations swap between clusters in bootstrap replicates. High scores indicate that observations rarely swap between clusters.

library("pheatmap")
library("bluster")
library("viridis")

km_fun <- function(x) {
    kmeans(x, centers = 4)$cluster
}
ratios <- bootstrapStability(pcs, FUN = km_fun, clusters = cluster$cluster)
pheatmap(ratios,
    cluster_rows = FALSE, cluster_cols = FALSE,
    col = viridis(10),
    breaks = seq(0, 1, length.out = 10)
)
Grid of 16 squares labelled 1-4 on each of the x and y axes. The diagonal and off-diagonal squares of the grid are coloured in green, indicating the highest scoring value of 1 according to the legend. The lower triangular squares are coloured in grey, indicating NA values since these would be the same as the upper triangular squares.

Grid of empirical cluster swapping behaviour estimated by the bootstrap samples.

Yellow boxes indicate values slightly greater than 1, which may be observed. These are “good” (despite being missing from the colour bar).

Challenge 3

Repeat the bootstrapping process with k=5. Do the results appear better or worse? Can you identify where the differences occur on the plotReducedDim()?

Solution

km_fun5 <- function(x) {
    kmeans(x, centers = 5)$cluster
}
set.seed(42)
ratios5 <- bootstrapStability(pcs, FUN = km_fun5, clusters = cluster5$cluster)
pheatmap(ratios5,
    cluster_rows = FALSE, cluster_cols = FALSE,
    col = viridis(10),
    breaks = seq(0, 1, length.out = 10)
)
Grid of 25 squares labelled 1-5 on each of the x and y axes. The diagonal and off-diagonal squares of the grid are coloured in green, indicating the highest scoring value of 1 according to the legend, with the exception of the square corresponding to (4, 5), which is slightly darker green indicating a lower value. The lower triangular squares are coloured in grey, indicating NA values since these would be the same as the upper triangular squares.

Grid of empirical cluster swapping behaviour estimated by the bootstrap samples.

When k=5, we can see that the values on the diagonal of the matrix are smaller, indicating that the clusters aren’t exactly reproducible in the bootstrap samples.

Similarly, the off-diagonal elements are considerably lower for some elements. This indicates that observations are “swapping” between these clusters in bootstrap replicates.

Consensus clustering

One useful and generic method of clustering is consensus clustering. This method can use k-means or other clustering methods.

The idea behind this is to bootstrap the data repeatedly, and cluster it each time, perhaps using different numbers of clusters. If a pair of data points always end up in the same cluster, it’s likely that they really belong to the same underlying cluster.

This is really computationally demanding but has been shown to perform very well in some situations. It also allows you to visualise how cluster membership changes over different values of k.
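
The idea can be sketched as follows, under several assumptions: simulated data, K-means with a fixed $k$ of 2 as the base method, 100 resamples, and average-linkage hierarchical clustering of the resulting consensus matrix to obtain final labels. Dedicated packages implement more complete versions of this procedure.

```r
# Simulated data: two groups of 20 points in two dimensions
set.seed(42)
toy <- rbind(
    matrix(rnorm(40, mean = 0, sd = 0.5), ncol = 2),
    matrix(rnorm(40, mean = 4, sd = 0.5), ncol = 2)
)
n <- nrow(toy)
together <- matrix(0, n, n)  # times a pair landed in the same cluster
sampled <- matrix(0, n, n)   # times a pair was drawn in the same resample

for (b in seq_len(100)) {
    # Resample with replacement, keeping each selected observation once
    idx <- unique(sample(n, replace = TRUE))
    cl <- kmeans(toy[idx, ], centers = 2, nstart = 5)$cluster
    sampled[idx, idx] <- sampled[idx, idx] + 1
    together[idx, idx] <- together[idx, idx] + outer(cl, cl, "==")
}

# Consensus: proportion of shared resamples in which a pair clustered together
consensus <- together / pmax(sampled, 1)

# Final labels: hierarchical clustering on 1 - consensus as a dissimilarity
final <- cutree(hclust(as.dist(1 - consensus), method = "average"), k = 2)
table(final, simulated_group = rep(1:2, each = 20))
```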

Speed

It’s worth noting that a lot of the methods we’ve discussed here are very computationally demanding. When clustering data, we may have to compare points to each other many times. This becomes more and more difficult when we have many observations and many features. This is especially problematic when we want to do things like bootstrapping, which requires us to cluster the data over and over.

As a result, there are a lot of approximate methods for finding clusters in the data. For example, the mbkmeans package includes an algorithm for clustering extremely large data. The idea behind this algorithm is that if the clusters we find are robust, we don’t need to look at all of the data every time. This is very helpful because it reduces the amount of data that needs to be held in memory at once, but also because it minimises the computational cost.

Similarly, approximate nearest neighbour methods like Annoy can be used to identify what the $k$ closest points are in the data, and this can be used in some clustering methods (for example, graph-based clustering).

Generally, these methods sacrifice a bit of accuracy for a big gain in speed.

Key Points

  • K-means is an intuitive algorithm for clustering data.

  • K-means has various advantages but can be computationally intensive.

  • Apparent clusters in high-dimensional data should always be treated with some scepticism.

  • Silhouette width and bootstrapping can be used to assess how well our clustering algorithm has worked.