{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# K-Means Algorithm Demo\n", "\n", "> ☝ Before moving on with this demo, you might want to take a look at:\n", "> - 📗[Math behind the K-Means Algorithm](../../k-means.md)\n", "\n", "**K-means clustering** aims to partition _n_ observations into _K_ clusters in which each observation belongs to the cluster with the nearest mean, which serves as a prototype of the cluster.\n", "\n", "> **Demo Project:** In this example we will try to cluster Iris flowers into three categories, which we don't know in advance, based on their `petal_length` and `petal_width` parameters, using the unsupervised K-Means learning algorithm." ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# To make debugging of the k_means module easier, we enable autoreloading of imported modules.\n", "# This way you may change the code of the k_means library and all the changes will be available here.\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "# Add the project root folder to the module loading paths.\n", "import sys\n", "sys.path.append('..')" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Import Dependencies\n", "\n", "- [pandas](https://pandas.pydata.org/) - library that we will use for loading and displaying the data in a table\n", "- [numpy](http://www.numpy.org/) - library that we will use for linear algebra operations\n", "- [matplotlib](https://matplotlib.org/) - library that we will use for plotting the data\n", "- [k_means](https://github.com/trekhleb/homemade-machine-learning/blob/master/homemade/k_means/k_means.py) - custom implementation of the K-Means algorithm" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Import 3rd party dependencies.\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "# Import custom k-means implementation.\n", "from utils.k_means import KMeans" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Load the Data\n", "\n", "In this demo we will use the [Iris data set](http://archive.ics.uci.edu/ml/datasets/Iris).\n", "\n", "The data set consists of 50 samples from each of three species of Iris (`Iris setosa`, `Iris virginica` and `Iris versicolor`). Four features were measured from each sample: the length and the width of the sepals and petals, in centimeters. Based on the combination of these four features, [Ronald Fisher](https://en.wikipedia.org/wiki/Iris_flower_data_set) developed a linear discriminant model to distinguish the species from each other." ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "   sepal_length  sepal_width  petal_length  petal_width   class\n", "0           5.1          3.5           1.4          0.2  SETOSA\n", "1           4.9          3.0           1.4          0.2  SETOSA\n", "2           4.7          3.2           1.3          0.2  SETOSA\n", "3           4.6          3.1           1.5          0.2  SETOSA\n", "4           5.0          3.6           1.4          0.2  SETOSA\n", "5           5.4          3.9           1.7          0.4  SETOSA\n", "6           4.6          3.4           1.4          0.3  SETOSA\n", "7           5.0          3.4           1.5          0.2  SETOSA\n", "8           4.4          2.9           1.4          0.2  SETOSA\n", "9           4.9          3.1           1.5          0.1  SETOSA" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Load the data.\n", "data = pd.read_csv('../data/iris.csv')\n", "\n", "# Print the first 10 rows of the data table.\n", "data.head(10)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the Data\n", "\n", "Let's take two parameters of each flower, `petal_length` and `petal_width`, into consideration and plot how the Iris class depends on them.\n", "\n", "Since we have the advantage of knowing the actual flower labels (classes), let's illustrate the real-world classification on the plot. However, K-Means is an unsupervised learning algorithm, which means that it doesn't need to know the labels. So later in this demo we will split the Iris flowers into clusters without using their labels and compare the result of that split with the actual flower classification." ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# List of supported Iris classes.\n", "iris_types = ['SETOSA', 'VERSICOLOR', 'VIRGINICA']\n", "\n", "# Pick the Iris parameters for consideration.\n", "x_axis = 'petal_length'\n", "y_axis = 'petal_width'\n", "\n", "# Make the figure a bit bigger than the default one.\n", "plt.figure(figsize=(12, 5))\n", "\n", "# Plot the scatter for every type of Iris.\n", "# This is the case when we know the flower labels in advance.\n", "plt.subplot(1, 2, 1)\n", "for iris_type in iris_types:\n", "    plt.scatter(\n", "        data[x_axis][data['class'] == iris_type],\n", "        data[y_axis][data['class'] == iris_type],\n", "        label=iris_type\n", "    )\n", "\n", "plt.xlabel(x_axis + ' (cm)')\n", "plt.ylabel(y_axis + ' (cm)')\n", "plt.title('Iris Types (labels are known)')\n", "plt.legend()\n", "\n", "# Plot an unlabeled scatter of Iris flowers.\n", "# This is the case when we don't know the flower labels in advance.\n", "# This is how K-Means sees the dataset.\n", "plt.subplot(1, 2, 2)\n", "plt.scatter(data[x_axis], data[y_axis])\n", "plt.xlabel(x_axis + ' (cm)')\n", "plt.ylabel(y_axis + ' (cm)')\n", "plt.title('Iris Types (labels are NOT known)')\n", "\n", "# Show all subplots.\n", "plt.show()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Prepare the Data for Training\n", "\n", "Let's extract the `petal_length` and `petal_width` columns and form the training feature set."
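] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before extracting the features, it may help to see roughly what K-Means training does, since the `KMeans` class from `utils.k_means` hides its training loop. The next cell is a minimal NumPy sketch of the standard K-Means iteration (Lloyd's algorithm) on a small hypothetical toy array, not on the Iris features. Each example is assigned to its nearest centroid, then each centroid moves to the mean of the examples assigned to it. This repeats until the centroids stop moving or `max_iterations` is reached.\n", "\n", "The helpers `assign_to_nearest` and `k_means_sketch`, the random initialization from training examples and the handling of empty clusters are assumptions made for this sketch. `utils.k_means` may do these things differently." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal NumPy sketch of the K-Means loop (Lloyd's algorithm) on toy data.\n", "# Hypothetical helpers for illustration; not taken from utils.k_means.\n", "import numpy as np\n", "\n", "\n", "def assign_to_nearest(x, centroids):\n", "    # Squared Euclidean distance from every example to every centroid,\n", "    # shape (num_examples, num_clusters); return the index of the nearest one.\n", "    distances = ((x[:, np.newaxis, :] - centroids[np.newaxis, :, :]) ** 2).sum(axis=2)\n", "    return distances.argmin(axis=1)\n", "\n", "\n", "def k_means_sketch(x, num_clusters, max_iterations, seed=0):\n", "    # Assumption: start from num_clusters distinct, randomly chosen training examples.\n", "    random_state = np.random.RandomState(seed)\n", "    centroids = x[random_state.choice(x.shape[0], num_clusters, replace=False)].astype(float)\n", "    for _ in range(max_iterations):\n", "        # Assignment step: attach every example to its nearest centroid.\n", "        closest_ids = assign_to_nearest(x, centroids)\n", "        # Update step: move every centroid to the mean of its examples\n", "        # (assumption: an empty cluster keeps its previous centroid).\n", "        new_centroids = np.array([\n", "            x[closest_ids == k].mean(axis=0) if np.any(closest_ids == k) else centroids[k]\n", "            for k in range(num_clusters)\n", "        ])\n", "        # Stop early once the centroids no longer move.\n", "        if np.allclose(new_centroids, centroids):\n", "            break\n", "        centroids = new_centroids\n", "    # Return the final centroids and the assignments that match them.\n", "    return centroids, assign_to_nearest(x, centroids)\n", "\n", "\n", "# Hypothetical toy data: three well-separated groups of 2-D points.\n", "toy_x = np.array([[1.0, 0.2], [1.1, 0.3], [4.5, 1.5], [4.7, 1.4], [6.0, 2.2], [6.1, 2.3]])\n", "toy_centroids, toy_closest_ids = k_means_sketch(toy_x, num_clusters=3, max_iterations=10)\n", "print(toy_centroids)\n", "print(toy_closest_ids)"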
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Get the total number of Iris examples.\n", "num_examples = data.shape[0]\n", "\n", "# Get the features as a (num_examples, 2) NumPy array.\n", "x_train = data[[x_axis, y_axis]].values.reshape((num_examples, 2))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Init and Train the K-Means Model\n", "\n", "> ☝🏻 This is the place where you might want to play with the model configuration.\n", "\n", "- `num_clusters` - number of clusters into which we want to split our training dataset.\n", "- `max_iterations` - maximum number of training iterations." ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# Set K-Means parameters.\n", "num_clusters = 3  # Number of clusters into which we want to split our training dataset.\n", "max_iterations = 50  # Maximum number of training iterations.\n", "\n", "# Init K-Means instance.\n", "k_means = KMeans(x_train, num_clusters)\n", "\n", "# Train K-Means instance.\n", "(centroids, closest_centroids_ids) = k_means.train(max_iterations)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the Clustering Results\n", "\n", "Now let's plot the original Iris flower classification along with our unsupervised K-Means clusters to see how the algorithm performed." ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# List of supported Iris classes.\n", "iris_types = ['SETOSA', 'VERSICOLOR', 'VIRGINICA']\n", "\n", "# Pick the Iris parameters for consideration.\n", "x_axis = 'petal_length'\n", "y_axis = 'petal_width'\n", "\n", "# Make the figure a bit bigger than the default one.\n", "plt.figure(figsize=(12, 5))\n", "\n", "# Plot the ACTUAL Iris flower classification.\n", "plt.subplot(1, 2, 1)\n", "for iris_type in iris_types:\n", "    plt.scatter(\n", "        data[x_axis][data['class'] == iris_type],\n", "        data[y_axis][data['class'] == iris_type],\n", "        label=iris_type\n", "    )\n", "\n", "plt.xlabel(x_axis + ' (cm)')\n", "plt.ylabel(y_axis + ' (cm)')\n", "plt.title('Iris Real-World Clusters')\n", "plt.legend()\n", "\n", "# Plot the UNSUPERVISED Iris flower clustering.\n", "plt.subplot(1, 2, 2)\n", "for centroid_id, centroid in enumerate(centroids):\n", "    current_examples_indices = (closest_centroids_ids == centroid_id).flatten()\n", "    plt.scatter(\n", "        data[x_axis][current_examples_indices],\n", "        data[y_axis][current_examples_indices],\n", "        label='Cluster #' + str(centroid_id)\n", "    )\n", "\n", "# Plot the cluster centroids.\n", "for centroid in centroids:\n", "    plt.scatter(centroid[0], centroid[1], c='black', marker='x')\n", "\n", "plt.xlabel(x_axis + ' (cm)')\n", "plt.ylabel(y_axis + ' (cm)')\n", "plt.title('Iris K-Means Clusters')\n", "plt.legend()\n", "\n", "# Show all subplots.\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }