This notebook walks you through how to play a game of multimodal AI Telephone!
Here’s how the game of AI Telephone works: a text-to-image model and an image-to-text model pass a message back and forth for n turns (in our case n=10) — each turn, the latest text is turned into an image, and that image is captioned back into text. To run this code, you will need to install the FiftyOne open source library for dataset curation, the OpenAI Python Library, and the Replicate Python client.
!pip install fiftyone openai replicate tqdm
We will import all of the necessary modules:
import hashlib
import os
import requests
import openai
import replicate
import fiftyone as fo
from fiftyone import ViewField as F
First we define the base class:
class Text2Image(object):
    """Wrapper around a text-to-image model.

    Subclasses set ``name`` (a short display identifier) and, for
    Replicate-hosted models, ``model_name`` (the Replicate model slug
    including its version hash).
    """

    def __init__(self):
        self.name = None        # short display name, e.g. "stable-diffusion"
        self.model_name = None  # Replicate model identifier (owner/model:version)

    def generate_image(self, text):
        """Generate an image from ``text`` and return its URL.

        Some Replicate models return a list of URLs; in that case the
        first entry is used.
        """
        response = replicate.run(self.model_name, input={"prompt": text})
        # isinstance is the idiomatic type check (and covers list subclasses)
        if isinstance(response, list):
            response = response[0]
        return response
Then we create a class for each model:
class StableDiffusion(Text2Image):
    """Wrapper for a StableDiffusion model hosted on Replicate."""

    def __init__(self):
        # Initialize base-class attributes first, matching the convention
        # used by the Image2Text subclasses below.
        super().__init__()
        self.name = "stable-diffusion"
        self.model_name = "stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478"
class VQGANCLIP(Text2Image):
    """Wrapper for a VQGAN-CLIP model hosted on Replicate."""

    def __init__(self):
        # Initialize base-class attributes first, matching the convention
        # used by the Image2Text subclasses below.
        super().__init__()
        self.name = "vqgan-clip"
        self.model_name = "mehdidc/feed_forward_vqgan_clip:28b5242dadb5503688e17738aaee48f5f7f5c0b6e56493d7cf55f74d02f144d8"
class DALLE2(Text2Image):
    """Wrapper for a DALL-E 2 model served by the OpenAI API."""

    def __init__(self):
        # Call the base initializer so ``model_name`` exists (it stays None:
        # DALL-E 2 is not a Replicate model).
        super().__init__()
        self.name = "dalle-2"

    def generate_image(self, text):
        """Generate one 512x512 image from ``text`` and return its URL.

        NOTE(review): this uses the legacy ``openai.Image.create`` interface
        (openai-python < 1.0); newer client versions renamed this endpoint.
        """
        response = openai.Image.create(
            prompt=text,
            n=1,
            size="512x512"
        )
        return response['data'][0]['url']
Once again, we define the base class:
class Image2Text(object):
    """Wrapper around a Replicate-hosted image-captioning model.

    Subclasses provide ``name`` and ``model_name``; ``task_description``
    is the instruction passed to prompt-driven captioning models.
    """

    def __init__(self):
        self.name = None
        self.model_name = None
        self.task_description = "Write a detailed description of this image."

    def _clean_response(self, response):
        """Lower-case the caption and strip common boilerplate lead-ins."""
        cleaned = response.lower()
        boilerplate = ["caption: ", "the image shows ", "the image features"]
        for lead_in in boilerplate:
            if lead_in in cleaned:
                # keep only the text after the first occurrence of the lead-in
                cleaned = cleaned.split(lead_in)[1].strip()
        return cleaned

    def _generate_text(self, image_url):
        """Run the captioning model on the image at ``image_url``."""
        return replicate.run(
            self.model_name,
            input={
                "image": image_url,
                "prompt": self.task_description,
            }
        )

    def generate_text(self, image_url):
        """Caption the image at ``image_url`` and return the cleaned text."""
        return self._clean_response(self._generate_text(image_url))
Then we create a class for each model:
class BLIP(Image2Text):
    """Wrapper around Salesforce's BLIP captioning model on Replicate."""

    def __init__(self):
        super().__init__()
        self.name = "blip"
        self.model_name = (
            "salesforce/blip:"
            "2e1dddc8621f72155f24cf2e0adbde548458d3cab9f00c0139eea840d0ac4746"
        )
class CLIPPrefix(Image2Text):
    """Wrapper around a CLIP prefix-captioning model on Replicate."""

    def __init__(self):
        super().__init__()
        self.name = "clip-prefix"
        self.model_name = (
            "rmokady/clip_prefix_caption:"
            "9a34a6339872a03f45236f114321fb51fc7aa8269d38ae0ce5334969981e4cd8"
        )
class MiniGPT4(Image2Text):
    """Wrapper around a MiniGPT-4 captioning model on Replicate."""

    def __init__(self):
        super().__init__()
        self.name = "minigpt-4"
        self.model_name = (
            "daanelson/minigpt-4:"
            "b96a2f33cc8e4b0aa23eacfce731b9c41a7d9466d9ed4e167375587b54db9423"
        )
class MPLUGOwl(Image2Text):
    """Wrapper for a mPLUG Owl model hosted on Replicate."""

    def __init__(self):
        super().__init__()
        self.name = "mplug-owl"
        self.model_name = "joehoover/mplug-owl:51a43c9d00dfd92276b2511b509fcb3ad82e221f6a9e5806c54e69803e291d6b"

    def _generate_text(self, image_url):
        """Run mPLUG Owl on the image and return the full caption string.

        Unlike the base class, this model takes its image under the ``img``
        key and yields its output in chunks, which are joined here.
        """
        response = replicate.run(
            self.model_name,
            input={
                "img": image_url,
                "prompt": self.task_description,
            }
        )
        # join at C speed instead of quadratic ``+=`` concatenation
        return ''.join(response)
These specific prompts were generated using GPT-4. Feel free to generate prompts however you'd like!
# Seed prompts grouped by how hard they should be for the models to
# preserve through repeated text -> image -> text round trips.
easy_texts = [
"A red apple sitting on a wooden table with sunlight streaming in from a window.",
"A small white dog is playing in a lush green park, chasing a yellow frisbee.",
"A bluebird is perched on a blooming cherry blossom branch on a clear spring day.",
]
medium_texts = [
"A busy city street with neon signs in the evening, people walking with umbrellas, a vendor selling hot dogs, and a red double-decker bus passing by.",
"A quaint cobblestone alleyway in a European town during a bright day. There are colorful flowers in the window boxes, a bicycle leaning against the wall, and a cat lounging near a doorway.",
"An astronaut floating in the International Space Station, looking out at Earth through the window, with a space capsule docked in the background.",
]
hard_texts = [
"A grand medieval banquet hall filled with elegantly dressed lords and ladies feasting on a spread of exotic dishes, a minstrel playing a lute, and a knight narrating his adventures.",
"A bustling marketplace in an ancient Middle Eastern city. Traders haggling over spices and silks, camels carrying goods, the sun setting behind a mosque with a crescent moon visible.",
"A complex network of futuristic machines in a high-tech lab. Scientists are observing data on holographic screens, while autonomous robots are assembling nanobots.",
]
impossible_texts = [
"A panoramic scene of an advanced alien civilization on a distant exoplanet. Interstellar vehicles flying in an indigo sky above towering crystalline structures. Aliens with varying physical features are interacting, engaging in activities like exchanging energy orbs, communicating through light patterns, and tending to exotic, bio-luminescent flora. The planet’s twin moons are visible in the horizon over a glistening alien ocean."
]
# Difficulty labels, aligned index-by-index with the prompt lists above.
levels = ["easy", "medium", "hard", "impossible"]
level_prompts = [easy_texts, medium_texts, hard_texts, impossible_texts]
def get_prompts():
    """Build one Prompt object per (difficulty level, prompt text) pair.

    NOTE(review): relies on a ``Prompt(text, level)`` class that is not
    defined in this file -- confirm it is defined elsewhere before running.
    """
    return [
        Prompt(text, level)
        for level, texts in zip(levels, level_prompts)
        for text in texts
    ]
def download_image(image_url, filename):
    """Download ``image_url`` and write the raw bytes to ``filename``.

    Raises ``requests.HTTPError`` on an error status code, instead of
    silently writing an error page to disk as the original did.
    """
    response = requests.get(image_url)
    response.raise_for_status()
    with open(filename, 'wb') as handler:
        handler.write(response.content)
class TelephoneLine(object):
    """Pairs one text-to-image model with one image-to-text model and
    plays games of AI Telephone between them.
    """

    def __init__(self, t2i, i2t):
        self.t2i = t2i
        self.i2t = i2t
        self.name = f"{t2i.name}_{i2t.name}"
        self.conversations = {}  # conversation hash -> game record

    def get_conversation_name(self, text):
        """Return a short md5-based id for this (model pair, prompt) game."""
        digest = hashlib.md5(f"{self.name}{text}".encode())
        return digest.hexdigest()[:6]

    def play(self, prompt, nturns=10):
        """Play a game of telephone for ``nturns`` turns."""
        print(f"Connecting {self.t2i.name} <-> {self.i2t.name} with prompt: {prompt.text[:20]}...")
        texts, image_urls = [prompt.text], []
        for _ in range(nturns):
            # one turn: latest text -> image -> new caption
            image_url = self.t2i.generate_image(texts[-1])
            image_urls.append(image_url)
            texts.append(self.i2t.generate_text(image_url))
        self.conversations[self.get_conversation_name(prompt.text)] = {
            "texts": texts,
            "image_urls": image_urls,
            "level": prompt.level,
        }

    def save_conversations_to_dataset(self, dataset):
        """Save conversations to a dataset, one sample per telephone step."""
        for conversation_name, conversation in self.conversations.items():
            texts = conversation["texts"]
            for step, image_url in enumerate(conversation["image_urls"]):
                filepath = os.path.join("telephone_images", f"{conversation_name}_{step}.jpg")
                download_image(image_url, filepath)
                dataset.add_sample(fo.Sample(
                    filepath=filepath,
                    conversation_name=conversation_name,
                    prompt=texts[0],
                    level=conversation["level"],
                    t2i_model=self.t2i.name,
                    i2t_model=self.i2t.name,
                    step_number=step,
                    text_before=texts[step],
                    text_after=texts[step + 1],
                ))
Set the directory where you'd like images downloaded/stored:
# Directory where generated images are downloaded/stored.
IMAGES_DIR = "telephone_images"
# exist_ok avoids the check-then-create race and is a no-op on re-runs
os.makedirs(IMAGES_DIR, exist_ok=True)
## Image2Text (captioning) models
mplug_owl = MPLUGOwl()
blip = BLIP()
clip_prefix = CLIPPrefix()
mini_gpt4 = MiniGPT4()
image2text_models = [mplug_owl, blip, clip_prefix, mini_gpt4]
## Text2Image (generation) models
vqgan_clip = VQGANCLIP()
sd = StableDiffusion()
dalle2 = DALLE2()
text2image_models = [sd, dalle2, vqgan_clip]
## Every (text2image, image2text) pairing gets its own telephone line
combos = [(t2i, i2t) for t2i in text2image_models for i2t in image2text_models]
lines = [TelephoneLine(*combo) for combo in combos]
prompts = get_prompts()
Create the dataset where we will store the results:
# Persistent dataset that will store every telephone game step.
dataset = fo.Dataset(name='telephone', persistent=True)

# Declare the per-sample fields up front; step_number is the only integer.
_field_types = {
    "conversation_name": fo.StringField,
    "prompt": fo.StringField,
    "level": fo.StringField,
    "t2i_model": fo.StringField,
    "i2t_model": fo.StringField,
    "step_number": fo.IntField,
    "text_before": fo.StringField,
    "text_after": fo.StringField,
}
for _field_name, _field_type in _field_types.items():
    dataset.add_sample_field(_field_name, _field_type)
Play all of the games:
# BUG FIX: ``tqdm`` was used here but never imported anywhere in the file,
# which raises a NameError. Import it, degrading to a plain passthrough
# if the package is not installed (it is progress display only).
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        return iterable

# Play every prompt over every telephone line, then persist the results.
for line in tqdm(lines):
    for prompt in prompts:
        line.play(prompt, nturns=10)
    line.save_conversations_to_dataset(dataset)
Check out the results in the FiftyOne App:
## Launch the FiftyOne App on the dataset.
## auto=False to prevent the app from opening. Open with new tab in browser: http://localhost:5151
session = fo.launch_app(dataset, auto = False)
Use the dynamic groups functionality in the FiftyOne App: click on the splitting icon in the menu bar to group images by conversation, select `conversation_name` from the dropdown, then toggle the selector to `ordered` and select `step_number`.
import numpy as np
from scipy.spatial.distance import cosine as cosine_distance
Create a unique hash key for each prompt and store the embeddings in a dictionary:
def hash_prompt(prompt):
    """Return a short, stable md5-based identifier for a prompt string."""
    digest = hashlib.md5(prompt.encode())
    return digest.hexdigest()[:6]
## Use ImageBind to embed text. You can use any text embedding model here.
## You can also embed the generated images if you appropriately modify the code below.
# Replicate model slug (owner/model:version) used by embed_text() below.
MODEL_NAME = "daanelson/imagebind:0383f62e173dc821ec52663ed22a076d9c970549c209666ac3db181618b7a304"
def embed_text(text):
    """Embed ``text`` with the Replicate-hosted model, as a numpy array."""
    embedding = replicate.run(
        MODEL_NAME,
        input={"text_input": text, "modality": "text"},
    )
    return np.array(embedding)
prompts = dataset.distinct("prompt")

### Embed initial prompts, keyed by their short hash, and tag every
### sample with the hash of its originating prompt.
prompt_embeddings = {}
dataset.add_sample_field("prompt_hash", fo.StringField)
prompt_groups = dataset.group_by("prompt")
for pg in prompt_groups.iter_dynamic_groups():
    prompt = pg.first().prompt
    # renamed from ``hash`` so the builtin is not shadowed
    prompt_key = hash_prompt(prompt)
    prompt_embeddings[prompt_key] = embed_text(prompt)
    view = pg.set_field("prompt_hash", prompt_key)
    view.save("prompt_hash")
Compute a distance between the text description and the prompt:
# For every step of every conversation, store the caption embedding and its
# cosine distance from the original prompt's embedding.
dataset.add_sample_field("text_after_dist", fo.FloatField)

# BUG FIX: this was assigned to ``prompt_groups`` but iterated as
# ``conversation_groups``, raising a NameError at runtime.
conversation_groups = dataset.group_by("conversation_name")
for cg in conversation_groups.iter_dynamic_groups(progress=True):
    prompt_key = cg.first().prompt_hash  # renamed to avoid shadowing builtin ``hash``
    prompt_embedding = prompt_embeddings[prompt_key]
    ordered_samples = cg.sort_by("step_number")
    for sample in ordered_samples.iter_samples(autosave=True):
        text_embedding = embed_text(sample.text_after)
        sample["text_embedding"] = text_embedding
        sample.text_after_dist = cosine_distance(
            prompt_embedding,
            text_embedding
        )
Aggregate the results by level of prompt difficulty, T2I model, and I2T model:
### Aggregate performance by level
levels = dataset.distinct("level")
t2i_models = dataset.distinct("t2i_model")
i2t_models = dataset.distinct("i2t_model")
pairs = [(t2i, i2t) for t2i in t2i_models for i2t in i2t_models]
steps = sorted(dataset.distinct("step_number"))

# pair_dists[level][(t2i, i2t)] -> [0., mean dist at step 0, step 1, ...]
pair_dists = {}
for level in levels:
    pair_level_dists = {}
    level_view = dataset.match(F("level") == level)
    for pair in pairs:
        t2i_model, i2t_model = pair
        model_view = level_view.match(F("t2i_model") == t2i_model).match(F("i2t_model") == i2t_model)
        # leading 0.: at step zero the text is the prompt itself
        step_dists = [0.]
        for step in steps:
            step_view = model_view.match(F("step_number") == step)
            # BUG FIX: a stray ``step_view.mean("image_dist")`` append was
            # removed -- no "image_dist" field is ever populated, so it
            # injected a bogus value between every real distance.
            step_dists.append(step_view.mean("text_after_dist"))
        pair_level_dists[pair] = step_dists
    pair_dists[level] = pair_level_dists
import matplotlib.pyplot as plt
import numpy as np
Set the style for each curve in the plot based on the T2I model and I2T model, so that we can easily distinguish between them:
## Each curve's color encodes the text-to-image model...
t2i_colors = {
    "dalle-2": "r",
    "stable-diffusion": "b",
    "vqgan-clip": "y",
}

## ...and its marker encodes the image-to-text model.
i2t_markers = {
    "clip-prefix": "+",
    "minigpt-4": "o",
    "blip": "v",
    "mplug-owl": "s",
}
def get_style(pair):
    """Matplotlib format string (solid line + color + marker) for a model pair."""
    t2i_model, i2t_model = pair
    color = t2i_colors[t2i_model]
    marker = i2t_markers[i2t_model]
    return f"-{color}{marker}"
def format_pair(pair):
    """Legend label "<t2i> <-> <i2t>" using a LaTeX double-arrow."""
    t2i_model, i2t_model = pair
    return rf"{t2i_model}$\leftrightarrow${i2t_model}"
Function that plots results for each difficulty level:
def plot_level_results(level):
    """Plot distance-vs-step curves for every model pair at ``level``
    and save the figure to ``<level>.png``.
    """
    plt.figure(figsize=(20, 10))
    for pair in pairs:
        dists = pair_dists[level][pair]
        step_axis = np.arange(len(dists)) + 1  # 1-based x axis
        plt.plot(step_axis, dists, get_style(pair), label=format_pair(pair))
    plt.xlabel("Step Number")
    plt.ylabel("Cosine Distance")
    plt.title(f"AI Telephone Results: {level.capitalize()} Prompts", fontsize=20)
    plt.legend(frameon=False)
    plt.savefig(f"{level}.png")
    # Close the figure so repeated calls (one per level) don't accumulate
    # open figures in memory.
    plt.close()
Iterate over each difficulty level and plot the results:
# Render and save one results figure per difficulty level.
for level in levels:
    plot_level_results(level)