CosineSimilarity.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.commons.text.similarity;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
/**
* Measures the cosine similarity of two vectors of an inner product space, that is, the cosine of the angle between them.
* <p>
* For further explanation about the Cosine Similarity, refer to https://en.wikipedia.org/wiki/Cosine_similarity.
* </p>
* <p>
* Instances of this class are immutable and are safe for use by multiple concurrent threads.
* </p>
*
* @since 1.0
*/
public class CosineSimilarity {

    /**
     * Singleton instance.
     */
    static final CosineSimilarity INSTANCE = new CosineSimilarity();

    /**
     * Calculates the cosine similarity for two given vectors.
     * <p>
     * Each vector maps a term to its weight; a term absent from a map contributes nothing to the dot product.
     * The result is {@code dot(left, right) / (|left| * |right|)}, clamped to {@code [-1.0, 1.0]} so that
     * floating-point rounding cannot push it outside the mathematically valid range. If either vector has
     * zero magnitude (for example, it is empty), the result is {@code 0.0}.
     * </p>
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @return cosine similarity between the two vectors, in the range {@code [-1.0, 1.0]}
     * @throws IllegalArgumentException if either vector is {@code null}
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
                                   final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }
        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);
        final double dotProduct = dot(leftVector, rightVector, intersection);
        final double d1 = sumOfSquares(leftVector);
        final double d2 = sumOfSquares(rightVector);
        if (d1 <= 0.0 || d2 <= 0.0) {
            // A zero-magnitude vector has no direction; report no similarity rather than dividing by zero.
            return 0.0;
        }
        final double cosineSimilarity = dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
        // Rounding in the two square roots can produce e.g. 1.0000000000000002 for identical vectors;
        // clamp so callers never observe a value outside [-1, 1].
        return Math.max(-1.0, Math.min(1.0, cosineSimilarity));
    }

    /**
     * Computes the dot product of two vectors restricted to their common keys. Keys present in only one
     * vector contribute zero and are therefore skipped.
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @param intersection keys present in both vectors
     * @return The dot product
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
                       final Set<CharSequence> intersection) {
        // Accumulate in double: a long accumulator can silently overflow
        // (e.g. two terms of Integer.MIN_VALUE squared sum to 2^63).
        double dotProduct = 0.0d;
        for (final CharSequence key : intersection) {
            dotProduct += (double) leftVector.get(key) * rightVector.get(key);
        }
        return dotProduct;
    }

    /**
     * Computes the squared Euclidean norm of a vector, i.e. the sum of the squares of its values.
     *
     * @param vector the vector
     * @return the sum of squared values ({@code 0.0} for an empty vector)
     */
    private static double sumOfSquares(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            final double v = value;
            sum += v * v;
        }
        return sum;
    }

    /**
     * Returns a set with the keys common to the two given maps.
     *
     * @param leftVector left vector map
     * @param rightVector right vector map
     * @return common keys
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
                                              final Map<CharSequence, Integer> rightVector) {
        // Copy the smaller key set and filter it against the larger one, so the work is
        // proportional to the smaller vector rather than always to the left one.
        final boolean leftIsSmaller = leftVector.size() <= rightVector.size();
        final Map<CharSequence, Integer> smaller = leftIsSmaller ? leftVector : rightVector;
        final Map<CharSequence, Integer> larger = leftIsSmaller ? rightVector : leftVector;
        final Set<CharSequence> intersection = new HashSet<>(smaller.keySet());
        intersection.retainAll(larger.keySet());
        return intersection;
    }
}