This notebook walks you through how to download Google's Conceptual Captions Dataset, and then clean and curate the data. Once you have a refined dataset, you can use this to train your own state-of-the-art ControlNet model, or to train a model for image captioning tasks!
!pip install pandas fiftyone
import hashlib
import pandas as pd
from tqdm.notebook import tqdm
import fiftyone as fo
import fiftyone.zoo as foz
import fiftyone.brain as fob
from fiftyone import ViewField as F
Download the tab-separated values (.tsv) file by clicking the “Download” button at the bottom of Google’s Conceptual Captions webpage.
We can load the TSV file as a pandas DataFrame just as we would a CSV, passing in sep='\t' to specify that the separator is a tab.
df = pd.read_csv("Train_GCC-training.tsv", sep='\t', header=None)  # the file has no header row
Give the columns of the DataFrame descriptive names:
df.columns = ['caption', 'url']
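The Conceptual Captions TSV file has no header row, so it should be read with header=None. Otherwise pandas treats the first caption-URL pair as column names and silently drops it. The sketch below uses a made-up two-row TSV string (hypothetical data, not real dataset rows) to show the difference.

```python
import io
import pandas as pd

# Two made-up caption/URL rows in the same tab-separated, header-less layout
tsv = "a dog on a beach\thttp://example.com/a.jpg\na red car\thttp://example.com/b.jpg\n"

# Default header=0: the first data row is consumed as the column names
df_default = pd.read_csv(io.StringIO(tsv), sep="\t")
# header=None plus explicit names keeps every row
df_fixed = pd.read_csv(io.StringIO(tsv), sep="\t", header=None, names=["caption", "url"])

print(len(df_default), len(df_fixed))  # 1 2
```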
And then hash the URL of each entry to generate a unique ID:
def hash_url(url):
    return hashlib.md5(url.encode()).hexdigest()[:12]

df['url_hash'] = df['url'].apply(hash_url)
The DataFrame looks like this:
         caption                                            url                                                url_hash
0        sierra looked stunning in this top and this sk...  http://78.media.tumblr.com/3b133294bdc7c7784b7...  e7023a8dfcd2
1        young confused girl standing in front of a war...  https://media.gettyimages.com/photos/young-con...  92679c323fc6
2        interior design of modern living room with fir...  https://thumb1.shutterstock.com/display_pic_wi...  74c4fa5539f4
3        cybernetic scene isolated on white background .    https://thumb1.shutterstock.com/display_pic_wi...  f1ea388e05e1
4        gangsta rap artist attends sports team vs play...  https://media.gettyimages.com/photos/jayz-atte...  9a6f8026f593
...      ...                                                ...                                                ...
3318327  the teams line up for a photo after kick - off     https://i0.wp.com/i.dailymail.co.uk/i/pix/2015...  6aec77a477f9
3318328  stickers given to delegates at the convention .    http://cdn.radioiowa.com/wp-content/uploads/20...  7d42aea90652
3318329  this is my very favourite design that i recent...  https://i.pinimg.com/736x/96/f0/77/96f07728efe...  f6dd151121c0
3318330  man driving a car through the mountains            https://www.quickenloans.com/blog/wp-content/u...  ee4244df5c55
3318331  a longtail boat with a flag goes by spectacula...  http://l7.alamy.com/zooms/338c4740f7b2480dbb72...  7625946297b7

3318332 rows × 3 columns
We will use these IDs to specify the download locations (filepaths) of the images, so that we can associate each caption with its corresponding image.
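Each image file is named after its truncated hash, which is only 12 hex characters (48 bits). A collision between two distinct URLs is unlikely, but it is cheap to check. The check also exposes URLs that appear more than once in the TSV: those rows map to the same file. The URLs below are hypothetical; on the real data you would use df['url'] and df['url_hash'] instead.

```python
import hashlib
import pandas as pd

def hash_url(url):
    # First 12 hex characters (48 bits) of the URL's MD5 digest
    return hashlib.md5(url.encode()).hexdigest()[:12]

# Hypothetical URLs; the first one appears twice
urls = pd.Series([
    "http://example.com/a.jpg",
    "http://example.com/a.jpg",
    "http://example.com/b.jpg",
])
hashes = urls.apply(hash_url)

# Equal counts mean no two *distinct* URLs collided on the same 12-character ID
print(urls.nunique() == hashes.nunique())  # True
# Rows that repeat an earlier URL share (and overwrite) the same image file
print(int(hashes.duplicated().sum()))  # 1
```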
If we want to download the images in batches, we can do so as follows:
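import os

def download_batch(df, batch_size=10000, start_index=0):
    os.makedirs("images", exist_ok=True)  # curl -o does not create missing folders
    batch = df.iloc[start_index:start_index+batch_size]
    # iterate over the actual batch length so a short final batch doesn't raise IndexError
    for j in tqdm(range(len(batch))):
        url, uh = batch.iloc[j][['url', 'url_hash']]
        # -s: silent; give up on any single image after 3 seconds
        !curl -s --connect-timeout 3 --max-time 3 "{url}" -o images/{uh}.jpg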
Here we download batch_size images starting from start_index into the images folder, naming each file after the URL hash we generated above. We use curl to perform the download and cap the time spent on each image, because some of the links are no longer valid and would otherwise stall the loop.
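You may prefer not to shell out to curl from inside a notebook loop. The same idea (fetch each URL with a short timeout and skip failures) can be sketched in pure Python with urllib and a thread pool, so that many slow or dead links are waited on concurrently.

- The function names download_image and download_many, and the default of 16 worker threads, are hypothetical choices for illustration, not part of the original walkthrough.
- urllib's timeout applies to each blocking operation. Unlike curl's --max-time, it does not cap the total time for one download.

```python
import os
import shutil
import urllib.request
from concurrent.futures import ThreadPoolExecutor

def download_image(url, path, timeout=3):
    """Save url to path; return True on success, False on any failure.

    timeout (seconds) applies to connecting and to each blocking read,
    so it does not cap the total download time.
    """
    opened = False
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            with open(path, "wb") as f:
                opened = True
                shutil.copyfileobj(resp, f)
        return True
    except Exception:
        # Dead link, HTTP error, or timeout: drop a partially written file, if any
        if opened and os.path.exists(path):
            os.remove(path)
        return False

def download_many(pairs, out_dir="images", workers=16):
    """Download (url, url_hash) pairs to out_dir/{url_hash}.jpg concurrently."""
    os.makedirs(out_dir, exist_ok=True)
    jobs = [(url, os.path.join(out_dir, f"{uh}.jpg")) for url, uh in pairs]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        # Threads suit this I/O-bound work; results come back in input order
        return list(pool.map(lambda job: download_image(*job), jobs))
```

On the DataFrame above, this might be called as download_many(df[['url', 'url_hash']].itertuples(index=False, name=None)). With millions of URLs you would still process the rows in batches.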
To download a total of num_images images, run the following:
def download_images(df, batch_size=10000, num_images=100000):
    # step through start indices so a final partial batch is not skipped
    for start in range(0, num_images, batch_size):
        download_batch(df, batch_size=min(batch_size, num_images - start), start_index=start)

## download all images
num_images = len(df)
download_images(df, batch_size=10000, num_images=num_images)
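Splitting num_images into batches with floor division (num_images // batch_size) silently skips the last num_images % batch_size rows. Stepping through start indices with range avoids that. The hypothetical helper below isolates the arithmetic so it can be checked without downloading anything.

```python
def batch_ranges(num_images, batch_size):
    """Yield (start_index, size) pairs covering range(num_images) exactly once."""
    for start in range(0, num_images, batch_size):
        yield start, min(batch_size, num_images - start)

# Floor division alone would produce only 2 full batches here and skip 5 images
print(25 // 10)                    # 2
print(list(batch_ranges(25, 10)))  # [(0, 10), (10, 10), (20, 5)]
```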
Once the images are downloaded into the images folder, we can load them and their captions as a Dataset in FiftyOne:
dataset = fo.Dataset(name="gcc", persistent=True)
dataset.add_sample_field("caption", fo.StringField)

samples = []
for i in tqdm(range(num_images)):
    caption, uh = df.iloc[i]['caption'], df.iloc[i]['url_hash']
    filepath = f"images/{uh}.jpg"
    sample = fo.Sample(
        filepath=filepath,
        caption=caption
    )
    samples.append(sample)

dataset.add_samples(samples)
This code creates a Dataset named “gcc”, which is persisted to the underlying database. It then iterates through the first num_images rows of the pandas DataFrame, creating a Sample with the appropriate filepath and caption for each.
For this walkthrough, I downloaded roughly the first 310,000 images.
The first step we should take when inspecting a new computer vision dataset is to visualize it! We can do this by launching the FiftyOne App:
session = fo.launch_app(dataset)
When we look at the data, we can immediately see that some of the images are not valid. This may be due to links which are no longer working, interruptions during downloading, or some other issue entirely.
Fortunately, we can filter out these invalid images easily. In FiftyOne, the compute_metadata() method computes media-type-specific metadata for each sample. For image samples, this includes the image width, height, number of channels, and size in bytes.
When a media file is missing or corrupted, its metadata is left as None. We can therefore filter out the corrupted images by running compute_metadata() and matching for samples where the metadata exists:
dataset.compute_metadata()
## view containing only valid images
valid_image_view = dataset.exists("metadata")
session = fo.launch_app(valid_image_view)
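One common source of invalid files is a dead link. The server often returns an HTML error page, and curl (without --fail) saves that body under the .jpg name anyway. For a quick pre-check before handing files to FiftyOne, a rough heuristic is to look at each file's leading bytes.

This sketch is hypothetical and much cruder than what compute_metadata() does:

- It only recognizes JPEG, PNG, GIF, and WebP signatures.
- It cannot detect a file that starts correctly but was truncated by the download timeout.

```python
# Leading-byte signatures of a few common image formats
IMAGE_SIGNATURES = (
    b"\xff\xd8\xff",       # JPEG
    b"\x89PNG\r\n\x1a\n",  # PNG
    b"GIF87a", b"GIF89a",  # GIF
)

def looks_like_image(path):
    """Heuristic: True if the file starts with a known image signature."""
    try:
        with open(path, "rb") as f:
            head = f.read(12)
    except OSError:
        return False  # missing or unreadable file
    if head.startswith(IMAGE_SIGNATURES):
        return True
    # WebP: "RIFF" + 4-byte little-endian size + "WEBP"
    return head[:4] == b"RIFF" and head[8:12] == b"WEBP"
```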
Next, we may want to filter out samples with unusual aspect ratios. If our goal is to control the outputs of a diffusion model, we will likely only be working with images within a certain range of reasonable aspect ratios.
We can do this using FiftyOne’s ViewField, which allows us to apply arbitrary expressions to attributes of our samples and then filter based on them. For instance, to discard all images that are more than twice as long in one dimension as in the other, we can use the following code:
from fiftyone import ViewField as F
long_filter = F("metadata.width") > 2*F("metadata.height")
tall_filter = F("metadata.height") > 2*F("metadata.width")
aspect_ratio_filter = (~long_filter) & (~tall_filter)
good_aspect_view = valid_image_view.match(aspect_ratio_filter)
For the sake of clarity, this is what the discarded samples look like:
bad_aspect_view = valid_image_view.match(~aspect_ratio_filter)
session = fo.launch_app(bad_aspect_view)
If you so choose, you can use a more or less stringent aspect ratio filter!
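The two ViewField conditions are equivalent to requiring that the longer side be at most twice the shorter side. A plain-Python version of the same predicate makes the boundary behavior easy to check. The function name keep_aspect is hypothetical. Note that an image with an exact 2:1 ratio is kept, because both filters use a strict >.

```python
def keep_aspect(width, height, max_ratio=2):
    """Keep an image unless one side is more than max_ratio times the other."""
    # Same as (~long_filter) & (~tall_filter) for positive width and height
    return max(width, height) <= max_ratio * min(width, height)

print(keep_aspect(600, 300))  # True  (exactly 2:1 is kept)
print(keep_aspect(601, 300))  # False
print(keep_aspect(300, 601))  # False
```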
In a similar vein, we might want to remove low-resolution images. We want to generate stunning, photorealistic images, so there is no sense in including low-resolution images in the training data.
This filter is similar to the aspect ratio filter. If we require both width and height to exceed 300 pixels, the filter takes the form:
hires_filter = (F("metadata.width") > 300) & (F("metadata.height") > 300)
view = good_aspect_view.match(hires_filter)
Once again, you can choose whatever thresholds you like. For clarity, here is a representative view of the discarded images:
lowres_view = good_aspect_view.match(~hires_filter)
session = fo.launch_app(lowres_view)
Looking at the low-resolution images, we might also be reminded that some of the images in our dataset are grayscale. We likely want to generate images that are as vibrant as possible, so we should discard the black-and-white images.
In FiftyOne, one of the attributes logged in image metadata is the number of channels: color images have three channels (RGB), whereas grayscale images only have one channel. Removing grayscale images is as simple as matching for images with three channels!
## gray images to discard (selected before view is narrowed to color images)
gray_view = view.match(F("metadata.num_channels") == 1)
## color images to keep
view = view.match(F("metadata.num_channels") == 3)
session = fo.launch_app(gray_view)
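Matching on num_channels only catches images stored with a single channel. A grayscale photo saved as a three-channel JPEG will still pass. If that matters for your data, one option is to load the pixels and check whether the three channels are nearly identical. This is an assumption on my part, not a step from this walkthrough.

The sketch works on an H×W×3 uint8 NumPy array, for example one produced by np.asarray(Image.open(path).convert("RGB")) with Pillow. The tolerance of 8 intensity levels is an arbitrary illustration value. JPEG compression means the channels of a visually gray RGB image are rarely exactly equal.

```python
import numpy as np

def is_effectively_gray(pixels, tol=8):
    """Return True if an H x W x 3 uint8 array has (near-)identical channels."""
    px = pixels.astype(np.int16)  # avoid uint8 wraparound when subtracting
    r, g, b = px[..., 0], px[..., 1], px[..., 2]
    # Largest per-pixel disagreement between any two channels
    max_diff = max(np.abs(r - g).max(), np.abs(g - b).max(), np.abs(r - b).max())
    return bool(max_diff <= tol)
```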
Our next task in our data curation quest is to remove duplicate images. When an image is exactly or approximately duplicated in a training dataset, the resulting model may be biased toward this small set of overrepresented samples, not to mention the added training cost.
We can find approximate duplicates in our dataset by using a model to generate embeddings for our images (we will use a CLIP model for illustration):
## load CLIP model from the FiftyOne Model Zoo
model = foz.load_zoo_model("clip-vit-base32-torch")

## Compute embeddings and store them in embeddings_field
view.compute_embeddings(
    model,
    embeddings_field="image_clip_embedding"
)
Then we create a similarity index based on these embeddings:
results = fob.compute_similarity(view, embeddings="image_clip_embedding")
Finally, we can set a distance threshold (here we choose 0.3) within which images are treated as approximate duplicates, and retain only one representative from each group of approximate duplicates:
results.find_duplicates(thresh=0.3)

# view the duplicates, paired up
dup_view = results.duplicates_view()
session = fo.launch_app(dup_view, auto=False)

# get one image from each group of duplicates
dup_rep_ids = list(results.neighbors_map.keys())

# get ids of non-duplicates
non_dup_ids = view.exclude(
    dup_view.values("id")
).values("id")

# ids to keep
ids = dup_rep_ids + non_dup_ids

# create view from ids
view = view.select(ids)
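FiftyOne handles the duplicate search for you, but the idea behind a distance threshold fits in a few lines. The following is a simplified greedy illustration, not FiftyOne's actual implementation:

1. Walk through the embeddings in order.
2. Keep a sample only if its cosine distance to every sample kept so far exceeds thresh.

It is O(n²) in the worst case, so it is only meant for building intuition on small arrays. The function name greedy_dedup is hypothetical.

```python
import numpy as np

def greedy_dedup(embeddings, thresh=0.3):
    """Return indices kept after greedy cosine-distance deduplication."""
    X = np.asarray(embeddings, dtype=np.float64)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # unit-normalize each row
    kept = []
    for i in range(len(X)):
        if kept:
            # cosine distance = 1 - cosine similarity of unit vectors
            dists = 1.0 - X[kept] @ X[i]
            if dists.min() <= thresh:
                continue  # near-duplicate of a sample we already kept
        kept.append(i)
    return kept

# Rows 0 and 1 point in almost the same direction; rows 2 and 3 do not
emb = [[1.0, 0.0], [0.99, 0.01], [0.0, 1.0], [-1.0, 0.0]]
print(greedy_dedup(emb, thresh=0.3))  # [0, 2, 3]
```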
Okay, now you’re in luck, because we saved the coolest step for last!
Google’s Conceptual Captions Dataset consists of image-caption pairs from the internet. More precisely, “the raw descriptions are harvested from the Alt-text HTML attribute associated with web images”. This is great as an initial pass, but there are bound to be some low-quality captions in there.
We may not be able to ensure that all of our captions perfectly describe their images, but we can certainly filter out some poorly aligned image-caption pairs!
We will do so using CLIPScore, which is a “reference-free evaluation metric for image captioning”. In other words, it needs only the image and the caption, with no human-written reference captions. CLIPScore is easy to implement. First, we use SciPy’s cosine distance function to define a cosine similarity function:
from scipy.spatial.distance import cosine as cosine_distance

def cosine(vector1, vector2):
    return 1. - cosine_distance(vector1, vector2)
Then we define a function that takes in a Sample and computes the CLIPScore between the image embedding and the caption embedding stored on that sample:
def compute_clip_score(sample):
    image_embedding = sample["image_clip_embedding"]
    caption_embedding = sample["caption_clip_embedding"]
    return max(100.*cosine(image_embedding, caption_embedding), 0.)
The max(..., 0.) simply lower-bounds the score at zero, so every negatively aligned pair scores 0. The scaling factor of 100 is the same one used by the CLIPScore implementation in TorchMetrics.
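For a quick sanity check outside FiftyOne, the same score can be written with NumPy. This is a sketch, not the author's code:

- It takes the two embedding vectors directly rather than a Sample.
- It exposes the scale as a parameter w.

The CLIPScore paper uses w = 2.5 rather than 100. Because the scale is a positive constant, it changes the numeric threshold you would pick but not the ranking of samples.

```python
import numpy as np

def clip_score(image_embedding, caption_embedding, w=100.0):
    """w * max(cos(image, caption), 0); w=100 matches compute_clip_score above."""
    u = np.asarray(image_embedding, dtype=np.float64)
    v = np.asarray(caption_embedding, dtype=np.float64)
    cos = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
    return w * max(cos, 0.0)

print(clip_score([1, 0], [1, 0]))   # 100.0 (perfectly aligned)
print(clip_score([1, 0], [-1, 0]))  # 0.0 (negative similarity clamped)
```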
We can then compute the CLIPScore, our measure of alignment between images and captions, by adding the fields to our dataset and iterating over our samples:
dataset.add_sample_field("caption_clip_embedding", fo.VectorField)
dataset.add_sample_field("clip_score", fo.FloatField)

for sample in view.iter_samples(autosave=True, progress=True):
    sample["caption_clip_embedding"] = model.embed_prompt(sample["caption"])
    sample["clip_score"] = compute_clip_score(sample)

view.save()
If we want to see the “least aligned” samples, we can sort by “clip_score”.
## 100 least aligned samples
least_aligned_view = view.sort_by("clip_score")[:100]
To see the most aligned samples, we can do the same, but pass in reverse=True:
## 100 most aligned samples
most_aligned_view = view.sort_by("clip_score", reverse=True)[:100]
We can then set a CLIPScore threshold depending on how closely aligned we require the image-caption pairs to be. To my taste, a threshold of 21.8 seemed good enough:
view = view.match(F("clip_score") > 21.8)
gcc_clean = view.clone(name="gcc_clean", persistent=True)
The second line clones the view into a new persistent Dataset named “gcc_clean”.
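Picking a threshold by eye is a judgment call. To see what a candidate threshold costs you, a pair of hypothetical helpers can do two things:

- Report the fraction of samples that survive a threshold.
- Find the threshold that keeps roughly a chosen fraction.

In FiftyOne, the scores could be pulled out with view.values("clip_score"). The array below is made up for illustration. With tied scores, the kept fraction can differ slightly from the requested one.

```python
import numpy as np

def kept_fraction(scores, thresh):
    # Strict >, matching F("clip_score") > thresh
    return float(np.mean(np.asarray(scores) > thresh))

def threshold_for_fraction(scores, keep=0.8):
    # Score below which the bottom (1 - keep) share of samples falls
    return float(np.quantile(scores, 1.0 - keep))

scores = np.array([15.0, 18.5, 20.0, 22.0, 23.5, 25.0, 27.0, 30.0])  # made-up values
print(kept_fraction(scores, 21.8))  # 0.625
```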
After our data cleaning and curation, we have turned a relatively mediocre initial dataset into a high-quality dataset that is ready for training a ControlNet model! We surely haven’t created a perfect dataset; a perfect dataset does not exist. What we have done is address all of the data quality issues that plagued ControlNet 1.0, plus a few more, just for good measure :)