This tutorial is based on the Urban Goggles project created at AstroHackWeek 2016. More information can be found on the AHW Wiki.
Location: GitHub HQ (San Francisco) & Berkeley Institute for Data Science, 29th August – 2nd September, 2016.
This notebook takes you through each step of the process, dissecting some of the functions along the way. The code has been adapted to suit the Apollo images and the purposes of this tutorial. The images come from the Apollo Project Flickr galleries.
This tutorial is aimed at those who are new to Python and would like to start using the matplotlib and scikit-learn packages. Experience or familiarity with other programming languages (dare I say Fortran, IDL etc.) may be helpful. Comments and suggestions are welcome.
Thanks to Adrian Price-Whelan (Princeton), Dan Foreman-Mackey (University of Washington), and Ben Nelson (CIERA, Northwestern University) for sharing their code.
http://matplotlib.org/api/pyplot_api.html
Also useful is the Markdown guide: http://markdown-guide.readthedocs.io/en/latest/
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import skimage.data as sd
import skimage.color as color
import skimage.transform as st
from skimage.io import imread
from sklearn.utils import shuffle
from sklearn.cluster import KMeans
# GMM was removed in scikit-learn 0.20; GaussianMixture is its replacement
from sklearn.mixture import GaussianMixture as GMM
from mpl_toolkits.mplot3d import Axes3D
from IPython.display import Image
image_filename = "apollo_images/AS07-4-1584"
# image_filename = "apollo_images/AS07-7-1776"
# image_filename = "apollo_images/AS11-40-5868"
# image_filename = "apollo_images/AS07-3-1524"
# image_filename = "apollo_images/AS07-11-2027"
# image_filename = "apollo_images/AS17-163-24122"
In this example, the image we're using is an Apollo 7 Hasselblad image from film magazine 4/N, taken from Earth orbit.
**A quick intro to image arrays:**
First we need to create a data array from the image. The images used in this example are 612 x 640 px (width x height), so the array has 640 rows and 612 columns. Pixels are indexed by [row, column], counted from the top left corner of the image. The final data array holds one value per pixel per colour channel, i.e. it is indexed as [row, column, RGB channel].
rgb = np.array(imread(image_filename+".jpg")[...,:3], dtype=np.float64)
rgb.min(), rgb.max()
# print('min value')
print(rgb.min())
# print('max value')
print(rgb.max())
# print('image size')
print(rgb.shape)
print(rgb.size)
pixels_per_image = rgb.size
rgb = st.rescale(rgb, 1.0)
8.0
255.0
(640, 612, 3)
1175040
Just for fun, we'll print the RGB values for the pixel in the first row and column. We can get an individual channel by specifying 0, 1, or 2 as the third index, or all three values at once like this:
# Print the R, G, B values
print(rgb[0,0,:3])
[ 27. 50. 42.]
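To make the [row, column, channel] indexing concrete, here is a minimal sketch on a tiny synthetic array. The array `demo` and its values are made up for illustration; they are not read from the Apollo image.

```python
import numpy as np

# Hypothetical 2 x 3 pixel "image" with 3 colour channels (R, G, B).
demo = np.zeros((2, 3, 3), dtype=np.float64)
demo[0, 0] = [27.0, 50.0, 42.0]   # top-left pixel
demo[1, 2] = [255.0, 0.0, 0.0]    # bottom-right pixel, pure red

# Index as [row, column, channel]; channel 0 = R, 1 = G, 2 = B.
red_top_left = demo[0, 0, 0]
green_top_left = demo[0, 0, 1]
blue_top_left = demo[0, 0, 2]

# A whole channel is a 2D array with the same rows and columns as the image.
red_channel = demo[..., 0]

print(red_top_left, green_top_left, blue_top_left)
print(red_channel.shape)
```

The same pattern applies to the real `rgb` array: `rgb[0, 0, 1]` would give the green value of the top-left pixel.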
**Image manipulation:**
First we convert to floats instead of the default 8-bit integer encoding. Dividing by 255 normalises the RGB values to the range [0, 1], which is what plt.imshow expects for float data. We also rescale the image size (e.g. 0.1 = 10%; the cell below uses 0.3) to reduce the flood of pixels. This should also help with the clustering, since there are fewer points to fit.
rgb = np.array(imread(image_filename+".jpg")[...,:3], dtype=np.float64) / 255
rgb = st.rescale(rgb, 0.3)
print(rgb.shape)
print(rgb.size)
print(rgb[0,0,:3])
(192, 184, 3)
105984
[ 0.1092711 0.19141091 0.16537369]
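The normalisation step can be checked on a small synthetic image. This is only a sketch: the array `img_u8` is random made-up data, and the slicing at the end is a crude stand-in for `st.rescale` (which interpolates rather than simply dropping pixels).

```python
import numpy as np

# Hypothetical 8-bit image: 4 x 6 pixels, 3 channels, values 0..255.
rng = np.random.default_rng(0)
img_u8 = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)

# Convert to float and normalise to the range [0, 1].
img_f = img_u8.astype(np.float64) / 255

# Crude stand-in for st.rescale: keep every second pixel in each direction.
# (st.rescale interpolates; plain slicing does not.)
img_small = img_f[::2, ::2]

print(img_f.min(), img_f.max())
print(img_small.shape)
```

If you run this notebook with a recent scikit-image (0.19 or later, as far as I know), `st.rescale` may need `channel_axis=-1` so that the colour axis is not rescaled along with the rows and columns.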
# Transform to a 2D numpy array.
# For a description of tuples, see here: http://openbookproject.net/thinkcs/python/english3e/tuples.html
# rgb.shape is (rows, columns, channels), i.e. (height, width, depth)
h, w, d = original_shape = tuple(rgb.shape)
assert d == 3
rgb_data = np.reshape(rgb, (w * h, d))
print(rgb_data.shape)
print(rgb_data)
(35328, 3)
[[ 0.1092711 0.19141091 0.16537369]
 [ 0.10004263 0.18623188 0.15750213]
 [ 0.09658994 0.18823529 0.16508951]
 ...,
 [ 0.08627451 0.18823529 0.15294118]
 [ 0.09092072 0.18431373 0.16078431]
 [ 0.09477124 0.18366013 0.16078431]]
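The reshape above turns the image into one row per pixel. A minimal sketch of how the flattened index relates to [row, column], and how the stored shape restores the image, using a made-up array `demo` rather than the Apollo data:

```python
import numpy as np

# Hypothetical small image: 4 rows, 5 columns, 3 channels.
demo = np.arange(4 * 5 * 3, dtype=np.float64).reshape(4, 5, 3)
rows, cols, depth = demo.shape

# Flatten to one row per pixel: shape (rows * cols, 3).
flat = demo.reshape(rows * cols, depth)

# Pixel [r, c] ends up at flat index r * cols + c (row-major order).
r, c = 2, 3
same_pixel = np.array_equal(flat[r * cols + c], demo[r, c])

# Reshaping with the stored shape restores the original image exactly.
restored = flat.reshape(rows, cols, depth)
print(flat.shape, same_pixel, np.array_equal(restored, demo))
```

This is why it is worth keeping `original_shape` around: any per-pixel result computed on the flat array can be reshaped back into an image.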
plt.figure(figsize=(12,10))
#figsize = (width, height) in inches
plt.imshow(rgb)
Some useful references:
A few notes about K-means clustering and similar algorithms:
Wikipedia has a nice introduction to K-means clustering. In astronomy, bimodal mixture-modelling (KMM) algorithms are widely used to separate populations. I've used KMM to determine whether clusters of galaxies exhibit sub-structure in terms of their galaxy position and radial velocity distributions. Expectation-Maximisation is used to determine the parameters. K-means clustering is a little different.
K-means defines 'hard clusters', i.e. every data point is assigned to exactly one of N defined clusters, whereas mixture models let you test how many sub-populations the data support without necessarily assigning each data point to a single sub-cluster. In KMM the sub-populations are essentially probability distributions. With a K-means algorithm, you pre-define the number of clusters (k), then seed k random points throughout the parameter space as the initial centroids. Each data point is assigned to the cluster with the nearest centroid, each centroid is recomputed as the mean of the points assigned to it, and these two steps are repeated until the assignments stop changing and the algorithm converges on a solution. K-means can be viewed as a limiting case of a Gaussian mixture model fitted with Expectation-Maximisation (equal, spherical covariances shrinking towards zero)... I think. I need to go back and look at each of these in more detail.
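The assign-and-update loop described above can be sketched in plain numpy. This is an illustrative toy, not the scikit-learn implementation; the function name `kmeans_sketch`, its parameters, and the two synthetic blobs are all made up for this example, and the result depends on the random seeding.

```python
import numpy as np

def kmeans_sketch(points, k, n_iter=50, seed=0):
    """Plain K-means (Lloyd's algorithm) on an (n_points, n_dims) array."""
    rng = np.random.default_rng(seed)
    # Seed the centroids with k distinct data points chosen at random.
    centroids = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid for every point.
        dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Update step: move each centroid to the mean of its members
        # (an empty cluster keeps its previous centroid).
        new_centroids = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(new_centroids, centroids):
            break  # converged: centroids stopped moving
        centroids = new_centroids
    return centroids, labels

# Two well-separated synthetic blobs (made-up data, not image pixels).
rng = np.random.default_rng(1)
blob_a = rng.normal(loc=[0.2, 0.2, 0.2], scale=0.02, size=(100, 3))
blob_b = rng.normal(loc=[0.8, 0.8, 0.8], scale=0.02, size=(100, 3))
points = np.vstack((blob_a, blob_b))

centroids, labels = kmeans_sketch(points, k=2)
print(np.sort(centroids[:, 0]))
```

In the rest of this notebook the same job is done by `KMeans` from scikit-learn, which adds smarter seeding and multiple restarts.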
A really nice description of K-means clustering using numpy can be found on the Data Science Lab blog, Clustering With K-Means in Python. It's one of several blog posts on the topic.
What this means is that you have the freedom to decide how many clusters you want: clf = KMeans(n_clusters=6). For these images I settled on 6-8, because that was close to what I would expect based on a quick visual inspection of each image. I recommend playing around with this parameter so you can get a better feeling for what's going on.
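One way to play around with the number of clusters is to fit several values of k and compare the `inertia_` attribute of each fit (the sum of squared distances from every point to its nearest centroid). This is a common heuristic rather than something the original project did, as far as I can tell, and the three-blob data set below is synthetic.

```python
import numpy as np
from sklearn.cluster import KMeans

# Made-up data: three tight blobs in a unit cube, standing in for pixel colours.
rng = np.random.default_rng(0)
centres = np.array([[0.1, 0.1, 0.1], [0.5, 0.9, 0.5], [0.9, 0.2, 0.8]])
data = np.vstack([rng.normal(c, 0.02, size=(200, 3)) for c in centres])

# Fit K-means for several values of k and record the inertia
# (sum of squared distances of points to their nearest centroid).
inertias = {}
for k in range(1, 7):
    clf = KMeans(n_clusters=k, n_init=10, random_state=0)
    clf.fit(data)
    inertias[k] = clf.inertia_

for k, value in inertias.items():
    print(k, round(value, 3))
```

The inertia should drop sharply up to the true number of blobs and only slowly after that. Real image data rarely has such a clean break, so visual inspection of the palette remains a reasonable guide.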
hsv_data = color.rgb2hsv(rgb)
hsv = hsv_data.reshape(-1, 3).T
plt.figure(figsize=(10,6))
plt.plot(hsv[0], hsv[1], linestyle='none', alpha=1., marker=',');
plt.xlabel('Hue')
plt.ylabel('Saturation')
phi = 2*np.pi*hsv[0]
x = hsv[1]*np.cos(phi)
y = hsv[1]*np.sin(phi)
z = hsv[2]
plt.figure(figsize=(10,10))
ax = plt.subplot(1,1,1,projection='3d')
plt.plot(x, y, z, linestyle='none', c='teal', alpha=0.3, marker=',');
ax.set_xlabel(r'$S\,\cos (2\pi H)$')
ax.set_ylabel(r'$S\,\sin (2\pi H)$')
ax.set_zlabel(r'$VALUE$')
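The cell above maps hue to an angle, so that HSV becomes a cylinder with coordinates (S cos 2πH, S sin 2πH, V). One reason this helps, as far as I understand it, is that hue is periodic: H = 0.01 and H = 0.99 are both red, but look far apart if hue is treated as a plain number. A minimal sketch, with hypothetical helper names `hsv_to_cartesian` and `cartesian_to_hsv` and made-up colour values:

```python
import numpy as np

def hsv_to_cartesian(h, s, v):
    """Map HSV (all in [0, 1]) to (S cos 2*pi*H, S sin 2*pi*H, V)."""
    phi = 2 * np.pi * h
    return np.array([s * np.cos(phi), s * np.sin(phi), v])

def cartesian_to_hsv(x, y, z):
    """Inverse mapping; hue is wrapped back into [0, 1)."""
    h = (np.arctan2(y, x) / (2 * np.pi)) % 1.0
    return np.array([h, np.hypot(x, y), z])

# Two nearly identical reds on either side of the hue wrap-around.
red_low = np.array([0.01, 0.9, 0.8])
red_high = np.array([0.99, 0.9, 0.8])

# Distance between the two colours in raw HSV and in Cartesian coordinates.
dist_hsv = np.linalg.norm(red_low - red_high)
dist_cart = np.linalg.norm(hsv_to_cartesian(*red_low) - hsv_to_cartesian(*red_high))
print(dist_hsv, dist_cart)

# The round trip recovers the original HSV triple.
round_trip = cartesian_to_hsv(*hsv_to_cartesian(*red_high))
print(round_trip)
```

Because K-means relies on Euclidean distances, clustering in the Cartesian coordinates keeps similar reds together. The inverse mapping is needed whenever cluster centres found in these coordinates have to be turned back into HSV or RGB colours.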
X = np.vstack((x,y,z)).T
subset = shuffle(X)
clf = KMeans(n_clusters=6)
clf.fit(subset)
centroids = clf.cluster_centers_
# clf = GMM(n_components=16, )
# clf.fit(subset)
# centroids = clf.means_
# centroids = centroids[np.argsort(clf.weights_)]
plt.figure(figsize=(10,10))
ax = plt.subplot(1,1,1,projection='3d')
ax.plot(subset[:,0], subset[:,1], subset[:,2], c='teal', linestyle='none', alpha=0.3, marker=',')
ax.scatter(centroids[:,0], centroids[:,1], centroids[:,2], c='red', marker='x', s=75, linewidths=3, zorder=10)
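# The clustered coordinates here are the Cartesian HSV values, not RGB
ax.set_xlabel(r'$S\,\cos (2\pi H)$')
ax.set_ylabel(r'$S\,\sin (2\pi H)$')
ax.set_zlabel(r'$VALUE$')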
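# The centroids live in (S cos phi, S sin phi, V) space, so recover (H, S, V) before converting to RGB
hue = (np.arctan2(centroids[:,1], centroids[:,0]) / (2*np.pi)) % 1.0
sat = np.hypot(centroids[:,0], centroids[:,1])
hsv_clusters = np.column_stack((hue, sat, centroids[:,2]))
rgb_clusters = color.hsv2rgb(hsv_clusters[None])[0]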
# Displaying the image again...
plt.figure(figsize=(12,10))
plt.imshow(rgb)
Now we generate the color map from the HSV clusters or centroids. Dissection: len() is a built-in Python function that returns the length, or number of items, in an object; in this case six clusters (the rows of centroids). np.sqrt() returns the positive square root of each element, and int() truncates the result to an integer. The syntax for these functions is fairly standard across programming languages.
_n = int(np.sqrt(len(centroids)))
fig,ax = plt.subplots(1,1,figsize=(8,2))
#figure(num=None, figsize=(8, 6), dpi=80, facecolor='w', edgecolor='k')
#http://matplotlib.org/users/image_tutorial.html
ax.imshow(rgb_clusters.reshape(1,len(centroids),3))
#print(len(centroids))
ax.xaxis.set_visible(True)
ax.yaxis.set_visible(False)
_n = int(np.sqrt(len(centroids)))
fig,ax = plt.subplots(1,1,figsize=(8,2))
ax.imshow(rgb_clusters.reshape(1,len(centroids),3), interpolation='nearest')
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
fig.savefig(image_filename+"_hsv_palette.jpg", dpi=100)
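# x and y (S cos phi, S sin phi) span [-1, 1]; z (value) spans [0, 1]
xy_bins = np.linspace(-1,1,16)
bins = np.linspace(0,1,16)
H,edges = np.histogramdd(X, bins=(xy_bins,xy_bins,bins))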
from scipy.ndimage import gaussian_filter
# plt.imshow(np.sum(H, axis=1))
plt.imshow(gaussian_filter(H, 0.5)[8], interpolation='nearest', cmap='plasma')
x = rgb_data[...,0].ravel()
y = rgb_data[...,1].ravel()
z = rgb_data[...,2].ravel()
plt.figure(figsize=(10,10))
ax = plt.subplot(1,1,1,projection='3d')
plt.plot(x, y, z, linestyle='none', c='teal', alpha=0.3, marker=',');
ax.set_xlabel(r'red')
ax.set_ylabel(r'green')
ax.set_zlabel(r'blue')
X = np.vstack((x,y,z)).T
subset = shuffle(X)
clf = KMeans(n_clusters=6)
clf.fit(subset)
centroids = clf.cluster_centers_
# clf = GMM(n_components=16, )
# clf.fit(subset)
# centroids = clf.means_
# centroids = centroids[np.argsort(clf.weights_)]
plt.figure(figsize=(10,10))
ax = plt.subplot(1,1,1,projection='3d')
ax.plot(subset[:,0], subset[:,1], subset[:,2], c='teal', linestyle='none', alpha=0.9, marker=',', zorder=0)
ax.scatter(centroids[:,0], centroids[:,1], centroids[:,2], marker='x', alpha=1.,
c='red', s=100, linewidths=3, zorder=10);
ax.set_xlabel(r'red')
ax.set_ylabel(r'green')
ax.set_zlabel(r'blue')
rgb_clusters = centroids
_n = int(np.sqrt(len(centroids)))
fig,ax = plt.subplots(1,1,figsize=(8,2))
#figure(num=None, figsize=(8, 6), dpi=80, facecolor='w', edgecolor='k')
#http://matplotlib.org/users/image_tutorial.html
ax.imshow(rgb_clusters.reshape(1,len(centroids),3))
#print(len(centroids))
ax.xaxis.set_visible(True)
ax.yaxis.set_visible(False)
_n = int(np.sqrt(len(centroids)))
fig,ax = plt.subplots(1,1,figsize=(8,2))
ax.imshow(rgb_clusters.reshape(1,len(centroids),3), interpolation='nearest')
ax.xaxis.set_visible(False)
ax.yaxis.set_visible(False)
fig.savefig(image_filename+"_rgb_palette.jpg", dpi=100)
bins = np.linspace(0,1,16)
H,edges = np.histogramdd(X, bins=(bins,bins,bins))
from scipy.ndimage import gaussian_filter
# plt.imshow(np.sum(H, axis=1))
plt.imshow(gaussian_filter(H, 0.5)[6], interpolation='nearest', cmap='plasma')