Saturday, October 25, 2014

Extracting SVO Triples from Wikipedia


I recently came across this discussion (login required) on LinkedIn about extracting (subject, verb, object) (SVO) triples from text. Jack Park, owner of the SolrSherlock project, suggested using ReVerb to do this. I remembered an entertaining Programming Assignment from the Natural Language Processing course I did on Coursera, which involved finding spouse names from a small subset of Wikipedia, so I figured it would be interesting to try using ReVerb against this data.

This post describes that work. As before, given the difference between this approach and the "preferred" approach that the automatic grader expects, results are likely to be wildly off the mark. BTW, I highly recommend taking the course if you haven't already; there are lots of great ideas in there. One of those ideas deals with generating "raw" triples, then filtering them using known (subject, object) pairs to find candidate verbs, and then turning around and using those verbs to find unknown (subject, object) pairs.
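
To make that bootstrapping idea concrete, here is a minimal sketch of the filtering step, assuming we already have a list of raw (subject, verb, object) triples and a set of known (subject, object) pairs (the function and variable names here are made up for illustration):

// Sketch of the bootstrapping idea described above: use known
// (subject, object) pairs to find candidate verbs, then use those
// verbs to propose previously unknown (subject, object) pairs.
def bootstrap(rawTriples: List[(String, String, String)],
              knownPairs: Set[(String, String)]): Set[(String, String)] = {
  // verbs that connect a known subject/object pair
  val candidateVerbs = rawTriples
    .filter(t => knownPairs.contains((t._1, t._3)))
    .map(_._2)
    .toSet
  // pairs connected by one of the candidate verbs, minus the known ones
  rawTriples
    .filter(t => candidateVerbs.contains(t._2))
    .map(t => (t._1, t._3))
    .toSet -- knownPairs
}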

So in order to find the known (subject, object) pairs, I decided to parse the Infobox content (the "semi-structured" part of Wikipedia pages). Wikipedia markup is a mini programming language in itself, so I went looking on StackOverflow for some pointers on how to parse it (third party parsers or just ideas). Someone suggested using DBPedia instead, since they have already done the Infobox extraction for you. I tried both, and somewhat surprisingly, manually parsing the Infoboxes gave me better results in some cases, so I describe both approaches below.

Data and Initial Setup


The data is a small (tiny) XML dump of 24 Wikipedia pages about famous people - US Presidents, authors, actors and other historical figures. The markup is in MediaWiki format. I used Scala's native XML support and some character counting logic to extract the Infoboxes from the Wikipedia pages and parse them into Maps of name-value pairs.

To extract the text (ie, the non-Infobox portion) from the Wikipedia data, I used the Bliki engine to convert the MediaWiki markup to HTML, then used the Jericho HTML parser to convert that to plain text. I needed this two-step conversion because of the richness of the Wiki format - direct conversion to text leaves some of the markup behind.

In order to access data on DBPedia, I used Apache Jena, the Java framework for all things Semantic. As I learnt at the Knowledge Engineering with Semantic Web Technologies course on OpenHPI, the Semantic Web landscape is full of great ideas and implementations which don't mesh together very well. Jena provides a common (albeit complex) API that attempts to unify all of these. It is composed of multiple sub-projects, but fortunately sbt (and friends) allow you to declare it with a single call to an uber-dependency (see below). In any case, all I used Jena for was to build a client to query DBPedia's SPARQL endpoint. As an aside, the call is an HTTP GET and the result can be streamed back as XML, so I could just have used plain Scala to do this, as this EBook by Mark Watson demonstrates.
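
Just to show what that plain-Scala alternative might look like, here is a minimal sketch that hits the SPARQL endpoint with an HTTP GET and pulls the labels out of the XML results using Scala's native XML support. Treat it as an illustration under the assumption that the endpoint accepts the query and format request parameters; the code I actually used is the Jena-based client shown later.

import java.net.{URL, URLEncoder}

import scala.xml.XML

// Minimal sketch: query DBPedia's SPARQL endpoint over plain HTTP GET
// and parse the SPARQL XML results with Scala's native XML support.
object PlainSparqlClient extends App {
  val sparql = """
    PREFIX dbpedia: <http://dbpedia.org/resource/>
    PREFIX onto: <http://dbpedia.org/ontology/>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    SELECT ?label WHERE {
      dbpedia:Abraham_Lincoln onto:spouse ?o .
      ?o rdfs:label ?label .
    } LIMIT 10"""
  val url = "http://dbpedia.org/sparql" +
    "?query=" + URLEncoder.encode(sparql, "UTF-8") +
    "&format=" + URLEncoder.encode("application/sparql-results+xml", "UTF-8")
  val results = XML.load(new URL(url))
  // pick out the bindings named "label" and print their literal values
  (results \\ "binding")
    .filter(b => (b \ "@name").text == "label")
    .foreach(b => println((b \ "literal").text))
}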

For reference, here are the additions I had to make to my build.sbt for this post.

resolvers ++= Seq(
  ...
  "Bliki Repo" at "http://gwtwiki.googlecode.com/svn/maven-repository/",
  ...
)

libraryDependencies ++= Seq(
  ...
  "info.bliki.wiki" % "bliki-core" % "3.0.19",
  "net.htmlparser.jericho" % "jericho-html" % "3.3",
  "edu.washington.cs.knowitall" % "reverb-core" % "1.4.0",
  "org.apache.jena" % "apache-jena-libs" % "2.12.1",
  ...
)

Parsing Infoboxes


As mentioned above, to figure out which triples in the text are interesting, we need some indication of who or what the subject and object are, so we can filter on them and find interesting verbs. So the first step was to parse out the Infobox. We do this by isolating, for each page in the dump, the section of text bounded by the string "{{Infobox" and the matching "}}". Here is the code to parse out the various components we care about - the titles, Infoboxes and texts for each page. The Infoboxes are returned as a List of Maps of name-value pairs. Also, as mentioned above, the text is extracted in two steps - first to HTML, then to plain text.

// Source: src/main/scala/com/mycompany/scalcium/utils/MediaWikiParser.scala
package com.mycompany.scalcium.utils

import java.io.File

import scala.collection.mutable.ArrayBuffer
import scala.util.control.Breaks._
import scala.xml.XML

import info.bliki.wiki.dump.WikiPatternMatcher
import info.bliki.wiki.model.WikiModel
import net.htmlparser.jericho.Source

class MediaWikiParser(xmlfile: File) {

  val InfoboxStartPattern = "{{Infobox"
    
  val titles = ArrayBuffer[String]()
  val texts = ArrayBuffer[String]()
  val infoboxes = ArrayBuffer[Map[String,String]]()
  
  def parse(): Unit = {
    val mediaWikiElement = XML.loadFile(xmlfile)
    (mediaWikiElement \ "page").map(pageElement => {
      val title = (pageElement \ "title").text
      val text = (pageElement \ "revision" \ "text").text
      // parse out Infobox
      val infoboxStart = text.indexOf(InfoboxStartPattern)
      // start at 2 to account for the two open braces in "{{Infobox"
      var bcount = 2
      var infoboxEnd = infoboxStart + InfoboxStartPattern.length()
      breakable {
        (infoboxEnd until text.length()).foreach(i => { 
          val c = text.charAt(i)
          var binc = 0
          if (c == '}') binc = -1
          else if (c == '{') binc = 1
          else binc = 0
          bcount += binc
          infoboxEnd = i
          if (bcount == 0) break
        })
      }
      if (infoboxStart >= 0) {
        addTitle(title)
        addInfobox(text.substring(infoboxStart, infoboxEnd))
        addText(text.substring(infoboxEnd + 1))
      }
    })
  }
  
  def addTitle(title: String): Unit = titles += title

  def addInfobox(ibtext: String): Unit = {
    val infobox = ibtext.split("\n")
      .map(line => {
        val pipePos = line.indexOf('|')
        val nvp = line.substring(pipePos + 1).split("=")
        if (nvp.length == 2) {
          val wpm = new WikiPatternMatcher(nvp(1).trim())
          (nvp(0).trim(), wpm.getPlainText())
        } 
        else ("None", "")
      })
      .filter(nvp => ! "None".equals(nvp._1))
      .toMap
    infoboxes += infobox
  }
  
  def addText(text: String): Unit = {
    // convert wiki text to HTML, then to plain text
    val htmltext = WikiModel.toHtml(text)
    val plaintext = new Source(htmltext).getTextExtractor().toString()
    texts += plaintext
  }
  
  def getTitles(): List[String] = titles.toList
  
  def getInfoboxes(): List[Map[String,String]] = infoboxes.toList
  
  def getTexts(): List[String] = texts.toList
}

For the DBPedia client, we query the spouse relationship for each entry using the code below. Note that the actual code is mostly boilerplate, similar to JDBC code.

// Source: src/main/scala/com/mycompany/scalcium/triples/DBPediaClient.scala
package com.mycompany.scalcium.triples

import scala.collection.mutable.ArrayBuffer

import com.hp.hpl.jena.query.Query
import com.hp.hpl.jena.query.QueryExecution
import com.hp.hpl.jena.query.QueryExecutionFactory
import com.hp.hpl.jena.query.QueryFactory
import com.hp.hpl.jena.rdf.model.Literal

class DBPediaClient(url: String = "http://dbpedia.org/sparql") {

  val sparqlQueryTemplate = """
    PREFIX dbpedia: <http://dbpedia.org/resource/>
    PREFIX onto: <http://dbpedia.org/ontology/>
    PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
    
    SELECT ?label  WHERE {
      dbpedia:%s onto:%s ?o .
      ?o rdfs:label ?label .
    } LIMIT 100
  """

  def getObject(subj: String, verb: String): String = {
    val sparqlQuery = sparqlQueryTemplate
      .format(subj.replace(" ", "_"), verb)
    val query: Query = QueryFactory.create(sparqlQuery)
    val qexec: QueryExecution = QueryExecutionFactory.sparqlService(url, query)
    val results = qexec.execSelect()
    val objs = ArrayBuffer[String]()
    while (results.hasNext()) {
      val qsol = results.next()
      val literal = qsol.get("label").asInstanceOf[Literal]
      if ("en".equals(literal.getLanguage())) 
        objs += literal.getLexicalForm()
    }
    if (objs.isEmpty) null else objs.head
  }
}

The JUnit test below demonstrates the use of my MediaWikiParser and DBPediaClient to extract information directly from Infoboxes and from DBPedia resource pages respectively. The people list is the list of titles that were extracted from the Wikipedia XML dump.

// Source: src/test/scala/com/mycompany/scalcium/triples/DBPediaClientTest.scala
package com.mycompany.scalcium.triples

import java.io.File

import scala.collection.mutable.ArrayBuffer
import scala.util.matching.Regex

import org.junit.Test

import com.mycompany.scalcium.utils.MediaWikiParser

class DBPediaClientTest {

  val people = List("Andrew Johnson", "Edgar Allan Poe", 
      "Desi Arnaz", "Elvis Presley", "Albert Camus", 
      "Arthur Miller", "Boris Yeltsin", "Ernest Hemingway", 
      "Benjamin Franklin", "Bill Oddie", "Abraham Lincoln", 
      "Billy Crystal", "Bill Clinton", "Alfonso V of Aragon", 
      "Dwight D. Eisenhower", "Colin Powell", "Cary Elwes", 
      "Alexander II of Russia", "Arnold Schwarzenegger", 
      "Christopher Columbus", "Barry Bonds", "Bill Gates", 
      "Elizabeth Garrett Anderson")
  
  @Test
  def testExtractSpouseFromInfobox(): Unit = {
    val inParens = new Regex("""\(.*?\)""")
    val xmlfile = new File("src/main/resources/wiki/small_wiki.xml")
    val parser = new MediaWikiParser(xmlfile)
    parser.parse()
    val triples = ArrayBuffer[(String,String,String)]()
    parser.getTitles().zip(parser.getInfoboxes())
      .map(ti => {
        val spouse = if (ti._2.contains("spouse")) ti._2("spouse") else null
        // clean up data received (situation dependent)
        if (spouse != null) {
          val spouses = inParens.replaceAllIn(spouse, "")
            .split("\\s{2,}")
            .map(_.trim)
          spouses.foreach(spouse => 
            triples += ((ti._1, "spouse", spouse)))
        } else triples += ((ti._1, "spouse", "NOTFOUND"))
      })
    triples.foreach(Console.println(_))
  }
  
  @Test
  def testExtractSpouseFromDBPedia(): Unit = {
    val triples = ArrayBuffer[(String,String,String)]()
    val client = new DBPediaClient()
    people.map(person => {
      val spouse = client.getObject(person, "spouse")
      if (spouse != null) {
        // clean up data received (situation dependent)
        spouse.replace(',', ' ')
          .replace(')', ' ')
          .split('(')
          .map(s => s.trim())
          .foreach(s => triples += ((person, "spouse", s)))
      } else triples += ((person, "spouse", "NOTFOUND"))
    })
    triples.foreach(Console.println(_))
  }
}

Parsing Text


Finally, I use the ReVerb project to parse triples out of the text portion of each Wikipedia entry. The classical approach to extracting triples is to do a deep parse of each sentence (ie, convert it to a tree) and then find structures that are rooted at verbs. From what I could glean from the paper backing ReVerb, the approach used here is to match regular expressions over the POS tags of input sentences. The authors claim that this results in higher quality triples being extracted, since it avoids "incoherent extractions". It is also faster, since (time-consuming) deep parsing is not involved, and is therefore more suitable for large amounts of data. Here is the code for my ReVerbClient. It takes in a block of text and uses ReVerb's default sentence reader to break it into sentences.

// Source: src/main/scala/com/mycompany/scalcium/triples/ReVerbClient.scala
package com.mycompany.scalcium.triples

import java.io.StringReader

import scala.collection.JavaConversions._

import edu.washington.cs.knowitall.extractor.ReVerbExtractor
import edu.washington.cs.knowitall.normalization.BinaryExtractionNormalizer
import edu.washington.cs.knowitall.util.DefaultObjects

class ReVerbClient {

  def parse(text: String): List[(String,String,String)] = {
    DefaultObjects.initializeNlpTools()
    val reader = DefaultObjects.getDefaultSentenceReader(
      new StringReader(text))
    val extractor = new ReVerbExtractor()
    val normalizer = new BinaryExtractionNormalizer()
    reader.iterator()
      .flatMap(sent => extractor.extract(sent))
      .map(extract => {
        val normExtract = normalizer.normalize(extract)
        val subj = normExtract.getArgument1().toString()
        val verb = normExtract.getRelation().toString()
        val obj = normExtract.getArgument2().toString()
        (subj, verb, obj)
      })
      .toList
  }
}
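
To illustrate the flavor of the POS-pattern idea mentioned above (this is a toy version, not ReVerb's actual pattern), one could encode each token's POS tag as a single character and run a regular expression over that string; the character offsets of the match then map directly back to token positions:

import scala.util.matching.Regex

// Toy illustration of matching a pattern over POS tags to pull out a
// verb-centered relation phrase (NOT ReVerb's actual pattern).
object PosPatternDemo extends App {
  // (token, POS tag) pairs for "Lincoln was married to Mary Todd"
  val tagged = List(("Lincoln", "NNP"), ("was", "VBD"), ("married", "VBN"),
                    ("to", "TO"), ("Mary", "NNP"), ("Todd", "NNP"))
  // encode tags as single chars: N = noun, V = verb, P = prep/particle
  val encoded = tagged.map(t => t._2 match {
    case tag if tag.startsWith("NN") => 'N'
    case tag if tag.startsWith("VB") => 'V'
    case "IN" | "TO" | "RP"          => 'P'
    case _                           => 'O'
  }).mkString
  // subject nouns, then verbs optionally followed by a preposition,
  // then object nouns
  val pattern = new Regex("""(N+)(V+P?)(N+)""")
  pattern.findFirstMatchIn(encoded).foreach(m => {
    val words = (s: Int, e: Int) => tagged.slice(s, e).map(_._1).mkString(" ")
    // prints (Lincoln,was married to,Mary Todd)
    println((words(m.start(1), m.end(1)),
             words(m.start(2), m.end(2)),
             words(m.start(3), m.end(3))))
  })
}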

This extracts a large number of raw triples. Since our objective is to return spouse triples, we do a considerable amount of filtering and post-processing on the raw triples to produce triples that are easily consumable by downstream code. This custom logic is in the JUnit test below.

// Source: src/test/scala/com/mycompany/scalcium/triples/ReVerbClientTest.scala
package com.mycompany.scalcium.triples

import java.io.File

import scala.collection.mutable.ArrayBuffer
import scala.util.matching.Regex

import org.junit.Test

import com.mycompany.scalcium.utils.MediaWikiParser

class ReVerbClientTest {

  val SpouseWords = Set("wife", "husband")
  val BeWords = Set("is", "was", "be")
  val StopWords = Set("of")

  @Test
  def testExtractTriplesFromXml(): Unit = {
    val reverb = new ReVerbClient()
    val infile = new File("src/main/resources/wiki/small_wiki.xml")
    val parser = new MediaWikiParser(infile)
    parser.parse()
    parser.getTitles.zip(parser.getTexts())
      .map(tt => {
        val title = tt._1
        val triples = reverb.parse(tt._2)
        Console.println(">>> " + title)
        // clean up triple
        val resolvedTriples = triples.map(triple => {
          // resolve pronouns in subj, replace with title
          if (isPronoun(triple._1)) (title, triple._2, triple._3)
          else triple
        })
        val tripleBuf = ArrayBuffer[(String,String,String)]()
        // keep triples whose verb contains (married, divorced)
        tripleBuf ++= resolvedTriples.filter(triple => {
          (triple._2.indexOf("married") > -1 || 
            triple._2.indexOf("divorced") > -1)
        })
        // keep triples where subj or obj contains (wife, husband)
        // and the verb is (is, was, be)
        tripleBuf ++= resolvedTriples.filter(triple => {
          val wordsInSubj = triple._1.split("\\s+").map(_.toLowerCase).toSet
          val wordsInVerb = triple._2.split("\\s+").map(_.toLowerCase).toSet
          val wordsInObj = triple._3.split("\\s+").map(_.toLowerCase).toSet
          (wordsInSubj.intersect(SpouseWords).size > 0 &&
            wordsInVerb.intersect(BeWords).size > 0 &&
            isProperNoun(triple._3)) ||
          (isProperNoun(triple._1) &&
            wordsInVerb.intersect(BeWords).size > 0 &&
            wordsInObj.intersect(SpouseWords).size > 0)
        })
        // extract patterns like "Bill and Hillary Clinton" from either
        // subj or obj
        tripleBuf ++= resolvedTriples.map(triple => {
            val names = title.split("\\s+")
            val pattern = new Regex("""%s (and|&) (\w+) %s"""
              .format(names(0), names(names.size - 1)))
            val sfName = pattern.findAllIn(triple._1).matchData
                .map(m => m.group(2)).toList ++
              pattern.findAllIn(triple._3).matchData
                .map(m => m.group(2)).toList
            if (sfName.size == 1)
              (title, "spouse", List(sfName.head, 
                names(names.size - 1)).mkString(" "))
            else ("x", "x", "x")
          })
          .filter(triple => "spouse".equals(triple._2))
        // post-process the triples
        val spouseTriples = tripleBuf.map(triple => {
            // fix incomplete name in subj and obj from title
            val subjHasTitle = containsTitle(triple._1, title)
            val objHasTitle = containsTitle(triple._3, title)
            val subj = if (subjHasTitle) title else triple._1
            val obj = if (objHasTitle && !subjHasTitle) title else triple._3
            val verb = if (subjHasTitle || objHasTitle) "spouse" else triple._2
            (subj, verb, obj)
          })
          .filter(triple => 
            // make sure both subj and obj are proper nouns
            (isProperNoun(triple._1) && 
              "spouse".equals(triple._2) &&
              isProperNoun(triple._3)))
          
        spouseTriples.foreach(Console.println(_))
      })
  }
  
  def isPronoun(s: String): Boolean =
    "she".equalsIgnoreCase(s) || "he".equalsIgnoreCase(s)
    
  def isProperNoun(s: String): Boolean = {
    val words = s.split("\\s+").filter(w => !StopWords.contains(w))
    val initCapWords = words.filter(w => w.charAt(0).isUpper == true)
    words.length == initCapWords.length
  }
  
  def containsTitle(s: String, title: String): Boolean = {
    val words = Set() ++ s.split("\\s+") 
    val names = Set() ++ title.split("\\s+")
    words.intersect(names).size > 0
  }
}

Results


Each of the three approaches results in overlapping sets of spouse triples, as you can see from the summary below. Note that the "spouse" relationship is bidirectional, so the number of triples extracted is actually double what is shown. As you can see from the results below, none of the three sources is authoritative in the sense that you can depend on it absolutely, and at least a few of the triples are incorrect (although that could very likely be due to a bad text parse on my part).

Title Infobox DBPedia Text
Andrew Johnson
(Andrew Johnson, spouse, Eliza McCardle Johnson)

(Andrew Johnson, spouse, Eliza McCardle Johnson)

(Andrew Johnson, spouse, Eliza McCardle)
Edgar Allan Poe
(Edgar Allan Poe, spouse, Virginia Eliza Clemm Poe)

(Edgar Allan Poe, spouse, Virginia Eliza Clemm Poe)

(Edgar Allan Poe, spouse, Virginia Clemm)

(Edgar Allan Poe, spouse, Virginia)
Desi Arnaz

(Desi Arnaz, spouse, Lucille Ball)

(Desi Arnaz, spouse, Edith Mack Hirsch)

(Desi Arnaz, spouse, Lucille Ball)
Elvis Presley


Albert Camus


(Albert Camus, spouse, Simone Hie)

(Albert Camus, spouse, Francine Faure)
Arthur Miller

(Arthur Miller, spouse, Mary Slattery)

(Arthur Miller, spouse, Marilyn Monroe)

(Arthur Miller, spouse, Inge Morath)

(Arthur Miller, spouse, Mary Slattery)

(Arthur Miller, spouse, Marilyn Monroe)
Boris Yeltsin
(Boris Yeltsin, spouse, Naina Yeltsina)

(Boris Yeltsin, spouse, Naina Yeltsina)

Ernest Hemingway
(Ernest Hemingway, spouse, Pauline Pfeiffer)

(Ernest Hemingway, spouse, Elizabeth Hadley Richardson)

(Ernest Hemingway, spouse, Pauline Pfeiffer)

(Ernest Hemingway, spouse, Martha Gellhorn)

(Ernest Hemingway, spouse, Mary Welsh Hemingway)

(Ernest Hemingway, spouse, Hadley Richardson)
Benjamin Franklin
(Benjamin Franklin, spouse, Deborah Read)

(Benjamin Franklin, spouse, Deborah Read)

(Benjamin Franklin, spouse, Richard Bache)

(Benjamin Franklin, spouse, Deborah Franklin)
Bill Oddie

(Bill Oddie, spouse, Jean Hart)

(Bill Oddie, spouse, Laura Beaumont)

(Bill Oddie, spouse, Laura Beaumont)
Abraham Lincoln
(Abraham Lincoln, spouse, Mary Todd Lincoln)

(Abraham Lincoln, spouse, Mary Todd Lincoln)

Billy Crystal

(Billy Crystal, spouse, Janice Goldfinger)

(Billy Crystal, spouse, Janice Goldfinger)
Bill Clinton

(Bill Clinton, spouse, Hillary Rodham Clinton)

(Bill Clinton, spouse, Hillary Clinton)
Alfonso V of Aragon
(Alfonso V of Aragon, spouse, Maria of Castile Queen of Aragon)

(Alfonso V of Aragon, spouse, Maria of Castile)

Dwight D. Eisenhower
(Dwight D. Eisenhower, spouse, Mamie Eisenhower)

(Dwight D. Eisenhower, spouse, Mamie Mamie Doud Eisenhower)

(Dwight D. Eisenhower, spouse, Mamie Geneva Doud of Denver)
Colin Powell

(Colin Powell, spouse, Alma Vivian Johnson Powell)

(Colin Powell, spouse, Alma)

(Colin Powell, spouse, Alma Powell)
Cary Elwes

(Cary Elwes, spouse, Lisa Marie Kurbikoff)

(Cary Elwes, spouse, 1 child)

Alexander II of Russia
(Alexander II of Russia, spouse, Maria Alexandrovna)

(Alexander II of Russia, spouse, Marie of Hesse and by Rhine)

(Alexander II of Russia, spouse, Marie of Hesse and by Rhine)

(Alexander II of Russia, spouse, Princess Marie of Hesse)
Arnold Schwarzenegger


(Arnold Schwarzenegger, spouse, Maria Shriver)
Christopher Columbus
(Christopher Columbus, spouse, Filipa Moniz Perestrelo)

(Christopher Columbus, spouse, Filipa Moniz)

(Christopher Columbus, spouse, Filipa Moniz Perestrello)
Barry Bonds


Bill Gates

(Bill Gates, spouse, Melinda Gates)

(Bill Gates, spouse, Melinda Gates)
Elizabeth Garrett Anderson


(Elizabeth Garrett Anderson, spouse, Richard Garrett III)
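
As an aside, since the spouse relation is symmetric, the reverse triples alluded to above can be generated mechanically from the extracted ones; a minimal sketch (the function name is mine, for illustration):

// Sketch: spouse is a symmetric relation, so emit both directions
// for each extracted (subject, "spouse", object) triple.
def symmetrize(triples: List[(String, String, String)]):
    List[(String, String, String)] =
  triples.flatMap(t => List(t, (t._3, t._2, t._1))).distinct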

This is what I have for this week. Hope you found it interesting. I am hoping to do something with SVO triples on a larger scale (not on Wikipedia), so this was a good way for me to check out the toolkit. It also introduced me to Apache Jena and SPARQL, something I have been meaning to explore ever since I attended the OpenHPI course on Knowledge Engineering.

Sunday, October 19, 2014

Experiments in Tuning Neural Networks


One of the Programming Assignments (PA) for the Neural Networks for Machine Learning course on Coursera is simply to investigate the effect of various parameters on a Neural Network's (NN) performance. The input is the MNIST handwritten digits dataset provided as part of the MATLAB starter code, for which I substituted the simplified version available at the UCI Machine Learning Repository. As in previous posts where I have used Coursera PAs as inspiration for my own learning, I will formally state the obvious - the data and approach are very different, and hence are likely to produce incorrect results for the PA.

The NN itself consists of an input layer of 64 neurons, one for each pixel in the 8x8 handwritten digit, a hidden layer of sigmoid activation units, and an output layer of 10 softmax activation units corresponding to the digits 0-9. There are 3,823 records in the training set and 1,797 records in the testing set. I split the training set 50/50 into training and cross-validation portions. I then varied several common NN tunable parameters and observed their effect on the error rate. The wikibooks link provides a good overview of these tunable parameters. Code for creating and evaluating the NN with various tunable parameters is shown below:

// Source: src/main/scala/com/mycompany/scalcium/langmodel/EncogNNEval.scala
package com.mycompany.scalcium.langmodel

import java.io.File

import scala.collection.JavaConversions._
import scala.io.Source
import scala.util.Random

import org.encog.Encog
import org.encog.engine.network.activation.ActivationSigmoid
import org.encog.engine.network.activation.ActivationSoftMax
import org.encog.mathutil.randomize.RangeRandomizer
import org.encog.ml.data.MLDataSet
import org.encog.ml.data.basic.BasicMLData
import org.encog.ml.data.basic.BasicMLDataSet
import org.encog.neural.networks.BasicNetwork
import org.encog.neural.networks.layers.BasicLayer
import org.encog.neural.networks.training.propagation.back.Backpropagation

class EncogNNEval {
  
  val Debug = false
  val encoder = new OneHotEncoder(10)

  def evaluate(trainfile: File, decay: Float, hiddenLayerSize: Int, 
      numIters: Int, learningRate: Float, momentum: Float, 
      miniBatchSize: Int, earlyStopping: Boolean): 
      (Double, Double, BasicNetwork) = {
    // parse training file into a 50/50 training and validation set
    val datasets = parseFile(trainfile, 0.5F)
    val trainset = datasets._1; val valset = datasets._2
    // build network
    val network = new BasicNetwork()
    network.addLayer(new BasicLayer(null, true, 8 * 8))
    network.addLayer(new BasicLayer(new ActivationSigmoid(), true, hiddenLayerSize))
    network.addLayer(new BasicLayer(new ActivationSoftMax(), false, 10))
    network.getStructure().finalizeStructure()
    new RangeRandomizer(-1, 1).randomize(network)
    // set up trainer
    val trainer = new Backpropagation(network, trainset, learningRate, momentum)
    trainer.setBatchSize(miniBatchSize)
    var currIter = 0
    var trainError = 0.0D
    var valError = 0.0D
    var pValError = 0.0D
    var contLoop = false
    do {
      trainer.iteration()
      // decay the learning rate linearly over the course of the run
      if (decay > 0.0F) trainer.setLearningRate(
        (1.0 - (decay * currIter / numIters)) * learningRate)
      // calculate training and validation error
      trainError = error(network, trainset)
      valError = error(network, valset)
      if (Debug) {
        Console.println("Epoch: %d, Train error: %.3f, Validation Error: %.3f"
          .format(currIter, trainError, valError))
      }
      currIter += 1
      contLoop = shouldContinue(currIter, numIters, earlyStopping, 
        valError, pValError)
      pValError = valError
    } while (contLoop)
    trainer.finishTraining()
    Encog.getInstance().shutdown()
    (trainError, valError, network)
  }

  def parseFile(f: File, holdout: Float): (MLDataSet, MLDataSet) = {
    val trainset = new BasicMLDataSet()
    val valset = new BasicMLDataSet()
    Source.fromFile(f).getLines()
      .foreach(line => {
        val cols = line.split(",")
        val inputs = cols.slice(0, 64).map(_.toDouble / 64.0D)
        val output = encoder.encode(cols(64).toInt)
        if (Random.nextDouble < holdout)
          valset.add(new BasicMLData(inputs), new BasicMLData(output))
        else trainset.add(new BasicMLData(inputs), new BasicMLData(output))
      })
    (trainset, valset)
  } 
  
  def error(network: BasicNetwork, dataset: MLDataSet): Double = {
    var numCorrect = 0.0D
    var numTested = 0.0D
    dataset.foreach(pair => {
      val predicted = network.compute(pair.getInput()).getData()
      val actual = encoder.decode(pair.getIdeal().getData())
      if (actual == predicted.indexOf(predicted.max)) numCorrect += 1.0D
      numTested += 1.0D
    })
    numCorrect / numTested
  }

  def shouldContinue(currIter: Int, numIters: Int, earlyStopping: Boolean,
      validationError: Double, prevValidationError: Double): Boolean = 
    if (earlyStopping) 
      (currIter < numIters && prevValidationError < validationError)
    else currIter < numIters  
}
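
The EncogNNEval class uses an OneHotEncoder that is part of my project but not shown in this post. For completeness, here is a minimal sketch of what such an encoder could look like, inferred from how it is used above (the actual implementation may differ):

// Hypothetical stand-in for the OneHotEncoder referenced above; inferred
// from its usage in EncogNNEval, not the actual project class.
class OneHotEncoder(val arity: Int) {

  // encode a label in [0, arity) as a 1-of-n array of doubles
  def encode(label: Int): Array[Double] = {
    val encoded = Array.fill[Double](arity)(0.0D)
    encoded(label) = 1.0D
    encoded
  }

  // decode a 1-of-n (or softmax output) array back to the label index
  def decode(encoded: Array[Double]): Int = encoded.indexOf(encoded.max)
}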

The first experiment is to vary the learning rate with and without momentum. The NN is trained for 70 iterations in each case. The second experiment is a repeat of the first, but uses early stopping to stop training if the cross validation error starts to increase. The third experiment tests the effect of weight decay, ie, lowering the learning rate at each iteration by a fixed amount; I keep the momentum at 0 in this case. The fourth experiment tests the effect of varying the number of hidden units on the error rate, keeping all other parameters constant. The final experiment is a much longer run with the best parameter values discovered in the previous experiments. Here is the code for the unit test.

// Source: src/test/scala/com/mycompany/scalcium/langmodel/EncogNNEvalTest.scala
package com.mycompany.scalcium.langmodel

import java.io.File
import java.io.FileWriter
import java.io.PrintWriter

import org.junit.Test

class EncogNNEvalTest {

  val trainfile = new File("src/main/resources/langmodel/optdigits_train.txt")
  val testfile = new File("src/main/resources/langmodel/optdigits_test.txt")
  
  @Test
  def testVaryLearningRateAndMomentum(): Unit = {
    val results = new PrintWriter(new FileWriter(
      new File("results1.csv")), true)
    val nneval = new EncogNNEval()
    val weightDecay = 0.0F
    val numHiddenUnit = 10
    val numIterations = 70
    val learningRates = Array[Float](0.002F, 0.01F, 0.05F, 0.2F, 1.0F, 
                                     5.0F, 20.0F)
    val momentums = Array[Float](0.0F, 0.9F)
    val miniBatchSize = 10
    val earlyStopping = false
    var lineNo = 0
    for (learningRate <- learningRates;
         momentum <- momentums) {
      runAndReport(nneval, results, trainfile, weightDecay, numHiddenUnit, 
        numIterations, learningRate, momentum, miniBatchSize, earlyStopping,
        lineNo == 0)
      lineNo += 1
    }
    results.flush()
    results.close()
  }
  
  @Test
  def testVaryLearningRateAndMomentumWithEarlyStopping(): Unit = {
    val results = new PrintWriter(new FileWriter(
      new File("results2.csv")), true)
    val nneval = new EncogNNEval()
    val weightDecay = 0.0F
    val numHiddenUnit = 10
    val numIterations = 70
    val learningRates = Array[Float](0.002F, 0.01F, 0.05F, 0.2F, 1.0F, 
                                     5.0F, 20.0F)
    val momentums = Array[Float](0.0F, 0.9F)
    val miniBatchSize = 10
    val earlyStopping = true
    var lineNo = 0
    for (learningRate <- learningRates;
         momentum <- momentums) {
      runAndReport(nneval, results, trainfile, weightDecay, numHiddenUnit, 
        numIterations, learningRate, momentum, miniBatchSize, earlyStopping,
        lineNo == 0)
      lineNo += 1
    }
    results.flush()
    results.close()
  }
  
  @Test
  def testVaryWeightDecay(): Unit = {
    val results = new PrintWriter(new FileWriter(
      new File("results3.csv")), true)
    val nneval = new EncogNNEval()
    val weightDecays = Array[Float](10.0F, 1.0F, 0.0F, 0.1F, 0.01F, 0.001F)
    val numHiddenUnit = 10
    val numIterations = 70
    val learningRate = 0.05F
    val momentum = 0.0F
    val miniBatchSize = 10
    val earlyStopping = true
    var lineNo = 0
    for (weightDecay <- weightDecays) {
      runAndReport(nneval, results, trainfile, weightDecay, numHiddenUnit, 
        numIterations, learningRate, momentum, miniBatchSize, earlyStopping,
        lineNo == 0)
      lineNo += 1
    }
    results.flush()
    results.close()
  }
  
  @Test
  def testVaryHiddenUnits(): Unit = {
    val results = new PrintWriter(new FileWriter(
      new File("results4.csv")), true)
    val nneval = new EncogNNEval()
    val weightDecay = 0.0F
    val numHiddenUnits = Array[Int](10, 50, 100, 150, 200, 250, 500)
    val numIterations = 70
    val learningRate = 0.05F
    val momentum = 0.0F
    val miniBatchSize = 10
    val earlyStopping = true
    var lineNo = 0
    for (numHiddenUnit <- numHiddenUnits) {
      runAndReport(nneval, results, trainfile, weightDecay, numHiddenUnit, 
        numIterations, learningRate, momentum, miniBatchSize, earlyStopping,
        lineNo == 0)
      lineNo += 1
    }
    results.flush()
    results.close()
  }
  
  @Test
  def testFinalRun(): Unit = {
    val nneval = new EncogNNEval()
    val weightDecay = 0.0F
    val numHiddenUnit = 200
    val numIterations = 1000
    val learningRate = 0.1F
    val momentum = 0.9F
    val miniBatchSize = 100
    val earlyStopping = true
    val scores = nneval.evaluate(trainfile, weightDecay, numHiddenUnit, 
      numIterations, learningRate, momentum, miniBatchSize, earlyStopping) 
    // verify on test set
    val testds = nneval.parseFile(testfile, 0.0F)
    val network = scores._3
    val testError = nneval.error(network, testds._1)
    Console.println("Train Error: %.3f, Validation Error: %.3f, Test Error: %.3f"
      .format(scores._1, scores._2, testError))
  }
  
  def runAndReport(nneval: EncogNNEval, results: PrintWriter, 
      trainfile: File, weightDecay: Float, numHiddenUnit: Int, 
      numIterations: Int, learningRate: Float, momentum: Float, 
      miniBatchSize: Int, earlyStopping: Boolean,
      writeHeader: Boolean): Unit = {
    val scores = nneval.evaluate(trainfile, weightDecay, numHiddenUnit, 
      numIterations, learningRate, momentum, miniBatchSize, earlyStopping) 
    if (writeHeader)
      results.println("DECAY\tHUNITS\tITERS\tLR\tMOM\tBS\tES\tTRNERR\tVALERR")
    results.println("%.3f\t%d\t%d\t%.3f\t%.3f\t%d\t%d\t%.3f\t%.3f"
      .format(weightDecay, numHiddenUnit, numIterations, learningRate, 
        momentum, miniBatchSize, if (earlyStopping) 1 else 0, 
        scores._1, scores._2))
  }
}

I used matplotlib to chart the results for each of the four experiments described above. Here is the code:

# Source: nneval_charts.py
import pandas as pd
import matplotlib.pyplot as plt
import os

DATA_DIR = "/path/to/data/files"

def draw_4chart(xs, ys1, ys2, ys3, ys4, title, xlabel, ylabel, legends):
    plt.plot(xs, ys1, label=legends[0])
    plt.plot(xs, ys2, label=legends[1])
    plt.plot(xs, ys3, label=legends[2])
    plt.plot(xs, ys4, label=legends[3])
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()
    
def chart1():
    rdf = pd.read_csv(os.path.join(DATA_DIR, "results1.csv"),
                      sep="\t", header=0)
    # split into momentum groups
    rdf0 = rdf[rdf["MOM"] == 0.0]
    xvals = rdf0["LR"].values
    yvals0_tr = rdf0["TRNERR"].values
    yvals0_vl = rdf0["VALERR"].values
    rdf1 = rdf[rdf["MOM"] > 0.0]
    yvals1_tr = rdf1["TRNERR"].values
    yvals1_vl = rdf1["VALERR"].values
    draw_4chart(xvals, yvals0_tr, yvals0_vl, yvals1_tr, yvals1_vl, 
               "Error vs Learning Rate and Momentum", 
               "Learning Rate", "Error Rate", 
               ["Trn Err (Mom=0)", "CV Err (Mom=0)",
                "Trn Err (Mom=0.9)", "CV Err (Mom=0.9)"])

def chart2():
    rdf = pd.read_csv(os.path.join(DATA_DIR, "results2.csv"),
                      sep="\t", header=0)
    # split into momentum groups
    rdf0 = rdf[rdf["MOM"] == 0.0]
    xvals = rdf0["LR"].values
    yvals0_tr = rdf0["TRNERR"].values
    yvals0_vl = rdf0["VALERR"].values
    rdf1 = rdf[rdf["MOM"] > 0.0]
    yvals1_tr = rdf1["TRNERR"].values
    yvals1_vl = rdf1["VALERR"].values
    draw_4chart(xvals, yvals0_tr, yvals0_vl, yvals1_tr, yvals1_vl, 
               "Error vs Learning Rate & Momentum (w/Early Stopping)", 
               "Learning Rate", "Error Rate",
               ["Trn Err (Mom=0)", "CV Err (Mom=0)",
                "Trn Err (Mom=0.9)", "CV Err (Mom=0.9)"])

def draw_2chart(xs, ys1, ys2, title, xlabel, ylabel, legends):
    plt.plot(xs, ys1, label=legends[0])
    plt.plot(xs, ys2, label=legends[1])
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.show()
    
def chart3():
    rdf = pd.read_csv(os.path.join(DATA_DIR, "results3.csv"),
                      sep="\t", header=0)
    xvals = rdf["DECAY"].values
    yvals1 = rdf["TRNERR"].values
    yvals2 = rdf["VALERR"].values
    draw_2chart(xvals, yvals1, yvals2, 
                "Error vs Weight Decay (w/Early Stopping)", 
                "Weight Decay", "Error Rate", ["Trn Err", "CV Err"])    

def chart4():
    rdf = pd.read_csv(os.path.join(DATA_DIR, "results4.csv"),
                      sep="\t", header=0)
    xvals = rdf["HUNITS"].values
    yvals1 = rdf["TRNERR"].values
    yvals2 = rdf["VALERR"].values
    draw_2chart(xvals, yvals1, yvals2, 
                "Error vs #-Hidden Units (w/Early Stopping)", 
                "#-Hidden Units", "Error Rate", ["Trn Err", "CV Err"])    

    
chart1()
chart2()
chart3()
chart4()

The first three results mostly coincide with our intuition that the graphs should look like a hockey stick. However, in the case of the number of hidden units, it seems like the error rate is lowest with 10 hidden units.

The final run completed with a training error of 0.267, a validation error of 0.269 and a test set error of 0.257.

For me, this exercise was a way to understand the various knobs you can turn to get a NN to perform better, as well as a way to familiarize myself further with the Encog library. The NN here still needs to be tuned quite a bit - although the results are not terrible, handwritten digit recognition is a well-studied problem and state of the art accuracies are in the high 90% range (ie, error rates under 0.1).