Original code taken from https://gist.github.com/EderSantana/c7222daa328f0e885093
!python3 -m pip install matplotlib --upgrade
%matplotlib inline
import json
import numpy as np
import random
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import matplotlib.animation
import IPython.display
The idea in this FallingFruit game is that the fruit falls one row per time step and the player moves left, stays put, or moves right to try to dodge it. If the player dodges the fruit, they win and the episode ends; if the fruit hits them, they lose and the episode ends. We are trying to teach the computer to play this game. (A short sanity-check cell just after the class definition below plays one full episode.)
class FallingFruit(object):
def __init__(self, grid_size=10):
'''
Input: grid_size (length of the side of the canvas)
Initializes internal state.
'''
self.grid_size = grid_size
self.min_player_center = 1
self.max_player_center = self.grid_size-2
self.reset()
def _update_state(self, action):
'''
Input: action (0 for left, 1 for stay, 2 for right)
Moves player according to action. Moves fruit down. Updates state to reflect these movements
'''
if action == 0: # left
movement = -1
elif action == 1: # stay
movement = 0
elif action == 2: # right
movement = 1
else:
raise Exception('Invalid action {}'.format(action))
fruit_x, fruit_y, player_center = self.state
# move the player unless this would move it off the edge of the grid
new_player_center = min(max(self.min_player_center, player_center + movement), self.max_player_center)
# move fruit down
fruit_y += 1
out = np.asarray([fruit_x, fruit_y, new_player_center])
self.state = out
def _draw_state(self):
'''
Returns a 2D numpy array with 1s (white squares) at the locations of the fruit and player and
0s (black squares) everywhere else.
'''
im_size = (self.grid_size, self.grid_size)
canvas = np.zeros(im_size)
fruit_x, fruit_y, player_center = self.state
canvas[fruit_y, fruit_x] = 1 # draw fruit
canvas[-1, player_center-1:player_center + 2] = 1 # draw 3-pixel player
return canvas
def _get_reward(self):
'''
        Returns 1 if the fruit was dodged, -1 if the player was hit, and 0 if the fruit is still in the air.
'''
fruit_x, fruit_y, player_center = self.state
if fruit_y == self.grid_size-1:
if abs(fruit_x - player_center) <= 1:
return -1 # it was hit by fruit
else:
return 1 # it dodged the fruit
else:
return 0 # the fruit is still in the air
def observe(self):
'''
        Returns the current canvas, flattened into a single row (shape (1, grid_size**2)).
'''
canvas = self._draw_state()
return canvas.reshape((1, -1))
def act(self, action):
'''
Input: action (0 for left, 1 for stay, 2 for right)
Returns:
            current canvas (flattened into a single row)
reward received after this action
True if episode is over and False otherwise
'''
self._update_state(action)
observation = self.observe()
reward = self._get_reward()
episode_over = (reward != 0) # if the reward is zero, the fruit is still in the air
return observation, reward, episode_over
def reset(self):
'''
        Resets the internal state:
            the fruit starts in a random column of the top row
            the player center starts in a random valid column
'''
fruit_x = random.randint(0, self.grid_size-1)
fruit_y = 0
player_center = random.randint(self.min_player_center, self.max_player_center)
self.state = np.asarray([fruit_x, fruit_y, player_center])
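As a quick sanity check (not part of the original gist), we can roll out one episode with random actions and look at the final reward; the demo_env variable name is just for this illustration.
# Illustrative sketch: play one episode of FallingFruit with random actions.
demo_env = FallingFruit(grid_size=10)
demo_env.reset()
episode_over = False
while not episode_over:
    action = random.randint(0, 2)  # 0 = left, 1 = stay, 2 = right
    observation, reward, episode_over = demo_env.act(action)
print('Final reward (1 = dodged, -1 = hit):', reward)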
class ExperienceReplay(object):
def __init__(self, max_memory=100, discount=.9):
self.max_memory = max_memory
self.memory = list()
self.discount = discount
def remember(self, states, episode_over):
'''
Input:
states: [starting_observation, action_taken, reward_received, new_observation]
episode_over: boolean
        Append the states and the episode_over flag to the internal memory list. If the list grows
        beyond self.max_memory, drop the oldest memory.
'''
self.memory.append([states, episode_over])
if len(self.memory) > self.max_memory:
del self.memory[0]
def get_batch(self, model, batch_size=10):
'''
        Randomly chooses batch_size memories (sampling with replacement).
        For each of these memories, updates the model's current best guess about the value of the
        action that was taken from the starting state, based on the reward received and the model's
        current estimate of how valuable the new state is.
'''
len_memory = len(self.memory)
num_actions = model.output_shape[-1] # the number of possible actions
env_dim = self.memory[0][0][0].shape[1] # the number of pixels in the image
input_size = min(len_memory, batch_size)
inputs = np.zeros((input_size, env_dim))
targets = np.zeros((input_size, num_actions))
for i, idx in enumerate(np.random.randint(0, len_memory, size=input_size)):
starting_observation, action_taken, reward_received, new_observation = self.memory[idx][0]
episode_over = self.memory[idx][1]
# Set the input to the state that was observed in the game before an action was taken
inputs[i] = starting_observation[0]
# Start with the model's current best guesses about the value of taking each action from this state
targets[i] = model.predict(starting_observation)[0]
# Now we need to update the value of the action that was taken
if episode_over:
# if the episode is over, give the actual reward received
targets[i, action_taken] = reward_received
else:
                # if the episode is not over, give the reward received (always zero in this particular game)
                # plus the discounted maximum Q-value the model predicts for the state reached by taking this action
Q_sa = np.max(model.predict(new_observation)[0])
targets[i, action_taken] = reward_received + self.discount * Q_sa
return inputs, targets
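In other words, the training target for the action that was taken is reward + discount * max_a Q(new_state, a). A tiny worked example with made-up numbers (purely illustrative, no model needed):
# Made-up numbers, for illustration only: the fruit is still in the air (reward 0),
# and the model's Q-value guesses for the next state are [0.2, 0.5, -0.1].
discount = 0.9
reward_received = 0
predicted_next_q = [0.2, 0.5, -0.1]
target_for_action_taken = reward_received + discount * max(predicted_next_q)
print(target_for_action_taken)  # 0 + 0.9 * 0.5 = 0.45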
# parameters
epsilon = .1 # probability of exploration (choosing a random action instead of the current best one)
num_actions = 3 # [move_left, stay, move_right]
max_memory = 500
hidden_size = 100
batch_size = 50
grid_size = 10
print_freq = 10
def build_model():
'''
    Returns three initialized objects: the model, the environment, and the experience replay buffer.
'''
model = Sequential()
model.add(Dense(hidden_size, input_shape=(grid_size**2,), activation='relu'))
model.add(Dense(hidden_size, activation='relu'))
model.add(Dense(num_actions))
    model.compile(optimizer=Adam(), loss='mse')
# Define environment/game
env = FallingFruit(grid_size)
# Initialize experience replay object
exp_replay = ExperienceReplay(max_memory=max_memory)
return model, env, exp_replay
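As a quick shape check (illustrative only; the check_model and check_env names are just for this demo), the network maps a flattened 10x10 canvas to one Q-value per action:
# Illustrative: a freshly built (untrained) model outputs one Q-value per action.
check_model, check_env, _ = build_model()
check_obs = check_env.observe()          # shape (1, 100): the flattened canvas
print(check_model.predict(check_obs))    # shape (1, 3): Q-values for [left, stay, right]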
def take_step(env, exp_replay, model, starting_observation):
    '''
    Chooses an action (epsilon-greedy), applies it to the environment, and stores the
    resulting experience in the replay memory.
    Returns: the new observation, the reward received, and True if the episode is over.
    '''
    # get next action
    if np.random.rand() <= epsilon:
        # epsilon of the time, we just choose randomly
        action = np.random.randint(0, num_actions)
    else:
        # find which action the model currently thinks is best from this state
        q = model.predict(starting_observation)
        action = np.argmax(q[0])
    # apply action, get rewards and new state
    new_observation, reward, episode_over = env.act(action)
    # store experience
    exp_replay.remember([starting_observation, action, reward, new_observation], episode_over)
    return new_observation, reward, episode_over
def train_model(model, env, exp_replay, num_episodes, pretrain_episodes=100):
    '''
    Inputs:
        model, env, and exp_replay objects as returned by build_model
        num_episodes: integer, the number of episodes that should be rolled out for training
        pretrain_episodes: integer, the number of episodes rolled out first (without any
            training updates) just to fill the experience replay memory
    '''
    # first, roll out some episodes without training updates to seed the replay memory
    for episode in range(pretrain_episodes):
env.reset()
episode_over = False
# get initial input
starting_observation = env.observe()
while not episode_over:
            starting_observation, reward, episode_over = take_step(env, exp_replay, model, starting_observation)
rewards = []
for episode in range(1, num_episodes+1):
episode_reward = 0
loss = 0.
env.reset()
episode_over = False
# get initial input
starting_observation = env.observe()
while not episode_over:
            starting_observation, reward, episode_over = take_step(env, exp_replay, model, starting_observation)
episode_reward += reward
rewards.append(episode_reward)
# get data updated based on the stored experiences
inputs, targets = exp_replay.get_batch(model, batch_size=batch_size)
# train model on the updated data
loss += model.train_on_batch(inputs, targets)
# Print update from this episode
if episode % print_freq == 0:
print("Episodes {:04d}-{:04d}/{:04d} | Loss {:.4f} | Rewards {}".format(
episode - print_freq + 1, episode, num_episodes, loss, rewards))
rewards = []
def create_animation(model, env, num_episodes):
'''
Inputs:
model and env objects as returned from build_model
num_episodes: integer, the number of episodes to be included in the animation
Returns: a matplotlib animation object
'''
# Animation code from
# https://matplotlib.org/examples/animation/dynamic_image.html
# https://stackoverflow.com/questions/35532498/animation-in-ipython-notebook/46878531#46878531
# First, play the episodes and collect all of the images for each observed state
observations = []
rewards = []
for _ in range(num_episodes):
episode_reward = 0
env.reset()
observation = env.observe()
observations.append(observation)
episode_over = False
        while not episode_over:
            q = model.predict(observation)
action = np.argmax(q[0])
# apply action, get rewards and new state
observation, reward, episode_over = env.act(action)
observations.append(observation)
episode_reward += reward
rewards.append(episode_reward)
fig = plt.figure()
    image = plt.imshow(np.zeros((grid_size, grid_size)), interpolation='none', cmap='gray', animated=True, vmin=0, vmax=1)
print('Rewards in Animation: {}'.format(rewards))
def animate(observation):
image.set_array(observation.reshape((grid_size, grid_size)))
return [image]
    animation = matplotlib.animation.FuncAnimation(fig, animate, frames=observations, blit=True)
return animation
model, env, exp_replay = build_model()
animation = create_animation(model, env, num_episodes=10)
IPython.display.HTML(animation.to_jshtml())
Rewards in Animation: [1, 1, -1, 1, 1, -1, -1, 1, -1, -1]
Now train the model and see how much better it gets at dodging the fruit.
train_model(model, env, exp_replay, num_episodes=400)
animation = create_animation(model, env, num_episodes=10)
IPython.display.HTML(animation.to_jshtml())
Episodes 0001-0010/0400 | Loss 0.0786 | Rewards [1, 1, -1, 1, 1, 1, 1, -1, 1, -1]
Episodes 0011-0020/0400 | Loss 0.0162 | Rewards [1, 1, -1, 1, 1, 1, -1, -1, 1, 1]
Episodes 0021-0030/0400 | Loss 0.0224 | Rewards [-1, -1, 1, 1, 1, 1, 1, -1, 1, 1]
Episodes 0031-0040/0400 | Loss 0.0298 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, -1, 1]
Episodes 0041-0050/0400 | Loss 0.0144 | Rewards [1, -1, 1, 1, 1, -1, 1, 1, 1, -1]
Episodes 0051-0060/0400 | Loss 0.0317 | Rewards [1, 1, 1, 1, 1, 1, 1, -1, 1, 1]
Episodes 0061-0070/0400 | Loss 0.0202 | Rewards [1, 1, 1, -1, 1, 1, 1, 1, 1, 1]
Episodes 0071-0080/0400 | Loss 0.0092 | Rewards [1, 1, 1, 1, 1, 1, -1, 1, 1, 1]
Episodes 0081-0090/0400 | Loss 0.0071 | Rewards [1, 1, 1, 1, 1, -1, 1, 1, 1, 1]
Episodes 0091-0100/0400 | Loss 0.0031 | Rewards [1, 1, 1, 1, -1, -1, 1, 1, -1, -1]
Episodes 0101-0110/0400 | Loss 0.0231 | Rewards [1, 1, -1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0111-0120/0400 | Loss 0.0127 | Rewards [1, -1, 1, 1, -1, 1, 1, 1, 1, 1]
Episodes 0121-0130/0400 | Loss 0.0379 | Rewards [-1, 1, 1, -1, 1, 1, 1, 1, 1, 1]
Episodes 0131-0140/0400 | Loss 0.0088 | Rewards [-1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0141-0150/0400 | Loss 0.0030 | Rewards [1, 1, 1, -1, 1, 1, 1, 1, 1, -1]
Episodes 0151-0160/0400 | Loss 0.0038 | Rewards [1, -1, 1, -1, 1, -1, 1, 1, 1, 1]
Episodes 0161-0170/0400 | Loss 0.0328 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0171-0180/0400 | Loss 0.0036 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0181-0190/0400 | Loss 0.0019 | Rewards [1, 1, 1, 1, -1, 1, 1, -1, -1, 1]
Episodes 0191-0200/0400 | Loss 0.0142 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0201-0210/0400 | Loss 0.0215 | Rewards [1, 1, -1, -1, 1, -1, 1, 1, 1, -1]
Episodes 0211-0220/0400 | Loss 0.0065 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0221-0230/0400 | Loss 0.0226 | Rewards [-1, 1, 1, 1, 1, 1, -1, 1, 1, 1]
Episodes 0231-0240/0400 | Loss 0.0055 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0241-0250/0400 | Loss 0.0076 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, -1, 1]
Episodes 0251-0260/0400 | Loss 0.0032 | Rewards [1, 1, 1, 1, -1, 1, 1, 1, 1, -1]
Episodes 0261-0270/0400 | Loss 0.0032 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0271-0280/0400 | Loss 0.0026 | Rewards [-1, 1, 1, 1, 1, -1, 1, 1, 1, 1]
Episodes 0281-0290/0400 | Loss 0.0023 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0291-0300/0400 | Loss 0.0068 | Rewards [1, 1, 1, 1, 1, -1, 1, -1, 1, 1]
Episodes 0301-0310/0400 | Loss 0.0052 | Rewards [1, 1, 1, 1, 1, -1, 1, 1, 1, 1]
Episodes 0311-0320/0400 | Loss 0.0027 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0321-0330/0400 | Loss 0.0046 | Rewards [1, 1, 1, 1, 1, -1, 1, 1, 1, 1]
Episodes 0331-0340/0400 | Loss 0.0090 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0341-0350/0400 | Loss 0.0027 | Rewards [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0351-0360/0400 | Loss 0.0022 | Rewards [-1, 1, 1, 1, 1, 1, 1, 1, -1, 1]
Episodes 0361-0370/0400 | Loss 0.0009 | Rewards [1, 1, 1, 1, 1, 1, -1, 1, 1, 1]
Episodes 0371-0380/0400 | Loss 0.0026 | Rewards [1, -1, 1, 1, 1, 1, 1, 1, 1, 1]
Episodes 0381-0390/0400 | Loss 0.0057 | Rewards [-1, 1, 1, 1, 1, 1, -1, 1, 1, 1]
Episodes 0391-0400/0400 | Loss 0.0025 | Rewards [1, 1, 1, 1, -1, 1, 1, 1, -1, 1]
Rewards in Animation: [1, 1, 1, 1, -1, 1, 1, 1, 1, 1]
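If you would rather have a number than an animation, a sketch like this (dodge_rate is a hypothetical helper, not part of the original notebook) estimates how often the trained model dodges the fruit when playing greedily:
# Illustrative evaluation sketch: fraction of episodes in which the fruit is dodged.
def dodge_rate(model, env, num_episodes=100):
    wins = 0
    for _ in range(num_episodes):
        env.reset()
        observation = env.observe()
        episode_over = False
        while not episode_over:
            action = np.argmax(model.predict(observation)[0])
            observation, reward, episode_over = env.act(action)
        wins += (reward == 1)
    return wins / num_episodes

print('Dodge rate over 100 episodes:', dodge_rate(model, env))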