Using a K-Means Clustering Algorithm for Customer Segmentation

Machine learning (ML) is an umbrella topic encompassing different techniques, learning methods, and applications. My first coding project tackles a problem that matters to many companies today: customer segmentation. Customer segmentation can be performed with a K-Means clustering algorithm, which I built based on the code from a tutorial by Nerd for Tech, and this is what I’ll be explaining in today’s piece!

Introduction to Clustering for Segmentation

Unsupervised Learning

ML is a subset of AI that learns from data and makes predictions in order to solve tasks. It can learn without being programmed with explicit instructions, and there are three main types of algorithms: supervised learning, unsupervised learning, and reinforcement learning. In supervised learning, an algorithm is trained on labelled data; in reinforcement learning, an algorithm learns from rewards and punishments; and in unsupervised learning, the data provided to the algorithm is not classified or labelled.

Because unsupervised learning algorithms are not given any labels or hints about the right answer, they must identify patterns in the training data on their own. Basically, they’re thrown into the deep end and have to figure it out. Unsupervised learning is well suited to scenarios where little human intervention is possible or desired, and one of its main uses is clustering. Clustering is the process of splitting a dataset into groups (or clusters) based on detected similarities and patterns. No target to predict = unsupervised learning problem!
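To make “no target to predict” concrete, here is a minimal sketch using scikit-learn’s KMeans on a handful of made-up 2-D points. The coordinates are hypothetical and not from the project’s dataset. Only the features X are passed in, with no label vector y, yet the algorithm still groups nearby points together.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical, unlabelled 2-D points: two loose groups, but we never say which is which
X = np.array([
    [1.0, 1.2], [0.8, 1.0], [1.1, 0.9],   # points near (1, 1)
    [8.0, 8.1], [7.9, 8.3], [8.2, 7.8],   # points near (8, 8)
])

# Only X is passed: there is no target column, so this is an unsupervised problem
model = KMeans(n_clusters=2, n_init=10, random_state=0)
labels = model.fit_predict(X)

# Points in the same group share a cluster id; which id each group gets is arbitrary
print(labels)
```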


Why Customer Segmentation?

Customer segmentation is an essential strategy for optimizing and targeting marketing tactics. It also helps you know your customers well enough to maximize their value and improve their experience with your products. Segmentation means grouping customers with similar attributes so that you can target your communications and personalize your business without individual outreach (which is pretty much impossible for a large corporation like Google, Facebook, or Twitter). For example, if you segment your customers into three clusters based on income, you can recommend products that make sense for each group. This is often done using K-means clustering, a very common clustering algorithm!

Getting Started

Getting started with this project, we can import the necessary libraries:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans

The dataset comes from Kaggle and is called “Mall Customer Segmentation Data”. It has 5 variables: customer ID, age, annual income, spending score, and gender. Customer ID isn’t useful because it is only a unique identifier for each customer. We therefore delete that column from the Pandas DataFrame with Python’s del statement (del df['CustomerID']). We can then print the head (the first five rows) of the dataset. This translates to the following code:

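# Load dataset
df = pd.read_csv("/Users/")
# Drop the ID column, which carries no information about customer behaviour
del df['CustomerID']
# Show the first five rows
print(df.head())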

And here is our output, so we can take a peek at our data!

Gender  Age  Annual Income (k$)  Spending Score (1-100)
Male     19                  15                      39
Male     21                  15                      81
Female   20                  16                       6
Female   23                  16                      77
Female   31                  17                      40

In total, there are 200 rows and 4 columns. Age, annual income, and spending score are all numerical, but gender is categorical. This means it has to be pre-processed and converted to numeric form. We’ll do that next, right after plotting a comparison of the genders.
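Checking the shape and column types in code is a quick way to spot which columns need encoding. This is a minimal sketch on a stand-in DataFrame built from only the five rows shown in the df.head() output above. With the real file you would run the same calls on the loaded df.

```python
import pandas as pd

# Stand-in for the Kaggle file: just the five rows shown in the df.head() output above
df = pd.DataFrame({
    'Gender': ['Male', 'Male', 'Female', 'Female', 'Female'],
    'Age': [19, 21, 20, 23, 31],
    'Annual Income (k$)': [15, 15, 16, 16, 17],
    'Spending Score (1-100)': [39, 81, 6, 77, 40],
})

print(df.shape)    # (rows, columns); the full dataset gives (200, 4)
print(df.dtypes)   # Gender shows up as a non-numeric (text) dtype

# Split the columns by type to see which ones K-means can use directly
numeric_cols = df.select_dtypes(include='number').columns.tolist()
categorical_cols = df.select_dtypes(exclude='number').columns.tolist()
print(numeric_cols)      # ['Age', 'Annual Income (k$)', 'Spending Score (1-100)']
print(categorical_cols)  # ['Gender']
```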

Exploring Data

Now that all our data is loaded, we can do our EDA (exploratory data analysis) by creating visuals. Seeing things in graph form helps us find relationships, patterns, and similarities between variables. Numbers and code can sometimes look complicated, but when you visualize them, things become a lot clearer! There are three main types of EDA: univariate, bivariate, and multivariate analysis. We will do all three.

Univariate Analysis

Univariate analysis is the simplest form of data analysis, in which only one variable is examined at a time. Common methods for displaying univariate analysis are tables, charts, and histograms. Let’s take a look at a comparison in gender to see if we have more of one gender than the other.

We can use seaborn’s countplot(), which is like a histogram or bar graph for categorical data. It shows the number of observations in each categorical “bin” as a bar.

Calling countplot() on the Gender column gives us this graph:

We can see that there are more females than males by a decent amount: over 20 more.
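The bars in a countplot() are simply category counts, so the same numbers can be read off directly with pandas. This minimal sketch uses a stand-in Gender column holding only the five rows from the df.head() output. On the full data, the difference is the “over 20” reported above.

```python
import pandas as pd

# Stand-in Gender column: the five rows from the df.head() output above
gender = pd.Series(['Male', 'Male', 'Female', 'Female', 'Female'], name='Gender')

# countplot() draws one bar per category; these counts are the bar heights
counts = gender.value_counts()
print(counts)

# How many more females than males there are in this sample
difference = counts['Female'] - counts['Male']
print(difference)
```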

Since the model needs numbers rather than the labels ‘male’ and ‘female’, we can convert them to numerical form, AKA 0s and 1s, using the following code:

gender= {'Male':0, 'Female':1}
df['Gender']= df['Gender'].map(gender)

And we get this output, where 0s represent males and 1s represent females:

Gender  Age  Annual Income (k$)  Spending Score (1-100)
0        19                  15                      39
0        21                  15                      81
1        20                  16                       6
1        23                  16                      77
1        31                  17                      40
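One detail worth knowing about Series.map() is that any value missing from the dictionary becomes NaN instead of raising an error. A stray “male” or “F” in raw data would otherwise slip through silently. This minimal sketch repeats the mapping on the five Gender values shown above and adds a hypothetical sanity check.

```python
import pandas as pd

# Stand-in: the Gender values from the five rows shown above
df = pd.DataFrame({'Gender': ['Male', 'Male', 'Female', 'Female', 'Female']})

gender = {'Male': 0, 'Female': 1}
df['Gender'] = df['Gender'].map(gender)
print(df['Gender'].tolist())  # [0, 0, 1, 1, 1]

# map() turns values missing from the dict into NaN, so check none slipped through
assert df['Gender'].notna().all(), "unmapped Gender values found"
```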

Next, we’ll look at the spread of ages across our dataset. This matters to any company because the age distribution is crucial for understanding its target demographics. We can use seaborn’s displot(), a distribution plot, which shows the distribution of a continuous variable as a histogram.

By inputting the following code,

sns.displot(df['Age'], bins=20)

We get this graph as output:

The age values are distributed across a large range, but we can see the age group best represented by our data is the mid-30s.
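To read the most common age group off numerically rather than by eye, this sketch counts values in 20 equal-width bins with NumPy. That is roughly what displot(..., bins=20) draws, though seaborn’s exact bin edges may differ slightly. The ages below are hypothetical; with the real data you would pass df['Age'].to_numpy().

```python
import numpy as np

# Hypothetical ages standing in for df['Age']
ages = np.array([19, 21, 20, 23, 31, 22, 35, 23, 64, 30,
                 67, 35, 58, 24, 37, 22, 35, 20, 52, 35])

# Split the age range into 20 equal-width bins and count the ages in each
counts, edges = np.histogram(ages, bins=20)

# The tallest bar in the plot corresponds to the bin with the largest count
peak = np.argmax(counts)
print(f"Most common age bin: {edges[peak]:.1f} to {edges[peak + 1]:.1f} ({counts[peak]} people)")
```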

Bivariate Analysis

Bivariate analysis examines the relationship between two variables and is used to look for correlations between features. We can now do this for any pair of variables because everything is in numerical form. I made six scatter plots.

The first was between age and annual income,

plt.scatter(df['Age'],df['Annual Income (k$)'], marker='o');
plt.xlabel('Age')
plt.ylabel('Annual Income (k$)')
plt.title('Scatter plot between Age and Annual Income')

Which yielded this graph:

The highest annual incomes are for those between ages 30 and 50. Both ends of the scale look alike: 20-year-olds earn a similar annual income to 60- and 70-year-olds. A plausible explanation is that people early in their careers tend to make less money, and people in their later years go into retirement.

The second plot was between age and spending score,

plt.scatter(df['Age'],df['Spending Score (1-100)'], marker='o');
plt.xlabel('Age')
plt.ylabel('Spending Score (1-100)')
plt.title('Scatter plot between Age and Spending Score')

Which yielded this graph:

Younger customers tend to have higher spending scores. People in their 20s and 30s have the highest spending scores. From age 40 through 70, the scores flatten out and look similar.

The third plot was between annual income and spending score,

plt.scatter(df['Annual Income (k$)'],df['Spending Score (1-100)'], marker='o');
plt.xlabel('Annual Income (k$)')
plt.ylabel('Spending Score (1-100)')
plt.title('Scatter plot between Annual Income and Spending Score')

Which yielded this graph:

We can see that the points spread out into what look like five different sub-groups. The one in the middle sits between $40k and $70k of income, at the 40–60 spending score mark. Beyond that, the points are quite scattered; there is no linear correlation between annual income and spending score. You might expect higher income = higher score, but the highest spending scores occur at both the 20k mark and the 80k mark. This fits what we saw in the previous graphs: although 20-year-olds earn relatively less than the other age groups, they also have some of the highest spending scores.

The fourth plot was between gender and annual income,

plt.scatter(df['Gender'],df['Annual Income (k$)'], marker='o');
plt.xlabel('Gender (0 = Male, 1 = Female)')
plt.ylabel('Annual Income (k$)')
plt.title('Scatter plot between Gender and Annual Income')

Which yielded this graph:

There isn’t a huge difference, but males (mapped to 0) have a slightly higher income than females (mapped to 1).

The fifth plot was between gender and spending score,

plt.scatter(df['Gender'],df['Spending Score (1-100)'], marker='o');
plt.xlabel('Gender (0 = Male, 1 = Female)')
plt.ylabel('Spending Score (1-100)')
plt.title('Scatter plot between Gender and Spending Score')

Which yielded this graph:

This difference is even less notable. The spending scores of females are marginally higher than those of males, suggesting they may spend slightly more.

Lastly, the final and sixth plot was between gender and age,

plt.scatter(df['Gender'],df['Age'], marker='o');
plt.xlabel('Gender (0 = Male, 1 = Female)')
plt.ylabel('Age')
plt.title('Scatter plot between Gender and Age')

Which yielded this graph:

This difference is even more subtle! The males in the group are slightly older. It’s also important to take into account that there are quite a few more females than males in the dataset.
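With a 0/1 variable on the x-axis, a scatter plot stacks every point into two vertical strips, which makes small differences hard to judge. A per-gender average puts all three comparisons in one small table. This is a minimal sketch on hypothetical customers; with the real data you would run the same groupby on the mapped df.

```python
import pandas as pd

# Hypothetical customers (0 = Male, 1 = Female, matching the mapping used above)
df = pd.DataFrame({
    'Gender': [0, 0, 0, 1, 1, 1, 1],
    'Age': [45, 38, 52, 30, 27, 41, 35],
    'Annual Income (k$)': [70, 62, 55, 58, 49, 60, 52],
    'Spending Score (1-100)': [40, 55, 35, 60, 72, 45, 58],
})

# One row per gender with the average of every other column
summary = df.groupby('Gender').mean()
print(summary.round(1))
```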


Multivariate Analysis

Our last EDA method is multivariate analysis, which analyzes three or more variables at the same time to understand their relationships with one another. A common way to summarize this is a heatmap of the correlation matrix. Each cell of the grid holds the correlation between two variables and is shaded by colour. We can compare all four of our variables: age, annual income, spending score, and gender.

We can use this code:

fig_dims = (10, 10)
fig, ax = plt.subplots(figsize=fig_dims)
sns.heatmap(df.corr(), annot=True, cmap='inferno')

The last line may look a bit like gibberish, so let’s break it down. df.corr() computes the pairwise correlation between the columns of the DataFrame (Pearson correlation by default). annot=True annotates each cell with its value, and cmap='inferno' is just the colour scheme.

We get this graph as output:

We can see that age is negatively correlated with spending score, while its correlations with annual income and gender are weak. Annual income and spending score are barely correlated at all, which matches the scattered plot we saw earlier.
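To see what the numbers in the heatmap cells mean, here is a minimal sketch of df.corr() on tiny hypothetical columns chosen so the answers are easy to predict. A value of +1 is a perfect positive linear relationship, -1 a perfect negative one, and values near 0 indicate no linear relationship.

```python
import pandas as pd

# Tiny hypothetical columns with perfectly predictable relationships
demo = pd.DataFrame({
    'x': [1, 2, 3, 4, 5],
    'same_direction': [2, 4, 6, 8, 10],      # rises exactly with x
    'opposite_direction': [10, 8, 6, 4, 2],  # falls exactly as x rises
})

# Pearson correlation (the default method) between every pair of columns
corr = demo.corr()
print(corr.round(2))
```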


Before we build our model, we have to standardize our data, which is a key part of pre-processing. K-means groups points by their Euclidean distance. A variable measured on a larger scale would therefore dominate the distance calculation simply because its numbers are bigger. Standardizing puts the variables on the same scale so that each one contributes comparably.

The industry standard for standardization (pun intended) is scikit-learn’s StandardScaler, which removes the mean and scales each variable to unit variance. Scaling to unit variance means dividing each value by the standard deviation (a measure of how dispersed the data is around the mean). After the transformation, each variable has a mean of 0 and a standard deviation of 1.
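The transformation StandardScaler applies is the z-score, (x - mean) / standard deviation, computed column by column. This minimal sketch checks that equivalence on the Age and Annual Income values from the five rows shown earlier. StandardScaler uses the population standard deviation (ddof=0), which is also NumPy’s default for std().

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Age and Annual Income (k$) from the five rows shown earlier
X = np.array([[19, 15], [21, 15], [20, 16], [23, 16], [31, 17]], dtype=float)

scaled = StandardScaler().fit_transform(X)

# The same z-score by hand: subtract each column's mean, divide by its std (ddof=0)
manual = (X - X.mean(axis=0)) / X.std(axis=0)

print(np.allclose(scaled, manual))   # True
print(scaled.mean(axis=0).round(6))  # each column now has mean 0
print(scaled.std(axis=0).round(6))   # and standard deviation 1
```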

We will do this with age, annual income, and spending score, but we do not have to include gender since it is already represented by 0s and 1s. With that said, we can input the following code:

scaler = StandardScaler()
scaled_data = scaler.fit_transform(df[['Age', 'Annual Income (k$)', 'Spending Score (1-100)']])
print(scaled_data)

To yield this output:

[[-1.42456879 -1.73899919 -0.43480148]
[-1.28103541 -1.73899919 1.19570407]
[-1.3528021 -1.70082976 -1.71591298]
[-1.13750203 -1.70082976 1.04041783]
[-0.56336851 -1.66266033 -0.39597992]
[-1.20926872 -1.66266033 1.00159627]
[-0.27630176 -1.62449091 -1.71591298]
[-1.13750203 -1.62449091 1.70038436]
[ 1.80493225 -1.58632148 -1.83237767]
[-0.6351352 -1.58632148 0.84631002]
[ 2.02023231 -1.58632148 -1.4053405 ]
[-0.27630176 -1.58632148 1.89449216]
[ 1.37433211 -1.54815205 -1.36651894]

We can see that the variables have been transformed and are now centred around 0.

Building The Model

Now that we’re done with pre-processing, onto building the model!

How the K-Means Algorithm Works

K-means is a centroid-based algorithm where we calculate distances in order to assign a point to a cluster. In K-Means, each cluster is associated with a centroid, AKA the location that represents the cluster’s centre. The number of clusters = the value k.

Before we get into building the model, let’s outline how the K-Means Algorithm really works. Breaking it into steps, the process is like this:

  1. Pick the number of clusters for the dataset (K)
  2. Randomly select K data points as the initial centroids, one per cluster
  3. Assign each data point to the nearest centroid (using a measure such as the Euclidean distance)
  4. Recompute each cluster’s centroid as the mean of all the data points assigned to it
  5. Once again, assign each point to the nearest of the new centroids
  6. Repeat steps 3–5 until the centroids stop moving (the assignments no longer change)
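To connect the numbered steps to code, here is a minimal from-scratch sketch in NumPy. It is an illustration of the steps above, and the helper names assign_to_nearest and kmeans_from_scratch are hypothetical. It is not the scikit-learn implementation used in the rest of this project, which adds smarter initialization and multiple restarts.

```python
import numpy as np


def assign_to_nearest(X, centroids):
    """Index of the nearest centroid (Euclidean distance) for every row of X."""
    distances = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return distances.argmin(axis=1)


def kmeans_from_scratch(X, k, max_iter=100, seed=0):
    """Plain K-means (Lloyd's algorithm); step 1 is choosing k."""
    rng = np.random.default_rng(seed)
    # Step 2: pick k distinct data points at random as the initial centroids
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(max_iter):
        # Steps 3 and 5: assign every point to its nearest centroid
        labels = assign_to_nearest(X, centroids)
        # Step 4: move each centroid to the mean of its points
        # (an empty cluster keeps its previous centroid)
        new_centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        # Step 6: stop once the centroids no longer move
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    # Final assignment so the labels always match the returned centroids
    return assign_to_nearest(X, centroids), centroids


# Usage on hypothetical 2-D data with two obvious groups
X = np.array([[1.0, 1.0], [1.5, 2.0], [1.0, 1.5], [8.0, 8.0], [8.5, 9.0], [9.0, 8.0]])
labels, centroids = kmeans_from_scratch(X, k=2)
print(labels)
print(centroids)
```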

Let’s start by specifying the number of clusters we want the data to be grouped into. We start with an arbitrary value for K and then use a more principled technique to find the optimal number of clusters.

First, we create a copy of the data with x = df.copy(), so the original DataFrame stays untouched while we work on the copy. Then we choose an arbitrary number of clusters, 3, with kmeans = KMeans(n_clusters=3). Next, we fit the model, which adjusts its parameters (here, the centroid positions) to best match the input data. Finally, we make another copy of the input data and store each customer’s predicted cluster in a new column, cluster_pred. This comes together in the following code:

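x = df.copy()
kmeans = KMeans(n_clusters=3)
# Fit the model and store each customer's predicted cluster
clusters = x.copy()
clusters['cluster_pred'] = kmeans.fit_predict(x)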

We can use the following code to visualize this in a graph,

plt.scatter(clusters['Annual Income (k$)'],clusters['Spending Score (1-100)'],c=clusters['cluster_pred'],cmap='rainbow')
plt.title("Clustering customers based on Annual Income and Spending score", fontsize=15,fontweight="bold")
plt.xlabel("Annual Income")
plt.ylabel("Spending Score")

Which outputs this:

It’s important to note that K = 3 was chosen arbitrarily. We can see that the green group is by far the largest, and it looks like it contains three clusters within itself. To find a more appropriate K, we can use the elbow method.

The Elbow Method

Just before we implement the elbow method, let’s understand what the K-means algorithm is really trying to do on a more technical level.

The K-means clustering algorithm’s goal is to group similar points. This means making the distance between each point and its cluster’s centroid as small as possible. Inertia measures these intracluster distances (distances between objects in the same cluster). For each data point, we take the Euclidean distance (the length of the line segment between two points) to its centroid and square it. We then sum these squares over all points in all clusters. The lower the inertia, the better the clusters, because the points sit closer to their centroids.
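The inertia definition can be checked directly against the inertia_ attribute of a fitted scikit-learn KMeans model. In symbols, inertia = sum over all points of ||x_i - mu_c(i)||^2, where mu_c(i) is the centroid of the cluster containing x_i. This minimal sketch uses hypothetical 2-D data with two obvious groups.

```python
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical 2-D data with two obvious groups
X = np.array([[1.0, 1.0], [1.5, 2.0], [1.0, 1.5], [8.0, 8.0], [8.5, 9.0], [9.0, 8.0]])

model = KMeans(n_clusters=2, n_init=10, random_state=0).fit(X)

# For every point, the squared Euclidean distance to its own cluster's centroid,
# summed over all points in all clusters
assigned_centroids = model.cluster_centers_[model.labels_]
manual_inertia = np.sum((X - assigned_centroids) ** 2)

print(manual_inertia, model.inertia_)  # the two numbers should match
```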

To find the optimal number of clusters, one method we can use is the elbow method. We run K-means for a range of K values (here 1 to 10) and record the inertia of each fitted model. We then plot an “elbow curve” with the number of clusters on the x-axis and the inertia on the y-axis. Inertia generally decreases as K grows, so we look for the “elbow”: the point after which adding more clusters no longer reduces inertia by much.

With that in mind, let’s input the following code:

SSE = []
for cluster in range(1, 11):
    kmeans = KMeans(n_clusters=cluster, init='k-means++')
    kmeans.fit(x)
    # inertia_ is the sum of squared distances of points to their closest centroid
    SSE.append(kmeans.inertia_)

# converting the results into a dataframe and plotting them
frame = pd.DataFrame({'Cluster': range(1, 11), 'SSE': SSE})
plt.figure(figsize=(10, 10))
plt.plot(frame['Cluster'], frame['SSE'], marker='o')
plt.xlabel('Number of clusters')
plt.ylabel('Inertia (SSE)')

PS: SSE in our code stands for the sum of squared errors, which is the same quantity as the inertia.

Note the init='k-means++' in our code. If the initial centroids are chosen poorly, K-Means can end up with badly grouped clusters. In the step-by-step description above, the initial K centroids were picked at random from the data points. Purely random picks can land several centroids close together, which can lead to poor clusters. (scikit-learn’s KMeans actually uses k-means++ by default; the code just makes it explicit.)

That’s why we use K-Means++, which specifies a procedure to initialize the centroids before running the standard k-means algorithm. The steps are:

1.) The first centroid is chosen uniformly at random from the data points (just one, instead of all K at once).
2.) For each data point, the distance to the nearest centroid already selected is computed.
3.) The next centroid is chosen at random from the data points, with each point’s probability proportional to the square of that distance, so far-away points are more likely to be picked.
4.) Steps 2 and 3 repeat until K centroids have been chosen.

We get this plot as output:

The elbow is at 5 clusters, where the inertia stops dropping sharply and starts to level off. Therefore, we use 5 clusters.
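The K-Means++ initialization steps described above can be sketched in a few lines of NumPy. This is the textbook version, and the function name kmeans_plus_plus_init is hypothetical. scikit-learn’s own implementation adds refinements (it evaluates several candidate points per step), so treat this as an illustration rather than the library’s code. It assumes the data contains at least k distinct points.

```python
import numpy as np


def kmeans_plus_plus_init(X, k, seed=0):
    """Choose k starting centroids with the k-means++ rule (illustrative sketch)."""
    rng = np.random.default_rng(seed)
    # 1) First centroid: a data point chosen uniformly at random
    centroids = [X[rng.integers(len(X))]]
    while len(centroids) < k:
        # 2) Squared distance from each point to its nearest already-chosen centroid
        d2 = np.min([np.sum((X - c) ** 2, axis=1) for c in centroids], axis=0)
        # 3) Next centroid: sampled with probability proportional to that squared distance
        #    (already-chosen points have distance 0, so they cannot be picked again)
        probs = d2 / d2.sum()
        centroids.append(X[rng.choice(len(X), p=probs)])
    # 4) Stop once k centroids have been picked
    return np.array(centroids)


# Usage on hypothetical 2-D data
X = np.array([[1.0, 1.0], [1.5, 2.0], [1.0, 1.5], [8.0, 8.0], [8.5, 9.0], [9.0, 8.0]])
init = kmeans_plus_plus_init(X, k=2)
print(init)
```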

Silhouette Coefficient

Something else to take into account with this algorithm is that the clusters should be distinct from one another, so that each cluster has features the others do not. A good clustering makes the clusters as different from each other as possible. The distance between the centroids of two different clusters is called the inter-cluster distance, and we want this to be maximized.

One clustering evaluation metric that measures how dense and well separated the clusters are is the silhouette score. For each point, it takes the mean distance to the other points in its own cluster (a) and the mean distance to the points of the nearest other cluster (b). It computes (b - a) / max(a, b) and then averages this over all points. The score ranges from -1 to 1:

- Values near 1 mean the clusters are compact and clearly separated.
- Values near 0 mean the clusters overlap.
- Negative values suggest points have been assigned to the wrong cluster.

It is essentially a way to measure the quality of our clustering.
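The silhouette formula can be written out directly and compared with scikit-learn’s silhouette_score. This minimal sketch uses hypothetical 2-D data and hand-assigned labels. The helper name silhouette_by_hand is hypothetical, and it assumes no cluster has only one point (scikit-learn scores such points as 0).

```python
import numpy as np
from sklearn.metrics import silhouette_score


def silhouette_by_hand(X, labels):
    """Mean silhouette coefficient, written out directly (no singleton clusters assumed)."""
    # Pairwise Euclidean distances between all points
    dist = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=2)
    scores = []
    for i in range(len(X)):
        same = labels == labels[i]
        same[i] = False  # exclude the point itself
        # a: mean distance to the other members of its own cluster (cohesion)
        a = dist[i, same].mean()
        # b: mean distance to the points of the nearest other cluster (separation)
        b = min(dist[i, labels == other].mean()
                for other in np.unique(labels) if other != labels[i])
        scores.append((b - a) / max(a, b))
    return float(np.mean(scores))


X = np.array([[1.0, 1.0], [1.5, 2.0], [1.0, 1.5], [8.0, 8.0], [8.5, 9.0], [9.0, 8.0]])
labels = np.array([0, 0, 0, 1, 1, 1])

print(silhouette_by_hand(X, labels))
print(silhouette_score(X, labels, metric='euclidean'))  # should agree
```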

We can input the following code,

# Score the K = 3 clustering from earlier: features in x, labels in cluster_pred
print(silhouette_score(x, clusters['cluster_pred'], metric='euclidean'))

This returns a silhouette score of about 0.38. That is reasonable: the clusters are more separated than overlapping, though far from perfectly distinct.

We can assign the optimal number of clusters as 5 in our code and make a new data frame with the predicted clusters.

kmeans_new = KMeans(n_clusters=5)
# Fit the model and create a new data frame with the predicted clusters
clusters_new = x.copy()
clusters_new['cluster_pred'] = kmeans_new.fit_predict(x)

Finally, we can convert the categorical feature ‘Gender’ back from 0s and 1s to its original labels with the following code:

gender= {0:'Male',1:'Female'}
clusters_new['Gender']= clusters_new['Gender'].map(gender)

Scatter Plot

The last step before diving into the analysis of our clusters is visualizing them with a scatter plot! It’s also nice to see how they’ve turned out.

We can input the following code, which looks similar to the code we used for the clusters with K = 3. The only change is that we now plot the clusters_new DataFrame instead of clusters, because it holds the 5 clusters we chose with the elbow method.

plt.scatter(clusters_new['Annual Income (k$)'],clusters_new['Spending Score (1-100)'],c=clusters_new['cluster_pred'],cmap='rainbow')
plt.title("Clustering customers based on Annual Income and Spending score", fontsize=15,fontweight="bold")
plt.xlabel("Annual Income")
plt.ylabel("Spending Score")

We are given the following plot:

We can now see that the clusters are divided quite nicely, with no overlaps. One green point is a bit close to a red point, and points like that pull the silhouette score down slightly, but all in all, it looks good.

This is a comparison of the graphs before and after determining the optimal number of clusters:

We can see that the graph on the right looks much more equally divided.

Analyzing the Clusters

We now know that our data can be divided into 5 clusters, and we can start to interpret them and draw insights. We can sort them like this:

  • Orange corresponds with a low annual income and low spending score
  • Green corresponds with an average annual income and average spending score
  • Blue corresponds with a low annual income and high spending score
  • Red corresponds with a high annual income and low spending score
  • Purple corresponds with a high annual income and high spending score

Let’s go deeper. To compare the attributes of the different clusters, we can find the average of all variables across each cluster using the following code:

# numeric_only=True skips the text Gender column, which cannot be averaged
avg_data = clusters_new.groupby(['cluster_pred'], as_index=False).mean(numeric_only=True)

To clarify some of the functions here: groupby() splits the data into separate groups (one per cluster) so that we can perform computations on each group. mean() then averages every numeric column within each group. Passing as_index=False tells groupby() not to turn cluster_pred into the index, so it stays an ordinary column. This is what the output looks like:

cluster_pred        Age  Annual Income (k$)  Spending Score (1-100)
0             24.960000           28.040000               77.000000
1             43.727273           55.480519               49.324675
2             32.692308           86.538462               82.128205
3             40.666667           87.750000               17.583333
4             45.217391           26.304348               20.913043

This means that in the first cluster (cluster 0) the average person is around 25 years old, and has a low income and high spending score. If we look at our description above, this corresponds to the blue cluster! Let’s visualize this with some graphs.
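The effect of as_index=False is easiest to see side by side. This minimal sketch runs the same groupby both ways on a few hypothetical clustered customers. Keeping cluster_pred as a regular column makes it easy to refer to by name in later plotting calls.

```python
import pandas as pd

# Hypothetical customers with an already-assigned cluster
df = pd.DataFrame({
    'cluster_pred': [0, 0, 1, 1, 1],
    'Age': [24, 26, 39, 45, 48],
    'Spending Score (1-100)': [80, 74, 50, 48, 52],
})

# Default: cluster_pred becomes the index of the result
by_index = df.groupby('cluster_pred').mean()

# as_index=False: cluster_pred stays an ordinary column next to the averages
as_column = df.groupby('cluster_pred', as_index=False).mean()

print(by_index)
print(as_column)
```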


We can run this simple code to plot the cluster averages for annual income and spending score:

sns.barplot(x='cluster_pred',y='Annual Income (k$)',palette="plasma",data=avg_data)
plt.show()
sns.barplot(x='cluster_pred',y='Spending Score (1-100)',palette="plasma",data=avg_data)
plt.show()

From this code, we get the following graphs:

And of course, we can’t forget about the gender variable. We can input the following code to count how many customers of each gender are in each cluster:

data2 = pd.DataFrame(clusters_new.groupby(['cluster_pred','Gender'])['Gender'].count())

Here is our output:

cluster_pred  Female  Male
0                 47    33
1                 21    18
2                 17    19
3                 14     9
4                 13     9

As our first graph showed at the beginning of this project, there are more females than males overall. Clusters 0, 1, 3, and 4 have more females than males, while cluster 2 is nearly even, with slightly more men.
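A cross-tabulation gives the same counts as a cluster-by-gender grid. With normalize='index', it also turns them into within-cluster proportions, which makes clusters of different sizes easier to compare. This is a minimal sketch on hypothetical clustered customers. Also keep in mind that K-means cluster numbers are arbitrary labels that can be reshuffled between runs; passing a fixed random_state to KMeans (for example KMeans(n_clusters=5, random_state=0)) keeps them reproducible.

```python
import pandas as pd

# Hypothetical clustered customers
clusters_new = pd.DataFrame({
    'cluster_pred': [0, 0, 0, 1, 1, 2, 2, 2],
    'Gender': ['Female', 'Female', 'Male', 'Male', 'Female', 'Male', 'Male', 'Female'],
})

# Counts: one row per cluster, one column per gender
counts = pd.crosstab(clusters_new['cluster_pred'], clusters_new['Gender'])

# Proportions within each cluster (each row sums to 1)
shares = pd.crosstab(clusters_new['cluster_pred'], clusters_new['Gender'], normalize='index')

print(counts)
print(shares.round(2))
```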

Attributes of Clusters

Cluster 0 (violet): intermediate annual income, intermediate spending score

  • Early 40s
  • 55k annual income
  • Intermediate spending score of 49
  • Predominantly female

Cluster 1 (magenta): High annual income, low spending score

  • Late 30s
  • 86k annual income
  • Low spending score of 17
  • More or less equal in gender

Cluster 2 (pink): High annual income, high spending score

  • Early 30s
  • 85k annual income
  • High spending score of 82
  • Predominantly female

Cluster 3 (orange): Low annual income, low spending score

  • Mid 40s
  • 26k annual income
  • Low spending score of 21
  • Predominantly female

Cluster 4 (yellow): Low annual income, high spending score

  • Mid 20s
  • 26k annual income
  • High spending score of 78
  • Predominantly female


Now we have reached the last step and my favourite part: building personas around each cluster. This is exactly what it sounds like. We take the attributes of each cluster and apply them to people through a bit of storytelling. By doing so, we can also create recommendations for those in each cluster!

Cluster 0: Middle Income

This is a group of people who earn and spend their money at an intermediate level. They are careful not to spend beyond their means, as they don’t have an enormous stream of income. Being in their early 40s, they may well have children and have to be frugal at times to minimize financial worries.

Recommendations: Discount coupons, family packages

Cluster 1: Careful spenders

Next up we have our careful individuals, who have the highest annual income out of all the clusters and the lowest spending score. This group consists of middle-aged people potentially saving up for something like children, moving to another city, etc.

Recommendations: Special offers, promos, discount coupons

Cluster 2: Affluent individuals

Here we reach our affluent individuals. They have done well in their careers so far, are still in their early 30s, and have built up wealth that they spend at fairly high rates. If something catches their eye, they’ll buy it!

Recommendations: Investment properties, car deals

Cluster 3: Nearing retirement

This group is the oldest, with both a low annual income and a low spending score. They are likely saving up for retirement.

Recommendation: Healthcare products, promos

Cluster 4: Impulsive buyers

Lastly, we reach our impulsive buyers! They’re young, just starting out in their careers, and don’t make much money. But that doesn’t mean they don’t spend! This group spends beyond its means and likes living a good lifestyle.

Recommendation: Travel coupons, and coupons for clothing, makeup, and skincare


Creating a customer segmentation project was such a fun first project and although simple, it really helped me grasp the basics of ML. I learned a lot from building a K-means algorithm and am excited to keep working on projects!

All my code is available on Github here.

Check out my video on this project here!

Thank you so much for reading this! I’m a 15-year-old passionate about sustainability, and am the author of “Chronicles of Illusions: The Blue Wild”. If you want to see more of my work, connect with me on LinkedIn, Twitter, or subscribe to my monthly newsletter!


