// CalinskiHarabasz.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.math4.legacy.ml.clustering.evaluation;
import org.apache.commons.math4.legacy.exception.InsufficientDataException;
import org.apache.commons.math4.legacy.ml.clustering.Cluster;
import org.apache.commons.math4.legacy.ml.clustering.ClusterEvaluator;
import org.apache.commons.math4.legacy.ml.clustering.Clusterable;
import org.apache.commons.math4.legacy.core.MathArrays;
import java.util.Collection;
import java.util.List;
/**
 * Compute the Calinski and Harabasz score.
 * <p>
 * It is also known as the Variance Ratio Criterion.
 * <p>
 * The score is defined as the ratio of the between-cluster dispersion to the
 * within-cluster dispersion, each scaled by its degrees of freedom:
 * <pre>
 *   score = (B * (n - k)) / (W * (k - 1))
 * </pre>
 * where
 * <ul>
 *  <li>{@code B} is the sum, over the clusters, of the cluster size times the
 *      squared Euclidean distance between the cluster centroid and the centroid
 *      of all the points,</li>
 *  <li>{@code W} is the sum, over all the points, of the squared Euclidean
 *      distance between a point and the centroid of its cluster,</li>
 *  <li>{@code n} is the total number of points,</li>
 *  <li>{@code k} is the number of non-empty clusters.</li>
 * </ul>
 * A higher score is considered better (see {@link #isBetterScore(double, double)}).
 *
 * @see <a href="https://www.tandfonline.com/doi/abs/10.1080/03610927408827101">A dendrite method for cluster
 * analysis</a>
 */
public class CalinskiHarabasz implements ClusterEvaluator {
    /**
     * {@inheritDoc}
     * <p>
     * Empty clusters are ignored: they contribute to neither dispersion and are
     * not counted in the number of clusters.
     * When the within-cluster dispersion is exactly zero, the score is {@code 1}.
     * When there is a single non-empty cluster (and a non-zero within-cluster
     * dispersion), the normalization divides by zero and the returned value is
     * not a finite number.
     *
     * @throws InsufficientDataException if the clusters do not contain any point.
     */
    @Override
    public double score(List<? extends Cluster<? extends Clusterable>> clusters) {
        final int dimension = dimensionOfClusters(clusters);
        final double[] centroid = meanOfClusters(clusters, dimension);

        double intraDistanceProduct = 0.0;
        double extraDistanceProduct = 0.0;
        int clusterCount = 0;
        for (final Cluster<? extends Clusterable> cluster : clusters) {
            final Collection<? extends Clusterable> points = cluster.getPoints();
            if (points.isEmpty()) {
                // The centroid of an empty cluster is undefined (0 / 0); computing
                // it would propagate NaN into the final score.
                continue;
            }
            clusterCount++;

            // Calculate the center of the cluster.
            final double[] clusterCentroid = mean(points, dimension);
            for (final Clusterable p : points) {
                // Increase the within-cluster (intra) dispersion.
                intraDistanceProduct += squaredDistance(clusterCentroid, p.getPoint());
            }
            // Increase the between-cluster (extra) dispersion.
            extraDistanceProduct += points.size() * squaredDistance(centroid, clusterCentroid);
        }

        if (intraDistanceProduct == 0.0) {
            // Every point coincides with the centroid of its cluster.
            return 1.0;
        }

        final int pointCount = countAllPoints(clusters);
        // Ratio of the between-cluster dispersion to the within-cluster dispersion,
        // each scaled by its degrees of freedom.
        return extraDistanceProduct * (pointCount - clusterCount) /
            (intraDistanceProduct * (clusterCount - 1));
    }

    /** {@inheritDoc} */
    @Override
    public boolean isBetterScore(double a,
                                 double b) {
        return a > b;
    }

    /**
     * Calculate the squared Euclidean distance between two points.
     * <pre>
     *   distance^2 = sum((p1[i] - p2[i])^2)
     * </pre>
     * The lengths of the arrays are validated with
     * {@code MathArrays.checkEqualLength} before the computation.
     *
     * @param p1 Coordinates of the first point.
     * @param p2 Coordinates of the second point.
     * @return the sum of the squared coordinate differences.
     */
    private static double squaredDistance(double[] p1, double[] p2) {
        MathArrays.checkEqualLength(p1, p2);
        double sum = 0;
        for (int i = 0; i < p1.length; i++) {
            final double dp = p1[i] - p2[i];
            sum += dp * dp;
        }
        return sum;
    }

    /**
     * Calculate the mean of all the points.
     *
     * @param points A non-empty collection of points (an empty collection
     * would result in a division by zero and {@code NaN} coordinates).
     * @param dimension The dimension of each point.
     * @return the coordinate-wise mean of the points.
     */
    private static double[] mean(final Collection<? extends Clusterable> points, final int dimension) {
        final double[] centroid = new double[dimension];
        for (final Clusterable p : points) {
            final double[] point = p.getPoint();
            for (int i = 0; i < centroid.length; i++) {
                centroid[i] += point[i];
            }
        }
        final int count = points.size();
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= count;
        }
        return centroid;
    }

    /**
     * Calculate the mean of all the points in the clusters.
     *
     * @param clusters A collection of clusters holding at least one point overall.
     * @param dimension The dimension of each point.
     * @return the coordinate-wise mean of all the points.
     */
    private static double[] meanOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters,
                                           final int dimension) {
        final double[] centroid = new double[dimension];
        int allPointsCount = 0;
        for (final Cluster<? extends Clusterable> cluster : clusters) {
            for (final Clusterable p : cluster.getPoints()) {
                final double[] point = p.getPoint();
                for (int i = 0; i < centroid.length; i++) {
                    centroid[i] += point[i];
                }
                allPointsCount++;
            }
        }
        for (int i = 0; i < centroid.length; i++) {
            centroid[i] /= allPointsCount;
        }
        return centroid;
    }

    /**
     * Count all the points in a collection of clusters.
     *
     * @param clusters Collection of clusters.
     * @return the total number of points.
     */
    private static int countAllPoints(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
        int pointCount = 0;
        for (final Cluster<? extends Clusterable> cluster : clusters) {
            pointCount += cluster.getPoints().size();
        }
        return pointCount;
    }

    /**
     * Detect the dimension of the points in the clusters.
     *
     * @param clusters Collection of clusters.
     * @return the dimension of the first point found in the clusters.
     * @throws InsufficientDataException if the clusters do not contain any point.
     */
    private static int dimensionOfClusters(final Collection<? extends Cluster<? extends Clusterable>> clusters) {
        // Use the first point of the first non-empty cluster.
        for (final Cluster<? extends Clusterable> cluster : clusters) {
            final Collection<? extends Clusterable> points = cluster.getPoints();
            if (!points.isEmpty()) {
                return points.iterator().next().getPoint().length;
            }
        }
        // Throw exception if there is no point.
        throw new InsufficientDataException();
    }
}