Sunday, May 24, 2015

Computing Comorbidity Network with Spark on EC2


Some time back (while I was at my previous job), I was experimenting with Apache Spark, and I built a disease similarity network from anonymized claims data released by the Centers for Medicare and Medicaid Services (CMS). The data consisted of 1.3M inpatient and 15.8M outpatient records for 2.3M unique members. Unfortunately, I was only able to run the code locally against a small amount of dev data. I wanted to run it against the full dataset using Amazon Elastic MapReduce (EMR), but couldn't figure out how. The standard way to run Spark jobs on the Amazon Web Services (AWS) platform is to spin up your own clusters on Amazon Elastic Compute Cloud (EC2), as described on this Running Spark on EC2 page. However, I was using EMR quite heavily at that point, and several people appeared to have succeeded in running Spark jobs on EMR, so I was hoping to do the same. In the end, I gave up after a few failures.

Fast forward a few months to my current job, where we use Spark quite heavily. On one of my projects, we are using the Databricks Cloud Platform, where you write Python or Scala notebooks on top of a managed Spark cluster. On another, I am using PySpark and spinning up my own clusters on EC2. So I have gotten over my aversion to (or the learning curve of) managing my own clusters, and I am not as paranoid as I used to be about being hit with a hefty EC2 bill because I forgot to turn off a cluster after use.

So, long story short, I decided to spin up an EC2 cluster, run the job against the full dataset, and finish what I started. The final output is shown below. The nodes represent the diseases that the members had, and the edges represent the similarity between diseases in terms of the treatment procedures (in inpatient and outpatient settings) used to treat them. The edge weights are normalized so that the highest weight is 1, and only edges with a normalized weight above 0.3 are drawn. From the graph, it looks like Cancer (CNCR) and Stroke/Transient Ischemic Attack (TIA) don't share many procedures with the other diseases or with each other, that Diabetes (DIAB) and Ischemic Heart Disease (IHD) have a lot of procedures in common, while Osteoporosis (OSTR) and Chronic Kidney Disease (CKD) have much lower overlap, and so on.


As mentioned above, I initially wanted to run this on EMR, so I had used a specialized project structure from Snowplow Analytics that creates a fat JAR for deployment to EMR. Now that I am following the recommended EC2 approach, the project setup is much simpler: a skeletal build.sbt file and a single Scala object (the separate Job class recommended by the Snowplow Analytics project is no longer required). Here is the build.sbt file.

// Source: build.sbt
name := "cms-disease-graph"

version := "0.1"

scalaVersion := "2.10.4"

libraryDependencies ++= Seq(
    "org.apache.spark" %% "spark-core" % "1.3.1"
)

I had initially planned to run the code as is, or with minimal changes. However, I had forgotten to handle header lines in the data, so I had to go back to testing it locally. When I did that, I started getting mysterious stage failures on my Mac, which turned out to be a funky networking issue that went away once I explicitly set SPARK_LOCAL_IP to 127.0.0.1. By then I was also checking whether anything was wrong with my code, and since I had been using the Spark API heavily over the last couple of months (admittedly with PySpark, but the API carries over very well across languages), I found my old code hard to read. So I rewrote it the way I now think of as "correct". Here's the code; it's very different from the version in my previous post, but hopefully more readable.

// Source: src/main/scala/com/mycompany/diseasegraph/GraphDataGenerator.scala
package com.mycompany.diseasegraph

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkConf

import SparkContext._

object GraphDataGenerator {

    // column number => comorbidity in benefits_summary
    val ColumnDiseaseMap = Map(
        12 -> "ALZM",
        13 -> "CHF",
        14 -> "CKD",
        15 -> "CNCR",
        16 -> "COPD",
        17 -> "DEP",
        18 -> "DIAB",
        19 -> "IHD",
        20 -> "OSTR",
        21 -> "ARTH",
        22 -> "TIA"
    )

    def main(args: Array[String]): Unit = {
        if (args.size != 5) {
            Console.println("""
Usage: GraphDataGenerator \
    s3n://path/to/benefit_summary.csv \
    s3n://path/to/inpatient_claims.csv \
    s3n://path/to/outpatient_claims.csv \
    s3n://path/to/disease_code_output \
    s3n://path/to/disease_pairs_output""")
        } else {
            val conf = new SparkConf().setAppName("GraphDataGenerator")
            val sc = new SparkContext(conf)

            // permissions to read and write data on S3
            sc.hadoopConfiguration.set("fs.s3n.awsAccessKeyId", 
                sys.env("AWS_ACCESS_KEY"))
            sc.hadoopConfiguration.set("fs.s3n.awsSecretAccessKey", 
                sys.env("AWS_SECRET_KEY"))

            // dedupe member so only record with highest number of
            // comorbidities is retained
            val membersDeduped = sc.textFile(args(0))
                // remove heading line (repeats)
                .filter(line => ! line.startsWith("\""))
                // extract (memberId, [comorbidity_indicators])
                .map(line => {
                    val cols = line.split(",")
                    val memberId = cols(0)
                    val comorbs = ColumnDiseaseMap.keys.toList.sorted
                        .map(e => cols(e).toInt)
                    (memberId, comorbs)
                })
                .reduceByKey((v1, v2) => {
                    // 1 == Yes, 2 == No
                    val v1size = v1.filter(_ == 1).size
                    val v2size = v2.filter(_ == 1).size
                    if (v1size > v2size) v1 else v2
                })
                // normalize to (member_id, (disease, weight)) 
                .flatMap(x => {
                    val diseases = x._2.zipWithIndex
                      .filter(di => di._1 == 1) // retain only Yes
                      .map(di => ColumnDiseaseMap(di._2 + 12)) 
                    val weight = 1.0 / diseases.size
                    diseases.map(disease => (x._1, (disease, weight)))
                })

            // normalize inpatient and outpatient claims to 
            // (member_id, (code, weight))
            val inpatientClaims = sc.textFile(args(1))
                // remove heading line (repeats)
                  .filter(line => ! line.startsWith("\""))
                  // extracts (member_id:claims_id, proc_codes)
                  .map(line => {
                      val cols = line.split(",")
                      val memberId = cols(0)
                      val claimsId = cols(1)
                      val procCodes = cols.slice(30, 35)
                        .filter(pc => ! pc.isEmpty())
                      val memberClaimId = Array(memberId, claimsId).mkString(":")
                      (memberClaimId, procCodes)
                  })
                  // remove encounters with no procedure codes (they
                  // may have other codes we are not interested in)
                  .filter(claimProcs => claimProcs._2.size > 0)
                  // find list of procedures done per encounter
                  .groupByKey()
                  // reweight codes per encounter. If many codes were
                  // administered, then weigh them lower (assuming doctors
                  // have limited time per patient, so more codes mean
                  // that each code is less important - this assumption is
                  // not necessarily correct, but lets go with it).
                  .flatMap(grouped => {
                      val memberId = grouped._1.split(":")(0)
                      val codes = grouped._2.flatMap(x => x).toList
                      val weight = 1.0 / codes.size
                      codes.map(code => (memberId, (code, weight)))
                  })
                  
            // same normalization for outpatient claims (third argument)
            val outpatientClaims = sc.textFile(args(2))
                  .filter(line => ! line.startsWith("\""))
                  .map(line => {
                      val cols = line.split(",")
                      val memberId = cols(0)
                      val claimsId = cols(1)
                      val procCodes = cols.slice(31, 75)
                        .filter(pc => ! pc.isEmpty())
                      val memberClaimId = Array(memberId, claimsId).mkString(":")
                      (memberClaimId, procCodes)
                  })
                  .filter(claimProcs => claimProcs._2.size > 0)
                  .groupByKey()
                  .flatMap(grouped => {
                      val memberId = grouped._1.split(":")(0)
                      val codes = grouped._2.flatMap(x => x).toList
                      val weight = 1.0 / codes.size
                      codes.map(code => (memberId, (code, weight)))
                  })
                  
            // combine the two claims RDDs into one
            val claims = inpatientClaims.union(outpatientClaims)

            // join membersDeduped and claims on member_id
            // to get (code, (disease, weight))
            val codeDisease = membersDeduped.join(claims)
                .map(joined => {
                    val disease = joined._2._1._1
                    val procCode = joined._2._2._1
                    val weight = joined._2._1._2 * joined._2._2._2
                    (Array(disease, procCode).mkString(":"), weight)
                })
                // combine weights for same disease + procedure code
                .reduceByKey(_ + _)
                // key by procedure code for self join
                .map(reduced => {
                    val Array(disease, procCode) = reduced._1.split(":")
                    (procCode, (disease, reduced._2))
                })
                .cache()
                
            // save a copy for future analysis
            codeDisease.map(x => "%s\t%s\t%.5f".format(x._2._1, x._1, x._2._2))
                .saveAsTextFile(args(3))
            
            // finally self join on procedure code. The idea is to 
            // compute the relationship between diseases by the weighted
            // sum of procedures that have been observed to have been 
            // done for people with the disease
            val diseaseDisease = codeDisease.join(codeDisease)
                // eliminate cases where LHS == RHS
                .filter(dd => ! dd._2._1._1.equals(dd._2._2._1))
                // compute disease-disease relationship weights
                .map(dd => {
                    val lhsDisease = dd._2._1._1
                    val rhsDisease = dd._2._2._1
                    val diseases = Array(lhsDisease, rhsDisease).sorted
                    val weight = dd._2._1._2 * dd._2._2._2
                    (diseases.mkString(":"), weight)
                })
                // combine the disease pair weights
                .reduceByKey(_ + _)
                // bring it into a single file for convenience
                .coalesce(1, false)
                // sort them (purely for cosmetic reasons; they should
                // be small enough at this point to make this feasible)
                .sortByKey()
                // split it back out for rendering
                .map(x => {
                    val diseases = x._1.split(":")
                    "%s\t%s\t%.5f".format(diseases(0), diseases(1), x._2)
                })
                .saveAsTextFile(args(4))

            sc.stop()
        }
    }
}    

The code above extracts (member_id, (disease, weight)) data from the member data and (member_id, (procedure, weight)) data from the claims data, then joins them on member_id to produce (procedure, (disease, weight)) data, which is then joined with itself on procedure to produce ((disease, disease), weight) data. The code is heavily commented, so it should be easy to figure out how it's doing what it's doing. Our final output (tab-separated) looks something like this:

ALZM ARTH 53129371.82649
ALZM CHF 97277186.14343
ALZM CKD 76223548.83137
ALZM CNCR 24270358.94585
ALZM COPD 59983199.31700
...
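To make the weighting concrete, here's a plain-Python sketch of the same pipeline on a few made-up rows, with lists and dicts standing in for RDDs (it doesn't use Spark at all). The function names and the toy_row helper are hypothetical, and it assumes the same column layout as the Scala job: member ID in column 0, comorbidity flags in columns 12 to 22, with 1 meaning Yes.

```python
from collections import defaultdict
from itertools import product

# Column index -> comorbidity code, copied from ColumnDiseaseMap in the Scala job
COLUMN_DISEASE_MAP = {12: "ALZM", 13: "CHF", 14: "CKD", 15: "CNCR", 16: "COPD",
                      17: "DEP", 18: "DIAB", 19: "IHD", 20: "OSTR", 21: "ARTH",
                      22: "TIA"}


def member_diseases(rows):
    """rows: split benefit-summary lines. Keeps the row with the most Yes (1)
    flags per member (ties keep the first row seen), then emits
    (member_id, (disease, 1/n)) for a member with n diseases."""
    best = {}
    for cols in rows:
        diseases = [d for c, d in sorted(COLUMN_DISEASE_MAP.items()) if cols[c] == "1"]
        if len(diseases) > len(best.get(cols[0], [])):
            best[cols[0]] = diseases
    return [(m, (d, 1.0 / len(ds))) for m, ds in best.items() for d in ds]


def claim_codes(claims):
    """claims: (member_id, claim_id, proc_codes) tuples. Groups the non-empty
    codes per member:claim and weights each code 1 / (codes in that claim)."""
    grouped = defaultdict(list)
    for member_id, claim_id, codes in claims:
        grouped[(member_id, claim_id)].extend(c for c in codes if c)
    return [(m, (code, 1.0 / len(codes)))
            for (m, _claim_id), codes in grouped.items() for code in codes]


def disease_pairs(diseases, codes):
    """Join on member_id, sum weights per (disease, code), then self-join on
    code and sum the products per sorted (disease_a, disease_b) pair."""
    codes_by_member = defaultdict(list)
    for member_id, code_weight in codes:
        codes_by_member[member_id].append(code_weight)
    disease_code = defaultdict(float)
    for member_id, (disease, d_weight) in diseases:
        for code, c_weight in codes_by_member[member_id]:
            disease_code[(disease, code)] += d_weight * c_weight

    by_code = defaultdict(list)  # re-key by procedure code
    for (disease, code), weight in disease_code.items():
        by_code[code].append((disease, weight))
    pairs = defaultdict(float)
    for entries in by_code.values():
        for (lhs, lw), (rhs, rw) in product(entries, entries):
            if lhs != rhs:  # drop LHS == RHS, like the Scala filter
                pairs[tuple(sorted((lhs, rhs)))] += lw * rw
    return dict(pairs)


def toy_row(member_id, yes_columns):
    """Hypothetical 23-column benefit-summary row; flags default to 2 (No)."""
    cols = [member_id] + ["0"] * 11 + ["2"] * 11
    for c in yes_columns:
        cols[c] = "1"
    return cols


if __name__ == "__main__":
    diseases = member_diseases([toy_row("m1", {18, 19}), toy_row("m1", {18}),
                                toy_row("m2", {15})])
    codes = claim_codes([("m1", "c1", ["P1", ""]), ("m1", "c1", ["P2"]),
                         ("m2", "c2", ["P1"])])
    print(disease_pairs(diseases, codes))
```

One thing this makes visible: because the self-join emits both (A, B) and (B, A) for every shared procedure, each pair's weight ends up doubled. That's harmless here, since the visualization divides every weight by the maximum.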

For visualization, we use the following Python program (again described in my previous post). The one important difference is that it no longer reads a directory of part files: the Spark code above forces the final output into a single file with a coalesce() call, so the script just reads that one file. The only other change is a larger figure size, so the edge labels don't overlap as much.

# Source: src/main/python/disease_graph.py
# Draws a disease interaction chart (based on common procedures for treatment)
# Adapted from:
# https://www.udacity.com/wiki/creating-network-graphs-with-python

import networkx as nx
import matplotlib.pyplot as plt
import os

def draw_graph(G, labels=None, graph_layout='shell',
               node_size=1600, node_color='blue', node_alpha=0.3,
               node_text_size=12,
               edge_color='blue', edge_alpha=0.3, edge_tickness=1,
               edge_text_pos=0.3,
               text_font='sans-serif'):

    # these are different layouts for the network you may try
    # shell seems to work best
    if graph_layout == 'spring':
        graph_pos=nx.spring_layout(G)
    elif graph_layout == 'spectral':
        graph_pos=nx.spectral_layout(G)
    elif graph_layout == 'random':
        graph_pos=nx.random_layout(G)
    else:
        graph_pos=nx.shell_layout(G)
    # draw graph
    nx.draw_networkx_nodes(G,graph_pos,node_size=node_size, 
                           alpha=node_alpha, node_color=node_color)
    nx.draw_networkx_edges(G,graph_pos,width=edge_tickness,
                           alpha=edge_alpha,edge_color=edge_color)
    nx.draw_networkx_labels(G, graph_pos,font_size=node_text_size,
                            font_family=text_font)
    nx.draw_networkx_edge_labels(G, graph_pos, edge_labels=labels, 
                                 label_pos=edge_text_pos)
    # show graph
    frame = plt.gca()
    plt.gcf().set_size_inches(10, 10)
    frame.axes.get_xaxis().set_visible(False)
    frame.axes.get_yaxis().set_visible(False)

    plt.show()

def add_node_to_graph(G, node, node_labels):
    if node not in node_labels:
        G.add_node(node)
    node_labels.add(node)

datafile = os.path.join("data", "disease_disease_part-000000")
lines = []
# each line is a tab-separated (disease_1, disease_2, weight) triple
with open(datafile, 'r') as fin:
    for line in fin:
        disease_1, disease_2, weight = line.strip().split("\t")
        lines.append((disease_1, disease_2, float(weight)))

# scale edge weights so the largest one is 1.0
max_weight = max([x[2] for x in lines])
norm_lines = [(x[0], x[1], x[2] / max_weight) for x in lines]

G = nx.Graph()
edge_labels = dict()
node_labels = set()
for line in norm_lines:
    add_node_to_graph(G, line[0], node_labels)
    add_node_to_graph(G, line[1], node_labels)
    if line[2] > 0.3:
        G.add_edge(line[0], line[1], weight=line[2])
        edge_labels[(line[0], line[1])] = "%.2f" % (line[2])
draw_graph(G, labels=edge_labels, graph_layout="shell")

All the code described above is available here on GitHub.
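If the pair output ever grows too large for the coalesce(1) call to be sensible, the job would instead write one part-NNNNN file per partition (plus an empty _SUCCESS marker) into the output directory. Here's a sketch of a loader for that case, which could stand in for the file-reading part of the script above. It assumes the output directory has already been copied to your local machine, and read_weighted_edges is a hypothetical name.

```python
import glob
import os


def read_weighted_edges(output_dir):
    """Read every part-* file in a Spark output directory and return
    (disease_1, disease_2, weight) tuples, scaled so the largest weight is 1.0.
    Non-part files such as _SUCCESS are skipped by the glob pattern."""
    edges = []
    for path in sorted(glob.glob(os.path.join(output_dir, "part-*"))):
        with open(path) as fin:
            for line in fin:
                line = line.strip()
                if not line:  # tolerate blank lines
                    continue
                disease_1, disease_2, weight = line.split("\t")
                edges.append((disease_1, disease_2, float(weight)))
    if not edges:
        return []
    max_weight = max(weight for _, _, weight in edges)
    return [(d1, d2, weight / max_weight) for d1, d2, weight in edges]


# Usage (hypothetical local copy of the disease_pairs output directory):
#   norm_lines = read_weighted_edges("data/disease_pairs")
```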

To set up the cluster, you first need to download and install Spark locally. I chose version 1.3.1 pre-built for Hadoop 2.6 and later. It's just a tarball that expands into a folder on your local machine.

You also need an AWS account and credentials. You probably have them already from previous work on AWS, but if you don't, the process is described here. The credentials consist of the AWS_ACCESS_KEY and AWS_SECRET_KEY. You will also have to generate a key pair, whose private key is a .pem file that you download and store locally. Assuming you have these set up in your environment already, you need to set two additional environment variables for starting up your cluster, then invoke the spark-ec2 script to start it.

laptop$ export AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY
laptop$ export AWS_SECRET_ACCESS_KEY=$AWS_SECRET_KEY
laptop$ cd $SPARK_INSTALL_ROOT/ec2
laptop$ ./spark-ec2 -k your_key_name \
                    -i /path/to/your/pem_file \
                    --region=your_region \
                    -s 4 \
                    launch your_cluster_name

The spark-ec2 command above launches a 5-node cluster (1 master + 4 slaves) of m1.large instances. The -k option specifies the name you gave your key pair, and -i specifies the path to the PEM file for that key pair. The region should be the same as your default region, since that is where your key pair is valid (key pairs are region-specific). The -s option specifies the number of slaves, and the last argument is the name of your cluster.

This process takes a while (around 10-20 minutes), and you will see progress messages on the console where you ran the spark-ec2 command. If you go over to the EC2 console on the AWS web UI, you will see machines whose names are prefixed with your_cluster_name; the master will be called your_cluster_name-master. Wait for the spark-ec2 command to finish. Successful completion looks something like this:

...
Connection to ec2-10-20-30-40.compute-1.amazonaws.com closed.
Spark standalone cluster started at http://ec2-10-20-30-40.compute-1.amazonaws.com:8080
Ganglia started at http://ec2-10-20-30-40.compute-1.amazonaws.com:5080/ganglia
Done!

You can bring up the Spark web console from the cluster URL provided in the "Spark standalone cluster started..." line. Now that the master is up, you can also build the JAR and copy it to the master. You can get the hostname for the cluster master from the web console URL, or from the EC2 console (click on the master and the public DNS is shown in the detail panel below).

laptop$ cd $PROJECT_ROOT
laptop$ sbt package
...
[info] Compiling 1 Scala source to target/scala-2.10/classes...
[info] Packaging target/scala-2.10/cms-disease-graph_2.10-0.1.jar ...
[info] Done packaging.
[success] Total time: 4 s, completed May 24, 2015 9:53:03 AM
laptop$ scp -i /path/to/pem/file \
           $PROJECT_ROOT/target/scala-2.10/cms-disease-graph_2.10-0.1.jar \
           root@ec2-10-20-30-40.compute-1.amazonaws.com:/tmp/

Then log on to the master, navigate to the spark subdirectory (it's right under the home directory), and submit the job to the cluster using spark-submit.

laptop$ ssh -i /path/to/pem/file root@ec2-10-20-30-40.compute-1.amazonaws.com
root@cluster-master$ cd ~/spark
root@cluster-master$ export AWS_ACCESS_KEY=your_access_key
root@cluster-master$ export AWS_SECRET_KEY=your_secret_key
root@cluster-master$ bin/spark-submit \
    --master spark://ec2-10-20-30-40.compute-1.amazonaws.com:7077 \
    --class com.mycompany.diseasegraph.GraphDataGenerator \
    /tmp/cms-disease-graph_2.10-0.1.jar \
    s3n://my_bucket/cms_disease_graph/inputs/benefit_summary.csv \
    s3n://my_bucket/cms_disease_graph/inputs/inpatient_claims.csv \
    s3n://my_bucket/cms_disease_graph/inputs/outpatient_claims.csv \
    s3n://my_bucket/cms_disease_graph/disease_codes \
    s3n://my_bucket/cms_disease_graph/disease_pairs

The code requires you to specify your AWS access and secret keys because it reads from and writes to S3; it reads them from the AWS_ACCESS_KEY and AWS_SECRET_KEY environment variables. The --master option specifies the URL of the Spark master, which you can find on the Spark web console (it starts with spark:// and ends with port 7077). The --class option specifies the fully qualified name of the Scala object representing the Spark job. The first argument after the options is the full path to the JAR file, and all subsequent arguments are parameters to the job itself: in my case, 3 input CSV files (member information, inpatient claims and outpatient claims) and 2 output directories. One thing to note is that the S3 paths use the s3n scheme instead of s3, because the code sets the fs.s3n.* Hadoop configuration properties for S3 access. Even though s3 versions of these properties exist, they didn't seem to work correctly; one likely reason is that in Apache Hadoop (unlike on EMR), s3:// refers to a block-based filesystem that can't read ordinary S3 objects. The entire process took a disappointing 3.2 minutes to complete on my 5-node cluster, so it really took me more time to set up the job than to run it :-).
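Since the --master URL is easy to mistype, here's a small sketch that derives it from the web console URL printed by spark-ec2. It assumes the standalone defaults seen in this post (web console on port 8080, master on port 7077); spark_master_url is a hypothetical helper, not part of Spark.

```python
from urllib.parse import urlparse


def spark_master_url(web_console_url, master_port=7077):
    """Build the spark://host:port URL for spark-submit --master from the
    standalone web console URL (e.g. http://host:8080) printed by spark-ec2."""
    host = urlparse(web_console_url).hostname
    if host is None:
        raise ValueError("no hostname in %r" % web_console_url)
    return "spark://%s:%d" % (host, master_port)


print(spark_master_url("http://ec2-10-20-30-40.compute-1.amazonaws.com:8080"))
# spark://ec2-10-20-30-40.compute-1.amazonaws.com:7077
```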

Once the job is done, you should terminate the cluster, unless you plan on reusing it (probably for different jobs) in the near future. In that case, you should stop the cluster instead: this keeps any data you have (such as JAR files and scripts) on your root partition (though obviously not in /tmp), and startup is a little faster. You do pay for a stopped cluster, but much less than if you kept it running. Since I don't have immediate plans to use Spark on my own projects, I terminated it with the destroy command. For safety, you should confirm that the cluster is really shut down by looking at the EC2 console.

root@cluster-master$ logout
laptop$ ./spark-ec2 -k your_key_name \
                    -i /path/to/your/pem_file \
                    --region=your_region \
                    destroy your_cluster_name
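Looking at the EC2 console works, but the same check can also be scripted. The sketch below assumes boto3 (the AWS SDK for Python, which this post doesn't otherwise use) and relies on the instance names being prefixed with your cluster name, as seen in the EC2 console earlier. The helper, live_cluster_instances, is hypothetical; it only inspects a single describe_instances() response, so it ignores pagination, which should be fine for a handful of instances.

```python
def live_cluster_instances(response, cluster_name):
    """Return IDs of instances in an EC2 describe_instances() response whose
    Name tag starts with '<cluster_name>-' and whose state is not 'terminated'.
    Stopped instances are reported too, since their volumes still cost money."""
    prefix = cluster_name + "-"
    live = []
    for reservation in response.get("Reservations", []):
        for instance in reservation.get("Instances", []):
            tags = {t["Key"]: t["Value"] for t in instance.get("Tags", [])}
            if not tags.get("Name", "").startswith(prefix):
                continue
            if instance["State"]["Name"] != "terminated":
                live.append(instance["InstanceId"])
    return live


# Usage (assumes boto3 is installed and credentials are configured):
#   import boto3
#   response = boto3.client("ec2", region_name="your_region").describe_instances()
#   leftovers = live_cluster_instances(response, "your_cluster_name")
#   print(leftovers or "cluster is fully terminated")
```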

And that's all I have for today. I still hear people talking about running Spark jobs on EMR, and I still think that would be preferable, but the EC2 approach is good enough for me until Amazon makes running Spark as easy as adding just another Step Type in their web UI.