by Giuseppe Insana, December 2022
# general
import os
import sys
from math import pi, ceil, sqrt
from tqdm.notebook import trange, tqdm
# sd
import torch
from torch import Tensor
import safetensors
import transformers
from diffusers import (
StableDiffusionPipeline,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DDIMScheduler,
DPMSolverMultistepScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
)
# image
from PIL import Image
import matplotlib.pyplot as plt
from IPython.display import Image as IImage # for gifs
from mpl_toolkits.axes_grid1 import ImageGrid # for image grid
# setup
%matplotlib inline
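# ask the CUDA caching allocator to garbage-collect more aggressively and cap the split block size,
# to reduce out-of-memory errors caused by memory fragmentation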
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "garbage_collection_threshold:0.6, max_split_size_mb:516"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# torch.set_grad_enabled(False)
print(f"Using device: {device}")
# image display helper functions
def display_images(images, prompt="", subtitles=[]):
"""
simple image display via matplotlib
prompt can be specified, seed is taken from global variable
subtitles can be a list with as many elements as the images, to specify different labels for the images
"""
if len(images) > 1:
fig, axs = plt.subplots(1, max(2, len(images)), figsize=(12, 6))
fig.tight_layout()
plt.subplots_adjust(wspace=0, hspace=0)
plt.margins(x=0, y=0)
plt.axis("off")
for c, img in enumerate(images):
axs[c].tick_params(length=0, labelbottom=False, labelleft=False)
axs[c].imshow(img)
axs[c].set_title(subtitles[c] if len(subtitles) else "")
fig.suptitle("{}\n{}".format(prompt, seed))
else:
if prompt:
fig = plt.figure()
fig.suptitle("{}\n{}".format(prompt, seed))
plt.margins(x=0, y=0)
plt.axis("off")
plt.imshow(images[0])
else:
display(images[0])
def display_images_grid(images, prompt="", subtitles=[], grid_size=None, scale=2):
"""
simple image display in a grid via matplotlib
if grid_size is not specified, the nearest square grid of appropriate size will be used
"""
if not grid_size:
grid_size = ceil(sqrt(len(images)))
fig = plt.figure(figsize=(grid_size * scale, grid_size * scale))
grid = ImageGrid(
fig,
nrows_ncols=(grid_size, grid_size), # creates grid of axes
axes_pad=0, # 0.1 # pad between axes in inch.
)
for ax, im in zip(grid, images):
ax.tick_params(length=0, labelbottom=False, labelleft=False)
ax.imshow(im)
if prompt:
fig.suptitle("{}\n{}".format(prompt, seed))
plt.show()
def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
"""
export a list of images as a gif, optionally with rubber band repetition
the gif will be both saved to file and displayed in notebook
"""
my_images = images.copy()
if rubber_band:
my_images += images[2:-1][::-1]
my_images[0].save(
filename,
save_all=True,
        append_images=my_images[1:],
duration=1000 // frames_per_second,
loop=0,
)
display(IImage(filename))
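# example usage (hypothetical filename), commented out:
# export_as_gif("walk.gif", images, frames_per_second=5, rubber_band=True)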
def torch_md_linspace(start: Tensor, stop: Tensor, num: int):
"""
    Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to stop, inclusive.
Replicates the multi-dimensional behaviour of numpy.linspace for PyTorch tensors.
e.g.:
start = torch.tensor([[0, 1], [2, 3]])
stop = torch.tensor([[10, 10], [10, 10]])
steps = 3
"""
# create a tensor of 'num' steps from 0 to 1
steps = torch.arange(num, dtype=torch.float16, device=start.device) / (num - 1)
for i in range(start.ndim):
steps = steps.unsqueeze(-1)
# the output starts at 'start' and increments until 'stop' in each dimension
out = start[None] + steps * (stop - start)[None]
return out
# test:
# start = torch.tensor([[0, 1], [2, 3]])
# stop = torch.tensor([[10, 10], [10, 10]])
# steps = 3
# np.isclose(torch_md_linspace(start, stop, num=steps), np.linspace(start, stop, num=steps)).all() # True
def eprint(*myargs, **kwargs):
"""
print to stderr, useful for error messages and to not clobber stdout
"""
print(*myargs, file=sys.stderr, **kwargs)
def text_enc(prompts, maxlen=None, device="cuda"):
"""
A function to take a textual prompt and convert it into embeddings
example: text_enc(["A dog wearing a white hat"])
"""
if maxlen is None:
maxlen = pipe.tokenizer.model_max_length
inp = pipe.tokenizer(
prompts,
padding="max_length",
max_length=maxlen,
truncation=True,
return_tensors="pt",
).input_ids.to(device)
return pipe.text_encoder(inp)[0].half()
def latents_to_pil(latents):
"""
Function to convert latents to images
"""
latents = (1 / 0.18215) * latents
with torch.no_grad():
image = pipe.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (image * 255).round().astype("uint8")
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def sample_space(
latents, emb, g, steps, save_int=False, return_int=False, device="cuda"
):
"""
    run the diffusion sampling loop and return the final image(s) as a list of PIL images
    optionally save or return intermediate states
"""
if save_int and not os.path.exists(f"./steps"):
os.mkdir(f"./steps")
intermediates = []
# Setting number of steps in scheduler
scheduler.set_timesteps(steps)
    # Scaling the initial latents by the scheduler's initial noise sigma
latents = latents.to(device).half() * scheduler.init_noise_sigma
# Iterating through defined steps
for i, ts in enumerate(tqdm(scheduler.timesteps, desc="iterations", leave=False)):
        # Scale the input latents to match the variance expected by the scheduler
inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
# Predicting noise residual using U-Net
with torch.no_grad():
u, t = pipe.unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
# Performing Guidance
pred = u + g * (t - u)
        # Scheduler step: compute the previous (less noisy) latents
latents = scheduler.step(pred, ts, latents).prev_sample
# Saving intermediate images
if save_int or return_int:
intermediate = latents_to_pil(latents)[0]
if save_int:
intermediate.save(f"steps/{i:04}.jpeg")
if return_int:
intermediates.append(intermediate)
if return_int:
return intermediates[0:-1]
else:
return latents_to_pil(latents)
def save_images(images, path="images"):
"""
save a list of images to a path, creating the directory if it does not exist
"""
if not os.path.exists(f"./{path}"):
os.mkdir(f"./{path}")
for index, image in enumerate(images):
image.save(f"./{path}/{index:04}.jpg")
def slerp(val, low, high):
"""
spherical interpolation
from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
compatible with image variation used by stable-diffusion-webui
"""
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
dot = (low_norm * high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (
torch.sin(val * omega) / so
).unsqueeze(1) * high
return res
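# quick sanity check (illustrative only, commented out):
# a, b = torch.randn(1, 4), torch.randn(1, 4)
# torch.allclose(slerp(0.0, a, b), a), torch.allclose(slerp(1.0, a, b), b)  # expected (True, True)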
def _initial_latents(seed=None, width=768, height=768, skip_random=0, verbose=False):
"""
return initial latents given an optional seed, optionally skipping a series of them
"""
# Setting the seed
if seed is not None:
if verbose:
eprint(f"using random seed {seed}")
torch.manual_seed(seed)
if skip_random:
if verbose:
eprint(f"skipping {skip_random} random")
# skip a series of random for latents (we want Nth image in a series generated from an initial seed)
_ = torch.randn(
(skip_random, pipe.unet.config.in_channels, height // 8, width // 8)
)
initial_latents = torch.randn(
(pipe.unet.config.in_channels, height // 8, width // 8)
)
if verbose:
eprint("initial latent, {}".format(initial_latents.sum()))
return initial_latents
def _enclat2img(
encodings=[],
initial_latents=[],
multiple_latents=[],
g=7.5,
steps=10,
neg_prompt=None,
device=device,
verbose=False,
):
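    """
    internal helper: run the sampling loop for each prompt encoding, using either the single
    initial latent or each of the multiple latents, and return the resulting list of PIL images
    """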
images = []
    # the unconditional (or negative-prompt) embedding is needed for classifier-free guidance
if neg_prompt is None:
uncond = text_enc([""] * 1, encodings.shape[1], device=device)
elif type(neg_prompt) != str:
eprint(f"ERROR: neg_prompt must be a string, not '{type(neg_prompt)}'")
return
else:
if verbose:
eprint(f"using negative prompt '{neg_prompt}'")
uncond = text_enc([neg_prompt] * 1, encodings.shape[1], device=device)
    for encoding in tqdm(encodings, desc="prompt", leave=False):
emb = torch.cat([uncond, encoding.reshape_as(uncond)])
if len(multiple_latents):
            for latents in tqdm(multiple_latents, desc="latent", leave=False):
images += sample_space(
torch.unsqueeze(latents, dim=0), emb, g, steps, device=device
)
else:
images += sample_space(
torch.unsqueeze(initial_latents, dim=0), emb, g, steps, device=device
)
return images
def prompt2img(
prompts=[],
neg_prompt=None,
n_samples=1,
g=10,
steps=30,
width=768,
height=768,
seed=None,
skip_random=0,
device=device,
verbose=False,
):
"""
Return a list of images equal to the number of prompts (optionally multiplied by n_samples)
    prompts: list of strings or a single string
n_samples: how many samples to produce for each given prompt
neg_prompt: negative conditioning string
g: classifier free guidance
steps: iteration steps for the diffusion process
width, height: dimensions for the resulting image
seed: initialize random generator; use None to get next available random
skip_random: how many random latents to discard before producing image
"""
if type(prompts) == str:
prompts = [prompts]
if n_samples < 1:
eprint("ERROR: n_samples must be positive!")
return
initial_latents = _initial_latents(
seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose
)
multiple_latents = []
if n_samples > 1:
sample_latents = [
torch.unsqueeze(initial_latents, dim=0)
] # the first one generated from the seed
for _ in range(n_samples - 1): # add as many more as requested
sample_latent = torch.randn(
(pipe.unet.config.in_channels, height // 8, width // 8)
)
if verbose:
eprint("adding latent, {}".format(sample_latent.sum()))
sample_latents.append(torch.unsqueeze(sample_latent, dim=0))
multiple_latents = torch.cat(sample_latents)
encodings = text_enc(prompts, device=device)
return _enclat2img(
encodings=encodings,
initial_latents=initial_latents,
multiple_latents=multiple_latents,
g=g,
steps=steps,
neg_prompt=neg_prompt,
device=device,
verbose=verbose,
)
def beyond_prompt(
prompt="",
neg_prompt=None,
walk_steps=1,
walk_stepsize=0.02,
g=10,
steps=30,
width=768,
height=768,
seed=None,
skip_random=0,
device=device,
verbose=False,
):
"""
Walk forward in prompt embedding space by walk_stepsize to produce walk_steps + 1 images,
each a step (of optionally specified size) forward from the previous
prompt: string
neg_prompt: negative conditioning string
walk_steps: number of steps to walk forward
walk_stepsize: how much to walk forward in prompt embedding space at each step
g: classifier free guidance
steps: iteration steps for the diffusion process
width, height: dimensions for the resulting image
seed: initialize random generator; use None to get next available random
skip_random: how many random latents to discard before producing image
"""
if type(prompt) != str:
eprint(f"ERROR: prompt must be a string, not '{type(prompt)}'")
return
initial_latents = _initial_latents(
seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose
)
multiple_latents = []
encoding = text_enc([prompt], device=device)
delta = torch.ones_like(encoding) * walk_stepsize
new_encodings = []
for step_index in range(0, walk_steps + 1):
new_encodings.append(encoding)
        if verbose:
            eprint("appended encoding nudged by {}".format(walk_stepsize * step_index))
        encoding = encoding + delta  # nudge prompt embedding for the next step
encodings = torch.cat(new_encodings)
return _enclat2img(
encodings=encodings,
initial_latents=initial_latents,
multiple_latents=multiple_latents,
g=g,
steps=steps,
neg_prompt=neg_prompt,
device=device,
verbose=verbose,
)
def interpolate_prompts(
prompts=[],
neg_prompt=None,
interpolate_steps=1,
g=10,
steps=30,
width=768,
height=768,
seed=None,
skip_random=0,
device=device,
verbose=False,
):
"""
    Given two prompts, interpolate between their embeddings, producing interpolate_steps images between the two endpoint images (interpolate_steps + 2 in total)
    Return a list of images exploring the embedding space between the first and second prompt
    prompts: list of exactly two strings
interpolate_steps: number of images to produce between the one from the first and the one from the second prompt
neg_prompt: negative conditioning string
g: classifier free guidance
steps: iteration steps for the diffusion process
width, height: dimensions for the resulting image
seed: initialize random generator; use None to get next available random
skip_random: how many random latents to discard before producing image
    Possible extension: interpolate four prompts and create a square grid
"""
if type(prompts) != list or len(prompts) != 2:
eprint("ERROR: you need to pass a list of 2 prompts!")
return
if interpolate_steps < 1:
eprint("ERROR: interpolate steps must be positive!")
return
initial_latents = _initial_latents(
seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose
)
multiple_latents = []
encodings = text_enc(prompts, device=device)
encodings = torch_md_linspace(encodings[0], encodings[1], interpolate_steps + 2)
return _enclat2img(
encodings=encodings,
initial_latents=initial_latents,
multiple_latents=multiple_latents,
g=g,
steps=steps,
neg_prompt=neg_prompt,
device=device,
verbose=verbose,
)
def revolve_prompt(
prompt="",
neg_prompt=None,
walk_type="circle",
walk_steps=1,
g=10,
steps=30,
width=768,
height=768,
seed=None,
seed2=None,
seed3=None,
skip_random=0,
device=device,
verbose=False,
):
"""
    Walk around a prompt in latent space to produce walk_steps images
    For a "circle" walk, two latents are used, which can be determined by specifying seed and seed2
    For a "spiral" walk, the path follows a spherical spiral through three latents, optionally determined by seed, seed2 and seed3
    Note: the image that would be generated from seed alone appears as the first one in the set,
    and the one from seed2 is found at around 1/4th of the total image count;
    in the case of a spiral walk, the image from seed is the first one,
    the image from seed2 is found at around 1/4th of the total image count,
    and the image from seed3 (approximately) at around 1/3rd of the total image count
prompt: string
neg_prompt: negative conditioning string
walk_type: "circle" (default) or "spiral"
    walk_steps: how many steps to take in total (equals number of returned images; should be even for the "spiral" walk)
g: classifier free guidance
steps: iteration steps for the diffusion process
width, height: dimensions for the resulting image
seed: initialize random generator; use None to get next available random
seed2: optional seed to determine circular walk
seed3: optional seed to further determine spiral walk
skip_random: how many random latents to discard before producing image
"""
if type(prompt) != str:
eprint(f"ERROR: prompt must be a string, not '{type(prompt)}'")
return
# initialize alternate latents
if seed2 is not None:
torch.manual_seed(seed2)
variation_latents = torch.randn(
(pipe.unet.config.in_channels, height // 8, width // 8)
)
if walk_type == "spiral" and seed3 is not None:
torch.manual_seed(seed3)
alt_variation_latents = torch.randn(
(pipe.unet.config.in_channels, height // 8, width // 8)
)
if walk_type == "circle" and seed3 is not None:
eprint(f"NOTICE: seed3 not used for 'circle' walk_type'")
initial_latents = _initial_latents(
seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose
)
multiple_latents = []
    # if no alternate seed was specified, use the next random latents after the ones used for the initial latents
if seed2 is None:
variation_latents = torch.randn(
(pipe.unet.config.in_channels, height // 8, width // 8)
)
if walk_type == "spiral" and seed3 is None:
alt_variation_latents = torch.randn(
(pipe.unet.config.in_channels, height // 8, width // 8)
)
encodings = text_enc([prompt], device=device)
if walk_type == "circle": # around a prompt and two random latents in a circle
##circular roundwalk:
# stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1] * pi
# walk_scale_x = torch.cos(stepspace)
# walk_scale_y = torch.sin(stepspace)
# (accelerates and decelerates, not very smooth in interpolation)
# spreadout circular walk:
        spread_factor = 30  # the lower this value, the more the points are pushed away from the 0, 90, 180, 270 degree bearings and concentrated towards 45, 135, ...
stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1]
stepspace -= torch.sin((stepspace + 0.25) * pi * 4) / spread_factor
stepspace *= pi
walk_scale_x = torch.cos(stepspace)
walk_scale_y = torch.sin(stepspace)
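        # cos/sin weights satisfy cos^2 + sin^2 = 1, so this mix of the two latents keeps
        # the combined noise at roughly the scale of a single latent all around the circle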
noise_x = torch.tensordot(walk_scale_x, initial_latents, dims=0)
noise_y = torch.tensordot(walk_scale_y, variation_latents, dims=0)
multiple_latents = torch.add(noise_x, noise_y)
elif walk_type == "spiral": # spherical spiral walk with three random latents
c = 2 # use 4 for double the amount of turns in the spherical spiral walk
# circular spherical spiral walk:
# theta = torch.linspace(1, 2, walk_steps + 1)[0:-1] * pi
# theta2 = torch.linspace(1, 0, walk_steps + 1)[0:-1] * pi
# walk_scale_x1 = torch.sin(theta) * torch.cos(c * theta)
# walk_scale_x2 = torch.sin(theta2) * torch.cos(c * theta2)
# walk_scale_x = torch.cat([walk_scale_x1, walk_scale_x2])
# walk_scale_y1 = torch.sin(theta) * torch.sin(c * theta)
# walk_scale_y2 = torch.sin(theta2) * torch.sin(c * theta2)
# walk_scale_y = torch.cat([walk_scale_y1, walk_scale_y2])
# walk_scale_z = torch.cos(torch.linspace(0, 2 * pi, walk_steps * 2 + 1)[0:-1])
# spread spherical spiral walk:
spread_factor = 30
theta = torch.linspace(1, 2, walk_steps // 2 + 1)[0:-1]
theta -= torch.sin((theta + 0.25) * pi * 4) / spread_factor
theta *= pi
theta2 = torch.linspace(1, 0, walk_steps // 2 + 1)[0:-1]
theta2 -= torch.sin((theta2 + 0.25) * pi * 4) / spread_factor
theta2 *= pi
walk_scale_x1 = torch.sin(theta) * torch.cos(c * theta)
walk_scale_x2 = torch.sin(theta2) * torch.cos(c * theta2)
walk_scale_x = torch.cat([walk_scale_x1, walk_scale_x2])
walk_scale_y1 = torch.sin(theta) * torch.sin(c * theta)
walk_scale_y2 = torch.sin(theta2) * torch.sin(c * theta2)
walk_scale_y = torch.cat([walk_scale_y1, walk_scale_y2])
stepspace = torch.linspace(0, 2, walk_steps + 1)[0:-1]
stepspace -= torch.sin((stepspace + 0.25) * pi * 4) / spread_factor
stepspace *= pi
walk_scale_z = torch.cos(stepspace)
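        # (sin(t)*cos(c*t), sin(t)*sin(c*t), cos(t)) traces a spiral over the unit sphere,
        # so the three latents are mixed while approximately preserving the overall noise scale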
noise_z = torch.tensordot(walk_scale_z, initial_latents, dims=0)
noise_x = torch.tensordot(walk_scale_x, variation_latents, dims=0)
noise_y = torch.tensordot(walk_scale_y, alt_variation_latents, dims=0)
multiple_latents = torch.add(torch.add(noise_x, noise_y), noise_z)
else:
eprint(f"ERROR: unknown walk_type '{walk_type}'")
return
return _enclat2img(
encodings=encodings,
initial_latents=initial_latents,
multiple_latents=multiple_latents,
g=g,
steps=steps,
neg_prompt=neg_prompt,
device=device,
verbose=verbose,
)
def prompt_variations(
prompt="",
neg_prompt=None,
variations=1,
variation_strength=0.1,
g=10,
steps=30,
width=768,
height=768,
seed=None,
skip_random=0,
device=device,
verbose=False,
):
"""
Return an image followed by a list of variations, each one of specified variation_strength from the first
    prompt: string
variations: how many variations to return after the normal image
variation_strength: how much to mix the initial latent and the variant ones (hence how different from initial image)
neg_prompt: negative conditioning string
g: classifier free guidance
steps: iteration steps for the diffusion process
width, height: dimensions for the resulting image
seed: initialize random generator; use None to get next available random
skip_random: how many random latents to discard before producing image
"""
if type(prompt) != str:
eprint(f"ERROR: prompt must be a string, not '{type(prompt)}'")
return
if variations < 1:
eprint("ERROR: number of variations must be positive!")
return
initial_latents = _initial_latents(
seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose
)
multiple_latents = []
sample_latents = [
torch.unsqueeze(initial_latents, dim=0)
] # the first one generated from the seed
for _ in range(variations): # add as many as requested
sample_latent = torch.randn(
(pipe.unet.config.in_channels, height // 8, width // 8)
)
sample_latent = slerp(variation_strength, initial_latents, sample_latent)
if verbose:
eprint("adding variation latent: {}".format(sample_latent.sum()))
sample_latents.append(torch.unsqueeze(sample_latent, dim=0))
multiple_latents = torch.cat(sample_latents)
encodings = text_enc([prompt], device=device)
return _enclat2img(
encodings=encodings,
initial_latents=initial_latents,
multiple_latents=multiple_latents,
g=g,
steps=steps,
neg_prompt=neg_prompt,
device=device,
verbose=verbose,
)
def variate_prompt(
prompt="",
neg_prompt=None,
variation_strength=0,
var_steps=0,
g=10,
steps=30,
width=768,
height=768,
seed=None,
seed2=None,
skip_random=0,
device=device,
verbose=False,
):
"""
create an image from the mix (in desired amount) of two random initial latents (optionally specified by seed and seed2)
alternatively, if var_steps is specified, it will interpolate between the two random latents, returning var_steps images
(i.e. like trying a linearly increasing set of values of variation_strength, from 0 to 1)
prompt: string
neg_prompt: negative conditioning string
variation_strength: how much to mix the initial latent and the variant one
var_steps: how many steps to go from an initial latent and a variation latent
g: classifier free guidance
steps: iteration steps for the diffusion process
width, height: dimensions for the resulting image
seed: initialize random generator; use None to get next available random
    seed2: optional seed to determine the variation latent
skip_random: how many random latents to discard before producing image
"""
if type(prompt) != str:
eprint(f"ERROR: prompt must be a string, not '{type(prompt)}'")
return
if var_steps > 0 and variation_strength > 0:
eprint("NOTICE: variation_strength will be ignored when var_steps specified")
if var_steps < 0 or variation_strength < 0:
eprint("ERROR: do not use negative values")
return
if var_steps <= 0 and variation_strength <= 0:
eprint("ERROR: nothing to do. specify either var_steps or variation_strength")
return
initial_latents = _initial_latents(
seed=seed, skip_random=skip_random, width=width, height=height, verbose=verbose
)
    # initialize variation latent
    if seed2 is not None:
        torch.manual_seed(seed2)
        variation_latents = torch.randn(
            (pipe.unet.config.in_channels, height // 8, width // 8)
        )
    else:
        # if no alternate seed was specified, use the next random after the one used for the initial latents
        variation_latents = torch.randn(
            (pipe.unet.config.in_channels, height // 8, width // 8)
        )
    if var_steps > 0:  # gradually interpolate between the two random latents
        var_latents = []
        stepspace = torch.linspace(0, 1, var_steps)  # include last point
        for stepvalue in stepspace:
            if verbose:
                eprint("generating variation at {}".format(stepvalue))  # debug
            var_latents.append(
                torch.unsqueeze(
                    slerp(stepvalue, initial_latents, variation_latents), dim=0
                )
            )
        multiple_latents = torch.cat(var_latents)
    else:
        multiple_latents = []
        initial_latents = slerp(variation_strength, initial_latents, variation_latents)
encodings = text_enc([prompt], device=device)
return _enclat2img(
encodings=encodings,
initial_latents=initial_latents,
multiple_latents=multiple_latents,
g=g,
steps=steps,
neg_prompt=neg_prompt,
device=device,
verbose=verbose,
)
# available:
# DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler, DDPMScheduler
model_id = "stabilityai/stable-diffusion-2-1-base"
scheduler = DPMSolverMultistepScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPipeline.from_pretrained(
model_id, scheduler=scheduler, torch_dtype=torch.float16
)
pipe.safety_checker = None
pipe.requires_safety_checker = False
pipe = pipe.to(device)
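# NOTE: stable-diffusion-2-1-base is trained at 512x512; the 768x768 defaults in the functions
# above match the larger stabilityai/stable-diffusion-2-1 checkpoint, so prefer passing
# width=512, height=512 (as most cells below do) when using the base model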
prompt = "hyper detailed digital painting of scenery, shibuya tokyo, post-apocalypse, ruins, rust, sky, skyscraper, abandoned, blue sky, broken window, building, cloud, crane machine, outdoors, overgrown, pillar, sunset"
seed = 24509723452345
prompt2img(prompt, seed=seed)[0]
seed = 64
prompt = "photo of a tiger demon"
images = prompt2img(prompt, width=512, height=768, g=7.5, steps=10, seed=seed)
display_images(images)
# Testing negative prompts
images = [None, None]
seed = 66
prompt1 = "photo of a castle in the middle of a forest with trees and bushes, detailed vegetation"
prompt2 = "green, leaves, summer, spring"
images[0] = prompt2img(prompt1, neg_prompt=None, seed=seed)[0]
images[1] = prompt2img(prompt1, neg_prompt=prompt2, seed=seed)[0]
# side by side comparison
display_images(images, prompt=prompt1, subtitles=["", "- " + prompt2])
# remove the prompt from itself for unexpected results:
prompt = "watercolour of a tiger"
prompt2img(prompt, neg_prompt=prompt, seed=1023458422345243)[0]
images = prompt2img(["cute puppy", "cute kitten"], width=512, height=512, seed=123)
display_images(images)
prompt1 = "a photo of a boy running on the beach"
prompt2 = "a photo of a cadillac on a highway"
seed = 749109862
images = interpolate_prompts(
[prompt1, prompt2], width=512, height=512, seed=seed, interpolate_steps=5
)
display_images(images, prompt1 + "<->" + prompt2)
prompt = "girl with green hair eating rice noodles"
images = beyond_prompt(
prompt,
neg_prompt="malformed",
width=512,
height=512,
seed=4093245,
walk_steps=8,
walk_stepsize=0.015,
)
display_images_grid(images, prompt)
prompt = "An oil masterpiece painting of horses in a field next to a farm in Normandy"
seed = 132432456352
seed2 = 42
walk_steps = 48
images = revolve_prompt(prompt, walk_steps=walk_steps, seed=seed, seed2=seed2)
display_images_grid(images, prompt)
# save_images(images, path=f"horses_r{walk_steps}")
export_as_gif(
f"horses_r{walk_steps}.gif", images, frames_per_second=2, rubber_band=False
)
prompt = "hires photo of shark, underwater scenery with tropical fishes and coral sea floor, caustics"
seed = 100
seed = 132432456352
seed2 = 42
seed3 = 897234234
walk_steps = 24
images = revolve_prompt(
prompt,
walk_type="spiral",
walk_steps=walk_steps,
width=512,
height=512,
seed=seed,
seed2=seed2,
seed3=seed3,
)
display_images_grid(images, prompt)
seed = 10234584620131114
prompt = "a watercolour painting of Cambridge Jesus Green"
images = prompt2img(
prompt, width=512, height=512, seed=seed, n_samples=9
)
display_images_grid(images)
# recreate the 5th variation:
images = prompt2img(
prompt, width=512, height=512, seed=seed, skip_random=4
)
images[0]
prompt = "An oil painting of horses in a field next to a farm in Normandy"
seed = 1022134
images = prompt_variations(prompt, variations=8, variation_strength=0.1, seed=seed)
display_images_grid(images)
seed = 1023458422345243
seed2 = 35634563
prompt = "movie cover of Schwarzenegger as the Terminator riding a Vespa"
images = prompt2img(prompts=prompt, neg_prompt=None, width=512, height=512, seed=seed)
images += variate_prompt(
prompt=prompt,
width=512,
height=512,
seed=seed,
seed2=seed2,
variation_strength=0.25,
)
display_images(images, prompt=prompt, subtitles=["", f"var{seed2} 25%"])
seed = 1023458422345243
seed2 = 35634563
prompt = "movie cover of Schwarzenegger as the Terminator riding a Vespa"
images = variate_prompt(
prompt=prompt, width=512, height=512, seed=seed, seed2=seed2, var_steps=6
)
display_images(images)
seed = 12345
images = interpolate_prompts(
["cow cat pawlephant", "muleskin beetledog"],
seed=seed,
interpolate_steps=1,
g=7.5,
steps=10,
)
display_images(images)
prompt2img("a cow covered in oreo cookies", seed=3534534, g=7.5, steps=10)[0]
seed = 3534534
seed2 = 3534533
var_steps = 360
images = variate_prompt(
    "a cow covered in oreo cookies", seed=seed, seed2=seed2, var_steps=var_steps
)
save_images(images, path="cow_covered_in_oreos")