Monday, January 30, 2017

Migrating VGG-CNN from Caffe to Keras


I attended (and presented at) the Demystifying Deep Learning and Artificial Intelligence Workshop in Oakland last November. One of the talks I attended was Introduction to Deep Learning for Images in Keras presented by Stephane Egly and Malaikannan (Malai) Sankarasubbu. Since my talk was on Transfer Learning and Fine Tuning CNNs for Image Classification (also using Keras), there was quite a bit of overlap between our interests, and Malai and I got to talking after his presentation. During that time I mentioned that I was using Keras now, but the project upon which my talk was based used a pre-trained VGG-CNN model from the Caffe model zoo.

We went on for a bit about how Keras is so much nicer compared to Caffe. I am a big fan of Keras, but as a committer to the Recurrent Shop project, he is probably a bigger fan (and very likely much more knowledgeable about it) than I am. So anyway, after some more Caffe bashing from my end, he asked me why I didn't migrate that project to Keras, and I mentioned that the model I was using was not yet available in Keras Applications (the Keras model zoo). At that point, Malai made what to me at the time was a blindingly insightful comment - that migration is basically just transferring the weights over.

So, after thinking about it for a while, I finally decided to try to migrate the pre-trained Caffe VGG-CNN model to Keras, and here is the result. It turned out to be more than just transferring the weights over, but I learned a few things during the process, which I would like to share with you in this post.

Caffe End: Dumping the Model


A Caffe model is packaged as a protobuf text (.prototxt) file that specifies the network structure and a binary model (.caffemodel) file that contains the weights. In addition, image inputs to the VGG-CNN need to be normalized by subtracting the mean value of each RGB channel. These mean values are provided as a mean image, stored in a binary protobuf (.binaryproto) file. The first step is to dump out the network structure and weights in a portable format that can be ingested by Keras.

The code below uses the PyCaffe API to print out the layers of the VGG-CNN and the shapes of their outputs. It also dumps out the weights in NumPy's native file format (.npy). We could just as easily have dumped them in text format, but .npy is faster to load. Finally, it reads the mean image file and writes it out as another .npy file.

from __future__ import division, print_function
import caffe
import numpy as np
import os

DATA_DIR = "/path/to/my/data"
OUTPUT_DIR = os.path.join(DATA_DIR, "vgg-cnn-weights")
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

CAFFE_HOME="/path/to/my/caffe/installation"

MODEL_DIR = os.path.join(CAFFE_HOME, "models", "vgg_cnn_s")
MODEL_PROTO = os.path.join(MODEL_DIR, "deploy.prototxt")
MODEL_WEIGHTS = os.path.join(MODEL_DIR, "VGG_CNN_S.caffemodel")
MEAN_IMAGE = os.path.join(MODEL_DIR, "VGG_mean.binaryproto")

caffe.set_mode_cpu()
net = caffe.Net(MODEL_PROTO, MODEL_WEIGHTS, caffe.TEST)

# layer names and output shapes
for layer_name, blob in net.blobs.items():
    print(layer_name, blob.data.shape)

# write out weight matrices and bias vectors
for k, v in net.params.items():
    print(k, v[0].data.shape, v[1].data.shape)
    np.save(os.path.join(OUTPUT_DIR, "W_{:s}.npy".format(k)), v[0].data)
    np.save(os.path.join(OUTPUT_DIR, "b_{:s}.npy".format(k)), v[1].data)

# write out mean image
blob = caffe.proto.caffe_pb2.BlobProto()
with open(MEAN_IMAGE, 'rb') as fmean:
    mean_data = fmean.read()
blob.ParseFromString(mean_data)
mu = np.array(caffe.io.blobproto_to_array(blob))
print("Mean image:", mu.shape)
np.save(os.path.join(OUTPUT_DIR, "mean_image.npy"), mu)

The layer names and output shapes are shown below. The first column is the name of the layer and the second is the shape of the output from that layer. The shapes for the convolutional layers are specified as 4-dimensional tensors, where the dimensions refer to batch size, number of channels, rows and columns. If you have worked with images, you are probably more familiar with specifying image dimensions as (row, column, channel) rather than (channel, row, column). In Keras, the latter is called the Theano dimension ordering, because it is the ordering Theano uses.
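Switching between the two orderings is just an axis permutation. A minimal NumPy sketch, using a dummy all-zeros batch with one marked pixel (the array and its contents are made up purely for illustration):

```python
import numpy as np

# Dummy batch of 10 RGB images in (row, column, channel) order, i.e. Keras "tf" ordering
batch_tf = np.zeros((10, 224, 224, 3), dtype="float32")
batch_tf[0, 5, 7, 2] = 1.0  # mark one pixel: image 0, row 5, column 7, channel 2

# Move the channel axis in front of the spatial axes: (channel, row, column), "th" ordering
batch_th = np.transpose(batch_tf, (0, 3, 1, 2))
print(batch_th.shape)        # (10, 3, 224, 224)
print(batch_th[0, 2, 5, 7])  # 1.0, the same pixel indexed as [image, channel, row, column]

# The inverse permutation restores the original layout
restored = np.transpose(batch_th, (0, 2, 3, 1))
print(np.array_equal(restored, batch_tf))  # True
```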

The last 3 entries (fc6, fc7 and prob) come from the fully connected layers (called Dense in Keras), whose output shapes are specified as (batch size, output size); prob is the softmax over the output of the last fully connected layer, fc8. Notice that the network can predict one of 1000 categories, hence the last layer has the shape (10, 1000).

The layers with names starting with "conv" are convolutional layers and those with names starting with "pool" are pooling layers. The single "norm" layer is a normalization layer. For a fuller description, take a look at the protobuf file.

layer.name   layer.output_shape
data         (10, 3, 224, 224)
conv1        (10, 96, 109, 109)
norm1        (10, 96, 109, 109)
pool1        (10, 96, 37, 37)
conv2        (10, 256, 33, 33)
pool2        (10, 256, 17, 17)
conv3        (10, 512, 17, 17)
conv4        (10, 512, 17, 17)
conv5        (10, 512, 17, 17)
pool5        (10, 512, 6, 6)
fc6          (10, 4096)
fc7          (10, 4096)
prob         (10, 1000)

Each of the convolutional and fully connected layers is associated with a weight matrix and a bias vector. The code above writes them to .npy files named W_{layer.name} for the weight matrix and b_{layer.name} for the bias vector. The table below shows each of these layers and the shapes of the associated weight matrix and bias vector.

layer   W.shape                b.shape
conv1   (96, 3, 7, 7)          (96,)
conv2   (256, 96, 5, 5)        (256,)
conv3   (512, 256, 3, 3)       (512,)
conv4   (512, 512, 3, 3)       (512,)
conv5   (512, 512, 3, 3)       (512,)
fc6     (4096, 18432)          (4096,)
fc7     (4096, 4096)           (4096,)
fc8     (1000, 4096)           (1000,)

Finally, the mean image is written out to a numpy file as a (1, 3, 224, 224) tensor, the same size as a batch of 1 image.
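Because the mean image has a leading batch axis of size 1, NumPy broadcasting lets it be subtracted from a whole batch at once, and per-channel mean values can be obtained by averaging it over its spatial axes. A small sketch; the arrays are made-up stand-ins for the dumped mean_image.npy and a batch of images:

```python
import numpy as np

rng = np.random.RandomState(0)
# Made-up stand-ins for the (1, 3, 224, 224) mean image and a batch of 10 images
mean_image = rng.uniform(90, 130, size=(1, 3, 224, 224)).astype("float32")
batch = rng.uniform(0, 255, size=(10, 3, 224, 224)).astype("float32")

# The size-1 batch axis of mean_image broadcasts across all 10 images
centered = batch - mean_image
print(centered.shape)  # (10, 3, 224, 224)

# One mean value per channel, averaged over rows and columns
channel_means = mean_image.mean(axis=(0, 2, 3))
print(channel_means.shape)  # (3,)

# Per-channel subtraction: reshape so the 3 values line up with the channel axis
centered_per_channel = batch - channel_means.reshape(1, 3, 1, 1)
print(centered_per_channel.shape)  # (10, 3, 224, 224)
```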

Keras End: Rebuilding the Model


I then use the information about the layers and their output shapes to build a skeleton Keras network that looks kind of like the Caffe original. Here I specify the imports so you know what libraries the subsequent calls come from.

from __future__ import division, print_function
from keras import backend as K
from keras.layers import Input
from keras.layers.core import Activation, Dense, Flatten
from keras.layers.convolutional import Convolution2D, ZeroPadding2D
from keras.layers.normalization import BatchNormalization
from keras.layers.pooling import MaxPooling2D
from keras.engine.topology import Layer
from keras.models import Model
from scipy.misc import imresize
import matplotlib.pyplot as plt
import numpy as np
import os
import re

Of late, I have been using Keras exclusively with the Tensorflow backend, but because the layer outputs and weights of the Caffe network are all in Theano (channel, row, col) ordering, I initially set the dimension ordering in Keras to Theano. Eventually I switched back to the Theano backend as well; more on that further down.

K.set_image_dim_ordering("th")

Getting the shapes right


Here is my first attempt at a Keras equivalent for the Caffe VGG-CNN network. Almost all the Caffe layers have analogs in Keras that behave similarly, except for the NORM1 layer. The protobuf file mentions a Local Response Normalization (LRN) layer that is not available in Keras. For the moment I replace this layer with a Keras BatchNormalization layer, since both types of normalization layers leave the shape unchanged.

data = Input(shape=(3, 224, 224), name="DATA")

conv1 = Convolution2D(96, 7, 7, subsample=(2, 2))(data)
conv1 = Activation("relu", name="CONV1")(conv1)

norm1 = BatchNormalization(name="NORM1")(conv1)

pool1 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL1")(norm1)

conv2 = Convolution2D(256, 5, 5)(pool1)
conv2 = Activation("relu", name="CONV2")(conv2)

pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="POOL2")(conv2)

conv3 = Convolution2D(512, 3, 3)(pool2)
conv3 = Activation("relu", name="CONV3")(conv3)

conv4 = Convolution2D(512, 3, 3)(conv3)
conv4 = Activation("relu", name="CONV4")(conv4)

conv5 = Convolution2D(512, 3, 3)(conv4)
conv5 = Activation("relu", name="CONV5")(conv5)

pool5 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL5")(conv5)

fc6 = Flatten()(pool5)
fc6 = Dense(4096)(fc6)
fc6 = Activation("relu", name="FC6")(fc6)

fc7 = Dense(4096)(fc6)
fc7 = Activation("relu", name="FC7")(fc7)

fc8 = Dense(1000, name="FC8")(fc7)
prob = Activation("softmax", name="PROB")(fc8)

model = Model(input=[data], output=[prob])

model.compile(optimizer="adam", loss="categorical_crossentropy")

As you can see below, the output shapes of several layers in my Keras copy differ from those in the original Caffe VGG-CNN network. The columns in the table below are the layer name, the shape of that layer's output in the Caffe model, and the corresponding shape in the Keras model. There are more Keras layers than Caffe layers here; I gave the Keras layers that correspond to Caffe layers the same names as the Caffe layers so we can compare their output shapes. The output shapes match for DATA, CONV1, NORM1, FC6, FC7, FC8 and PROB, and differ for POOL1, CONV2, POOL2, CONV3, CONV4, CONV5 and POOL5.

layer.name       caffe.output_shape  keras.output_shape
DATA             (10, 3, 224, 224)   (10, 3, 224, 224)
convolution2d_1  -                   (10, 96, 109, 109)
CONV1            (10, 96, 109, 109)  (10, 96, 109, 109)
NORM1            (10, 96, 109, 109)  (10, 96, 109, 109)
POOL1            (10, 96, 37, 37)    (10, 96, 36, 36)
convolution2d_2  -                   (10, 256, 32, 32)
CONV2            (10, 256, 33, 33)   (10, 256, 32, 32)
POOL2            (10, 256, 17, 17)   (10, 256, 16, 16)
convolution2d_3  -                   (10, 512, 14, 14)
CONV3            (10, 512, 17, 17)   (10, 512, 14, 14)
convolution2d_4  -                   (10, 512, 12, 12)
CONV4            (10, 512, 17, 17)   (10, 512, 12, 12)
convolution2d_5  -                   (10, 512, 10, 10)
CONV5            (10, 512, 17, 17)   (10, 512, 10, 10)
POOL5            (10, 512, 6, 6)     (10, 512, 3, 3)
flatten_1        -                   (10, 4608)
dense_1          -                   (10, 4096)
FC6              (10, 4096)          (10, 4096)
dense_2          -                   (10, 4096)
FC7              (10, 4096)          (10, 4096)
FC8              (10, 1000)          (10, 1000)
PROB             (10, 1000)          (10, 1000)

The first difference is in POOL1, where the Keras output is one row and one column smaller than the Caffe version. The incoming tensor to POOL1 has the shape (10, 96, 109, 109), and both the pooling window and the stride are (3, 3). Using the formula from the CS231N course page on Convolutional Networks (which incidentally is also a great resource if you want a quick refresher on CNNs), the output size is floor((109 - 3) / 3) + 1 = 36, so the expected output shape is (10, 96, 36, 36), which is exactly what the Keras layer produces. Caffe differs because it rounds the number of pooling windows up rather than down, ceil((109 - 3) / 3) + 1 = 37, which is equivalent to applying no padding at the beginning and just enough zero padding at the end, i.e., along the right and bottom edges. Since 37 windows of size 3 with stride 3 span 36 * 3 + 3 = 111 pixels, 2 more than the 109 available, this can be simulated in Keras by applying a ZeroPadding2D layer with padding of (0, 2, 0, 2), i.e., (top, bottom, left, right), just before the POOL1 layer.
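The same arithmetic can be written out for all three pooling layers. A pure-Python sketch, with hypothetical function names: keras_valid_pool_size uses the floor-based formula, caffe_pool_size rounds up the way Caffe does for pooling (no explicit pad assumed), and end_padding_needed is the number of zeros to append on the bottom and right so that the floor-based Keras layer yields Caffe's count.

```python
import math


def keras_valid_pool_size(size, pool, stride):
    """'valid' pooling: a window that would run off the edge is dropped (floor)."""
    return (size - pool) // stride + 1


def caffe_pool_size(size, pool, stride):
    """Caffe-style pooling size: the window count is rounded up (ceil)."""
    return int(math.ceil((size - pool) / float(stride))) + 1


def end_padding_needed(size, pool, stride):
    """Zeros to append at the end so 'valid' pooling produces Caffe's output size."""
    target = caffe_pool_size(size, pool, stride)
    return (target - 1) * stride + pool - size


for name, size, pool, stride in [("POOL1", 109, 3, 3), ("POOL2", 33, 2, 2), ("POOL5", 17, 3, 3)]:
    print(name,
          keras_valid_pool_size(size, pool, stride),
          caffe_pool_size(size, pool, stride),
          end_padding_needed(size, pool, stride))
# POOL1 36 37 2
# POOL2 16 17 1
# POOL5 5 6 1
```

Appending zeros is harmless for these max pooling layers: their inputs come from ReLU or LRN outputs and are therefore non-negative, so a padded zero never changes the maximum of a window that also contains real activations.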

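The remaining shape differences are fixed the same way, by placing an appropriate ZeroPadding2D layer in front of the affected layer. CONV2's mismatch is simply inherited from POOL1, POOL2 and POOL5 need the same kind of end padding as POOL1, and CONV3, CONV4 and CONV5 keep a 17 x 17 output in Caffe, which with a 3 x 3 kernel means their inputs are padded by 2 pixels in each spatial dimension. The full list of additions is below.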

  • Add (0, 2, 0, 2) zero padding before POOL1
  • Add (0, 1, 0, 1) zero padding before POOL2
  • Add (0, 2, 0, 2) zero padding before CONV3
  • Add (0, 2, 0, 2) zero padding before CONV4
  • Add (0, 2, 0, 2) zero padding before CONV5
  • Add (0, 1, 0, 1) zero padding before POOL5
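One detail that may be worth checking for CONV3, CONV4 and CONV5: Caffe's convolution padding (the pad parameter in the prototxt) is added equally on both sides, so a 17 x 17 output from a 3 x 3 kernel corresponds to pad: 1 on every edge. Padding of (0, 2, 0, 2) and (1, 1, 1, 1) give the same output shape, but they line the kernel up with different pixels. The pure-Python 1-D sketch below (conv_out_size and correlate1d are hypothetical helpers, not Keras functions) uses an identity kernel to make the one-pixel shift visible. If the prototxt does declare pad: 1 for these layers, ZeroPadding2D(padding=(1, 1, 1, 1)) may be the closer match.

```python
def conv_out_size(size, kernel, stride=1, pad_before=0, pad_after=0):
    """Output length of a 'valid' convolution applied after explicit zero padding."""
    return (size + pad_before + pad_after - kernel) // stride + 1


def correlate1d(signal, kernel, pad_before, pad_after):
    """1-D cross-correlation of a zero-padded signal (one row of a conv layer)."""
    padded = [0] * pad_before + list(signal) + [0] * pad_after
    width = len(kernel)
    return [sum(padded[i + j] * kernel[j] for j in range(width))
            for i in range(len(padded) - width + 1)]


# Both paddings turn a 17-pixel input into a 17-pixel output with a 3-pixel kernel
print(conv_out_size(17, 3, pad_before=1, pad_after=1))  # 17
print(conv_out_size(17, 3, pad_before=0, pad_after=2))  # 17

# An identity kernel shows which input pixel each output position is centred on
signal = [1, 2, 3, 4, 5]
kernel = [0, 1, 0]
print(correlate1d(signal, kernel, 1, 1))  # [1, 2, 3, 4, 5]  symmetric, like pad: 1
print(correlate1d(signal, kernel, 0, 2))  # [2, 3, 4, 5, 0]  shifted by one pixel
```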

With these changes, my Keras network now looks like this:

data = Input(shape=(3, 224, 224), name="DATA")

conv1 = Convolution2D(96, 7, 7, subsample=(2, 2))(data)
conv1 = Activation("relu", name="CONV1")(conv1)

norm1 = BatchNormalization(name="NORM1")(conv1)

pool1 = ZeroPadding2D(padding=(0, 2, 0, 2))(norm1)
pool1 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL1")(pool1)

conv2 = Convolution2D(256, 5, 5)(pool1)
conv2 = Activation("relu", name="CONV2")(conv2)

pool2 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="POOL2")(pool2)

conv3 = ZeroPadding2D(padding=(0, 2, 0, 2))(pool2)
conv3 = Convolution2D(512, 3, 3)(conv3)
conv3 = Activation("relu", name="CONV3")(conv3)

conv4 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv3)
conv4 = Convolution2D(512, 3, 3)(conv4)
conv4 = Activation("relu", name="CONV4")(conv4)

conv5 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv4)
conv5 = Convolution2D(512, 3, 3)(conv5)
conv5 = Activation("relu", name="CONV5")(conv5)

pool5 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv5)
pool5 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL5")(pool5)

fc6 = Flatten()(pool5)
fc6 = Dense(4096)(fc6)
fc6 = Activation("relu", name="FC6")(fc6)

fc7 = Dense(4096)(fc6)
fc7 = Activation("relu", name="FC7")(fc7)

fc8 = Dense(1000, name="FC8")(fc7)
prob = Activation("softmax", name="PROB")(fc8)

model = Model(input=[data], output=[prob])

model.compile(optimizer="adam", loss="categorical_crossentropy")

After these changes, there are no more shape differences between the Caffe layer outputs and their corresponding Keras layer outputs.

layer.name        caffe.output_shape  keras.output_shape
DATA              (10, 3, 224, 224)   (10, 3, 224, 224)
convolution2d_6   -                   (10, 96, 109, 109)
CONV1             (10, 96, 109, 109)  (10, 96, 109, 109)
NORM1             (10, 96, 109, 109)  (10, 96, 109, 109)
zeropadding2d_1   -                   (10, 96, 111, 111)
POOL1             (10, 96, 37, 37)    (10, 96, 37, 37)
convolution2d_7   -                   (10, 256, 33, 33)
CONV2             (10, 256, 33, 33)   (10, 256, 33, 33)
zeropadding2d_2   -                   (10, 256, 34, 34)
POOL2             (10, 256, 17, 17)   (10, 256, 17, 17)
zeropadding2d_3   -                   (10, 256, 19, 19)
convolution2d_8   -                   (10, 512, 17, 17)
CONV3             (10, 512, 17, 17)   (10, 512, 17, 17)
zeropadding2d_4   -                   (10, 512, 19, 19)
convolution2d_9   -                   (10, 512, 17, 17)
CONV4             (10, 512, 17, 17)   (10, 512, 17, 17)
zeropadding2d_5   -                   (10, 512, 19, 19)
convolution2d_10  -                   (10, 512, 17, 17)
CONV5             (10, 512, 17, 17)   (10, 512, 17, 17)
zeropadding2d_6   -                   (10, 512, 18, 18)
POOL5             (10, 512, 6, 6)     (10, 512, 6, 6)
flatten_2         -                   (10, 18432)
dense_3           -                   (10, 4096)
FC6               (10, 4096)          (10, 4096)
dense_4           -                   (10, 4096)
FC7               (10, 4096)          (10, 4096)
FC8               (10, 1000)          (10, 1000)
PROB              (10, 1000)          (10, 1000)
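The spatial sizes in this table follow from applying the 'valid' output-size formula, floor((n + padding - kernel) / stride) + 1, layer by layer. A small pure-Python sketch (trace_shapes is a hypothetical helper; it only tracks the total padding per spatial dimension, since that is all the output size depends on) reproduces them, along with the flattened sizes of both versions of the network, 18432 here and 4608 for the first, unpadded attempt:

```python
def trace_shapes(size=224, padded=True):
    """Trace (name, channels, rows, cols) through the Keras VGG-CNN for a square input.

    Assumes 'valid' convolutions and pooling, i.e. floor((n + pad - k) / s) + 1
    after explicit zero padding. padded=False reproduces the first attempt.
    """
    # (name, output channels or None for pooling, kernel, stride, total zero padding)
    layers = [
        ("CONV1", 96, 7, 2, 0),
        ("POOL1", None, 3, 3, 2),
        ("CONV2", 256, 5, 1, 0),
        ("POOL2", None, 2, 2, 1),
        ("CONV3", 512, 3, 1, 2),
        ("CONV4", 512, 3, 1, 2),
        ("CONV5", 512, 3, 1, 2),
        ("POOL5", None, 3, 3, 1),
    ]
    channels, rows = 3, size
    shapes = []
    for name, out_channels, kernel, stride, pad in layers:
        if not padded:
            pad = 0
        rows = (rows + pad - kernel) // stride + 1
        if out_channels is not None:  # pooling keeps the channel count
            channels = out_channels
        shapes.append((name, channels, rows, rows))
    # second value: flattened size fed into FC6
    return shapes, channels * rows * rows


shapes, flat = trace_shapes()
for name, c, r, w in shapes:
    print(name, (10, c, r, w))
print("flatten", flat)                                       # 18432
print("flatten (no padding)", trace_shapes(padded=False)[1])  # 4608
```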

Defining the Local Response Normalization Layer


Now that the output shapes are lined up, I needed to do something about the BatchNormalization placeholder layer. The NORM1 layer is defined in the protobuf file as a Local Response Normalization (LRN) layer, which performs a kind of lateral inhibition by normalizing over local input regions. It can work in two modes - ACROSS_CHANNEL and WITHIN_CHANNEL. In the first, the local regions extend across nearby channels but have no spatial extent; in the second, they extend spatially but stay within a single channel. The equation for LRN is shown below:

$$b_i = \frac{a_i}{\left(k + \frac{\alpha}{n} \sum_{j \in N(i)} a_j^2\right)^{\beta}}$$

where $a_i$ is an input activation, $b_i$ the corresponding normalized output, $N(i)$ the local region around $a_i$, $n$ the size of that region, and $k$, $\alpha$ and $\beta$ are hyperparameters.

Since Keras has no built-in LRN layer, I built a custom layer following the instructions on the Writing your own Keras Layers page. My custom LRN layer implements the WITHIN_CHANNEL approach; its code is shown below.

Note that the LRN is nowadays generally regarded as obsolete, having been replaced by better methods of regularization such as Dropout, Batch Normalization and better initialization. So the LRN is probably not something you want to use in your new networks. However, because we are trying to implement an existing network, we need to replicate the LRN functionality in Keras as well.

class LRN(Layer):
    
    def __init__(self, n=5, alpha=0.0005, beta=0.75, k=2, **kwargs):
        self.n = n
        self.alpha = alpha
        self.beta = beta
        self.k = k
        super(LRN, self).__init__(**kwargs)

    def build(self, input_shape):
        self.shape = input_shape
        super(LRN, self).build(input_shape)

    def call(self, x, mask=None):
        # image_dim_ordering is a function and must be called to get "th" or "tf"
        if K.image_dim_ordering() == "th":
            _, f, r, c = self.shape
        else:
            _, r, c, f = self.shape
        half_n = self.n // 2
        squared = K.square(x)
        # local spatial average of the squared activations
        pooled = K.pool2d(squared, (half_n, half_n), strides=(1, 1),
                         border_mode="same", pool_mode="avg")
        # sum over the channel axis, then broadcast the sum back to every channel
        if K.image_dim_ordering() == "th":
            summed = K.sum(pooled, axis=1, keepdims=True)
            averaged = (self.alpha / self.n) * K.repeat_elements(summed, f, axis=1)
        else:
            summed = K.sum(pooled, axis=3, keepdims=True)
            averaged = (self.alpha / self.n) * K.repeat_elements(summed, f, axis=3)
        denom = K.pow(self.k + averaged, self.beta)
        return x / denom
    
    def get_output_shape_for(self, input_shape):
        return input_shape

The constructor allows us to set the hyperparameters for the layer; I took their default values from the Caffe protobuf file. The call method computes the normalization using the Keras backend API: it squares the input, takes a local spatial average with 2D average pooling (pool size (n // 2, n // 2), stride (1, 1), "same" border mode), sums the result over the channel axis, scales it by alpha / n, and divides the input by (k + scaled sum) ^ beta. The get_output_shape_for method returns the output shape of the layer, which in this case is the same as the input shape.
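Since the layer above is assembled from backend pooling operations, it can help to compare its output against a direct, loop-based evaluation of the LRN formula, x / (k + (alpha / size) * sum of squares over the local region) ^ beta. The slow NumPy sketch below is a hypothetical reference (lrn_reference is not part of Keras or Caffe); it assumes the same k, alpha, beta form for both modes, takes size as the number of values in the region (n for ACROSS_CHANNEL, n x n for WITHIN_CHANNEL), treats out-of-range neighbours as zero, and expects an odd n.

```python
import numpy as np


def lrn_reference(x, n=5, alpha=0.0005, beta=0.75, k=2.0, mode="across"):
    """Slow reference LRN for a (batch, channels, rows, cols) array."""
    x = np.asarray(x, dtype=np.float64)
    squared = x ** 2
    half = n // 2
    summed = np.zeros_like(squared)
    if mode == "across":
        # region: up to n adjacent channels at the same (row, col)
        channels = x.shape[1]
        for c in range(channels):
            lo, hi = max(0, c - half), min(channels, c + half + 1)
            summed[:, c] = squared[:, lo:hi].sum(axis=1)
        size = n
    elif mode == "within":
        # region: an n x n spatial window inside each channel, zero padded at the borders
        rows, cols = x.shape[2], x.shape[3]
        padded = np.pad(squared, ((0, 0), (0, 0), (half, half), (half, half)),
                        mode="constant")
        for dr in range(n):
            for dc in range(n):
                summed += padded[:, :, dr:dr + rows, dc:dc + cols]
        size = n * n
    else:
        raise ValueError("mode must be 'across' or 'within'")
    return x / (k + (alpha / size) * summed) ** beta


# Hand-checkable example: all-ones input with 3 channels, n=3, alpha=1, beta=1, k=1.
# The middle channel sees 3 channels (1 / (1 + 3/3) = 0.5), the edge ones only 2
# (1 / (1 + 2/3) = 0.6).
out = lrn_reference(np.ones((1, 3, 2, 2)), n=3, alpha=1.0, beta=1.0, k=1.0)
print(out[0, :, 0, 0])  # approximately [0.6, 0.5, 0.6]
```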

Loading the weights


The next step is to load the weights that I dumped out of the Caffe binary .caffemodel file into .npy files. The convolutional weight matrices are 4-dimensional and the fully connected weight matrices are 2-dimensional. However, I noticed that the dimensions of the fully connected weights are reversed: for example, where the Keras layer needs a matrix of shape (18432, 4096), the matrix I get from Caffe has shape (4096, 18432), so the fully connected weights need to be transposed to be usable in Keras. There is also a difference in how Caffe and Keras (with the Theano backend) do convolutions: Caffe computes a cross-correlation, while Theano computes a true convolution, which flips the kernel. To account for this difference, the convolutional weights need to be rotated 180 degrees. The transform_conv_weight and transform_fc_weight functions implement these transformations.

You can also find code to do this in convert.py in the MarcBS Keras fork, which contains code specifically for converting Caffe networks to Keras. The MarcBS fork is very likely not up-to-date with the main Keras branch, and its README.md says that it is compatible with Theano only.
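Why a 180 degree rotation? Cross-correlation slides the kernel over the input as is, while a true convolution flips the kernel in both spatial axes first, so convolving with a rotated kernel reproduces the correlation with the original one. A small NumPy sketch on made-up data follows; correlate2d_valid and convolve2d_valid are hypothetical helpers, not Keras or Caffe functions. It assumes that the Keras backend in use performs a true convolution (Theano's conv2d flips the kernel by default), which is the case the rotation is meant for.

```python
import numpy as np


def correlate2d_valid(x, w):
    """Cross-correlation, as in Caffe: slide w over x without flipping it."""
    kh, kw = w.shape
    out = np.empty((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * w)
    return out


def convolve2d_valid(x, w):
    """True convolution: flip the kernel in both axes, then slide it."""
    return correlate2d_valid(x, w[::-1, ::-1])


rng = np.random.RandomState(0)
x = rng.rand(6, 6)        # made-up single-channel input
w_caffe = rng.rand(3, 3)  # made-up, non-symmetric 3 x 3 kernel

caffe_out = correlate2d_valid(x, w_caffe)
# A true convolution with the 180-degree-rotated kernel reproduces Caffe's output
flipped_out = convolve2d_valid(x, np.rot90(w_caffe, 2))
print(np.allclose(caffe_out, flipped_out))  # True
# Without the rotation, the two results disagree for a non-symmetric kernel
print(np.allclose(caffe_out, convolve2d_valid(x, w_caffe)))
```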

def transform_conv_weight(W):
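    # conv layers only: rotate each (row, col) kernel slice by 180 degrees, because
    # Keras with the Theano backend computes a true convolution (flipped kernel)
    # while Caffe computes a cross-correlation; note that W is modified in place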
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            W[i, j] = np.rot90(W[i, j], 2)
    return W

def transform_fc_weight(W):
    return W.T

# load weights
CAFFE_WEIGHTS_DIR = "/Users/palsujit/Projects/fttl-with-keras/data/vgg-cnn/saved-weights"

W_conv1 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv1.npy")))
b_conv1 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv1.npy"))

W_conv2 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv2.npy")))
b_conv2 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv2.npy"))

W_conv3 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv3.npy")))
b_conv3 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv3.npy"))

W_conv4 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv4.npy")))
b_conv4 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv4.npy"))

W_conv5 = transform_conv_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_conv5.npy")))
b_conv5 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_conv5.npy"))

W_fc6 = transform_fc_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_fc6.npy")))
b_fc6 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_fc6.npy"))

W_fc7 = transform_fc_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_fc7.npy")))
b_fc7 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_fc7.npy"))

W_fc8 = transform_fc_weight(np.load(os.path.join(CAFFE_WEIGHTS_DIR, "W_fc8.npy")))
b_fc8 = np.load(os.path.join(CAFFE_WEIGHTS_DIR, "b_fc8.npy"))
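The per-slice loop in transform_conv_weight is equivalent to reversing the last two axes of the 4-D weight array, which NumPy can do in a single step. A sketch with made-up random weights follows; rotate_kernels_loop and rotate_kernels_vectorized are hypothetical names. Both work on a copy, since transform_conv_weight modifies the array it is given (harmless above, because each array comes straight from np.load). The last lines check the fully connected transpose on a tiny example.

```python
import numpy as np


def rotate_kernels_loop(W):
    """Same per-slice loop as transform_conv_weight, but applied to a copy of W."""
    W = W.copy()
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            W[i, j] = np.rot90(W[i, j], 2)
    return W


def rotate_kernels_vectorized(W):
    """Flip the (row, col) axes of an (out, in, rows, cols) array in one step."""
    return W[:, :, ::-1, ::-1].copy()


rng = np.random.RandomState(1)
W = rng.rand(4, 3, 5, 5).astype("float32")  # made-up conv weights
print(np.array_equal(rotate_kernels_loop(W), rotate_kernels_vectorized(W)))  # True

# FC layers: Caffe stores (out, in) and computes W . x per sample, while Keras Dense
# stores (in, out) and computes x . W for a whole batch, so a transpose lines them up
W_caffe = rng.rand(4, 6)  # made-up (out=4, in=6) weights
x = rng.rand(2, 6)        # batch of 2 input vectors
W_keras = W_caffe.T       # (in=6, out=4)
print(np.allclose(x.dot(W_keras), np.array([W_caffe.dot(v) for v in x])))  # True
```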

I now set the weight matrices and bias vectors into their corresponding layers, and replace the BatchNormalization layer with my custom LRN layer. The code to do that is shown below.

# define network
data = Input(shape=(3, 224, 224), name="DATA")

conv1 = Convolution2D(96, 7, 7, subsample=(2, 2),
                     weights=(W_conv1, b_conv1))(data)
conv1 = Activation("relu", name="CONV1")(conv1)

norm1 = LRN(name="NORM1")(conv1)

pool1 = ZeroPadding2D(padding=(0, 2, 0, 2))(norm1)
pool1 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL1")(pool1)

conv2 = Convolution2D(256, 5, 5, weights=(W_conv2, b_conv2))(pool1)
conv2 = Activation("relu", name="CONV2")(conv2)

pool2 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv2)
pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name="POOL2")(pool2)

conv3 = ZeroPadding2D(padding=(0, 2, 0, 2))(pool2)
conv3 = Convolution2D(512, 3, 3, weights=(W_conv3, b_conv3))(conv3)
conv3 = Activation("relu", name="CONV3")(conv3)

conv4 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv3)
conv4 = Convolution2D(512, 3, 3, weights=(W_conv4, b_conv4))(conv4)
conv4 = Activation("relu", name="CONV4")(conv4)

conv5 = ZeroPadding2D(padding=(0, 2, 0, 2))(conv4)
conv5 = Convolution2D(512, 3, 3, weights=(W_conv5, b_conv5))(conv5)
conv5 = Activation("relu", name="CONV5")(conv5)

pool5 = ZeroPadding2D(padding=(0, 1, 0, 1))(conv5)
pool5 = MaxPooling2D(pool_size=(3, 3), strides=(3, 3), name="POOL5")(pool5)

fc6 = Flatten()(pool5)
fc6 = Dense(4096, weights=(W_fc6, b_fc6))(fc6)
fc6 = Activation("relu", name="FC6")(fc6)

fc7 = Dense(4096, weights=(W_fc7, b_fc7))(fc6)
fc7 = Activation("relu", name="FC7")(fc7)

fc8 = Dense(1000, weights=(W_fc8, b_fc8), name="FC8")(fc7)
prob = Activation("softmax", name="PROB")(fc8)

model = Model(input=[data], output=[prob])

model.compile(optimizer="adam", loss="categorical_crossentropy")

At this point, we have a pre-trained VGG-CNN network in Keras. Let's check to see if it's any good.

Predictions using Keras VGG-CNN


I wanted to see how this pretrained network would perform on some test images. The Caffe documentation has a Classification Tutorial where they demonstrate classification using an image of a cat, so I decided to use the same image.

The network is designed to take a batch of images and, for each image, return 1000 probabilities, one per ImageNet class; the index of the largest one, a number between 0 and 999, is the predicted class ID. The mapping of these IDs to actual labels is provided by this Github gist file. Below, I build a little dictionary to look up a label given the class ID.

# map class ID -> label; each line of the file is "<id>\t<label>"
id2label = {}
with open("../data/caffe2keras-labels.txt", "r") as flabel:
    for line in flabel:
        lid, lname = line.strip().split("\t")
        id2label[int(lid)] = lname

The network is designed to take a 4-dimensional tensor representing a batch of images. The images need to be of size (224, 224, 3), i.e., 224 pixels in height, 224 pixels in width and 3 (RGB) channels deep. The preprocess_image function defined below converts an RGB image into a 4-dimensional tensor representing a batch of one image, suitable for feeding into the network. The comments in the function describe each transformation.

def preprocess_image(img, resize_wh, mean_image):
    # resize
    img4d = imresize(img, (resize_wh, resize_wh))
    img4d = img4d.astype("float32")
    # RGB -> BGR (plt.imread returns RGB, Caffe models expect BGR)
    img4d = img4d[:, :, ::-1]
    # swap axes to theano mode
    img4d = np.transpose(img4d, (2, 0, 1))
    # add batch dimension
    img4d = np.expand_dims(img4d, axis=0)
    # subtract mean image
    img4d -= mean_image
    # clip to [0, 255] and cast to uint8 (values that fell below the mean become 0)
    img4d = np.clip(img4d, 0, 255).astype("uint8")
    return img4d
    
CAT_IMAGE = "/path/to/cat.jpg"
MEAN_IMAGE = "/path/to/mean_image.npy"
RESIZE_WH = 224

mean_image = np.load(MEAN_IMAGE)
image = plt.imread(CAT_IMAGE)
img4d = preprocess_image(image, RESIZE_WH, mean_image)

print(image.shape, mean_image.shape, img4d.shape)
plt.imshow(image)
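The final clipping step in preprocess_image sets every value that ended up below the mean to 0, which can be a large fraction of an image. The sketch below is a hypothetical variant, preprocess_image_float, that skips the clip and keeps float32 values. It assumes the image has already been resized to 224 x 224 and uses a made-up random image and a constant mean image, just to show how much the clip discards. It may be worth comparing predictions with and without the clip.

```python
import numpy as np


def preprocess_image_float(img, mean_image):
    """Hypothetical variant of preprocess_image that keeps float32 values.

    img: (224, 224, 3) RGB array, already resized; mean_image: (1, 3, 224, 224).
    """
    x = img.astype("float32")[:, :, ::-1]       # RGB -> BGR channel order
    x = np.transpose(x, (2, 0, 1))[np.newaxis]  # (1, 3, rows, cols)
    return x - mean_image                       # may be negative; no clipping


rng = np.random.RandomState(0)
img = rng.randint(0, 256, size=(224, 224, 3)).astype("uint8")  # made-up image
mean_image = np.full((1, 3, 224, 224), 120.0, dtype="float32")  # made-up mean

x = preprocess_image_float(img, mean_image)
clipped = np.clip(x, 0, 255).astype("uint8")  # what the original function returns

# Fraction of values the clip would set to 0 (about 120/256 for this made-up data)
print((x < 0).mean())
```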


To make the model predict the class of the cat, I just call the predict method on the model.

preds = model.predict(img4d)[0]
print(np.argmax(preds))

281

top_preds = np.argsort(preds)[::-1][0:10]
print(top_preds)

array([281, 285, 282, 277, 287, 284, 283, 263, 387, 892])

pred_probas = [(x, id2label[x], preds[x]) for x in top_preds]
print(pred_probas)

[(281, 'tabby, tabby cat', 0.091636732),
 (285, 'Egyptian cat', 0.060721621),
 (282, 'tiger cat', 0.038346913),
 (277, 'red fox, Vulpes vulpes', 0.028440412),
 (287, 'lynx, catamount', 0.020138463),
 (284, 'Siamese cat, Siamese', 0.015518123),
 (283, 'Persian cat', 0.012380695),
 (263, 'Pembroke, Pembroke Welsh corgi', 0.012343671),
 (387, 'lesser panda, red panda, panda, bear cat, ...', 0.01201651),
 (892, 'wall clock', 0.012008683)]

So it looks like our model is able to predict the tabby cat correctly. Even the other predictions in the top 10 look pretty good. Strangely enough, until I switched the backend from Tensorflow to Theano, the model was predicting a Lemur cat as the most likely class for this image. Once I switched the backend to Theano, I started getting the above results. A likely culprit is the 180 degree rotation of the weights itself: Theano's conv2d flips the kernel (a true convolution), while Tensorflow's conv2d, like Caffe, computes a cross-correlation, so the rotated weights suit Theano but end up flipped on Tensorflow. This is probably also why the MarcBS fork works with Theano only. Anyway, since the Theano backend works for models migrated from Caffe, I did not investigate further.

I found this blog post by Joe Marino immensely helpful during this migration. The post describes his experience of migrating GoogleNet from Caffe to Keras. I found it while looking for an implementation of LRN that I could either copy or adapt. The post does describe a Theano-specific implementation of an ACROSS_CHANNEL LRN, but I ended up writing my own backend-agnostic version of a WITHIN_CHANNEL LRN using the Keras backend API. The post also describes Caffe's weird padding style and the transformations necessary to convert the weight matrices from Caffe to Keras, both of which potentially saved me hours of googling and code reading. Finally, it describes a custom layer for the zero padding, which no longer seems necessary since the Keras ZeroPadding2D layer already supports this type of padding.

The code for this post is available in my fttl-with-keras project on Github (look for the caffe2keras-* files in src). You will need a working Caffe installation for the first part (dumping out the weights from the pretrained Caffe VGG-CNN model) and a working Keras installation with the Theano backend for the second part.