import subprocess
# Pull the "release X.Y" field out of `nvcc --version` to learn the local CUDA
# toolkit version, then pick the PyTorch 1.7.1 wheel suffix the pip cell uses.
CUDA_version = [
    field
    for field in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ")
    if field.startswith("release")
][0].split(" ")[-1]
print("CUDA version:", CUDA_version)
# 10.2 maps to no suffix (the plain torch==1.7.1 / torchvision==0.8.2 wheels);
# any version not listed falls back to the CUDA 11.0 build.
# NOTE(review): the fallback also catches e.g. 9.2 or 11.x, and a +cu100 build may
# not be published for torch 1.7.1 -- confirm against the wheel index.
torch_version_suffix = {
    "10.0": "+cu100",
    "10.1": "+cu101",
    "10.2": "",
}.get(CUDA_version, "+cu110")
! pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex
CUDA version: 10.1 Looking in links: https://download.pytorch.org/whl/torch_stable.html Requirement already satisfied: torch==1.7.1+cu101 in /usr/local/lib/python3.6/dist-packages (1.7.1+cu101) Requirement already satisfied: torchvision==0.8.2+cu101 in /usr/local/lib/python3.6/dist-packages (0.8.2+cu101) Requirement already satisfied: ftfy in /usr/local/lib/python3.6/dist-packages (5.8) Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (2019.12.20) Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (1.19.5) Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (3.7.4.3) Requirement already satisfied: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from torch==1.7.1+cu101) (0.8) Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision==0.8.2+cu101) (7.0.0) Requirement already satisfied: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy) (0.2.5)
!pip install big-sleep --upgrade
Requirement already up-to-date: big-sleep in /usr/local/lib/python3.6/dist-packages (0.4.6) Requirement already satisfied, skipping upgrade: pytorch-pretrained-biggan in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.1.1) Requirement already satisfied, skipping upgrade: torch>=1.7.1 in /usr/local/lib/python3.6/dist-packages (from big-sleep) (1.7.1+cu101) Requirement already satisfied, skipping upgrade: einops>=0.3 in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.3.0) Requirement already satisfied, skipping upgrade: tqdm in /usr/local/lib/python3.6/dist-packages (from big-sleep) (4.41.1) Requirement already satisfied, skipping upgrade: torchvision>=0.8.2 in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.8.2+cu101) Requirement already satisfied, skipping upgrade: fire in /usr/local/lib/python3.6/dist-packages (from big-sleep) (0.4.0) Requirement already satisfied, skipping upgrade: ftfy in /usr/local/lib/python3.6/dist-packages (from big-sleep) (5.8) Requirement already satisfied, skipping upgrade: regex in /usr/local/lib/python3.6/dist-packages (from big-sleep) (2019.12.20) Requirement already satisfied, skipping upgrade: requests in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-biggan->big-sleep) (2.23.0) Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-biggan->big-sleep) (1.19.5) Requirement already satisfied, skipping upgrade: boto3 in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-biggan->big-sleep) (1.16.63) Requirement already satisfied, skipping upgrade: dataclasses; python_version < "3.7" in /usr/local/lib/python3.6/dist-packages (from torch>=1.7.1->big-sleep) (0.8) Requirement already satisfied, skipping upgrade: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.7.1->big-sleep) (3.7.4.3) Requirement already satisfied, skipping upgrade: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages 
(from torchvision>=0.8.2->big-sleep) (7.0.0) Requirement already satisfied, skipping upgrade: six in /usr/local/lib/python3.6/dist-packages (from fire->big-sleep) (1.15.0) Requirement already satisfied, skipping upgrade: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->big-sleep) (1.1.0) Requirement already satisfied, skipping upgrade: wcwidth in /usr/local/lib/python3.6/dist-packages (from ftfy->big-sleep) (0.2.5) Requirement already satisfied, skipping upgrade: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (2.10) Requirement already satisfied, skipping upgrade: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (1.24.3) Requirement already satisfied, skipping upgrade: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (3.0.4) Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-biggan->big-sleep) (2020.12.5) Requirement already satisfied, skipping upgrade: botocore<1.20.0,>=1.19.63 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-biggan->big-sleep) (1.19.63) Requirement already satisfied, skipping upgrade: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-biggan->big-sleep) (0.10.0) Requirement already satisfied, skipping upgrade: s3transfer<0.4.0,>=0.3.0 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-biggan->big-sleep) (0.3.4) Requirement already satisfied, skipping upgrade: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.6/dist-packages (from botocore<1.20.0,>=1.19.63->boto3->pytorch-pretrained-biggan->big-sleep) (2.8.1)
from IPython.display import Image, display
import string
import torch
from torchvision.utils import save_image
from torchvision import transforms
import numpy as np
from big_sleep import Imagine
from big_sleep.clip import tokenize
from nltk.corpus import stopwords
from skimage.measure import compare_ssim
import cv2
from pathlib import Path
import ipywidgets
import PIL
from PIL import ImageFont, ImageDraw
# Settings handed to big_sleep's Imagine() when the image model is (re)built.
TEXT = 'story_hallucinator'
LEARNING_RATE = 0.1
ITERATIONS = 1
SAVE_EVERY = 1
SAVE_PROGRESS = True
# Not referenced in the code visible here.
REPEATS = 5
# Width, in words, of the sliding caption window used by burnin/vizualize_words.
SPAN = 6
def train_step(self, epoch, i, rand=0):
    """Run one optimisation step of an ``Imagine`` model and return its current image.

    Gradients are accumulated over ``self.gradient_accumulate_every`` forward
    passes of ``self.model`` on ``self.encoded_text``. Each loss is divided by
    that count, and a single optimizer step follows. The generator output is
    then decoded into an RGB PIL image.

    Args:
        self: the ``Imagine`` instance to train (passed explicitly; this is a
            free function, not a method).
        epoch, i: unused; accepted so existing call sites keep working.
        rand: scale of a random scalar added to each loss term. The term has no
            autograd history, so it leaves the gradients (and therefore the
            update) unchanged; it only consumes NumPy RNG state.

    Returns:
        ``PIL.Image`` in RGB mode built from the last item of ``self.model.model()``.
    """
    for _ in range(self.gradient_accumulate_every):
        losses = self.model(self.encoded_text)
        loss = sum(losses) / self.gradient_accumulate_every + rand * np.random.randn()
        loss.backward()
    self.optimizer.step()
    self.optimizer.zero_grad()
    # The image is only decoded for display, so no autograd graph is needed.
    with torch.no_grad():
        mres = self.model.model()
    return transforms.ToPILImage()(mres[-1].cpu()).convert("RGB")
# TEXT with every space turned into an underscore (usable as a file name stem).
filename = '_'.join(TEXT.split(' '))
def burnin(words, steps=10):
    """Warm up the global ``im_model`` on the first ``SPAN`` words of ``words``.

    The opening window is joined with spaces and stripped of ASCII punctuation.
    It is then tokenized onto the GPU and stored on ``im_model``. After that,
    ``train_step`` runs ``steps`` times.

    Args:
        words: sequence of word strings.
        steps: number of warm-up training steps (default 10, as before).
    """
    # The prompt is identical for every warm-up step, so encode it only once.
    phrase = " ".join(words[:SPAN])
    im_model.text = phrase.translate(str.maketrans('', '', string.punctuation))
    im_model.encoded_text = tokenize(im_model.text).cuda()
    for i in range(steps):
        train_step(im_model, 0, i)
def add_text_to_im(img, msg_orig):
    """Draw ``msg_orig`` onto ``img`` as a white caption with a 1-px black outline.

    The caption is centred horizontally, and its top edge sits at 7/8 of the
    free vertical space. A caption wider than the image is wrapped onto two
    lines at its middle word. ``img`` is modified in place and nothing is
    returned.

    Args:
        img: PIL image to draw on (mutated).
        msg_orig: caption text.
    """
    W, H = img.size
    draw = ImageDraw.Draw(img)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 18)

    def text_size(text):
        # ImageDraw.textsize was removed in Pillow 10; textbbox exists from Pillow 8.
        if hasattr(draw, "textbbox"):
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
            return right - left, bottom - top
        return draw.textsize(text, font=font)

    msgs = [msg_orig]
    w, _ = text_size(msg_orig)
    if w > W:
        # Too wide for one line: wrap at the middle word of the caption itself.
        # Empty halves (e.g. a single very long word) are dropped.
        tokens = msg_orig.split()
        half = len(tokens) // 2
        msgs = [line for line in (" ".join(tokens[:half]), " ".join(tokens[half:])) if line]

    # Outline: stamp the text in black at all 8 neighbouring offsets, then draw
    # the white text on top.
    adj = 1
    shadow_offsets = [
        (dx, dy)
        for dx in (-adj, 0, adj)
        for dy in (-adj, 0, adj)
        if (dx, dy) != (0, 0)
    ]
    for shift, msg in enumerate(msgs):
        w, h = text_size(msg)
        x, y = (W - w) / 2, 7 * (H - h) / 8 + shift * h
        for dx, dy in shadow_offsets:
            draw.text((x + dx, y + dy), msg, fill="black", font=font)
        draw.text((x, y), msg, fill="white", font=font)
def vizualize_words(words, segment_num, span=SPAN, frames_per_step=15, duration=50):
    """Render a captioned GIF that slides a ``span``-word window across ``words``.

    The outer loop advances through ``words`` in steps of ``span`` words. Each
    step renders ``frames_per_step`` frames, and the window moves forward by
    about ``span / frames_per_step`` words per frame. For every frame:
    the window is stripped of punctuation and tokenized onto the global
    ``im_model``; one ``train_step`` is run; the window text is drawn onto the
    image as a caption. Windows that slide past the end of the text are
    skipped.

    Args:
        words: sequence of word strings (e.g. ``story.split()``).
        segment_num: index used in the output name ``aidungeon{segment_num}.gif``.
        span: window width and outer step, in words (default: module ``SPAN``).
        frames_per_step: frames rendered per ``span``-word step (default 15).
        duration: per-frame display time in milliseconds for the GIF (default 50).

    Returns:
        The GIF filename written, or ``None`` if no frame was rendered
        (e.g. ``words`` is empty).
    """
    strip_punct = str.maketrans('', '', string.punctuation)
    frames = []
    for start in range(0, len(words), span):
        for i in range(frames_per_step):
            offset = start + int(i * span / frames_per_step)
            phrase = " ".join(words[offset:offset + span])
            if not phrase:
                # Window is past the end of the text: it would only train on an
                # empty prompt and produce an uncaptioned frame.
                continue
            im_model.text = phrase.translate(strip_punct)
            im_model.encoded_text = tokenize(im_model.text).cuda()
            frame = train_step(im_model, start, i)
            add_text_to_im(frame, phrase)
            frames.append(frame)
    if not frames:
        return None
    gif_name = f'aidungeon{segment_num}.gif'
    frames[0].save(fp=gif_name, format='GIF', append_images=frames[1:], save_all=True, duration=duration, loop=0)
    return gif_name
# Segment counter (passed to vizualize_words as the GIF index) and opening prompt.
iter_num, result = 0, "This is the start of the AI Dungeon game"
####
## REFRESH IMAGE MODEL
####
im_model = Imagine(
    text=TEXT,
    save_every=SAVE_EVERY,
    lr=LEARNING_RATE,
    iterations=ITERATIONS,
    save_progress=SAVE_PROGRESS,
)
# Burn in the first image on the whole opening prompt (whitespace-normalised,
# punctuation stripped). The prompt is fixed for all 20 warm-up steps, so it is
# tokenized and moved to the GPU once, before the loop.
phrase = " ".join(result.split())
im_model.text = phrase.translate(str.maketrans('', '', string.punctuation))
im_model.encoded_text = tokenize(im_model.text).cuda()
for i in range(20):
    train_step(im_model, 0, i)
# Next story segment: animate it word-window by word-window into
# aidungeon{iter_num}.gif, then advance the segment counter.
result = """
You are Marley, a mutant trying to survive after the deadly plague. You have
scales on your chest and feather on your left arm.
In the town you were born in, your strange condition was considered a sin, and
you have been banished since you were eleven. After a long journey, you find
a ravaged cottage. You look inside and see that it is not inhabited anymore.
The place has become a graveyard. A few of the bodies are still moving, but
they are all dead now
"""
vizualize_words(segment_num=iter_num, words=result.split())
iter_num = iter_num + 1
# Show the animation that was just rendered. The previous cell passed iter_num to
# vizualize_words() (output name aidungeon{iter_num}.gif) and then incremented
# it, so the newest GIF is index iter_num - 1.
# NOTE(review): assumes that cell ran exactly once since the last render.
gif_path = Path(f'aidungeon{iter_num - 1}.gif')
progress = ipywidgets.Image(
    value=gif_path.read_bytes(),
    format='gif',
    width=512,
    height=512,
)
display(progress)
Image(value=b'GIF89a\x00\x02\x00\x02\x87\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x…
!ls
aidungeon0.gif aidungeon1.gif aidungeon2.gif sample_data