KMeansPlusPlusClusterer.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.ml.clustering;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.ConvergenceException;
import org.apache.commons.math4.legacy.exception.NotStrictlyPositiveException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
import org.apache.commons.math4.legacy.ml.distance.EuclideanDistance;
import org.apache.commons.math4.legacy.stat.descriptive.moment.Variance;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
/**
 * Clustering algorithm based on David Arthur and Sergei Vassilvitskii k-means++ algorithm.
 * <p>
 * Initial centers are chosen with the k-means++ seeding procedure (each new center is
 * drawn with probability proportional to its squared distance to the nearest already
 * chosen center). The centers are then refined iteratively: every point is assigned to
 * its nearest center and each center is moved to the centroid of its points. This repeats
 * until no point changes cluster and no cluster is empty, or until the maximum number of
 * iterations is reached.
 *
 * @param <T> type of the points to cluster
 * @see <a href="http://en.wikipedia.org/wiki/K-means%2B%2B">K-means++ (wikipedia)</a>
 * @since 3.2
 */
public class KMeansPlusPlusClusterer<T extends Clusterable> extends Clusterer<T> {

    /** Strategies to use for replacing an empty cluster. */
    public enum EmptyClusterStrategy {
        /** Split the cluster with largest distance variance. */
        LARGEST_VARIANCE,
        /** Split the cluster with largest number of points. */
        LARGEST_POINTS_NUMBER,
        /** Create a cluster around the point farthest from its centroid. */
        FARTHEST_POINT,
        /** Generate an error. */
        ERROR
    }

    /** The number of clusters. */
    private final int numberOfClusters;
    /** The maximum number of iterations; a negative value means no maximum. */
    private final int maxIterations;
    /** Random generator for choosing initial centers. */
    private final UniformRandomProvider random;
    /** Selected strategy for empty clusters. */
    private final EmptyClusterStrategy emptyStrategy;

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The Euclidean distance will be used as default distance measure.
     *
     * @param k the number of clusters to split the data into
     * @throws NotStrictlyPositiveException if {@code k <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k) {
        this(k, Integer.MAX_VALUE);
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     * <p>
     * The Euclidean distance will be used as default distance measure.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @throws NotStrictlyPositiveException if {@code k <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations) {
        this(k, maxIterations, new EuclideanDistance());
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @param measure the distance measure to use
     * @throws NotStrictlyPositiveException if {@code k <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations, final DistanceMeasure measure) {
        this(k, maxIterations, measure, RandomSource.MT_64.create());
    }

    /** Build a clusterer.
     * <p>
     * The default strategy for handling empty clusters that may appear during
     * algorithm iterations is to split the cluster with largest distance variance.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used.
     * @param measure the distance measure to use
     * @param random random generator to use for choosing initial centers
     * @throws NotStrictlyPositiveException if {@code k <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k, final int maxIterations,
                                   final DistanceMeasure measure,
                                   final UniformRandomProvider random) {
        this(k, maxIterations, measure, random, EmptyClusterStrategy.LARGEST_VARIANCE);
    }

    /** Build a clusterer.
     *
     * @param k the number of clusters to split the data into
     * @param maxIterations the maximum number of iterations to run the algorithm for.
     *   If negative, no maximum will be used. If zero, {@link #cluster(Collection)}
     *   returns the initial k-means++ clusters without any refinement iteration.
     * @param measure the distance measure to use
     * @param random random generator to use for choosing initial centers
     * @param emptyStrategy strategy to use for handling empty clusters that
     *   may appear during algorithm iterations
     * @throws NotStrictlyPositiveException if {@code k <= 0}.
     */
    public KMeansPlusPlusClusterer(final int k,
                                   final int maxIterations,
                                   final DistanceMeasure measure,
                                   final UniformRandomProvider random,
                                   final EmptyClusterStrategy emptyStrategy) {
        super(measure);
        if (k <= 0) {
            throw new NotStrictlyPositiveException(k);
        }
        // No validation of maxIterations: a negative value is the documented
        // way of requesting "no maximum" (handled in cluster()).
        this.numberOfClusters = k;
        this.maxIterations = maxIterations;
        this.random = random;
        this.emptyStrategy = emptyStrategy;
    }

    /**
     * Return the number of clusters this instance will use.
     * @return the number of clusters
     */
    public int getNumberOfClusters() {
        return numberOfClusters;
    }

    /**
     * Returns the maximum number of iterations this instance will use.
     * @return the maximum number of iterations as given at construction;
     *   a negative value means no maximum is set
     */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Runs the K-means++ clustering algorithm.
     *
     * @param points the points to cluster
     * @return a list of clusters containing the points
     * @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
     * if the data points are null or the number of clusters is larger than the
     * number of data points
     * @throws ConvergenceException if an empty cluster is encountered and the
     * empty cluster strategy is set to {@link EmptyClusterStrategy#ERROR}
     */
    @Override
    public List<CentroidCluster<T>> cluster(final Collection<T> points) {
        // sanity checks
        NullArgumentException.check(points);

        // number of clusters has to be smaller or equal the number of data points
        if (points.size() < numberOfClusters) {
            throw new NumberIsTooSmallException(points.size(), numberOfClusters, false);
        }

        // create the initial clusters
        List<CentroidCluster<T>> clusters = chooseInitialCenters(points);

        // Latest assignment of each point to a cluster. It needs no explicit
        // initialization: the first call below fills every slot (its change
        // count is irrelevant and therefore ignored).
        final int[] assignments = new int[points.size()];
        assignPointsToClusters(clusters, points, assignments);

        // A negative maxIterations means "no maximum".
        final int max = maxIterations < 0 ? Integer.MAX_VALUE : maxIterations;

        // iterate through updating the centers until we're done
        for (int count = 0; count < max; count++) {
            final boolean hasEmptyCluster = clusters.stream().anyMatch(cluster -> cluster.getPoints().isEmpty());
            final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
            final int changes = assignPointsToClusters(newClusters, points, assignments);
            clusters = newClusters;

            // if there were no more changes in the point-to-cluster assignment
            // and there are no empty clusters left, return the current clusters
            if (changes == 0 && !hasEmptyCluster) {
                return clusters;
            }
        }
        return clusters;
    }

    /**
     * @return the random generator
     */
    UniformRandomProvider getRandomGenerator() {
        return random;
    }

    /**
     * @return the {@link EmptyClusterStrategy}
     */
    EmptyClusterStrategy getEmptyClusterStrategy() {
        return emptyStrategy;
    }

    /**
     * Builds new (point-less) clusters whose centers are the centroids of the given clusters.
     * <p>
     * For an empty input cluster the new center is chosen according to the configured
     * {@link EmptyClusterStrategy}; the non-error strategies remove the selected point
     * from one of the given clusters.
     *
     * @param clusters the origin clusters
     * @return adjusted clusters with center points (and no assigned points)
     * @throws ConvergenceException if an empty cluster is found and the strategy is
     *   {@link EmptyClusterStrategy#ERROR}, or if no non-empty cluster is available
     */
    List<CentroidCluster<T>> adjustClustersCenters(List<CentroidCluster<T>> clusters) {
        final List<CentroidCluster<T>> newClusters = new ArrayList<>(clusters.size());
        for (final CentroidCluster<T> cluster : clusters) {
            final Clusterable newCenter;
            if (cluster.getPoints().isEmpty()) {
                switch (emptyStrategy) {
                    case LARGEST_VARIANCE :
                        newCenter = getPointFromLargestVarianceCluster(clusters);
                        break;
                    case LARGEST_POINTS_NUMBER :
                        newCenter = getPointFromLargestNumberCluster(clusters);
                        break;
                    case FARTHEST_POINT :
                        newCenter = getFarthestPoint(clusters);
                        break;
                    default :
                        throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
                }
            } else {
                newCenter = cluster.centroid();
            }
            newClusters.add(new CentroidCluster<>(newCenter));
        }
        return newClusters;
    }

    /**
     * Adds the given points to the closest {@link Cluster}.
     *
     * @param clusters the {@link Cluster}s to add the points to
     * @param points the points to add to the given {@link Cluster}s
     * @param assignments points assignments to clusters; read as the previous
     *   assignment and overwritten with the new one
     * @return the number of points assigned to a different cluster than in the
     *   previous iteration
     */
    private int assignPointsToClusters(final List<CentroidCluster<T>> clusters,
                                       final Collection<T> points,
                                       final int[] assignments) {
        int assignedDifferently = 0;
        int pointIndex = 0;
        for (final T p : points) {
            final int clusterIndex = getNearestCluster(clusters, p);
            if (clusterIndex != assignments[pointIndex]) {
                assignedDifferently++;
            }

            final CentroidCluster<T> cluster = clusters.get(clusterIndex);
            cluster.addPoint(p);
            assignments[pointIndex++] = clusterIndex;
        }

        return assignedDifferently;
    }

    /**
     * Use K-means++ to choose the initial centers.
     *
     * @param points the points to choose the initial centers from
     * @return the initial centers
     */
    List<CentroidCluster<T>> chooseInitialCenters(final Collection<T> points) {

        // Convert to list for indexed access. Make it unmodifiable, since removal of items
        // would screw up the logic of this method.
        final List<T> pointList = Collections.unmodifiableList(new ArrayList<>(points));

        // The number of points in the list.
        final int numPoints = pointList.size();

        // Set the corresponding element in this array to indicate when
        // elements of pointList are no longer available.
        final boolean[] taken = new boolean[numPoints];

        // The resulting list of initial centers.
        final List<CentroidCluster<T>> resultSet = new ArrayList<>();

        // Choose one center uniformly at random from among the data points.
        final int firstPointIndex = random.nextInt(numPoints);

        final T firstPoint = pointList.get(firstPointIndex);

        resultSet.add(new CentroidCluster<>(firstPoint));

        // Must mark it as taken
        taken[firstPointIndex] = true;

        // To keep track of the minimum distance squared of elements of
        // pointList to elements of resultSet.
        final double[] minDistSquared = new double[numPoints];

        // Initialize the elements. Since the only point in resultSet is firstPoint,
        // this is very easy.
        for (int i = 0; i < numPoints; i++) {
            if (i != firstPointIndex) { // That point isn't considered
                final double d = distance(firstPoint, pointList.get(i));
                minDistSquared[i] = d * d;
            }
        }

        while (resultSet.size() < numberOfClusters) {

            // Sum up the squared distances for the points in pointList not
            // already taken.
            double distSqSum = 0.0;

            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    distSqSum += minDistSquared[i];
                }
            }

            // Add one new data point as a center. Each point x is chosen with
            // probability proportional to D(x)^2.
            final double r = random.nextDouble() * distSqSum;

            // The index of the next point to be added to the resultSet.
            int nextPointIndex = -1;

            // Sum through the squared min distances again, stopping when
            // sum >= r.
            double sum = 0.0;
            for (int i = 0; i < numPoints; i++) {
                if (!taken[i]) {
                    sum += minDistSquared[i];
                    if (sum >= r) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // If it's not set to >= 0, the point wasn't found in the previous
            // for loop, probably because distances are extremely small. Just pick
            // the last available point.
            if (nextPointIndex == -1) {
                for (int i = numPoints - 1; i >= 0; i--) {
                    if (!taken[i]) {
                        nextPointIndex = i;
                        break;
                    }
                }
            }

            // We found one.
            if (nextPointIndex >= 0) {

                final T p = pointList.get(nextPointIndex);

                resultSet.add(new CentroidCluster<>(p));

                // Mark it as taken.
                taken[nextPointIndex] = true;

                if (resultSet.size() < numberOfClusters) {
                    // Now update elements of minDistSquared. We only have to compute
                    // the distance to the new center to do this.
                    for (int j = 0; j < numPoints; j++) {
                        // Only have to worry about the points still not taken.
                        if (!taken[j]) {
                            final double d = distance(p, pointList.get(j));
                            final double d2 = d * d;
                            if (d2 < minDistSquared[j]) {
                                minDistSquared[j] = d2;
                            }
                        }
                    }
                }

            } else {
                // None found --
                // Break from the while loop to prevent
                // an infinite loop.
                break;
            }
        }

        return resultSet;
    }

    /**
     * Get a random point from the {@link Cluster} with the largest distance variance.
     * The returned point is removed from that cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestVarianceCluster(final Collection<CentroidCluster<T>> clusters) {

        double maxVariance = Double.NEGATIVE_INFINITY;
        Cluster<T> selected = null;
        for (final CentroidCluster<T> cluster : clusters) {
            if (!cluster.getPoints().isEmpty()) {

                // compute the distance variance of the current cluster
                final Clusterable center = cluster.getCenter();
                final Variance stat = new Variance();
                for (final T point : cluster.getPoints()) {
                    stat.increment(distance(point, center));
                }
                final double variance = stat.getResult();

                // select the cluster with the largest variance
                if (variance > maxVariance) {
                    maxVariance = variance;
                    selected = cluster;
                }
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }

    /**
     * Get a random point from the {@link Cluster} with the largest number of points.
     * The returned point is removed from that cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return a random point from the selected cluster
     * @throws ConvergenceException if clusters are all empty
     */
    private T getPointFromLargestNumberCluster(final Collection<? extends Cluster<T>> clusters) {

        int maxNumber = 0;
        Cluster<T> selected = null;
        for (final Cluster<T> cluster : clusters) {

            // get the number of points of the current cluster
            final int number = cluster.getPoints().size();

            // select the cluster with the largest number of points
            if (number > maxNumber) {
                maxNumber = number;
                selected = cluster;
            }
        }

        // did we find at least one non-empty cluster ?
        if (selected == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        // extract a random point from the cluster
        final List<T> selectedPoints = selected.getPoints();
        return selectedPoints.remove(random.nextInt(selectedPoints.size()));
    }

    /**
     * Get the point farthest from its cluster center.
     * The returned point is removed from its cluster.
     *
     * @param clusters the {@link Cluster}s to search
     * @return point farthest from its cluster center
     * @throws ConvergenceException if clusters are all empty
     */
    private T getFarthestPoint(final Collection<CentroidCluster<T>> clusters) {

        double maxDistance = Double.NEGATIVE_INFINITY;
        Cluster<T> selectedCluster = null;
        int selectedPoint = -1;
        for (final CentroidCluster<T> cluster : clusters) {

            // get the farthest point
            final Clusterable center = cluster.getCenter();
            final List<T> points = cluster.getPoints();
            for (int i = 0; i < points.size(); ++i) {
                final double distance = distance(points.get(i), center);
                if (distance > maxDistance) {
                    maxDistance = distance;
                    selectedCluster = cluster;
                    selectedPoint = i;
                }
            }
        }

        // did we find at least one non-empty cluster ?
        if (selectedCluster == null) {
            throw new ConvergenceException(LocalizedFormats.EMPTY_CLUSTER_IN_K_MEANS);
        }

        return selectedCluster.getPoints().remove(selectedPoint);
    }

    /**
     * Returns the nearest {@link Cluster} to the given point.
     * Ties are resolved in favor of the cluster that comes first.
     *
     * @param clusters the {@link Cluster}s to search
     * @param point the point to find the nearest {@link Cluster} for
     * @return the index of the nearest {@link Cluster} to the given point
     */
    private int getNearestCluster(final Collection<CentroidCluster<T>> clusters, final T point) {
        double minDistance = Double.MAX_VALUE;
        int clusterIndex = 0;
        int minCluster = 0;
        for (final CentroidCluster<T> c : clusters) {
            final double distance = distance(point, c.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                minCluster = clusterIndex;
            }
            clusterIndex++;
        }
        return minCluster;
    }
}