001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.text.similarity;
018
019import java.util.HashSet;
020import java.util.Map;
021import java.util.Set;
022
023/**
024 * Measures the Cosine similarity of two vectors of an inner product space and compares the angle between them.
025 * <p>
026 * For further explanation about the Cosine Similarity, refer to https://en.wikipedia.org/wiki/Cosine_similarity.
027 * </p>
028 * <p>
029 * Instances of this class are immutable and are safe for use by multiple concurrent threads.
030 * </p>
031 *
032 * @since 1.0
033 */
public class CosineSimilarity {

    /**
     * Singleton instance.
     */
    static final CosineSimilarity INSTANCE = new CosineSimilarity();

    /**
     * Construct a new instance.
     */
    public CosineSimilarity() {
        // empty
    }

    /**
     * Calculates the cosine similarity for two given vectors.
     * <p>
     * Each vector is a sparse map from a term to its integer weight; a term missing from a map
     * contributes nothing. The result is the dot product of the two vectors divided by the product
     * of their Euclidean norms. If either vector has a zero norm (for example, it is empty), the
     * result is {@code 0.0}.
     * </p>
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @return cosine similarity between the two vectors
     * @throws IllegalArgumentException if either vector is {@code null}
     */
    public Double cosineSimilarity(final Map<CharSequence, Integer> leftVector,
                                   final Map<CharSequence, Integer> rightVector) {
        if (leftVector == null || rightVector == null) {
            throw new IllegalArgumentException("Vectors must not be null");
        }

        final Set<CharSequence> intersection = getIntersection(leftVector, rightVector);

        final double dotProduct = dot(leftVector, rightVector, intersection);
        final double d1 = sumOfSquares(leftVector);
        final double d2 = sumOfSquares(rightVector);
        // A zero norm would make the quotient undefined (division by zero); report "no similarity".
        if (d1 <= 0.0 || d2 <= 0.0) {
            return 0.0;
        }
        return dotProduct / (Math.sqrt(d1) * Math.sqrt(d2));
    }

    /**
     * Computes the dot product of two vectors. Only the keys listed in {@code intersection}
     * contribute; a key present in just one of the vectors is ignored, since its counterpart
     * is implicitly zero.
     * <p>
     * The sum is accumulated in {@code double} rather than {@code long}: a single product of two
     * {@code int} values fits in a {@code long}, but the sum of several large products can exceed
     * {@link Long#MAX_VALUE} and silently wrap to a wrong (possibly negative) value.
     * </p>
     *
     * @param leftVector left vector
     * @param rightVector right vector
     * @param intersection common elements
     * @return The dot product
     */
    private double dot(final Map<CharSequence, Integer> leftVector, final Map<CharSequence, Integer> rightVector,
            final Set<CharSequence> intersection) {
        double dotProduct = 0.0d;
        for (final CharSequence key : intersection) {
            dotProduct += leftVector.get(key).doubleValue() * rightVector.get(key).doubleValue();
        }
        return dotProduct;
    }

    /**
     * Returns a set with keys common to the two given maps.
     *
     * @param leftVector left vector map
     * @param rightVector right vector map
     * @return common strings
     */
    private Set<CharSequence> getIntersection(final Map<CharSequence, Integer> leftVector,
            final Map<CharSequence, Integer> rightVector) {
        final Set<CharSequence> intersection = new HashSet<>(leftVector.keySet());
        intersection.retainAll(rightVector.keySet());
        return intersection;
    }

    /**
     * Sums the squares of all values of the given vector, i.e. its squared Euclidean norm.
     *
     * @param vector vector whose values are squared and summed
     * @return the sum of the squared values
     */
    private double sumOfSquares(final Map<CharSequence, Integer> vector) {
        double sum = 0.0d;
        for (final Integer value : vector.values()) {
            final double component = value.doubleValue();
            sum += component * component;
        }
        return sum;
    }

}