Saturday, May 31, 2008

Modeling an Ontology in memory with JGraphT

Last week, I blogged about a custom StAX parser that reads an OWL XML file representing a wine ontology and loads the extracted information into a MySQL database. The database structure has changed slightly since then, with a couple of tables renamed to make the design a bit more obvious. The new schema looks like this:

This database models a collection of facts as semantic triples. A semantic triple consists of two entities connected by a relationship. An Entity has only an id, a name and a list of Attribute objects. This lets us beef up our Entity over time, as we discover more properties for these objects, without having to change any code. Attributes are modeled as name-value tuples, and the AttributeType normalizes the attribute names, which are likely to be repeated across Entities.

One thing we did before going forward was to add reverse relationships. After the OWL file was parsed and loaded into the database, we ended up with 9 relationship types. These represent one-way relationships (such as subclassOf). Relationships usually work in both directions, so we manually added the reverse relationships, each with an id that is the negative of the original relationId. The complete list of relationships is shown below. Of course, we need not put in reverse relationships that don't make sense or that we don't want to expose.

+----+---------------------+
| id | name                |
+----+---------------------+
| -9 | colorProperty       | 
| -8 | vintageYearProperty | 
| -7 | mainIngredient      | 
| -6 | bodyProperty        | 
| -5 | flavorProperty      | 
| -4 | sugarProperty       | 
| -3 | makes               | 
| -2 | contains            | 
| -1 | superclassOf        | 
|  1 | subclassOf          | 
|  2 | locatedIn           | 
|  3 | hasMaker            | 
|  4 | hasSugar            | 
|  5 | hasFlavor           | 
|  6 | hasBody             | 
|  7 | madeFromGrape       | 
|  8 | hasVintageYear      | 
|  9 | hasColor            | 
+----+---------------------+
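
To make the negative-id convention concrete, here is a minimal sketch in plain Java. The RelationIds class and its methods are hypothetical helpers that are not part of the code in this post; the map holds only a few rows copied from the table above, and id 42 is a made-up id used to show what happens when no reverse exists.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper (not part of the original code): illustrates the
// "reverse id = negated forward id" convention from the table above.
class RelationIds {

  // A few rows copied from the relations table above.
  static final Map<Long, String> NAMES = new HashMap<>();
  static {
    NAMES.put(1L, "subclassOf");
    NAMES.put(-1L, "superclassOf");
    NAMES.put(2L, "locatedIn");
    NAMES.put(-2L, "contains");
    NAMES.put(3L, "hasMaker");
    NAMES.put(-3L, "makes");
  }

  // The reverse of a relation is stored under the negated id.
  static long reverseId(long relationId) {
    return -relationId;
  }

  // Returns the reverse relation's name, or null if no reverse was defined.
  static String reverseName(long relationId) {
    return NAMES.get(reverseId(relationId));
  }

  public static void main(String[] args) {
    System.out.println(reverseName(1L));   // superclassOf
    System.out.println(reverseName(-2L));  // locatedIn
    System.out.println(reverseName(42L));  // null (made-up id, no reverse)
  }
}
```

The nice thing about this convention is that finding the reverse of a relation needs no extra lookup column, only a sign flip.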

An ontology can be visualized as a forest of taxonomy trees, where the nodes of the trees are connected to nodes of other trees - in other words, a graph. So my next step is to convert this structure into an in-memory graph object so it can be navigated without having to resort to complex SQL.
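
As a rough illustration of why graph navigation beats SQL here, the sketch below walks a chain of subclassOf edges with a simple breadth-first search, something that would need repeated self-joins in SQL. It uses only the JDK (not JGraphT), the TinyGraph class is hypothetical, and the entity names are illustrative rather than pulled from my database.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Plain-JDK sketch (not JGraphT): a tiny directed graph with labeled edges.
class TinyGraph {

  // One labeled edge: a relation id plus the name of the target node.
  static final class Edge {
    final long relationId;
    final String target;

    Edge(long relationId, String target) {
      this.relationId = relationId;
      this.target = target;
    }
  }

  private final Map<String, List<Edge>> outgoing = new HashMap<>();

  void addEdge(String source, long relationId, String target) {
    outgoing.computeIfAbsent(source, k -> new ArrayList<>())
        .add(new Edge(relationId, target));
  }

  // Breadth-first walk that follows only edges with the given relation id.
  Set<String> reachable(String start, long relationId) {
    Set<String> seen = new LinkedHashSet<>();
    Deque<String> queue = new ArrayDeque<>();
    queue.add(start);
    while (!queue.isEmpty()) {
      String node = queue.remove();
      List<Edge> edges = outgoing.getOrDefault(node, Collections.<Edge>emptyList());
      for (Edge edge : edges) {
        if (edge.relationId == relationId && seen.add(edge.target)) {
          queue.add(edge.target);
        }
      }
    }
    return seen;
  }

  public static void main(String[] args) {
    TinyGraph g = new TinyGraph();
    g.addEdge("WhiteBurgundy", 1L, "Burgundy");  // 1 = subclassOf
    g.addEdge("Burgundy", 1L, "Wine");
    g.addEdge("WhiteBurgundy", 9L, "White");     // 9 = hasColor
    System.out.println(g.reachable("WhiteBurgundy", 1L)); // [Burgundy, Wine]
  }
}
```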

Searching for a decent Java-based graph data structure I could use, I came upon JGraphT, which provides not only standard graph data structures but also a large number of graph algorithms. I guess I could have cooked one up myself, since all I wanted to do was model a graph and navigate it, but the advantage of using a standard data structure from an established library is that the library author has already worked out the kinks, so it is likely to be more extensible. Moreover, while I don't need any of the graph algorithms built into JGraphT right now, it is conceivable that I will at some point down the road.

So anyway, this post describes the code that I wrote to load a JGraphT Graph object from my database and then hit the graph with a few basic queries to make sure everything works.

First the beans. I define an Entity bean, an Attribute bean, a Relation bean, and a Fact bean which models a semantic triple. These are simple holder classes.

// Entity.java
package com.mycompany.myapp.ontology;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.ReflectionToStringBuilder;
import org.apache.commons.lang.builder.ToStringStyle;

public class Entity implements Serializable {
  
  private static final long serialVersionUID = 54272228896206677L;

  private long id;
  private String name;
  private List<Attribute> attributes = new ArrayList<Attribute>();
  
  public Entity() {
    super();
  }
  
  public Entity(long id) {
    this();
    setId(id);
  }
  
  public long getId() {
    return id;
  }

  public void setId(long id) {
    this.id = id;
  }
  
  public String getName() {
    return name;
  }
  
  public void setName(String name) {
    this.name = name;
  }
  
  public List<Attribute> getAttributes() {
    return attributes;
  }

  public void setAttributes(List<Attribute> attributes) {
    this.attributes = attributes;
  }

  public void addAttribute(Attribute attribute) {
    this.attributes.add(attribute);
  }

  @Override
  public int hashCode() {
    return (int) id;
  }
  
  @Override
  public boolean equals(Object obj) {
    if (!(obj instanceof Entity)) {
      return false;
    }
    Entity that = (Entity) obj;
    return EqualsBuilder.reflectionEquals(this, that);
  }
  
  @Override
  public String toString() {
    return ReflectionToStringBuilder.reflectionToString(this, ToStringStyle.NO_FIELD_NAMES_STYLE);
  }
}
// Attribute.java
package com.mycompany.myapp.ontology;

public class Attribute {
  
  private String name;
  private String value;

  public Attribute() {
    super();
  }
  
  public Attribute(String name, String value) {
    this();
    setName(name);
    setValue(value);
  }
  
  public String getName() {
    return name;
  }
  
  public void setName(String name) {
    this.name = name;
  }
  
  public String getValue() {
    return value;
  }
  
  public void setValue(String value) {
    this.value = value;
  }
}
// Relation.java
package com.mycompany.myapp.ontology;

import java.io.Serializable;

import org.apache.commons.lang.builder.ReflectionToStringBuilder;
import org.apache.commons.lang.builder.ToStringStyle;

public class Relation implements Serializable {

  private static final long serialVersionUID = 8521110824988338681L;
  
  private long relationId;
  private String relationName;
  
  public Relation() {
    super();
  }

  public long getId() {
    return relationId;
  }

  public void setRelationId(long relationId) {
    this.relationId = relationId;
  }

  public String getName() {
    return relationName;
  }

  public void setRelationName(String relationName) {
    this.relationName = relationName;
  }

  @Override
  public String toString() {
    return ReflectionToStringBuilder.reflectionToString(this, ToStringStyle.NO_FIELD_NAMES_STYLE);
  }
}
// Fact.java
package com.mycompany.myapp.ontology;

public class Fact {

  private long sourceEntityId;
  private long targetEntityId;
  private long relationId;
  
  public Fact() {
    super();
  }
  
  public Fact(long sourceEntityId, long targetEntityId, long relationId) {
    this();
    setSourceEntityId(sourceEntityId);
    setTargetEntityId(targetEntityId);
    setRelationId(relationId);
  }
  
  public long getSourceEntityId() {
    return sourceEntityId;
  }
  
  public void setSourceEntityId(long sourceEntityId) {
    this.sourceEntityId = sourceEntityId;
  }
  
  public long getTargetEntityId() {
    return targetEntityId;
  }
  
  public void setTargetEntityId(long targetEntityId) {
    this.targetEntityId = targetEntityId;
  }
  
  public long getRelationId() {
    return relationId;
  }

  public void setRelationId(long relationId) {
    this.relationId = relationId;
  }
}

The overridden equals() and hashCode() methods on the Entity bean (above) are necessary: they enable JGraphT to locate a vertex in the graph when we look for it with a reference to an Entity object.
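
Here is a small plain-JDK sketch of why the overrides matter, assuming (as far as I can tell) that JGraphT keeps its vertices in hash-based collections. The Node and PlainNode classes are hypothetical stand-ins for Entity. Note that Node compares by id alone, which is a simplification: the Entity bean above compares every field via reflection, so with it the lookup object has to match on name and attributes too.

```java
import java.util.HashSet;
import java.util.Set;

// Demo: hash-based lookup with and without equals()/hashCode() overrides.
class LookupDemo {
  public static void main(String[] args) {
    Set<Node> vertices = new HashSet<>();
    vertices.add(new Node(7L, "Wine"));
    // A fresh object with the same id finds the stored vertex.
    System.out.println(vertices.contains(new Node(7L, null)));  // true

    Set<PlainNode> plain = new HashSet<>();
    plain.add(new PlainNode(7L));
    // Without the overrides, a distinct instance is never found.
    System.out.println(plain.contains(new PlainNode(7L)));      // false
  }
}

// Simplified stand-in for Entity: equality is based on the id alone.
class Node {
  final long id;
  final String name;

  Node(long id, String name) {
    this.id = id;
    this.name = name;
  }

  @Override
  public int hashCode() {
    return Long.hashCode(id);
  }

  @Override
  public boolean equals(Object obj) {
    return obj instanceof Node && ((Node) obj).id == id;
  }
}

// Same id field, but inherits identity-based equals() and hashCode().
class PlainNode {
  final long id;

  PlainNode(long id) {
    this.id = id;
  }
}
```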

To connect Entities, we need an Edge object that can be labelled, so we subclass JGraphT's DefaultEdge and add in an additional property relationId. Here is the code for RelationEdge.java.

// RelationEdge.java
package com.mycompany.myapp.ontology;

import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.ReflectionToStringBuilder;
import org.apache.commons.lang.builder.ToStringStyle;
import org.jgrapht.graph.DefaultEdge;

/**
 * Extends DefaultEdge to add a label to the edge. The label is the
 * relationId that relates the entities the edge connects.
 */
public class RelationEdge extends DefaultEdge {
  
  private static final long serialVersionUID = 1994877217677659613L;

  private long relationId;

  public RelationEdge() {
    super();
  }
  
  public RelationEdge(long relationId) {
    this();
    setRelationId(relationId);
  }
  
  public long getRelationId() {
    return relationId;
  }

  public void setRelationId(long relationId) {
    this.relationId = relationId;
  }
}

Finally, we define the container class which ties this all together. The client will call methods on this class.

// Ontology.java
package com.mycompany.myapp.ontology;

import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jgrapht.Graph;
import org.jgrapht.graph.ClassBasedEdgeFactory;
import org.jgrapht.graph.SimpleDirectedGraph;

public class Ontology implements Serializable {

  private static final long serialVersionUID = 8903265933795172508L;
  
  private final Log log = LogFactory.getLog(getClass());
  
  protected Map<Long,Entity> entityMap;
  protected Map<Long,Relation> relationMap;
  protected SimpleDirectedGraph<Entity,RelationEdge> ontology;

  public Ontology() {
    entityMap = new HashMap<Long,Entity>();
    relationMap = new HashMap<Long,Relation>();
    ontology = new SimpleDirectedGraph<Entity,RelationEdge>(
      new ClassBasedEdgeFactory<Entity,RelationEdge>(RelationEdge.class));
  }

  public Entity getEntityById(long entityId) {
    return entityMap.get(entityId);
  }

  public Relation getRelationById(long relationId) {
    return relationMap.get(relationId);
  }
  
  public Set<Long> getAvailableRelationIds(Entity entity) {
    Set<Long> relationIds = new HashSet<Long>();
    Set<RelationEdge> relationEdges = ontology.edgesOf(entity);
    for (RelationEdge relationEdge : relationEdges) {
      relationIds.add(relationEdge.getRelationId());
    }
    return relationIds;
  }
  
  public Set<Entity> getEntitiesRelatedById(Entity entity, long relationId) {
    Set<RelationEdge> relationEdges = ontology.outgoingEdgesOf(entity);
    Set<Entity> relatedEntities = new HashSet<Entity>();
    for (RelationEdge relationEdge : relationEdges) {
      if (relationEdge.getRelationId() == relationId) {
        Entity relatedEntity = ontology.getEdgeTarget(relationEdge);
        relatedEntities.add(relatedEntity);
      }
    }
    return relatedEntities;
  }
  
  public void addEntity(Entity entity) {
    entityMap.put(entity.getId(), entity);
    ontology.addVertex(entity);
  }
  
  public void addRelation(Relation relation) throws Exception {
    relationMap.put(relation.getId(), relation);
  }
  
  public void addFact(Fact fact) throws Exception {
    Entity sourceEntity = getEntityById(fact.getSourceEntityId());
    if (sourceEntity == null) {
      log.error("No entity found for source entityId:" + fact.getSourceEntityId());
      return;
    }
    Entity targetEntity = getEntityById(fact.getTargetEntityId());
    if (targetEntity == null) {
      log.error("No entity found for target entityId: " + fact.getTargetEntityId());
      return;
    }
    long relationId = fact.getRelationId();
    Relation relation = getRelationById(relationId);
    if (relation == null) {
      log.error("No relation found for relationId: " + relationId);
      return;
    }
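    // does fact exist? If so, don't do anything, just return. We check for
    // this exact (source, relation, target) triple; checking only whether the
    // source has any edge with this relationId would also reject new facts
    // that reuse the relation with a different target.
    Set<Entity> relatedEntities = getEntitiesRelatedById(sourceEntity, relationId);
    if (relatedEntities.contains(targetEntity)) {
      log.info("Fact: " + relation.getName() + "(" + 
        sourceEntity.getName() + "," + targetEntity.getName() + 
        ") already added to ontology");
      return;
    }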
    RelationEdge relationEdge = new RelationEdge();
    relationEdge.setRelationId(relationId);
    ontology.addEdge(sourceEntity, targetEntity, relationEdge);
    if (relationMap.get(-1L * relationId) != null) {
      RelationEdge reverseRelationEdge = new RelationEdge();
      reverseRelationEdge.setRelationId(-1L * relationId);
      ontology.addEdge(targetEntity, sourceEntity, reverseRelationEdge);
    }
  }
}
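
To show the add-fact logic in isolation, here is a minimal sketch on plain JDK collections instead of JGraphT. The MiniOntology class is hypothetical and works on bare ids; it stores the forward edge and, if a relation with the negated id has been registered, the reverse edge too. One difference to keep in mind: it allows several relations between the same pair of entities, whereas a simple directed graph holds at most one edge per ordered pair.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

// Plain-JDK sketch of the addFact() idea (hypothetical class, ids only).
class MiniOntology {

  private final Set<Long> relationIds = new HashSet<>();
  // source entity id -> relation id -> target entity ids
  private final Map<Long, Map<Long, Set<Long>>> edges = new HashMap<>();

  void addRelation(long relationId) {
    relationIds.add(relationId);
  }

  // Returns false if the relation is unknown or the exact triple exists.
  boolean addFact(long sourceId, long relationId, long targetId) {
    if (!relationIds.contains(relationId)) {
      return false;
    }
    if (!link(sourceId, relationId, targetId)) {
      return false;
    }
    // Add the reverse edge only when a reverse relation was registered.
    if (relationIds.contains(-relationId)) {
      link(targetId, -relationId, sourceId);
    }
    return true;
  }

  // Targets reachable from sourceId over the given relation.
  Set<Long> related(long sourceId, long relationId) {
    return edges.getOrDefault(sourceId, new HashMap<>())
        .getOrDefault(relationId, new HashSet<>());
  }

  // Set.add() returns false when the target was already present.
  private boolean link(long sourceId, long relationId, long targetId) {
    return edges.computeIfAbsent(sourceId, k -> new HashMap<>())
        .computeIfAbsent(relationId, k -> new HashSet<>())
        .add(targetId);
  }

  public static void main(String[] args) {
    MiniOntology o = new MiniOntology();
    o.addRelation(1L);   // subclassOf
    o.addRelation(-1L);  // superclassOf
    o.addFact(10L, 1L, 20L);
    System.out.println(o.related(10L, 1L));       // [20]
    System.out.println(o.related(20L, -1L));      // [10]
    System.out.println(o.addFact(10L, 1L, 20L));  // false (duplicate)
  }
}
```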

To load this object, we use the DbOntologyLoader class, which calls methods on the DAO classes to retrieve data from the database. Here is the code for the loader.

// DbOntologyLoader.java
package com.mycompany.myapp.ontology.loaders;

import com.mycompany.myapp.ontology.daos.EntityDao;
import com.mycompany.myapp.ontology.daos.FactDao;
import com.mycompany.myapp.ontology.daos.RelationDao;
import com.mycompany.myapp.ontology.Fact;
import com.mycompany.myapp.ontology.Ontology;
import com.mycompany.myapp.ontology.Entity;
import com.mycompany.myapp.ontology.Relation;

import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class DbOntologyLoader {
  
  private final Log log = LogFactory.getLog(getClass());
  
  private EntityDao entityDao;
  private RelationDao relationDao;
  private FactDao factDao;
  
  public void setEntityDao(EntityDao entityDao) {
    this.entityDao = entityDao;
  }

  public void setRelationDao(RelationDao relationDao) {
    this.relationDao = relationDao;
  }

  public void setFactDao(FactDao factDao) {
    this.factDao = factDao;
  }

  public Ontology load() throws Exception {
    Ontology ontology = new Ontology();
    log.debug("Loading entities");
    List<Entity> entities = entityDao.getAllEntities();
    for (Entity entity : entities) {
      ontology.addEntity(entity);
    }
    log.debug("Loading relations");
    List<Relation> relations = relationDao.getAllRelations();
    for (Relation relation : relations) {
      ontology.addRelation(relation);
      if (relationDao.isBidirectional(relation.getId())) {
        Relation reverseRelation = relationDao.getById(-1L * relation.getId());
        ontology.addRelation(reverseRelation);
      }
    }
    log.debug("Loading facts");
    List<Fact> facts = factDao.getAllFacts();
    for (Fact fact : facts) {
      ontology.addFact(fact);
    }
    log.debug("Ontology load complete");
    return ontology;
  }
}

The loader depends on three DAOs for Entity, Relation and Fact. The code for these is shown below for completeness.

// EntityDao.java
package com.mycompany.myapp.ontology.daos;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.support.JdbcDaoSupport;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;

import com.mycompany.myapp.ontology.Attribute;
import com.mycompany.myapp.ontology.Entity;

public class EntityDao extends JdbcDaoSupport {

  private final Log log = LogFactory.getLog(getClass());

  @SuppressWarnings("unchecked")
  public List<Entity> getAllEntities() {
    List<Entity> entities = new ArrayList<Entity>();
    List<Map<String,Object>> rows = getJdbcTemplate().queryForList(
      "select id, name from entities");
    for (Map<String,Object> row : rows) {
      Entity entity = new Entity();
      entity.setId((Integer) row.get("ID"));
      entity.setName((String) row.get("NAME"));
      entities.add(entity);
    }
    return entities;
  }

  @SuppressWarnings("unchecked")
  public Entity getById(long id) {
    try {
      Entity entity = new Entity();
      Map<String,Object> row = getJdbcTemplate().queryForMap(
        "select id, name from entities where id = ?", 
        new Long[] {id});
      entity.setId((Integer) row.get("ID"));
      entity.setName((String) row.get("NAME"));
      return entity;
    } catch (IncorrectResultSizeDataAccessException e) {
      return null;
    }
  }
  
  @SuppressWarnings("unchecked")
  public Entity getByName(String name) {
    try {
      Entity entity = new Entity();
      Map<String,Object> row = getJdbcTemplate().queryForMap(
        "select id, name from entities where name = ?", 
        new String[] {name});
      entity.setId((Integer) row.get("ID"));
      entity.setName((String) row.get("NAME"));
      entity.setAttributes(getAttributes(entity.getId()));
      return entity;
    } catch (IncorrectResultSizeDataAccessException e) {
      return null;
    }
  }
  
  @SuppressWarnings("unchecked")
  public List<Attribute> getAttributes(long entityId) {
    List<Attribute> attributes = new ArrayList<Attribute>();
    List<Map<String,String>> rows = getJdbcTemplate().queryForList(
      "select at.attr_name, a.value " +
      "from attributes a, attribute_types at " +
      "where a.attr_id = at.id " +
      "and a.entity_id = ?", new Long[] {entityId});
    for (Map<String,String> row : rows) {
      String name = row.get("ATTR_NAME");
      String value = row.get("VALUE");
      Attribute attribute = new Attribute(name, value);
      attributes.add(attribute);
    }
    return attributes;
  }

  @SuppressWarnings("unchecked")
  public Attribute getAttributeByName(long entityId, String attributeName) {
    try {
      Attribute attribute = new Attribute();
      Map<String,String> row = getJdbcTemplate().queryForMap(
        "select at.attr_name, a.value " +
        "from attributes a, attribute_types at " +
        "where a.attr_id = at.id " +
        "and a.entity_id = ? " +
        "and at.attr_name = ?", new Object[] {entityId, attributeName});
      // the query selects at.attr_name, so the column key is ATTR_NAME
      attribute.setName(row.get("ATTR_NAME"));
      attribute.setValue(row.get("VALUE"));
      return attribute;
    } catch (IncorrectResultSizeDataAccessException e) {
      return null;
    }
  }

  public long getAttributeTypeId(final String attributeName) {
    try {
      long attributeTypeId = getJdbcTemplate().queryForLong(
        "select id from attribute_types where attr_name = ?", 
        new String[] {attributeName});
      return attributeTypeId;
    } catch (IncorrectResultSizeDataAccessException e) {
      return 0L;
    }
  }
  
  public long save(final Entity entity) {
    Entity dbEntity = getByName(entity.getName());
    if (dbEntity == null) {
      log.debug("Saving entity:" + entity.getName());
      // insert the entity
      KeyHolder entityKeyHolder = new GeneratedKeyHolder();
      getJdbcTemplate().update(new PreparedStatementCreator() {
        public PreparedStatement createPreparedStatement(Connection conn)
        throws SQLException {
          PreparedStatement ps = conn.prepareStatement(
            "insert into entities(name) values (?)", 
            Statement.RETURN_GENERATED_KEYS);
          ps.setString(1, entity.getName());
          return ps;
        }
      }, entityKeyHolder);
      long entityId = entityKeyHolder.getKey().longValue();
      List<Attribute> attributes = entity.getAttributes();
      for (Attribute attribute : attributes) {
        saveAttribute(entityId, attribute);
      }
      // finally, always save the "english name" of the entity as an attribute
      saveAttribute(entityId, new Attribute("EnglishName", getEnglishName(entity)));
      return entityId;
    } else {
      // entity already exists: update the row found by name and return its id
      getJdbcTemplate().update("update entities set name = ? where id = ?", 
        new Object[] {entity.getName(), dbEntity.getId()});
      return dbEntity.getId();
    }
  }

  public long saveAttribute(final long entityId, final Attribute attribute) {
    // check to see if attribute exists in attribute_types
    long attributeTypeId = getAttributeTypeId(attribute.getName());
    if (attributeTypeId == 0L) {
      attributeTypeId = saveAttributeType(attribute.getName());
    }
    Attribute dbAttribute = getAttributeByName(entityId, attribute.getName());
    final long attrId = attributeTypeId;
    if (dbAttribute == null) {
      KeyHolder keyholder = new GeneratedKeyHolder();
      final String attributeName = attribute.getName();
      getJdbcTemplate().update(new PreparedStatementCreator() {
        public PreparedStatement createPreparedStatement(Connection conn)
        throws SQLException {
          // request generated keys so the KeyHolder below can be populated
          PreparedStatement ps = conn.prepareStatement(
            "insert into attributes(entity_id, attr_id, value) values (?, ?, ?)",
            Statement.RETURN_GENERATED_KEYS);
          ps.setLong(1, entityId);
          ps.setLong(2, attrId);
          ps.setString(3, attribute.getValue());
          return ps;
        }
      }, keyholder);
      long attributeId = keyholder.getKey().longValue();
      return attributeId;
    } else {
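      // three placeholders: the new value, then the entity and attribute type ids
      getJdbcTemplate().update(
        "update attributes set value = ? where entity_id = ? and attr_id = ?", 
        new Object[] {attribute.getValue(), entityId, attrId});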
      return attrId;
    }
  }

  public long saveAttributeType(final String attributeName) {
    long attributeTypeId = getAttributeTypeId(attributeName);
    if (attributeTypeId == 0L) {
      KeyHolder keyholder = new GeneratedKeyHolder();
      getJdbcTemplate().update(new PreparedStatementCreator() {
        public PreparedStatement createPreparedStatement(Connection conn)
        throws SQLException {
          // request generated keys so the KeyHolder below can be populated
          PreparedStatement ps = conn.prepareStatement(
            "insert into attribute_types(attr_name) values (?)",
            Statement.RETURN_GENERATED_KEYS);
          ps.setString(1, attributeName);
          return ps;
        }
      }, keyholder);
      attributeTypeId = keyholder.getKey().longValue();
    }
    return attributeTypeId;
  }
    
  /**
   * Split up Uppercase Camelcased names (like Java classnames or C++ variable
   * names) into English phrases by splitting wherever there is a transition 
   * from lowercase to uppercase.
   * @param entity the entity whose camel cased name is to be split.
   * @return the "english" name, or null if the entity is null.
   */
  public String getEnglishName(Entity entity) {
    if (entity == null) {
      return null;
    }
    StringBuilder englishNameBuilder = new StringBuilder();
    char[] namechars = entity.getName().toCharArray();
    for (int i = 0; i < namechars.length; i++) {
      if (i > 0 && Character.isUpperCase(namechars[i]) && 
          Character.isLowerCase(namechars[i-1])) {
        englishNameBuilder.append(' ');
      }
      englishNameBuilder.append(namechars[i]);
    }
    return englishNameBuilder.toString();
  }
}
// RelationDao.java
package com.mycompany.myapp.ontology.daos;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.support.JdbcDaoSupport;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;

import com.mycompany.myapp.ontology.Relation;

public class RelationDao extends JdbcDaoSupport {

  private final Log log = LogFactory.getLog(getClass());
  
  @SuppressWarnings("unchecked")
  public List<Relation> getAllRelations() {
    List<Relation> relations = new ArrayList<Relation>();
    List<Map<String,Object>> rows = getJdbcTemplate().queryForList(
      "select id, name from relations where id > 0");
    for (Map<String,Object> row : rows) {
      Relation relation = new Relation();
      relation.setRelationId((Integer) row.get("ID"));
      relation.setRelationName((String) row.get("NAME"));
      relations.add(relation);
    }
    return relations;
  }

  @SuppressWarnings("unchecked")
  public Relation getById(long relationId) {
    Relation relation = new Relation();
    try {
      Map<String,Object> row = getJdbcTemplate().queryForMap(
        "select id, name from relations where id = ?", 
        new Long[] {relationId});
      relation.setRelationId((Integer) row.get("ID"));
      relation.setRelationName((String) row.get("NAME"));
      return relation;
    } catch (IncorrectResultSizeDataAccessException e) {
      return null;
    }
  }

  @SuppressWarnings("unchecked")
  public Relation getByName(String name) {
    Relation relation = new Relation();
    try {
      // look up by name, not id
      Map<String,Object> row = getJdbcTemplate().queryForMap(
        "select id, name from relations where name = ?", 
        new String[] {name});
      relation.setRelationId((Integer) row.get("ID"));
      relation.setRelationName((String) row.get("NAME"));
      return relation;
    } catch (IncorrectResultSizeDataAccessException e) {
      return null;
    }
  }
  
  public boolean isBidirectional(long relationId) {
    int count = getJdbcTemplate().queryForInt(
      "select count(*) from relations where id = ?", 
      new Long[] {-1L * relationId});
    return count > 0;
  }
  
  public long save(final Relation relation) {
    Relation dbRelation = getByName(relation.getName());
    if (dbRelation == null) {
      KeyHolder keyholder = new GeneratedKeyHolder();
      getJdbcTemplate().update(new PreparedStatementCreator() {
        public PreparedStatement createPreparedStatement(Connection conn) 
            throws SQLException {
          PreparedStatement ps = conn.prepareStatement(
            "insert into relations(name) values (?)", 
            Statement.RETURN_GENERATED_KEYS);
          ps.setString(1, relation.getName());
          return ps;
        }
      }, keyholder);
      return keyholder.getKey().longValue();
    } else {
      getJdbcTemplate().update("update relations set name = ? where id = ?",
        new Object[] {relation.getName(), relation.getId()});
      return relation.getId();
    }
  }
}
// FactDao.java
package com.mycompany.myapp.ontology.daos;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.dao.IncorrectResultSizeDataAccessException;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.support.JdbcDaoSupport;
import org.springframework.jdbc.support.GeneratedKeyHolder;
import org.springframework.jdbc.support.KeyHolder;

import com.mycompany.myapp.ontology.Entity;
import com.mycompany.myapp.ontology.Fact;
import com.mycompany.myapp.ontology.Relation;

public class FactDao extends JdbcDaoSupport {

  private final Log log = LogFactory.getLog(getClass());
  
  private EntityDao entityDao;
  private RelationDao relationDao;
  
  public void setEntityDao(EntityDao entityDao) {
    this.entityDao = entityDao;
  }
  
  public void setRelationDao(RelationDao relationDao) {
    this.relationDao = relationDao;
  }

  @SuppressWarnings("unchecked")
  public List<Fact> getAllFacts() {
    List<Fact> facts = new ArrayList<Fact>();
    List<Map<String,Integer>> rows = getJdbcTemplate().queryForList(
      "select f.src_entity_id, f.trg_entity_id, f.relation_id " +
      "from facts f, relations r " +
      "where f.relation_id = r.id");
    for (Map<String,Integer> row : rows) {
      Fact fact = new Fact();
      fact.setSourceEntityId(row.get("SRC_ENTITY_ID"));
      fact.setTargetEntityId(row.get("TRG_ENTITY_ID"));
      fact.setRelationId(row.get("RELATION_ID"));
      facts.add(fact);
    }
    return facts;
  }

  public void save(Fact fact) {
    Entity sourceEntity = entityDao.getById(fact.getSourceEntityId());
    Entity targetEntity = entityDao.getById(fact.getTargetEntityId());
    if (sourceEntity == null || targetEntity == null) {
      log.error("Cannot relate null entities");
      return;
    }
    Relation relation = relationDao.getById(fact.getRelationId());
    if (relation == null) {
      log.error("Unknown relation, cannot save fact");
      return;
    }
    save(sourceEntity.getName(), targetEntity.getName(), relation.getName());
  }
  
  public void save(final String sourceEntityName, final String targetEntityName, 
      final String relationName) {
    // get the entity ids for source and target
    Entity sourceEntity = entityDao.getByName(sourceEntityName);
    Entity targetEntity = entityDao.getByName(targetEntityName);
    if (sourceEntity == null || targetEntity == null) {
      log.error("Cannot save relation: " + relationName + "(" + 
        sourceEntityName + "," + targetEntityName + ")"); 
      return;
    }
    log.debug("Saving relation: " + relationName + "(" + 
      sourceEntityName + "," + targetEntityName + ")");
    // get the relation id
    long relationTypeId = 0L;
    try {
      relationTypeId = getJdbcTemplate().queryForInt(
        "select id from relations where name = ?", 
        new String[] {relationName});
    } catch (IncorrectResultSizeDataAccessException e) {
      KeyHolder keyholder = new GeneratedKeyHolder();
      getJdbcTemplate().update(new PreparedStatementCreator() {
        public PreparedStatement createPreparedStatement(Connection conn) 
            throws SQLException {
          PreparedStatement ps = conn.prepareStatement(
            "insert into relations(name) values (?)", 
            Statement.RETURN_GENERATED_KEYS);
          ps.setString(1, relationName);
          return ps;
        }
      }, keyholder);
      relationTypeId = keyholder.getKey().longValue();
    }
    // save it
    getJdbcTemplate().update(
      "insert into facts(src_entity_id, trg_entity_id, relation_id) values (?, ?, ?)", 
      new Long[] {sourceEntity.getId(), targetEntity.getId(), relationTypeId});
  }
}
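
The second save() method above resolves a relation name to an id, and creates the relation when the lookup finds nothing. Here is a minimal sketch of that get-or-create step on its own. It assumes an in-memory map in place of the relations table, and the class name RelationRegistrySketch and its counter are hypothetical stand-ins for the database and its generated keys, not part of the original code.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for the "relations" table: name -> generated id.
public class RelationRegistrySketch {

  private final Map<String, Long> idsByName = new HashMap<String, Long>();
  private long nextId = 1L; // plays the role of the auto-increment key

  /** Returns the id for the relation name, creating a new id if none exists. */
  long getOrCreateId(String relationName) {
    Long existing = idsByName.get(relationName);
    if (existing != null) {
      return existing; // lookup succeeded, reuse the id
    }
    long created = nextId++; // lookup failed, "insert" and take the new key
    idsByName.put(relationName, created);
    return created;
  }

  public static void main(String[] args) {
    RelationRegistrySketch registry = new RelationRegistrySketch();
    System.out.println(registry.getOrCreateId("subclassOf")); // prints 1
    System.out.println(registry.getOrCreateId("locatedIn"));  // prints 2
    System.out.println(registry.getOrCreateId("subclassOf")); // prints 1
  }
}
```

The real DAO does the same thing in two SQL statements, so two concurrent writers could in principle both miss the lookup and insert the same name; a unique index on relations.name would be one way to guard against that.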

Finally, to test this out, we run the loader to populate the graph, then issue queries against it in the form of JUnit test methods and look at what comes back. Here is the JUnit test:

// OntologyTest.java
package com.mycompany.myapp.ontology;

import java.util.Set;

import javax.sql.DataSource;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jgrapht.Graph;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.springframework.jdbc.datasource.DriverManagerDataSource;

import com.mycompany.myapp.ontology.daos.EntityDao;
import com.mycompany.myapp.ontology.daos.FactDao;
import com.mycompany.myapp.ontology.daos.RelationDao;
import com.mycompany.myapp.ontology.loaders.DbOntologyLoader;

public class OntologyTest {

  private final Log log = LogFactory.getLog(getClass());
  
  private static Ontology ontology;
  
  @BeforeClass
  public static void setUpBeforeClass() throws Exception {
    
    DataSource dataSource = new DriverManagerDataSource(
        "com.mysql.jdbc.Driver", "jdbc:mysql://localhost:3306/ontodb", "ontodev", "****");
    
    EntityDao entityDao = new EntityDao();
    entityDao.setDataSource(dataSource);
    
    RelationDao relationDao = new RelationDao();
    relationDao.setDataSource(dataSource);
    
    FactDao factDao = new FactDao();
    factDao.setDataSource(dataSource);
    factDao.setEntityDao(entityDao);
    factDao.setRelationDao(relationDao);
    
    DbOntologyLoader loader = new DbOntologyLoader();
    loader.setEntityDao(entityDao);
    loader.setRelationDao(relationDao);
    loader.setFactDao(factDao);
    
    ontology = loader.load();
  }
  
  @Test
  public void testLoad() throws Exception {
    Graph<Entity,RelationEdge> ontologyGraph = ontology.ontology;
    // We expect 237 vertices and 500 edges for the wine ontology
    log.debug("# vertices = " + ontologyGraph.vertexSet().size());
    Assert.assertEquals("#-vertices test failed", 237, ontologyGraph.vertexSet().size());
    log.debug("# edges = " + ontologyGraph.edgeSet().size());
    Assert.assertEquals("#-edges test failed", 500, ontologyGraph.edgeSet().size());
  }
  
  @Test
  public void testWhereIsLoireRegion() throws Exception {
    Entity loireRegion = ontology.getEntityById(26);
    long locatedInRelationId = 2L;
    Set<Entity> entities = ontology.getEntitiesRelatedById(loireRegion, locatedInRelationId);
    log.debug("query> where is Loire Region?");
    for (Entity entity : entities) {
      log.debug("..." + entity.getName());
    }
  }
  
  @Test
  public void testWhatRegionsAreInUSRegion() throws Exception {
    Entity usRegion = ontology.getEntityById(23);
    long reverseLocatedInRelationId = -2L;
    Set<Entity> entities = ontology.getEntitiesRelatedById(usRegion, reverseLocatedInRelationId);
    log.debug("query> what regions are in US Region?");
    for (Entity entity : entities) {
      log.debug("..." + entity.getName());
    }
  }
  
  @Test
  public void testWhatAreSweetWines() throws Exception {
    Entity sweetWinesEntity = ontology.getEntityById(125);
    long reverseOfHasSugarRelationId = -4L;
    Set<Entity> entities = ontology.getEntitiesRelatedById(
      sweetWinesEntity, reverseOfHasSugarRelationId);
    log.debug("query> what are sweet wines?");
    for (Entity entity : entities) {
      log.debug("..." + entity.getName());
    }
  }
}

Here is the output (formatted for clarity): the questions we effectively asked in the unit tests above, and the answers the ontology gave back.

 query> where is Loire Region?
 ...FrenchRegion
 query> what regions are in US Region?
 ...CaliforniaRegion
 ...TexasRegion
 query> what are sweet wines?
 ...WhitehallLanePrimavera
 ...SchlossVolradTrochenbierenausleseRiesling
 ...SchlossRothermelTrochenbierenausleseRiesling
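
The second and third queries work because each reverse relation has the negated id of its forward relation (locatedIn is 2, so its reverse is -2). Here is a minimal plain-Java sketch of that convention, without JGraphT. It assumes every fact is stored once in each direction, which may not be how the loader actually builds the graph. The class TripleQuerySketch, its methods and the entity ids 100 to 102 are made up for illustration.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

// Hypothetical illustration of relation-filtered lookups over triples.
public class TripleQuerySketch {

  /** One semantic triple: source --relation--> target, all held as ids. */
  static final class Triple {
    final long sourceId;
    final long targetId;
    final long relationId;

    Triple(long sourceId, long targetId, long relationId) {
      this.sourceId = sourceId;
      this.targetId = targetId;
      this.relationId = relationId;
    }
  }

  private final List<Triple> triples = new ArrayList<Triple>();

  /** Stores the forward triple and its reverse, whose relation id is negated. */
  void addFact(long sourceId, long targetId, long relationId) {
    triples.add(new Triple(sourceId, targetId, relationId));
    triples.add(new Triple(targetId, sourceId, -relationId));
  }

  /** Returns ids of entities reachable from sourceId over the given relation. */
  Set<Long> getRelatedIds(long sourceId, long relationId) {
    Set<Long> related = new LinkedHashSet<Long>();
    for (Triple t : triples) {
      if (t.sourceId == sourceId && t.relationId == relationId) {
        related.add(t.targetId);
      }
    }
    return related;
  }

  public static void main(String[] args) {
    TripleQuerySketch sketch = new TripleQuerySketch();
    long locatedIn = 2L;
    // Made-up ids: 100 = US region, 101 = California, 102 = Texas
    sketch.addFact(101L, 100L, locatedIn);
    sketch.addFact(102L, 100L, locatedIn);
    System.out.println(sketch.getRelatedIds(101L, locatedIn));  // prints [100]
    System.out.println(sketch.getRelatedIds(100L, -locatedIn)); // prints [101, 102]
  }
}
```

The linear scan is only for brevity; a graph library such as JGraphT lets you look at just the edges touching one vertex.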

As my son would say -- "Cool, huh?"

Update 2009-04-26: Recent posts have been building on code written and described in earlier posts, so there were (and rightly so) quite a few requests for the code. I've created a project on SourceForge to host it. You will find the complete source code built so far in the project's SVN repository.