by Austin Poor
Algorithmically finding color palettes from movie stills.
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
from skimage import color
from sklearn.cluster import KMeans, AgglomerativeClustering
import color_palettes
# Generate KMeans/Agglomerative (RGB + HSV) palettes for every still, with no
# pixel filtering applied.
print("KMeans + Agglomerative clustering for RGB + HSV with 8 centers...")
print("-" * 65 + "\n")
# Path.glob() yields files in arbitrary (filesystem) order; sort so the stills
# are processed in the same, reproducible order on every run and machine.
for img_path in sorted(Path("stills").glob("*.jpg")):
    print("Image:", img_path)
    color_palettes.make_all_palettes(img_path)
KMeans + Agglomerative clustering for RGB + HSV with 8 centers... ----------------------------------------------------------------- Image: stills/09 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/13 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/18 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/03 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/11 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
====================================================================================================
# Load one still and prepare it for experimentation: `arr` is the raw image,
# `img_shape` its (height, width, 3) shape and `pixels` a flat
# (n_pixels, 3) float32 array scaled to [0, 1].
image_path = Path("stills/18 (795).jpg")
# The context manager closes the underlying file once the pixels have been
# copied into a NumPy array. convert("RGB") guarantees exactly three 8-bit
# channels, so the (-1, 3) reshape below is valid even for greyscale, RGBA
# or CMYK files.
with Image.open(image_path) as image:
    arr = np.array(image.convert("RGB"))
img_shape = arr.shape
# One row per pixel, channel values rescaled from 0..255 to [0, 1].
pixels = arr.reshape((-1, 3))
pixels = pixels.astype("float32") / 255
What if we want to filter out pixels that are too dark or too bright, so that near-black and near-white colors don't end up in our color palettes?
Let's experiment with choosing low and high cutoff points for this filtering.
# Sanity check: fold the normalized pixel rows back into the image layout
# and display them without axis decorations.
preview_fig, preview_ax = plt.subplots(figsize=(12, 8))
preview_ax.imshow(pixels.reshape(img_shape))
preview_ax.set_axis_off()
To help visualize which pixels are being filtered out, we can plot the original photo side by side with a filtered version.
The filtered pixels (which the clustering algorithm won't see) are highlighted in green so they stand out.
# Cut-offs share the [0.0, 1.0] scale of the normalized `pixels` array: a
# pixel is kept only if the mean of its R, G and B values lies inside
# [LOW_CUTOFF, HIGH_CUTOFF].
LOW_CUTOFF = 0.01
HIGH_CUTOFF = 0.9

avg_pix_val = pixels.mean(1)
mask = (LOW_CUTOFF <= avg_pix_val) & (avg_pix_val <= HIGH_CUTOFF)
# Lay the flat (n_pixels,) mask out as (height, width, 1) so np.where can
# broadcast it across the colour channels of `arr`.
mask = mask.reshape(img_shape[:2] + (1,))
# A single green RGB value in `arr`'s dtype; np.where broadcasts it over every
# filtered-out location and the composite stays in the image's dtype.
green = np.array([0, 255, 0], dtype=arr.dtype)

plt.figure(figsize=(18, 16))

# Left: the untouched still.
plt.subplot(121)
plt.imshow(arr)
plt.title("No Filtering")
plt.axis("off")

# Right: pixels outside the cut-off range painted green.
plt.subplot(122)
plt.imshow(np.where(mask, arr, green))
plt.title(f"Filtered Pixels in Green\nFiltered to the Range $[{LOW_CUTOFF},{HIGH_CUTOFF}]$")
plt.axis("off")

plt.tight_layout();
Great! Now let's compare a few combinations of lower and upper cutoff values. The following plot shows a grid of these combinations.
Moving left-to-right, the high cutoff is decreased (1 → 0.8 → 0.5), and moving top-to-bottom, the low cutoff is increased (0 → 0.001 → 0.01).
The top-left image isn't filtered and the bottom-right image is the most filtered.
# Candidate cut-offs on the [0.0, 1.0] scale of `pixels`. Rows step through
# LOW_CUTOFFS (increasing downward) and columns through HIGH_CUTOFFS
# (decreasing to the right), so filtering is strictest at the bottom-right.
LOW_CUTOFFS = [0, 0.001, 0.01]
HIGH_CUTOFFS = [1, 0.8, 0.5]

avg_pix_val = pixels.mean(1)
# Per-pixel mean laid out as (height, width) so each mask lines up with `arr`.
avg_pix_map = avg_pix_val.reshape(img_shape[:2])
# A single green RGB value in `arr`'s dtype; np.where broadcasts it over every
# filtered-out location.
green = np.array([0, 255, 0], dtype=arr.dtype)

# squeeze=False keeps `ax` two-dimensional even when a cut-off list has a
# single entry, so `ax[i, j]` indexing always works.
fig, ax = plt.subplots(
    nrows=len(LOW_CUTOFFS),
    ncols=len(HIGH_CUTOFFS),
    figsize=(20, 12),
    squeeze=False,
)
for i, LC in enumerate(LOW_CUTOFFS):
    for j, HC in enumerate(HIGH_CUTOFFS):
        mask = (LC <= avg_pix_map) & (avg_pix_map <= HC)
        # Trailing channel axis lets the (H, W) mask broadcast over RGB.
        ax[i, j].imshow(np.where(mask[..., np.newaxis], arr, green))
        ax[i, j].set_title(f"Range $[{LC},{HC}]$")
        ax[i, j].axis("off")
fig.tight_layout();
Now we can turn that logic into a filter function and pass it to our palette-generation function.
from functools import partial


def filter_pixels(pixels: np.ndarray, low: float = 0.0, high: float = 1.0) -> np.ndarray:
    """Keep only pixels whose mean channel value lies in ``[low, high]``.

    Meant to be passed (bound with ``functools.partial``) as the
    ``filter_fn`` of palette generation, so that near-black and near-white
    pixels are dropped before clustering.

    Args:
        pixels: Array whose last axis holds one pixel's channel values, e.g.
            a flat ``(n_pixels, 3)`` array or a ``(height, width, 3)`` image.
            Values must be on the same scale as ``low``/``high`` (this
            notebook uses RGB normalized to ``[0, 1]``).
        low: Inclusive lower bound on a pixel's mean channel value.
        high: Inclusive upper bound on a pixel's mean channel value.

    Returns:
        A ``(n_kept, n_channels)`` array of the retained pixels, in their
        original row-major order.
    """
    pix_mean = pixels.mean(axis=-1)
    mask = (low <= pix_mean) & (pix_mean <= high)
    # Boolean indexing over all leading axes flattens them, so flat and
    # image-shaped inputs both come back as one row per kept pixel.
    return pixels[mask]
And run it for a couple of high/low cut-off levels...
# Regenerate every still's palettes, ignoring pixels whose mean channel value
# falls outside [low_cutoff, high_cutoff].
low_cutoff = 0.01
high_cutoff = 0.8

# The filter depends only on the cut-offs, so build it once outside the loop.
pixel_filter = partial(filter_pixels, low=low_cutoff, high=high_cutoff)

print("KMeans + Agglomerative clustering for RGB + HSV with 8 centers...")
print(f"Filtering pixels outside the range [{low_cutoff},{high_cutoff}]...")
print("-" * 65 + "\n")
# Sorted so runs with different cut-offs list the stills in the same order.
for img_path in sorted(Path("stills").glob("*.jpg")):
    print("Image:", img_path)
    color_palettes.make_all_palettes(img_path, filter_fn=pixel_filter)
KMeans + Agglomerative clustering for RGB + HSV with 8 centers... Filtering pixels outside the range [0.01,0.8]... ----------------------------------------------------------------- Image: stills/09 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/13 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/18 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/03 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/11 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
====================================================================================================
# Regenerate every still's palettes, ignoring pixels whose mean channel value
# falls outside [low_cutoff, high_cutoff].
low_cutoff = 0.05
high_cutoff = 0.9

# The filter depends only on the cut-offs, so build it once outside the loop.
pixel_filter = partial(filter_pixels, low=low_cutoff, high=high_cutoff)

print("KMeans + Agglomerative clustering for RGB + HSV with 8 centers...")
print(f"Filtering pixels outside the range [{low_cutoff},{high_cutoff}]...")
print("-" * 65 + "\n")
# Sorted so runs with different cut-offs list the stills in the same order.
for img_path in sorted(Path("stills").glob("*.jpg")):
    print("Image:", img_path)
    color_palettes.make_all_palettes(img_path, filter_fn=pixel_filter)
KMeans + Agglomerative clustering for RGB + HSV with 8 centers... Filtering pixels outside the range [0.05,0.9]... ----------------------------------------------------------------- Image: stills/09 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/13 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/18 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/03 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/11 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
====================================================================================================
# Regenerate every still's palettes, ignoring pixels whose mean channel value
# falls outside [low_cutoff, high_cutoff].
low_cutoff = 0.1
high_cutoff = 0.99

# The filter depends only on the cut-offs, so build it once outside the loop.
pixel_filter = partial(filter_pixels, low=low_cutoff, high=high_cutoff)

print("KMeans + Agglomerative clustering for RGB + HSV with 8 centers...")
print(f"Filtering pixels outside the range [{low_cutoff},{high_cutoff}]...")
print("-" * 65 + "\n")
# Sorted so runs with different cut-offs list the stills in the same order.
for img_path in sorted(Path("stills").glob("*.jpg")):
    print("Image:", img_path)
    color_palettes.make_all_palettes(img_path, filter_fn=pixel_filter)
KMeans + Agglomerative clustering for RGB + HSV with 8 centers... Filtering pixels outside the range [0.1,0.99]... ----------------------------------------------------------------- Image: stills/09 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/13 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/18 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/03 (776).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
==================================================================================================== Image: stills/11 (795).jpg
KMeans RGB clustering
KMeans HSV clustering
Agglomerative RGB clustering
Agglomerative HSV clustering
====================================================================================================
We still aren't capturing some of the rarer but prominent colors -- for example, the red in the boxing-ring photo -- but adding pixel filtering definitely seems to make the palettes more vibrant and removes colors that approach pure white and pure black.