MiniBatchKMeansClusterer.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.ml.clustering;
import org.apache.commons.math4.legacy.exception.NullArgumentException;
import org.apache.commons.math4.legacy.exception.NumberIsTooSmallException;
import org.apache.commons.math4.legacy.ml.distance.DistanceMeasure;
import org.apache.commons.math4.legacy.core.Pair;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ListSampler;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
/**
* Clustering algorithm <a href="https://www.eecs.tufts.edu/~dsculley/papers/fastkmeans.pdf">
* based on KMeans</a>.
*
* @param <T> Type of the points to cluster.
*/
public class MiniBatchKMeansClusterer<T extends Clusterable>
extends KMeansPlusPlusClusterer<T> {
/** Batch data size in iteration. */
private final int batchSize;
/** Iteration count of initialize the centers. */
private final int initIterations;
/** Data size of batch to initialize the centers. */
private final int initBatchSize;
/** Maximum number of iterations during which no improvement is occuring. */
private final int maxNoImprovementTimes;
/**
* Build a clusterer.
*
* @param k Number of clusters to split the data into.
* @param maxIterations Maximum number of iterations to run the algorithm for all the points,
* The actual number of iterationswill be smaller than {@code maxIterations * size / batchSize},
* where {@code size} is the number of points to cluster.
* Disabled if negative.
* @param batchSize Batch size for training iterations.
* @param initIterations Number of iterations allowed in order to find out the best initial centers.
* @param initBatchSize Batch size for initializing the clusters centers.
* A value of {@code 3 * batchSize} should be suitable in most cases.
* @param maxNoImprovementTimes Maximum number of iterations during which no improvement is occuring.
* A value of 10 is suitable in most cases.
* @param measure Distance measure.
* @param random Random generator.
* @param emptyStrategy Strategy for handling empty clusters that may appear during algorithm iterations.
*/
public MiniBatchKMeansClusterer(final int k,
final int maxIterations,
final int batchSize,
final int initIterations,
final int initBatchSize,
final int maxNoImprovementTimes,
final DistanceMeasure measure,
final UniformRandomProvider random,
final EmptyClusterStrategy emptyStrategy) {
super(k, maxIterations, measure, random, emptyStrategy);
if (batchSize < 1) {
throw new NumberIsTooSmallException(batchSize, 1, true);
}
if (initIterations < 1) {
throw new NumberIsTooSmallException(initIterations, 1, true);
}
if (initBatchSize < 1) {
throw new NumberIsTooSmallException(initBatchSize, 1, true);
}
if (maxNoImprovementTimes < 1) {
throw new NumberIsTooSmallException(maxNoImprovementTimes, 1, true);
}
this.batchSize = batchSize;
this.initIterations = initIterations;
this.initBatchSize = initBatchSize;
this.maxNoImprovementTimes = maxNoImprovementTimes;
}
/**
* Runs the MiniBatch K-means clustering algorithm.
*
* @param points Points to cluster (cannot be {@code null}).
* @return the clusters.
* @throws org.apache.commons.math4.legacy.exception.MathIllegalArgumentException
* if the number of points is smaller than the number of clusters.
*/
@Override
public List<CentroidCluster<T>> cluster(final Collection<T> points) {
// Sanity check.
NullArgumentException.check(points);
if (points.size() < getNumberOfClusters()) {
throw new NumberIsTooSmallException(points.size(), getNumberOfClusters(), false);
}
final int pointSize = points.size();
final int batchCount = pointSize / batchSize + (pointSize % batchSize > 0 ? 1 : 0);
final int max = getMaxIterations() < 0 ?
Integer.MAX_VALUE :
getMaxIterations() * batchCount;
final List<T> pointList = new ArrayList<>(points);
List<CentroidCluster<T>> clusters = initialCenters(pointList);
final ImprovementEvaluator evaluator = new ImprovementEvaluator(batchSize,
maxNoImprovementTimes);
for (int i = 0; i < max; i++) {
clearClustersPoints(clusters);
final List<T> batchPoints = ListSampler.sample(getRandomGenerator(), pointList, batchSize);
// Training step.
final Pair<Double, List<CentroidCluster<T>>> pair = step(batchPoints, clusters);
final double squareDistance = pair.getFirst();
clusters = pair.getSecond();
// Check whether the training can finished early.
if (evaluator.converge(squareDistance, pointSize)) {
break;
}
}
// Add every mini batch points to their nearest cluster.
clearClustersPoints(clusters);
for (final T point : points) {
addToNearestCentroidCluster(point, clusters);
}
return clusters;
}
/**
* Helper method.
*
* @param clusters Clusters to clear.
*/
private void clearClustersPoints(final List<CentroidCluster<T>> clusters) {
for (CentroidCluster<T> cluster : clusters) {
cluster.getPoints().clear();
}
}
/**
* Mini batch iteration step.
*
* @param batchPoints Points selected for this batch.
* @param clusters Centers of the clusters.
* @return the squared distance of all the batch points to the nearest center.
*/
private Pair<Double, List<CentroidCluster<T>>> step(final List<T> batchPoints,
final List<CentroidCluster<T>> clusters) {
// Add every mini batch points to their nearest cluster.
for (final T point : batchPoints) {
addToNearestCentroidCluster(point, clusters);
}
final List<CentroidCluster<T>> newClusters = adjustClustersCenters(clusters);
// Add every mini batch points to their nearest cluster again.
double squareDistance = 0.0;
for (T point : batchPoints) {
final double d = addToNearestCentroidCluster(point, newClusters);
squareDistance += d * d;
}
return new Pair<>(squareDistance, newClusters);
}
/**
* Initializes the clusters centers.
*
* @param points Points used to initialize the centers.
* @return clusters with their center initialized.
*/
private List<CentroidCluster<T>> initialCenters(final List<T> points) {
final List<T> validPoints = initBatchSize < points.size() ?
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
new ArrayList<>(points);
double nearestSquareDistance = Double.POSITIVE_INFINITY;
List<CentroidCluster<T>> bestCenters = null;
for (int i = 0; i < initIterations; i++) {
final List<T> initialPoints = (initBatchSize < points.size()) ?
ListSampler.sample(getRandomGenerator(), points, initBatchSize) :
new ArrayList<>(points);
final List<CentroidCluster<T>> clusters = chooseInitialCenters(initialPoints);
final Pair<Double, List<CentroidCluster<T>>> pair = step(validPoints, clusters);
final double squareDistance = pair.getFirst();
final List<CentroidCluster<T>> newClusters = pair.getSecond();
//Find out a best centers that has the nearest total square distance.
if (squareDistance < nearestSquareDistance) {
nearestSquareDistance = squareDistance;
bestCenters = newClusters;
}
}
return bestCenters;
}
/**
* Adds a point to the cluster whose center is closest.
*
* @param point Point to add.
* @param clusters Clusters.
* @return the distance between point and the closest center.
*/
private double addToNearestCentroidCluster(final T point,
final List<CentroidCluster<T>> clusters) {
double minDistance = Double.POSITIVE_INFINITY;
CentroidCluster<T> closestCentroidCluster = null;
// Find cluster closest to the point.
for (CentroidCluster<T> centroidCluster : clusters) {
final double distance = distance(point, centroidCluster.getCenter());
if (distance < minDistance) {
minDistance = distance;
closestCentroidCluster = centroidCluster;
}
}
NullArgumentException.check(closestCentroidCluster);
closestCentroidCluster.addPoint(point);
return minDistance;
}
/**
* Stopping criterion.
* The evaluator checks whether improvement occurred during the
* {@link #maxNoImprovementTimes allowed number of successive iterations}.
*/
private static final class ImprovementEvaluator {
/** Batch size. */
private final int batchSize;
/** Maximum number of iterations during which no improvement is occuring. */
private final int maxNoImprovementTimes;
/**
* <a href="https://en.wikipedia.org/wiki/Moving_average">
* Exponentially Weighted Average</a> of the squared
* diff to monitor the convergence while discarding
* minibatch-local stochastic variability.
*/
private double ewaInertia = Double.NaN;
/** Minimum value of {@link #ewaInertia} during iteration. */
private double ewaInertiaMin = Double.POSITIVE_INFINITY;
/** Number of iteration during which {@link #ewaInertia} did not improve. */
private int noImprovementTimes;
/**
* @param batchSize Number of elements for each batch iteration.
* @param maxNoImprovementTimes Maximum number of iterations during
* which no improvement is occuring.
*/
private ImprovementEvaluator(int batchSize,
int maxNoImprovementTimes) {
this.batchSize = batchSize;
this.maxNoImprovementTimes = maxNoImprovementTimes;
}
/**
* Stopping criterion.
*
* @param squareDistance Total square distance from the batch points
* to their nearest center.
* @param pointSize Number of data points.
* @return {@code true} if no improvement was made after the allowed
* number of iterations, {@code false} otherwise.
*/
public boolean converge(final double squareDistance,
final int pointSize) {
final double batchInertia = squareDistance / batchSize;
if (Double.isNaN(ewaInertia)) {
ewaInertia = batchInertia;
} else {
final double alpha = Math.min(batchSize * 2 / (pointSize + 1), 1);
ewaInertia = ewaInertia * (1 - alpha) + batchInertia * alpha;
}
if (ewaInertia < ewaInertiaMin) {
// Improved.
noImprovementTimes = 0;
ewaInertiaMin = ewaInertia;
} else {
// No improvement.
++noImprovementTimes;
}
return noImprovementTimes >= maxNoImprovementTimes;
}
}
}