Upsampling with DeconvolutionLayer in Caffe

This example shows how to upsample a feature map with DeconvolutionLayer and checks the layer's output against skimage.transform.rescale.

In [1]:
%pylab inline
import caffe, caffe_helper # git clone git://github.com/tnarihi/tnarihi-caffe-helper.git and set the path to its python dir.
from scipy.misc import imread
Populating the interactive namespace from numpy and matplotlib

Creating a synthetic image

In [2]:
imsize = 3
x, y = ogrid[:imsize, :imsize]
img = repeat((x + y)[..., newaxis], 3, 2) / float(imsize + imsize)
imshow(img, interpolation='none')
Out[2]:
<matplotlib.image.AxesImage at 0x7f8d80a606d0>

Creating bottom and top blobs

Note that this requires #2148 to be merged.

In [3]:
bottom = [caffe.Blob([1, 3, img.shape[0], img.shape[1]])]
top = [caffe.Blob([])]
bottom[0].data[...] = img.reshape(1, *img.shape).transpose(0, 3, 1, 2)
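
Caffe blobs store data in N x C x H x W order, while imread and imshow work with H x W x C arrays. The following is a minimal NumPy-only sketch of the layout conversion used in the cell above. The small array and the names `blob` and `restored` are illustrative stand-ins, not part of the notebook.

```python
import numpy as np

# Hypothetical stand-in for an image: height 2, width 3, 3 channels (H x W x C).
img = np.arange(2 * 3 * 3, dtype=float).reshape(2, 3, 3)

# Add a leading batch axis, then move channels ahead of height and width.
blob = img.reshape(1, *img.shape).transpose(0, 3, 1, 2)  # N x C x H x W

# The inverse transpose recovers the H x W x C image for display.
restored = blob.transpose(0, 2, 3, 1)[0]

print(blob.shape)
```

The same inverse transpose, `transpose(0, 2, 3, 1)[0]`, is what later cells apply to the top blob before calling imshow.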

Creating an upsampling layer

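Here you create an upsampling layer from a DeconvolutionLayer whose weights are initialized by the BilinearUpsamplingFiller. Setting group equal to num_output makes the layer filter each channel independently. This requires #2213 for BilinearUpsamplingFiller and #2149 for creating a layer from Python.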

In [4]:
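# helper functions
def get_kernel_size(factor):
    # 2 * factor for even factors, 2 * factor - 1 for odd factors
    return 2 * factor - factor % 2
def get_pad(factor):
    # padding that crops the output to exactly factor * input size
    return int(ceil((factor - 1) / 2.))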

def upsample_creator(factor, num_in):
    kernel = get_kernel_size(factor)
    stride = factor
    lp = caffe.LayerParameter("""name: "upsample", type: "Deconvolution"
        convolution_param { kernel_size: %d stride: %d num_output: %d group: %d pad: %d
        weight_filler: { type: "bilinear_upsampling" } bias_term: false }""" % (
            kernel, stride, num_in, num_in, get_pad(factor)
        ))
    return caffe.create_layer(lp.to_python())
In [5]:
factor = 3
upl = upsample_creator(factor, 3)
upl.SetUp(bottom, top)
imshow(caffe_helper.visualize.blob_to_tile(upl.blobs[0].data), interpolation='none', cmap='gray')
Out[5]:
<matplotlib.image.AxesImage at 0x7f8d808f2410>
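
The kernel size and padding returned by the helper functions are chosen so that the output is exactly factor times the input size. The sketch below checks this arithmetic. It assumes the usual transposed-convolution size rule, output = stride * (input - 1) + kernel - 2 * pad. The function `deconv_output_size` is a hypothetical helper written for this check, not a Caffe API.

```python
import math

def get_kernel_size(factor):
    return 2 * factor - factor % 2

def get_pad(factor):
    return int(math.ceil((factor - 1) / 2.0))

def deconv_output_size(in_size, kernel, stride, pad):
    # Assumed size rule for a strided transposed convolution.
    return stride * (in_size - 1) + kernel - 2 * pad

in_size = 3  # same side length as the synthetic image
sizes = {}
for factor in range(2, 10):
    sizes[factor] = deconv_output_size(
        in_size, get_kernel_size(factor), factor, get_pad(factor))

print(sizes)
```

Under that assumption, every factor from 2 to 9 gives an output side of `factor * in_size`, for even and odd factors alike.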

Upsampling by forward pass

In [6]:
upl.Reshape(bottom, top)
upl.Forward(bottom, top)
imgo = top[0].data.copy().transpose(0, 2, 3, 1)[0]
imshow(imgo, interpolation='none')
Out[6]:
<matplotlib.image.AxesImage at 0x7f8d76c1ef50>
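
To make the forward pass less opaque, here is a NumPy-only sketch of what a strided transposed convolution with a bilinear kernel computes for a single channel. It assumes the layer scatters one scaled copy of the kernel per input pixel at multiples of the stride and then crops pad pixels from each border. The names `bilinear_kernel_1d`, `upsample_matrix` and `upsample_2d` are hypothetical, and the code is an illustration rather than Caffe's implementation.

```python
import numpy as np

def bilinear_kernel_1d(factor):
    size = 2 * factor - factor % 2
    center = (2.0 * factor - 1 - factor % 2) / (2.0 * factor)
    return 1 - np.abs(np.arange(size) / float(factor) - center)

def upsample_matrix(n, factor):
    """Matrix that maps n input samples to factor * n output samples."""
    kernel = bilinear_kernel_1d(factor)
    pad = int(np.ceil((factor - 1) / 2.0))
    full = np.zeros((factor * (n - 1) + len(kernel), n))
    for i in range(n):
        # Each input sample contributes one kernel copy, shifted by the stride.
        full[i * factor:i * factor + len(kernel), i] = kernel
    # Cropping pad rows from both ends leaves exactly factor * n rows.
    return full[pad:pad + factor * n]

def upsample_2d(image, factor):
    # The bilinear kernel is separable, so rows and columns are handled in turn.
    h, w = image.shape
    return upsample_matrix(h, factor).dot(image).dot(upsample_matrix(w, factor).T)

out = upsample_2d(np.ones((3, 3)), 2)
print(out.shape)
print(upsample_matrix(3, 2).sum(axis=1))
```

For an all-ones input and factor 2, the row sums of the matrix are 0.75, 1, 1, 1, 1, 0.75: interior values are preserved, while border values fall below 1 because nothing contributes from outside the input. That implicit zero border is a plausible reason the comparison with skimage.transform.rescale uses mode='constant' and cval=0.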

Checking consistency with skimage.transform.rescale

In [7]:
import skimage.transform
imgo2 = skimage.transform.rescale(img, factor, mode='constant', cval=0)
subplot(121)
imshow(imgo, interpolation='none')
subplot(122)
imshow(imgo2, interpolation='none')
assert np.allclose(imgo, imgo2)
In [8]:
for factor in xrange(2, 10):
    upl = upsample_creator(factor, 3)
    upl.SetUp(bottom, top)
    upl.Reshape(bottom, top)
    upl.Forward(bottom, top)
    imgo = squeeze(top[0].data.copy().transpose(0, 2, 3, 1))
    imgo2 = skimage.transform.rescale(img, factor, mode='constant', cval=0)
    assert np.allclose(imgo, imgo2)
    print "factor %d: OK" % factor
factor 2: OK
factor 3: OK
factor 4: OK
factor 5: OK
factor 6: OK
factor 7: OK
factor 8: OK
factor 9: OK

BilinearUpsamplingFiller in Python

In [9]:
# helper functions
def get_center(factor):
    return (2. * factor - 1 - factor % 2) / (2. * factor)
def get_radius(factor):
    return factor

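def filt_init(factor):
    rr = get_kernel_size(factor)  # kernel width and height
    c = get_center(factor)  # kernel center in units of input pixels
    r = get_radius(factor)  # not used below
    x, y = ogrid[:rr, :rr]
    f = float(factor)
    # separable triangle (bilinear) weights, broadcast to an rr x rr kernel
    return (1 - abs(x / f - c)) * (1 - abs(y / f - c))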
In [10]:
for factor in xrange(2, 10):
    upl = upsample_creator(factor, 3)
    upl.SetUp(bottom, top)
    upl.blobs[0].data[...] = filt_init(factor)
    upl.Reshape(bottom, top)
    upl.Forward(bottom, top)
    imgo = top[0].data.copy().transpose(0, 2, 3, 1)[0]
    imgo2 = skimage.transform.rescale(img, factor, mode='constant', cval=0)
    assert np.allclose(imgo, imgo2)
    print "factor %d: OK" % factor
factor 2: OK
factor 3: OK
factor 4: OK
factor 5: OK
factor 6: OK
factor 7: OK
factor 8: OK
factor 9: OK
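
As a small worked example of the filter formula, the sketch below prints the 1D weights for factors 2 and 3 and checks that the weights landing on the same output phase sum to one, which is what lets a constant region stay constant after upsampling. The 2D filter is the outer product of the 1D weights. The names `kernel_1d`, `kernel_2d` and `phase_sums` are hypothetical and only mirror the notebook's filt_init in plain NumPy.

```python
import numpy as np

def kernel_1d(factor):
    size = 2 * factor - factor % 2
    center = (2.0 * factor - 1 - factor % 2) / (2.0 * factor)
    return 1 - np.abs(np.arange(size) / float(factor) - center)

def kernel_2d(factor):
    # Same values as the broadcast product in filt_init.
    k = kernel_1d(factor)
    return np.outer(k, k)

def phase_sums(factor):
    # Weights at indices p, p + factor, ... fall on the same output phase.
    k = kernel_1d(factor)
    return [k[p::factor].sum() for p in range(factor)]

print(kernel_1d(2))  # 0.25, 0.75, 0.75, 0.25
print(kernel_1d(3))  # 1/3, 2/3, 1, 2/3, 1/3
print(phase_sums(3))
```

For even factors the kernel has no single peak weight (0.75, 0.75 for factor 2), while odd factors have a center weight of exactly 1.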

Color image example

In [11]:
img = imread('/home/narihira/caffe/examples/images/cat.jpg')
imshow(img, interpolation='none')
Out[11]:
<matplotlib.image.AxesImage at 0x7f8d76784c50>
In [12]:
bottom = [caffe.Blob([1, 3, img.shape[0], img.shape[1]])]
bottom[0].data[...] = img.reshape(1, *img.shape).transpose(0, 3, 1, 2)
factor = 3
upl = upsample_creator(factor, 3)
upl.SetUp(bottom, top)
upl.Reshape(bottom, top)
upl.Forward(bottom, top)
imgo = top[0].data.copy().transpose(0, 2, 3, 1)[0]
imshow(imgo.astype(uint8), interpolation='none')
Out[12]:
<matplotlib.image.AxesImage at 0x7f8d7666b450>
In [13]:
for factor in xrange(2, 10):
    upl = upsample_creator(factor, 3)
    upl.SetUp(bottom, top)
    upl.Reshape(bottom, top)
    upl.Forward(bottom, top)
    imgo = top[0].data.copy().transpose(0, 2, 3, 1)[0]
    imgo2 = skimage.transform.rescale(img, factor, mode='constant', cval=0) * 255
    assert np.allclose(imgo, imgo2)
    print "factor %d: OK" % factor
factor 2: OK
factor 3: OK
factor 4: OK
factor 5: OK
factor 6: OK
factor 7: OK
factor 8: OK
factor 9: OK