Monday, April 24, 2017

Predicting Image Similarity using Siamese Networks


In my previous post, I mentioned that I want to use Siamese Networks to predict image similarity on the INRIA Holidays Dataset. The Keras project on Github has an example Siamese network that learns to treat pairs of MNIST handwritten digits representing the same number as similar and pairs representing different numbers as dissimilar. This got me all excited and eager to try it out on the Holidays dataset, which contains 1491 photos from 500 different vacations.

My Siamese network is somewhat loosely based on the architecture in the Keras example. The main idea behind a Siamese network is that it takes two inputs that need to be compared to each other; each input is reduced to a denser and hopefully more "semantic" vector representation, and the two vectors are compared using some standard vector arithmetic. The dimensionality reduction transformation is implemented as a neural network. Since we want the two images to be transformed in the same way, the two networks are trained with shared weights. The output of the dimensionality reduction is a pair of vectors, which are compared in some way to yield a metric that can be used to predict similarity between the inputs.

The Siamese network I built is shown in the diagram below. It differs from the Keras example in two major ways. First, the Keras example uses Fully Connected Networks (FCNs) as the dimensionality reduction component, whereas I use a Convolutional Neural Network (CNN). Second, the example computes the Euclidean distance between the two output vectors and minimizes the contrastive loss between them to produce a number in the [0,1] range, which is thresholded to return a binary similar/dissimilar prediction. In my case, I use an FCN that takes as input the two output vectors combined with an element-wise dot product, use cross-entropy as my loss function, and predict 1/0 to indicate similar/dissimilar.
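Here is a minimal sketch of this kind of architecture in Keras (the layer sizes, input shape and Keras 1.x-style API are illustrative assumptions, not the exact configuration from my notebook):

from keras.layers import Input, Dense, Flatten, merge
from keras.layers.convolutional import Convolution2D
from keras.layers.pooling import MaxPooling2D
from keras.models import Model, Sequential

IMG_SHAPE = (3, 128, 128)   # hypothetical input size, Theano-style (channels, rows, cols)

def build_base_cnn(input_shape):
    # the shared dimensionality-reduction CNN (LeNet-style, sizes illustrative)
    cnn = Sequential()
    cnn.add(Convolution2D(32, 5, 5, activation="relu", input_shape=input_shape))
    cnn.add(MaxPooling2D(pool_size=(2, 2)))
    cnn.add(Convolution2D(64, 5, 5, activation="relu"))
    cnn.add(MaxPooling2D(pool_size=(2, 2)))
    cnn.add(Flatten())
    cnn.add(Dense(128, activation="relu"))
    return cnn

base_cnn = build_base_cnn(IMG_SHAPE)

left_input = Input(shape=IMG_SHAPE)
right_input = Input(shape=IMG_SHAPE)
# applying the same model instance to both inputs shares the weights
left_vec = base_cnn(left_input)
right_vec = base_cnn(right_input)

# element-wise product of the two vectors, followed by a small FCN head
merged = merge([left_vec, right_vec], mode="mul")
fc = Dense(64, activation="relu")(merged)
pred = Dense(2, activation="softmax")(fc)

model = Model(input=[left_input, right_input], output=pred)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])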


For the CNN, I tried various configurations. Unfortunately, I started running out of memory on the g2.2xlarge instance when I tried larger CNNs, and ended up migrating to a p2.xlarge. Even then, I had to cut down either the size of the input image or the complexity of the network, and eventually settled on a LeNet configuration for my CNN, which seemed a bit underpowered for the data. With the current configuration, shown in the 02-holidays-siamese-network notebook, the network pretty much refused to learn anything. In other tries, the best test set accuracy I was able to get was about 60%, but all of them involved compromising on the input size or the complexity of the CNN, so I gave up and started looking at other approaches.

I have had success with transfer learning in the past, where you take a large network pre-trained on some external corpus such as ImageNet, chop off the classification head, and expose the vector from the layer just before the head layer(s). The pre-trained network thus acts as the vectorizer or dimensionality reducer component. I generated vectors using the following pre-trained networks available in Keras applications. The code to do this can be found in the 03-pretrained-nets-vectorizers notebook; a minimal sketch of the idea follows the list below.

  • VGG-16
  • VGG-19
  • ResNet
  • InceptionV3
  • Xception
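
Below is a minimal sketch of the vectorization step with InceptionV3 (the global average pooling over the final convolutional feature map and the 299x299 input size are assumptions based on the standard InceptionV3 setup, not necessarily what the notebook does):

from keras.applications.inception_v3 import InceptionV3, preprocess_input
from keras.layers import GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing import image
import numpy as np

# chop off the classification head and pool the last convolutional feature map
base = InceptionV3(weights="imagenet", include_top=False)
pooled = GlobalAveragePooling2D()(base.output)
vectorizer = Model(input=base.input, output=pooled)

def vectorize(image_path):
    # load, resize and preprocess a single image, then run predict to get its vector
    img = image.load_img(image_path, target_size=(299, 299))
    x = np.expand_dims(image.img_to_array(img), axis=0)
    x = preprocess_input(x)
    return vectorizer.predict(x)[0]    # a 2048-dimensional vector

vec = vectorize("/path/to/holiday-photos/100301.jpg")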


The diagram above shows the general setup of this approach. The first step is simply to run the predict method on the pre-trained models to generate a vector for each image. These vectors then need to be combined and fed to another classifier component. The merge strategies I tried were the element-wise dot product, the absolute difference, and the squared (Euclidean) difference. With the dot product, positions where both vectors have high values become even higher, while positions where the values differ are suppressed. With the absolute and squared differences, positions where the vectors differ become larger, and the squared difference emphasizes large differences more strongly than small ones.
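
As a small illustration of the three merge strategies (using stand-in vectors of the same size as the Inception/ResNet outputs):

import numpy as np

# stand-in vectors for the two images of a pair; in practice these are the
# vectors produced by the pre-trained network for the left and right images
v_left = np.random.random(2048)
v_right = np.random.random(2048)

merged_dot = v_left * v_right             # element-wise product
merged_l1 = np.abs(v_left - v_right)      # absolute (L1) difference
merged_l2 = (v_left - v_right) ** 2       # squared (L2) difference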

The classifier component (shown as FCN in my previous diagram) can be any kind of classifier, including ones not based on neural networks. As a baseline, I tried several common classifiers from the Scikit-Learn and XGBoost packages. You can see the code in the 04-pretrained-vec-dot-classifier, 05-pretrained-vec-l1-classifier, and 06-pretrained-vec-l2-classifier notebooks. The resulting accuracies for each (vectorizer, merge strategy, classifier) combination on the held-out test set are summarized below, after a minimal sketch of this classifier step.
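
The sketch uses stand-in data and an XGBoost classifier; the variable names, parameters, and train/test split are illustrative rather than the exact notebook code.

import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# stand-in data: one merged vector per image pair, plus its 0/1 similarity label
# (in practice X is built by stacking the merged vectors from the previous step)
X = np.random.random((4144, 2048))
y = np.random.randint(0, 2, size=4144)

Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=42)
clf = xgb.XGBClassifier()
clf.fit(Xtrain, ytrain)
print(accuracy_score(ytest, clf.predict(Xtest)))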








Generally speaking, XGBoost seems to do the best across all merge strategies and vectorization schemes. Among the vectorizers, Inception and ResNet seem to be the best overall. We also now have a pretty high baseline for accuracy: about 96.5% for Inception vectors merged using the dot product and classified with XGBoost. The code for this can be found in the 07-pretrained-vec-nn-classifier notebook. The figure below shows the accuracies for the different merge strategies for ResNet and Inception.


The next step was to see if I could get even better performance by replacing the classifier head with a neural network. I ended up using a simple 3-layer FCN that gave 95.7% accuracy with Inception vectors and the dot product merge strategy. Not quite as good as the XGBoost classifier, but quite close.
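
A minimal sketch of such a classifier head (the layer sizes and dropout rates are guesses for illustration, not the exact notebook configuration):

from keras.layers import Dense, Dropout, Input
from keras.models import Model

merged_input = Input(shape=(2048,))          # merged (element-wise product) Inception vectors
x = Dense(512, activation="relu")(merged_input)
x = Dropout(0.2)(x)
x = Dense(128, activation="relu")(x)
x = Dropout(0.2)(x)
pred = Dense(2, activation="softmax")(x)     # similar vs dissimilar

fcn = Model(input=merged_input, output=pred)
fcn.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])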

Finally, I decided to merge the two approaches. For the vectorization, I chose a pre-trained Inception network with its classification head removed. Input to this network would be images, and I would use the Keras ImageDataGenerator to augment my dataset, using the mechanism I described in my previous post. I decided to keep all the pre-trained weights fixed. For the classification head, I decided to start with the FCN I trained in the previous step and fine tune its weights during training. The code for that is in the 08-holidays-siamese-finetune notebook.
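
A minimal sketch of this combined setup (Keras 1.x-era API as elsewhere in this post; the image size, dimension ordering and head sizes are assumptions):

from keras.applications.inception_v3 import InceptionV3
from keras.layers import Dense, GlobalAveragePooling2D, Input, merge
from keras.models import Model

# the pre-trained vectorizer, with all of its weights kept frozen
base = InceptionV3(weights="imagenet", include_top=False)
for layer in base.layers:
    layer.trainable = False

pooled = GlobalAveragePooling2D()(base.output)
vectorizer = Model(input=base.input, output=pooled)

# Theano dimension ordering assumed; use (299, 299, 3) for TensorFlow ordering
left_in = Input(shape=(3, 299, 299))
right_in = Input(shape=(3, 299, 299))
merged = merge([vectorizer(left_in), vectorizer(right_in)], mode="mul")

# the FCN head (in practice its weights would be initialized from the
# previously trained classifier and fine-tuned here)
x = Dense(512, activation="relu")(merged)
x = Dense(128, activation="relu")(x)
pred = Dense(2, activation="softmax")(x)

model = Model(input=[left_in, right_in], output=pred)
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])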


Unfortunately, this did not give me the stellar results I was hoping for; my best result was about 88% accuracy on the similarity prediction. In retrospect, it may make sense to experiment with a simpler pre-trained model such as VGG and fine-tune some of the later layer weights instead of keeping them all frozen. There is also a possibility that my final network is not getting the benefit of the fine-tuned model from the previous step. One symptom is that the accuracy after the first epoch is only around 0.6 - I would have expected it to be higher with a well-trained model. In another project where a similar thing happened, a colleague discovered that I was doing extra normalization with the ImageDataGenerator that I hadn't done during the vectorization step - that doesn't seem to be the case here, though.

Overall, I got the best results from the transfer learning approach, with Inception vectors, the dot product merge strategy, and an XGBoost classifier. The nice thing about transfer learning is that it is relatively cheap in terms of resources compared to fine-tuning, or to training from scratch. While XGBoost does take some time to train, you can do the whole thing on your laptop. This is also true if you replace the XGBoost classifier with an FCN. You can also do inline image augmentation (i.e., without augmenting and saving images to disk) with the Keras ImageDataGenerator if you use the random_transform call.


Saturday, February 18, 2017

Using the Keras ImageDataGenerator with a Siamese Network


I have been looking at training a Siamese network to predict if two images are similar or different. Siamese networks are a type of neural network that contains a pair of identical sub-networks sharing the same parameters and weights. During training, the parameters are updated identically across both subnetworks. Siamese networks were first proposed in 1993 by Bromley et al. in their paper Signature Verification using a Siamese Time Delay Neural Network. Keras provides an example of a Siamese network as part of the distribution.

My dataset is the INRIA Holidays Dataset, a set of 1491 photos from 500 different vacations. The photos have a naming convention from which the groups can be derived. Each photo is named with six digits - the first four identify the vacation and the last two are a sequence number within that vacation. For example, a photo named 100301.jpg comes from vacation 1003 and has sequence number 01 within that group.
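
In code, the group can be read straight off the file name:

# deriving the vacation group and sequence number from a photo's file name
fname = "100301.jpg"
group_id = fname[0:4]    # "1003" - the vacation
seq_id = fname[4:6]      # "01"   - sequence number within the vacation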

The input to my network consists of image pairs and the output is either 1 (similar) or 0 (different). Similar image pairs are drawn from the same vacation group. For example, the code snippet below displays three photos - the first two are from the same group and the last one is from a different group.

from __future__ import division, print_function
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils
from scipy.misc import imresize
import itertools
import matplotlib.pyplot as plt
import numpy as np
import random
import os

DATA_DIR = "../data"
IMAGE_DIR = os.path.join(DATA_DIR, "holiday-photos")

ref_image = plt.imread(os.path.join(IMAGE_DIR, "100301.jpg"))
sim_image = plt.imread(os.path.join(IMAGE_DIR, "100302.jpg"))
dif_image = plt.imread(os.path.join(IMAGE_DIR, "127202.jpg"))

def draw_image(subplot, image, title):
    plt.subplot(subplot)
    plt.imshow(image)
    plt.title(title)
    plt.xticks([])
    plt.yticks([])
    
draw_image(131, ref_image, "reference")
draw_image(132, sim_image, "similar")
draw_image(133, dif_image, "different")
plt.tight_layout()
plt.show()


To build the training set, I loop through the image directory and use the file naming convention to create all pairs of similar images and, for each similar pair, a corresponding dissimilar pair. Similar image pairs are generated by considering all combinations of image pairs within a group. Dissimilar image pairs are generated by pairing the left-hand image of each similar pair with a random image from some other group. This gives us 2072 similar image pairs and 2072 different image pairs, i.e., a total of 4144 image pairs for our training data.
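
A minimal sketch of this pair generation (a reconstruction of the logic described above rather than the original notebook code, reusing the IMAGE_DIR defined earlier):

import itertools
import os
import random

def create_image_triples(image_dir):
    # group photos by vacation: the first four digits of the file name
    groups = {}
    for fname in os.listdir(image_dir):
        groups.setdefault(fname[0:4], []).append(fname)

    triples = []
    group_names = list(groups.keys())
    for group, images in groups.items():
        other_groups = [g for g in group_names if g != group]
        # similar pairs: all combinations of images within a group
        for left, right in itertools.combinations(sorted(images), 2):
            triples.append((left, right, 1))
            # dissimilar pair: the same left image paired with a random
            # image from some other vacation group
            random_group = random.choice(other_groups)
            triples.append((left, random.choice(groups[random_group]), 0))
    return triples

image_triples = create_image_triples(IMAGE_DIR)
print(len(image_triples))    # 2072 similar + 2072 different = 4144 triples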

Fearing that this might not be nearly enough to train my network adequately, I decided to use the Keras ImageDataGenerator to augment the dataset. Before Keras, when I was working with Caffe, I would manually augment my input with a fixed number of standard transformations, such as rotations, flips, zooms and affine transforms (these are all just matrix transforms). The Keras ImageDataGenerator is much more sophisticated: you instantiate it with the range of transformations you want to allow on your dataset, and it returns a generator that yields transformed versions of your input images (or of images read from a directory).

I have used the ImageDataGenerator previously to augment my dataset to train a simple classification CNN, where the input was an image and the output was a label. This is the default case the component is built to handle, so it's actually very simple to use it this way. My problem this time was a little different - my input is a pair of image names from a triple, and I wanted the identical transformation to be applied to both images. (This is not strictly necessary in my case, but it can't hurt, and in any case I wanted to learn how to do this for another upcoming project.)

It seems to be something that others have been looking for as well, and there is some discussion in Keras Issue 3059. In addition, the ImageDataGenerator documentation covers some cases where this can be done, using a pair of ImageDataGenerator instances that are instantiated with the same parameters. However, all these seem to require that you either enumerate the LHS and RHS images in the pair as 4-dimensional tensors (using flow()) or store them in two parallel directories with identical names (using flow_from_directory()). The first seems a bit wasteful, and the second seems incredibly complicated for my use case.

So I went digging into the code and found a private (in the sense of undocumented) method called random_transform(). It applies a random sequence of the transformations you have specified in the ImageDataGenerator constructor to your input image. In this post, I will describe an image generator that I built for my Siamese network using the random_transform() method.

We start with a basic generator that returns a batch of image triples per invocation. The generator is instantiated once and loops forever; each pass through its outer loop corresponds to one epoch, and the next() method is called to get the next batch of triples.

def image_triple_generator(image_triples, batch_size):
    while True:
        # loop once per epoch
        num_recs = len(image_triples)
        indices = np.random.permutation(np.arange(num_recs))
        num_batches = num_recs // batch_size
        for bid in range(num_batches):
            # loop once per batch
            batch_indices = indices[bid * batch_size : (bid + 1) * batch_size]
            yield [image_triples[i] for i in batch_indices]
            
triples_batch_gen = image_triple_generator(image_triples, 4)
triples_batch_gen.next()

This gives us a batch of 4 triples as shown:

[('149601.jpg', '149604.jpg', 1),
 ('144700.jpg', '106201.jpg', 0),
 ('103304.jpg', '111701.jpg', 0),
 ('133200.jpg', '128100.jpg', 0)]

Calling next() again returns the next 4 triples, and so on for each subsequent batch.

triples_batch_gen.next()

[('135104.jpg', '122601.jpg', 0),
 ('137700.jpg', '137701.jpg', 1),
 ('136005.jpg', '105501.jpg', 0),
 ('132500.jpg', '132511.jpg', 1)]

Next, we apply ImageDataGenerator.random_transform() to a single image to see if it does indeed do what I think it does. My fear was that there might need to be some upstream initialization before I could call the random_transform() method. As you can see from the output, random_transform() augments the original image into variants that are quite close to it and could legitimately have been real photos.

datagen_args = dict(rotation_range=10,
                    width_shift_range=0.2,
                    height_shift_range=0.2,
                    shear_range=0.2,
                    zoom_range=0.2,
                    horizontal_flip=True)
datagen = ImageDataGenerator(**datagen_args)

sid = 150
np.random.seed(42)
image = plt.imread(os.path.join(IMAGE_DIR, "115201.jpg"))
sid += 1
draw_image(sid, image, "orig")
for j in range(4):
    augmented = datagen.random_transform(image)
    sid += 1
    draw_image(sid, augmented, "aug#{:d}".format(j + 1))

plt.tight_layout()
plt.show()


Next I wanted to see if I could take two images and apply the same transformation to both. I now take a pair of ImageDataGenerators configured the same way. The individual transformations applied inside the random_transform() method are all driven by numpy random number generators, so one way to make them do the same thing is to reset the random number seed to the same value for each ImageDataGenerator at the start of each batch. As you can see from the photos below, this strategy seems to work.

image_pair = ["108103.jpg", "112003.jpg"]

datagens = [ImageDataGenerator(**datagen_args),
            ImageDataGenerator(**datagen_args)]

sid = 240
for i, image_name in enumerate(image_pair):
    image = plt.imread(os.path.join(IMAGE_DIR, image_name))
    sid += 1
    draw_image(sid, image, "orig")
    # make sure the two image data generators generate same transformations
    np.random.seed(42)
    for j in range(3):
        augmented = datagens[i].random_transform(image)
        sid += 1
        draw_image(sid, augmented, "aug#{:d}".format(j + 1))

plt.tight_layout()
plt.show()


Finally, we are ready to build the final generator that can be plugged into the Siamese network. I haven't built the network yet, so there might be some changes once I try to integrate it, but here is the first cut. The caching is there because I noticed that it takes a while to read and resize the images for each batch, so caching the resized images should hopefully speed things up.

RESIZE_WIDTH = 300
RESIZE_HEIGHT = 300

def cached_imread(image_path, image_cache):
    if image_path not in image_cache:
        image = plt.imread(image_path)
        image = imresize(image, (RESIZE_WIDTH, RESIZE_HEIGHT))
        image_cache[image_path] = image
    return image_cache[image_path]

def preprocess_images(image_names, seed, datagen, image_cache):
    np.random.seed(seed)
    X = np.zeros((len(image_names), RESIZE_WIDTH, RESIZE_HEIGHT, 3))
    for i, image_name in enumerate(image_names):
        image = cached_imread(os.path.join(IMAGE_DIR, image_name), image_cache)
        X[i] = datagen.random_transform(image)
    return X

def image_triple_generator(image_triples, batch_size):
    datagen_args = dict(rotation_range=10,
                        width_shift_range=0.2,
                        height_shift_range=0.2,
                        shear_range=0.2,
                        zoom_range=0.2,
                        horizontal_flip=True)
    datagen_left = ImageDataGenerator(**datagen_args)
    datagen_right = ImageDataGenerator(**datagen_args)
    image_cache = {}
    
    while True:
        # loop once per epoch
        num_recs = len(image_triples)
        indices = np.random.permutation(np.arange(num_recs))
        num_batches = num_recs // batch_size
        for bid in range(num_batches):
            # loop once per batch
            batch_indices = indices[bid * batch_size : (bid + 1) * batch_size]
            batch = [image_triples[i] for i in batch_indices]
            # make sure image data generators generate same transformations
            seed = np.random.randint(low=0, high=1000, size=1)[0]
            Xleft = preprocess_images([b[0] for b in batch], seed, 
                                      datagen_left, image_cache)
            Xright = preprocess_images([b[1] for b in batch], seed,
                                       datagen_right, image_cache)
            Y = np_utils.to_categorical(np.array([b[2] for b in batch]))
            yield Xleft, Xright, Y

Here is a little snippet that calls my data generator and verifies that it returns data of the expected shapes.

triples_batch_gen = image_triple_generator(image_triples, 32)
Xleft, Xright, Y = triples_batch_gen.next()
print(Xleft.shape, Xright.shape, Y.shape)

which returns the expected shapes.

(32, 300, 300, 3) (32, 300, 300, 3) (32, 2)

So anyway, this is all I have so far. Once I have my Siamese network coded up and running, I will talk about it in a subsequent post. I haven't heard of anyone using ImageDataGenerator.random_transform() directly before, so I thought it might be interesting to describe my experience. Currently the enhancements in this area seem to be aimed at continuing to support the flow() and flow_from_directory() methods. I am not sure what more specialized requirements will come up in the future, but I think using the random_transform() method directly might be a good choice in many situations. Of course, it is quite likely that I am missing something, so if you know of problems with this approach, please let me know.


Monday, January 30, 2017

Migrating VGG-CNN from Caffe to Keras


I attended (and presented at) the Demystifying Deep Learning and Artificial Intelligence Workshop in Oakland last November. One of the talks I attended was Introduction to Deep Learning for Images in Keras, presented by Stephane Egly and Malaikannan (Malai) Sankarasubbu. Since my talk was on Transfer Learning and Fine Tuning CNNs for Image Classification (also using Keras), there was quite a bit of overlap between our interests, and Malai and I got to talking after his presentation. I mentioned that I was using Keras now, but the project on which my talk was based used a pre-trained VGG-CNN model from the Caffe model zoo.

We went on for a bit about how much nicer Keras is compared to Caffe. I am a big fan of Keras, but as a committer to the Recurrent Shop project, he is probably a bigger fan (and very likely much more knowledgeable about it) than I am. Anyway, after some more Caffe bashing from my end, he asked me why I didn't migrate that project to Keras, and I mentioned that the model I was using was not yet available in Keras Applications (the Keras model zoo). At that point, Malai made what was to me at the time a blindingly insightful comment - that migration is basically just transferring the weights over.

So, after thinking about it for a while, I finally decided to try to migrate the pre-trained Caffe VGG-CNN model to Keras, and here is the result. It turned out to be more than just transferring the weights over, but I learned a few things during the process, which I would like to share with you in this post.

Caffe End: Dumping the Model


A Caffe model is packaged as a protobuf (.prototxt) file that specifies the network structure and a binary model (.caffemodel) file that contains the weights. In addition, image inputs to the VGG-CNN need to be normalized by subtracting the mean value of each of the RGB channels. The mean values are supplied as a mean image in a binary (.binaryproto) file. The first step is to dump out the network structure and weights in a portable format that can be ingested by Keras.

The code below uses the PyCaffe API to dump out the layers of the VGG-CNN and the shapes of their outputs. It also dumps the weights into Numpy's native file format (.npy). We could just as easily have dumped them into text format, but .npy is faster to load. Finally, it reads the mean image file and writes it out as another .npy file.

from __future__ import division, print_function
import caffe
import numpy as np
import os

DATA_DIR = "/path/to/my/data"
OUTPUT_DIR = os.path.join(DATA_DIR, "vgg-cnn-weights")

CAFFE_HOME="/path/to/my/caffe/installation"

MODEL_DIR = os.path.join(CAFFE_HOME, "models", "vgg_cnn_s")
MODEL_PROTO = os.path.join(MODEL_DIR, "deploy.prototxt")
MODEL_WEIGHTS = os.path.join(MODEL_DIR, "VGG_CNN_S.caffemodel")
MEAN_IMAGE = os.path.join(MODEL_DIR, "VGG_mean.binaryproto")

caffe.set_mode_cpu()
net = caffe.Net(MODEL_PROTO, MODEL_WEIGHTS, caffe.TEST)

# layer names and output shapes
for layer_name, blob in net.blobs.iteritems():
    print(layer_name, blob.data.shape)

# write out weight matrices and bias vectors
for k, v in net.params.items():
    print(k, v[0].data.shape, v[1].data.shape)
    np.save(os.path.join(OUTPUT_DIR, "W_{:s}.npy".format(k)), v[0].data)
    np.save(os.path.join(OUTPUT_DIR, "b_{:s}.npy".format(k)), v[1].data)

# write out mean image
blob = caffe.proto.caffe_pb2.BlobProto()
with open(MEAN_IMAGE, 'rb') as fmean:
    mean_data = fmean.read()
blob.ParseFromString(mean_data)
mu = np.array(caffe.io.blobproto_to_array(blob))
print("Mean image:", mu.shape)
np.save(os.path.join(OUTPUT_DIR, "mean_image.npy"), mu)

The layer names and output shapes are shown below. The first column is the name of the layer and the second is the shape of the output from that layer. The shapes for the convolutional layers are specified as 4-dimensional tensors, where the dimensions refer to batch size, number of channels, rows and columns. If you have worked with images, you are probably more familiar with specifying image dimensions as (row, column, channel) rather than (channel, row, column). In Keras, the latter is called the Theano dimension ordering, because Theano uses it.

The last 3 layers are fully connected layers (called Dense in Keras) whose dimensions are specified as (batch size, output size). Notice that the network predicts one of 1000 categories, hence the last layer has shape (10, 1000) for a batch of 10 images.

The layers with names starting with "conv" are convolutional layers and those with names starting with "pool" are pooling layers. The single "norm" layer is a normalization layer. For a fuller description, take a look at the protobuf file.

layer.name    layer.output_shape
data          (10, 3, 224, 224)
conv1         (10, 96, 109, 109)
norm1         (10, 96, 109, 109)
pool1         (10, 96, 37, 37)
conv2         (10, 256, 33, 33)
pool2         (10, 256, 17, 17)
conv3         (10, 512, 17, 17)
conv4         (10, 512, 17, 17)
conv5         (10, 512, 17, 17)
pool5         (10, 512, 6, 6)
fc6           (10, 4096)
fc7           (10, 4096)
prob          (10, 1000)

Each of the convolutional and fully connected layers is associated with a weight matrix and a bias vector. The code above dumps them to .npy files named W_{layer.name} for the weight matrix and b_{layer.name} for the bias vector. The table below shows each of these layers and the shapes of the associated weight matrix and bias vector.

layer    W.shape            b.shape
conv1    (96, 3, 7, 7)      (96,)
conv2    (256, 96, 5, 5)    (256,)
conv3    (512, 256, 3, 3)   (512,)
conv4    (512, 512, 3, 3)   (512,)
conv5    (512, 512, 3, 3)   (512,)
fc6      (4096, 18432)      (4096,)
fc7      (4096, 4096)       (4096,)
fc8      (1000, 4096)       (1000,)

Finally, the mean image is written out to a numpy file as a (1, 3, 224, 224) tensor, the same size as a batch of 1 image.

Keras End: Rebuilding the Model


I then use the information about the layers and their output shapes to build a skeleton Keras network that looks kind of like the Caffe original. Here I specify the imports so you know what libraries the subsequent calls come from.

from __future__ import division, print_function
from keras import backend as K
from keras.layers import Input
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.convolutional import Convolution2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.layers.pooling import MaxPooling2D
from keras.models import Model
from scipy.misc import imresize
import matplotlib.pyplot as plt
import numpy as np
import os
import re

Of late, I have been using Keras exclusively with the Tensorflow backend, but because the dimension ordering of the layer outputs and the weights of the Caffe network are all in Theano (channel, row, col) format, I initially set the dimension ordering to Theano in Keras. Later, I switched to the Theano backend as well, but more on that later.

K.set_image_dim_ordering("th")

Getting the shapes right


Here is my first attempt at a Keras equivalent for the Caffe VGG-CNN network. Almost all the Caffe layers have analogs in Keras that behave similarly, except for the NORM1 layer. The protobuf file mentions a Local Response Normalization (LRN) layer that is not available in Keras. For the moment I replace this layer with a Keras BatchNormalization layer, since both types of normalization layers leave the shape unchanged.

data = Input(shape=(3, 224, 224), name="DATA")

conv1 = Convolution2D(96, 7, 7, subsample=(2, 2))(data)
conv1 = Activation("relu", name="CONV1")(conv1)

norm1 = BatchNormalization(name="NORM1")(conv1)

pool1 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL1")(norm1)

conv2 = Convolution2D(256, 5, 5)(pool1)
conv2 = Activation("relu", name="CONV2")(conv2)

pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="POOL2")(conv2)

conv3 = Convolution2D(512, 3, 3)(pool2)
conv3 = Activation("relu", name="CONV3")(conv3)

conv4 = Convolution2D(512, 3, 3)(conv3)
conv4 = Activation("relu", name="CONV4")(conv4)

conv5 = Convolution2D(512, 3, 3)(conv4)
conv5 = Activation("relu", name="CONV5")(conv5)

pool5 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL5")(conv5)

fc6 = Flatten()(pool5)
fc6 = Dense(4096)(fc6)
fc6 = Activation("relu", name="FC6")(fc6)

fc7 = Dense(4096)(fc6)
fc7 = Activation("relu", name="FC7")(fc7)

fc8 = Dense(1000, name="FC8")(fc7)
prob = Activation("softmax", name="PROB")(fc8)

model = Model(input=[data], output=[prob])

model.compile(optimizer="adam", loss="categorical_crossentropy")

As you can see below, there are differences between the shapes of the layer outputs in the original Caffe VGG-CNN network and my Keras copy. The columns in the table below are the layer name, the shape of the output from that layer in the Caffe model, and the corresponding shape in the Keras model. There are more Keras layers than Caffe layers here - I have given the Keras layers that correspond to Caffe layers the same names as the Caffe layers so we can compare their output shapes. Rows where the two output shapes differ are the ones that need to be fixed.

layer.name        caffe.output_shape    keras.output_shape
DATA              (10, 3, 224, 224)     (10, 3, 224, 224)
convolution2d_1   -                     (10, 96, 109, 109)
CONV1             (10, 96, 109, 109)    (10, 96, 109, 109)
NORM1             (10, 96, 109, 109)    (10, 96, 109, 109)
POOL1             (10, 96, 37, 37)      (10, 96, 36, 36)
convolution2d_2   -                     (10, 256, 32, 32)
CONV2             (10, 256, 33, 33)     (10, 256, 32, 32)
POOL2             (10, 256, 17, 17)     (10, 256, 16, 16)
convolution2d_3   -                     (10, 512, 14, 14)
CONV3             (10, 512, 17, 17)     (10, 512, 14, 14)
convolution2d_4   -                     (10, 512, 12, 12)
CONV4             (10, 512, 17, 17)     (10, 512, 12, 12)
convolution2d_5   -                     (10, 512, 10, 10)
CONV5             (10, 512, 17, 17)     (10, 512, 10, 10)
POOL5             (10, 512, 6, 6)       (10, 512, 3, 3)
flatten_1         -                     (10, 4608)
dense_1           -                     (10, 4096)
FC6               (10, 4096)            (10, 4096)
dense_2           -                     (10, 4096)
FC7               (10, 4096)            (10, 4096)
FC8               (10, 1000)            (10, 1000)
PROB              (10, 1000)            (10, 1000)

The first difference is in POOL1, where the Keras output is one row and one column smaller than the Caffe version. The incoming tensor to POOL1 has shape (10, 96, 109, 109), the pooling filter is of size (3, 3), and its stride is also (3, 3). Using the formula from the CS231N course page on Convolutional Networks (which incidentally is also a great resource if you want a quick refresher on CNNs), the expected output shape is (10, 96, 36, 36), which matches the output shape of the Keras layer. The difference in Caffe arises because Caffe rounds the pooling output size up instead of down, effectively applying padding that is VALID at the beginning and SAME at the end, i.e., zero padding along the right and bottom edges. This can be simulated in Keras by applying a ZeroPadding2D layer with padding of (0, 2, 0, 2) just before the POOL1 layer.
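
A quick sanity check of the arithmetic for POOL1 (input width 109, pool size 3, stride 3):

import math

# Keras: output = floor((input - pool) / stride) + 1
# Caffe rounds up instead, implicitly zero-padding the bottom/right edge
input_size, pool_size, stride = 109, 3, 3
keras_out = (input_size - pool_size) // stride + 1                        # 36
caffe_out = int(math.ceil((input_size - pool_size) / float(stride))) + 1  # 37
print(keras_out, caffe_out)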

The other shape differences arise for the same reason, and can be fixed similarly by applying an appropriate ZeroPadding2D layer in front of the affected layer. The full list of these additions is below.

  • Add (0, 2, 0, 2) zero padding before POOL1
  • Add (0, 1, 0, 1) zero padding before POOL2
  • Add (0, 2, 0, 2) zero padding before CONV3
  • Add (0, 2, 0, 2) zero padding before CONV4
  • Add (0, 2, 0, 2) zero padding before CONV5
  • Add (0, 1, 0, 1) zero padding before POOL5

With these changes, my Keras network now looks like this:

data = Input(shape=(3, 224, 224), name="DATA")

conv1 = Convolution2D(96, 7, 7, subsample=(2, 2))(data)
conv1 = Activation("relu", name="CONV1")(conv1)

norm1 = BatchNormalization(name="NORM1")(conv1)

pool1 = ZeroPadding2D(padding=(0, 2, 0, 2))(norm1)
pool1 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL1")(pool1)

conv2 = Convolution2D(256, 5, 5)(pool1)
conv2 = Activation("relu", name="CONV2")(conv2)

pool2 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="POOL2")(pool2)

conv3 = ZeroPadding2D(padding=(0, 2, 0, 2))(pool2)
conv3 = Convolution2D(512, 3, 3)(conv3)
conv3 = Activation("relu", name="CONV3")(conv3)

conv4 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv3)
conv4 = Convolution2D(512, 3, 3)(conv4)
conv4 = Activation("relu", name="CONV4")(conv4)

conv5 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv4)
conv5 = Convolution2D(512, 3, 3)(conv5)
conv5 = Activation("relu", name="CONV5")(conv5)

pool5 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv5)
pool5 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL5")(pool5)

fc6 = Flatten()(pool5)
fc6 = Dense(4096)(fc6)
fc6 = Activation("relu", name="FC6")(fc6)

fc7 = Dense(4096)(fc6)
fc7 = Activation("relu", name="FC7")(fc7)

fc8 = Dense(1000, name="FC8")(fc7)
prob = Activation("softmax", name="PROB")(fc8)

model = Model(input=[data], output=[prob])

model.compile(optimizer="adam", loss="categorical_crossentropy")

After these changes, there are no more shape differences between the Caffe layer outputs and their corresponding Keras layer outputs.

layer.name         caffe.output_shape    keras.output_shape
DATA               (10, 3, 224, 224)     (10, 3, 224, 224)
convolution2d_6    -                     (10, 96, 109, 109)
CONV1              (10, 96, 109, 109)    (10, 96, 109, 109)
NORM1              (10, 96, 109, 109)    (10, 96, 109, 109)
zeropadding2d_1    -                     (10, 96, 111, 111)
POOL1              (10, 96, 37, 37)      (10, 96, 37, 37)
convolution2d_7    -                     (10, 256, 33, 33)
CONV2              (10, 256, 33, 33)     (10, 256, 33, 33)
zeropadding2d_2    -                     (10, 256, 34, 34)
POOL2              (10, 256, 17, 17)     (10, 256, 17, 17)
zeropadding2d_3    -                     (10, 256, 19, 19)
convolution2d_8    -                     (10, 512, 17, 17)
CONV3              (10, 512, 17, 17)     (10, 512, 17, 17)
zeropadding2d_4    -                     (10, 512, 19, 19)
convolution2d_9    -                     (10, 512, 17, 17)
CONV4              (10, 512, 17, 17)     (10, 512, 17, 17)
zeropadding2d_5    -                     (10, 512, 19, 19)
convolution2d_10   -                     (10, 512, 17, 17)
CONV5              (10, 512, 17, 17)     (10, 512, 17, 17)
zeropadding2d_6    -                     (10, 512, 18, 18)
POOL5              (10, 512, 6, 6)       (10, 512, 6, 6)
flatten_2          -                     (10, 18432)
dense_3            -                     (10, 4096)
FC6                (10, 4096)            (10, 4096)
dense_4            -                     (10, 4096)
FC7                (10, 4096)            (10, 4096)
FC8                (10, 1000)            (10, 1000)
PROB               (10, 1000)            (10, 1000)

Defining the Local Response Normalization Layer


Now that the output shapes are lined up, I needed to do something about the BatchNormalization placeholder layer. The NORM1 layer is defined in the protobuf file as a Local Response Normalization (LRN) layer that performs lateral inhibition by normalizing over local input regions. It can work in two modes - ACROSS_CHANNEL and WITHIN_CHANNEL. In the first, the local regions extend across nearby channels but have no spatial extent and in the second, they extend spatially but are in separate channels. The equation for LRN is shown below:
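
The LRN scales each activation a_i by a function of the squared activations in its local region N(i) of n values (this is the general Caffe form of the equation, consistent with the custom layer below):

$$b_i = \frac{a_i}{\left(k + \frac{\alpha}{n} \sum_{j \in N(i)} a_j^2\right)^{\beta}}$$

where k, alpha, beta and n are the hyperparameters that appear in the custom layer constructor.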


Since Keras has no built-in LRN layer, I built a custom layer following the instructions on the Writing your own Keras Layers page. My custom LRN layer implements the WITHIN_CHANNEL approach; the code for it is shown below.

Note that the LRN is nowadays generally regarded as obsolete, having been replaced by better methods of regularization such as Dropout, Batch Normalization and better initialization. So the LRN is probably not something you want to use in your new networks. However, because we are trying to implement an existing network, we need to replicate the LRN functionality in Keras as well.

from keras.engine.topology import Layer

class LRN(Layer):
    
    def __init__(self, n=5, alpha=0.0005, beta=0.75, k=2, **kwargs):
        self.n = n
        self.alpha = alpha
        self.beta = beta
        self.k = k
        super(LRN, self).__init__(**kwargs)

    def build(self, input_shape):
        self.shape = input_shape
        super(LRN, self).build(input_shape)

    def call(self, x, mask=None):
        if K.image_dim_ordering() == "th":
            _, f, r, c = self.shape
        else:
            _, r, c, f = self.shape
        half_n = self.n // 2
        squared = K.square(x)
        pooled = K.pool2d(squared, (half_n, half_n), strides=(1, 1),
                         border_mode="same", pool_mode="avg")
        if K.image_dim_ordering() == "th":
            summed = K.sum(pooled, axis=1, keepdims=True)
            averaged = (self.alpha / self.n) * K.repeat_elements(summed, f, axis=1)
        else:
            summed = K.sum(pooled, axis=3, keepdims=True)
            averaged = (self.alpha / self.n) * K.repeat_elements(summed, f, axis=3)
        denom = K.pow(self.k + averaged, self.beta)
        return x / denom
    
    def get_output_shape_for(self, input_shape):
        return input_shape

The constructor allows us to set the hyperparameters for the layer. I have set default values for these hyperparameters from the values in the Caffe protobuf file. The call method implements the equation using the Keras backend API. The summation divided by n is equal to the sum of 2D average pooling operations with a pool size of (n/2, n/2) and a stride size of (1, 1). The rest of the code is mostly self-explanatory. The get_output_shape_for method returns the output shape of the layer, which in this case is the same as the input shape.

Loading the weights


The next step is to load up the weights that I dumped out of the Caffe binary .caffemodel file to .npy files. The convolutional weight matrices are 4 dimensional and the fully connected weight matrices are 2 dimensional. However, I noticed that their dimensions are reversed. For example, if the Keras layer needs a matrix of size (7, 7, 3, 96), the shape of the matrix I get from Caffe is (96, 3, 7, 7). Same with the fully connected weights. Turns out that the difference in the convolutional weights is based on a difference between the way Caffe and Keras do convolutions. In order to account for this difference, these Caffe weights need to be rotated 180 degrees. Similarly, the fully connected weights need to be transposed to be usable in Keras. The transform_conv_weights and transform_fc_weights functions implement the transformations necessary to make this happen.

You can also find code to do this in convert.py in the MarcBS Keras fork, which contains code specifically for converting Caffe networks to Keras. The MarcBS fork is very likely not up-to-date with the main Keras branch; its README.md says that it is compatible with Theano only.

def transform_conv_weight(W):
    # for non FC layers, do this because Keras does convolution vs Caffe correlation
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            W[i, j] = np.rot90(W[i, j], 2)
    return W

def transform_fc_weight(W):
    return W.T

# load weights
CAFFE_WEIGHTS_DIR = "/Users/palsujit/Projects/fttl-with-keras/data/vgg-cnn/saved-weights"

W_conv1 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv1.npy")))
b_conv1 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv1.npy"))

W_conv2 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv2.npy")))
b_conv2 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv2.npy"))

W_conv3 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv3.npy")))
b_conv3 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv3.npy"))

W_conv4 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv4.npy")))
b_conv4 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv4.npy"))

W_conv5 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv5.npy")))
b_conv5 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv5.npy"))

W_fc6 = transform_fc_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_fc6.npy")))
b_fc6 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_fc6.npy"))

W_fc7 = transform_fc_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_fc7.npy")))
b_fc7 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_fc7.npy"))

W_fc8 = transform_fc_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_fc8.npy")))
b_fc8 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_fc8.npy"))

I now set the weight matrices and bias vectors into their corresponding layers, and replace the BatchNormalization layer with my custom LRN layer. The code to do that is shown below.

# define network
data = Input(shape=(3, 224, 224), name="DATA")

conv1 = Convolution2D(96, 7, 7, subsample=(2, 2),
                     weights=(W_conv1, b_conv1))(data)
conv1 = Activation("relu", name="CONV1")(conv1)

norm1 = LRN(name="NORM1")(conv1)

pool1 = ZeroPadding2D(padding=(0, 2, 0, 2))(norm1)
pool1 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL1")(pool1)

conv2 = Convolution2D(256, 5, 5, weights=(W_conv2, b_conv2))(pool1)
conv2 = Activation("relu", name="CONV2")(conv2)

pool2 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="POOL2")(pool2)

conv3 = ZeroPadding2D(padding=(0, 2, 0, 2))(pool2)
conv3 = Convolution2D(512, 3, 3, weights=(W_conv3, b_conv3))(conv3)
conv3 = Activation("relu", name="CONV3")(conv3)

conv4 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv3)
conv4 = Convolution2D(512, 3, 3, weights=(W_conv4, b_conv4))(conv4)
conv4 = Activation("relu", name="CONV4")(conv4)

conv5 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv4)
conv5 = Convolution2D(512, 3, 3, weights=(W_conv5, b_conv5))(conv5)
conv5 = Activation("relu", name="CONV5")(conv5)

pool5 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv5)
pool5 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL5")(pool5)

fc6 = Flatten()(pool5)
fc6 = Dense(4096, weights=(W_fc6, b_fc6))(fc6)
fc6 = Activation("relu", name="FC6")(fc6)

fc7 = Dense(4096, weights=(W_fc7, b_fc7))(fc6)
fc7 = Activation("relu", name="FC7")(fc7)

fc8 = Dense(1000, weights=(W_fc8, b_fc8), name="FC8")(fc7)
prob = Activation("softmax", name="PROB")(fc8)

model = Model(input=[data], output=[prob])

model.compile(optimizer="adam", loss="categorical_crossentropy")

At this point, we have a pre-trained VGG-CNN network in Keras. Let's check to see if it's any good.

Predictions using Keras VGG-CNN


I wanted to see how this pretrained network would perform on some test images. The Caffe documentation has a Classification Tutorial where they demonstrate classification using an image of a cat, so I decided to use the same image.

The network is designed to take a batch of images, and for each image, return a number between 0 and 999 indicating one of 1000 ImageNet classes. The mapping of these numbers to actual labels is provided by this Github gist file. I build a little dictionary to look up a label given the class ID below.

id2label = {}
flabel = open("../data/caffe2keras-labels.txt", "rb")
for line in flabel:
    lid, lname = line.strip().split("\t")
    id2label[int(lid)] = lname
flabel.close()

The network is designed to take a 4-dimensional tensor representing a batch of images. The images need to be of size (224, 224, 3), i.e., 224 pixels in height, 224 pixels in width and 3 (RGB) channels deep. The preprocess_image function defined below converts an RGB image to the 4-dimensional tensor representing a batch of one image suitable for feeding into the network. The comments in the function show the transformations that are made.

def preprocess_image(img, resize_wh, mean_image):
    # resize
    img4d = imresize(img, (resize_wh, resize_wh))
    img4d = img4d.astype("float32")
    # RGB -> BGR (the Caffe-trained weights and mean image expect BGR channel order)
    img4d = img4d[:, :, ::-1]
    # swap axes to theano mode
    img4d = np.transpose(img4d, (2, 0, 1))
    # add batch dimension
    img4d = np.expand_dims(img4d, axis=0)
    # subtract mean image
    img4d -= mean_image
    # clip to uint
    img4d = np.clip(img4d, 0, 255).astype("uint8")
    return img4d
    
CAT_IMAGE = "/path/to/cat.jpg"
MEAN_IMAGE = "/path/to/mean_image.npy"
RESIZE_WH = 224

mean_image = np.load(MEAN_IMAGE)
image = plt.imread(CAT_IMAGE)
img4d = preprocess_image(image, RESIZE_WH, mean_image)

print(image.shape, mean_image.shape, img4d.shape)
plt.imshow(image)


To make the model predict the class of the cat, I just call the predict method on the model.

preds = model.predict(img4d)[0]
print(np.argmax(preds))

281

top_preds = np.argsort(preds)[::-1][0:10]
print(top_preds)

array([281, 285, 282, 277, 287, 284, 283, 263, 387, 892])

pred_probas = [(x, id2label[x], preds[x]) for x in top_preds]
print(pred_probas)

[(281, 'tabby, tabby cat', 0.091636732),
 (285, 'Egyptian cat', 0.060721621),
 (282, 'tiger cat', 0.038346913),
 (277, 'red fox, Vulpes vulpes', 0.028440412),
 (287, 'lynx, catamount', 0.020138463),
 (284, 'Siamese cat, Siamese', 0.015518123),
 (283, 'Persian cat', 0.012380695),
 (263, 'Pembroke, Pembroke Welsh corgi', 0.012343671),
 (387, 'lesser panda, red panda, panda, bear cat, ...', 0.01201651),
 (892, 'wall clock', 0.012008683)]

So it looks like our model is able to predict the tabby cat correctly. Even the other predictions in the top 10 look pretty good. Strangely enough, until I switched the backend from Tensorflow to Theano, the model was predicting a Lemur cat as the most likely class for this image. Once I switched the backend to Theano, I started getting the above results. It is possible that there is some other difference between Caffe and Tensorflow that is not accounted for by the 180 degree rotation of the weights (probably the reason why the MarcBS fork works with Theano only). Anyway, since it looks like the Theano backend can be used for migrated models from Caffe, I did not investigate further.

I found this blog post by Joe Marino immensely helpful during this migration. The post describes his experience of migrating GoogleNet from Caffe to Keras. I found it while looking for an implementation of LRN that I could either copy or adapt. The post does describe a Theano specific implementation of an ACROSS_CHANNEL LRN, but I ended up writing my own backend agnostic version of a WITHIN_CHANNEL LRN using the Keras backend API. However, the post also describes Caffe's weird padding style and the transformations necessary to convert the weight matrices from Caffe to Keras, both of which saved me potentially hours of googling and code reading. The post also described a custom layer for the zero padding, which looks to be no longer necessary since the Keras ZeroPadding2D layer already has support for this type of padding.

The code for this post is available in my fttl-with-keras project on Github (look for the caffe2keras-* files in src). You will need a working Caffe installation to dump out the weights from the pretrained Caffe VGG-CNN model part and a working Keras installation with Theano backend for the second part.