Sunday, September 27, 2015

Sentence Similarity using Word2Vec and Word Mover's Distance


Some time back, I read about the Word Mover's Distance (WMD) in the paper From Word Embeddings to Document Distances by Kusner, Sun, Kolkin and Weinberger. The WMD is a distance function that measures the distance between two texts as the cumulative sum of the minimum distances each word in one text must move in vector space to reach the closest word in the other text. In the paper, the authors provide some examples where WMD is calculated against a Word2Vec vector space. Since Word2Vec word embeddings preserve aspects of a word's context, they are a good way to capture semantic meaning (or difference in meaning) when calculating WMD.
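
To make that concrete, here is a toy (non-Spark) sketch of the computation as just described. The vectorOf lookup is hypothetical, standing in for the Word2Vec model; the rest of this post builds the same idea as a sequence of Spark transformations.

// Toy sketch: for each word in the first text, find the distance to its
// closest word in the second text, then sum up these minimum distances.
// vectorOf is a hypothetical word -> embedding lookup.
def euclidean(a: Array[Double], b: Array[Double]): Double =
    math.sqrt(a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum)

def wmd(words1: Seq[String], words2: Seq[String],
        vectorOf: String => Array[Double]): Double =
    words1.map(w1 => words2.map(w2 => euclidean(vectorOf(w1), vectorOf(w2))).min).sum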

The paper reminded me of a similar (in intent) algorithm that I had implemented earlier and written about in my post Computing Semantic Similarity for Short Sentences. There, we captured the semantic meaning using an external semantic network (Wordnet).

Since the problems were so similar, I figured it might be interesting to compute the WMD for those sentence pairs and see how the scores match up with intuition. I already had a dump of the GoogleNews vectors (vectors pretrained on about 100B words of Google News) lying around from a previous project. The dataset is just 16 short sentence pairs, so I decided to do this interactively on Spark using a Databricks notebook. We use Databricks at work and it's ideal for this kind of quick and dirty ad-hoc work.

First we load up our 16 sentence pairs. The input has 3 tab-separated columns: sentence #1, sentence #2 and the original score. Since we don't care about the original score, we discard it and convert the input to a pair.

Since we want to compare words across sentences in the same pair, it makes sense to have these words on the same worker when they are compared, so we add an index key to each sentence pair with zipWithIndex. The output of this cell is an RDD that looks like ((sentence1: String, sentence2: String), index: Long).

import org.apache.spark.storage.StorageLevel

// read tab-separated (sentence1, sentence2, score) lines, drop the score,
// and key each sentence pair with a unique index
val sentencePairs = sc.textFile("sentence_pairs.txt")
    .map(line => {
        val Array(s1, s2, _) = line.split('\t')
        (s1, s2)
    })
    .zipWithIndex
    .persist(StorageLevel.MEMORY_AND_DISK)
sentencePairs.count()

WMD between two sentences (or between any two blobs of text) is computed as the sum of the distances between the closest pairs of words in the two texts. The words are pre-processed to remove stopwords, so the next cell pulls in a list of English stopwords, which I convert to a Set and broadcast to the worker nodes.

val stopwords = sc.textFile("stopwords.txt").collect.toSet
val bStopwords = sc.broadcast(stopwords)

We now split up both sentences into words (lowercasing, removing punctuation and splitting on whitespace) and remove stopwords from each, then flatMap the cross-product of the remaining words into tuples of the form (index: Long, (word1: String, word2: String)). This gives us a list of 71 word pairs.

def getWordPairs(id: Long, s1: String, s2: String, stopwords: Set[String]): 
        List[(Long, (String, String))] = {
    val w1s = s1.toLowerCase
          .replaceAll("\\p{Punct}", "")
          .split(" ")
          .filter(w => !stopwords.contains(w))
    val w2s = s2.toLowerCase
          .replaceAll("\\p{Punct}", "")
          .split(" ")
          .filter(w => !stopwords.contains(w))
    val wpairs = for (w1 <- w1s; w2 <- w2s) yield (id, (w1, w2))
    wpairs.toList
}

val wordPairs = sentencePairs.flatMap(ssi => 
    getWordPairs(ssi._2, ssi._1._1, ssi._1._2, bStopwords.value))
wordPairs.count()

Next we ingest the Word2Vec vectors. I've used Gensim's Word2Vec module to convert the original Word2Vec binary format to TSV. The format of this dataset is (word: String, comma-separated list of 300 vector elements).

val w2vs = sc.textFile("GoogleNews-vectors-negative300.tsv")
    .map(line => {
        val Array(word, vector) = line.split('\t')
        (word, vector)
    })

Next, we join wordPairs against the w2vs RDD, first on the RHS word and then on the LHS word, to pick up the 300-dimensional Word2Vec vector for each side. There is a lot of moving things around, so I have used case pattern matching instead of the less intuitive underscore syntax to refer to tuple elements and sub-elements. Note that we need to hang on to the left word, because later we want to find the RHS word that is closest to each LHS word.

import breeze.linalg._

def dist(lvec: String, rvec: String): Double = {
    val lv = DenseVector(lvec.split(',').map(_.toDouble))
    val rv = DenseVector(rvec.split(',').map(_.toDouble))
    math.sqrt(sum((lv - rv) :* (lv - rv)))
}

val wordVectors = wordPairs.map({case (idx, (lword, rword)) => 
        (rword, (idx, lword))})
    .join(w2vs)    // (rword, ((idx, lword), rvec))
    .map({case (rword, ((idx, lword), rvec)) => (lword, (idx, rvec))})
    .join(w2vs)    // (lword, ((idx, rvec), lvec))
    .map({case (lword, ((idx, rvec), lvec)) => ((idx, lword), (lvec, rvec))})
    .map({case ((idx, lword), (lvec, rvec)) => 
        ((idx, lword), List(dist(lvec, rvec)))}) 
    .persist(StorageLevel.MEMORY_AND_DISK)

I used Euclidean distance in Word2Vec space as the distance between words. I also tried Cosine distance (1 - Cosine similarity) with similar results (a sketch of that variant appears after the next cell). For each LHS word we keep only the distance to its closest RHS word, then sum these shortest distances across all LHS words to get the WMD for the sentence pair.

val bestWMDs = wordVectors.reduceByKey((a, b) => a ++ b)
    .mapValues(dists => dists.sortWith(_ < _).head)  // dist to closest word
    .map({case ((idx, lword), wmd) => (idx, wmd)})
    .reduceByKey((a, b) => a + b)                    // sum all wmds for sent
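
For reference, here is a minimal sketch (not the exact code I ran) of the Cosine distance variant mentioned above. It assumes the same comma-separated vector strings and could be dropped in for dist() above.

import breeze.linalg._

// Cosine distance (1 - cosine similarity) variant of dist(), operating on
// the same comma-separated vector strings as the Euclidean version.
def cosDist(lvec: String, rvec: String): Double = {
    val lv = DenseVector(lvec.split(',').map(_.toDouble))
    val rv = DenseVector(rvec.split(',').map(_.toDouble))
    1.0 - ((lv dot rv) / (norm(lv) * norm(rv)))
}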

Finally, we join these WMD scores back into the original dataset using the pair index that we originally generated using zipWithIndex.

val sqlContext = new org.apache.spark.sql.SQLContext(sc)
import sqlContext.implicits._

case class SentencePair(s1: String, s2: String, wmd: Double)
val results = sentencePairs.map(_.swap)
    .join(bestWMDs)
    .map({case (id, ((s1, s2), wmd)) => SentencePair(s1, s2, wmd)})
val resultsDF = sqlContext.createDataFrame(results)
    .orderBy($"s1".asc, $"wmd".asc)
display(resultsDF)

The results are shown below. The pairs are sorted by the LHS sentence first, then by WMD (lowest first), so for each LHS sentence the closest RHS sentence appears at the top and can be compared against pairs that are not as close.

LHS Sentence                   RHS Sentence                WMD
A glass of cider.              A full cup of apple juice.  2.2169259719396095
Canis familiaris are animals.  Dogs are common pets.       1.859694788966317
Dogs are animals.              They are common pets.       1.4537090848972198
I have a hammer.               Take some nails.            1.1578027104196844
I have a hammer.               Take some apples.           1.3028564676146912
I have a pen.                  Where is ink?               1.020277185488236
I have a pen.                  Where do you live?          1.3924941078355293
I like that bachelor.          I like that unmarried man.  1.176742725809037
It is a dog.                   That must be your dog.      0
It is a dog.                   It is a pig.                1.04864558369858
It is a dog.                   It is a log.                1.3798001799052624
John is very nice.             Is John very nice?          0
Red alcoholic drink.           Fresh orange juice.         3.1161814560971166
Red alcoholic drink.           A bottle of wine.           3.386809492524872
Red alcoholic drink.           Fresh apple juice.          3.505168296314785
Red alcoholic drink.           An English dictionary.      4.106139922327307

As you can see, the scoring seems intuitively correct. For example, it finds that "A glass of cider" and "A full cup of apple juice" are quite similar, even though the two sentences share no words (other than stopwords). Similarly, "I have a hammer" is more similar to "Take some nails" than to "Take some apples". The only intuitively incorrect result in this set is that "Red alcoholic drink" comes out more similar to "Fresh orange juice" than to "A bottle of wine". However, "A bottle of wine" is still more similar to "Red alcoholic drink" than "Fresh apple juice" or "An English dictionary" are. So overall, WMD seems to work well on my limited dataset.

In my case, I already have the two sentences and just have to find the distance between them. In the more general case, where you have to find the closest sentence to a given one in a collection, the complexity of computing WMD is O(p³ log p), where p is the number of unique words in the documents. One suggestion from the paper is to prune the set of candidate RHS sentences by thresholding on the Word Centroid Distance (WCD) or the relaxed WMD (see the paper for details) between the two sentences, and to run the full WMD only on the pruned set of sentence pairs.
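
As a rough illustration of the WCD prefilter (not something I ran for this post), here is a sketch that computes the distance between the centroids of the two sentences' word vectors; vectorOf is again a hypothetical word-to-vector lookup, and pairs whose WCD exceeds some threshold would be skipped before running the full WMD.

// Sketch of a Word Centroid Distance (WCD) prefilter. vectorOf is a
// hypothetical in-memory word -> vector lookup. Sentence pairs whose
// centroids are far apart can be dropped before computing the full WMD.
def centroid(words: Seq[String], vectorOf: String => Array[Double]): Array[Double] = {
    val vecs = words.map(vectorOf)
    Array.tabulate(vecs.head.length)(i => vecs.map(_(i)).sum / vecs.size)
}

def wcd(words1: Seq[String], words2: Seq[String],
        vectorOf: String => Array[Double]): Double = {
    val pairs = centroid(words1, vectorOf).zip(centroid(words2, vectorOf))
    math.sqrt(pairs.map { case (x, y) => (x - y) * (x - y) }.sum)
}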