確率ロボティクス2017第9回-2

上田隆一

2017年11月15日@千葉工業大学

Q学習

  • 方策オフ型TD学習
  • 次の式を使う
    • $Q(s,a) \longleftarrow (1-\alpha )Q(s,a) + \alpha [r + \gamma \max_{a'} Q(s',a')]$
  • $\epsilon$-グリーディ方策を使っても非グリーディな行動が価値関数に影響を与えない

Q学習の実装

09.ipynbをコピーしてsarsaの関数のところから書き換えます。

エージェントの定義

In [1]:
class Agent:
    def __init__(self):
        self.actions = ["up","down","left","right"]
        self.pos = (0,0)
        
agent = Agent()

状態の定義

In [2]:
size = 3

class State:
    def __init__(self,actions):
        self.Q = {}
        for a in actions:
            self.Q[a] = 0.0
        self.best_action = "up"
        self.goal = False
        
    def set_goal(self,actions):
        for a in actions:
            self.Q[a] =0.0
        self.goal = True
        
states = [[State(agent.actions) for i in range(size)] for j in range(size)]
states[2][2].set_goal(agent.actions)

描画

In [3]:
import matplotlib.pyplot as plt  
import matplotlib.patches as patches

def draw(mark_pos):
    fig, ax = plt.subplots()
    values = [[states[i][j].Q[states[i][j].best_action] for j in range(size)] for i in range(size)]
    mp = ax.pcolor(values, cmap=plt.cm.YlOrRd,vmin=0,vmax=8)
    ax.set_aspect(1)
    ax.set_xticks(range(size), minor=False)
    ax.set_yticks(range(size), minor=False)
    
    for x in range(len(values)):
        for y in range(len(values[0])):
            s = states[x][y]
            plt.text(x+0.5,y+0.5,int(1000*s.Q[s.best_action])/1000,ha = 'center', va = 'center', size=20)
            if states[x][y].goal:
                plt.text(x+0.75,y+0.75,"G",ha = 'center', va = 'center', size=20)
                
    plt.text(agent.pos[0]+0.5,agent.pos[1]+0.25,"agent",ha = 'center', va = 'center', size=20)
    
    if mark_pos == "all":   # 指定した位置にactionの文字列を書くという処理
        for x in range(size):
            for y in range(size):
                if states[x][y].goal: continue
                plt.text(x+0.5,y+0.25,states[x][y].best_action,ha = 'center', va = 'center', size=20)
    elif mark_pos != None: 
        s = states[mark_pos[0]][mark_pos[1]]
        plt.text(mark_pos[0]+0.5,mark_pos[1]+0.25,s.best_action,ha = 'center', va = 'center', size=20)
            
    plt.show()
    fig.clear()
    
draw(None)

状態遷移の実装

In [4]:
import random

def state_transition(s_pos,a):
    ###確率10%で元のまま ###
    if random.uniform(0,1) < 0.1:
        return s_pos
    
    x,y = s_pos
    if   a == "up": y += 1
    elif a == "down": y -= 1
    elif a == "right": x += 1
    elif a == "left": x -= 1
        
    if x < 0:       x = 0
    elif x >= size: x = size-1
    if y < 0:       y = 0
    elif y >= size: y = size-1
        
    return (x,y)

方策($\epsilon$-greedy)

In [5]:
def e_greedy(s):
    best_a = None
    best_q = 1000000000
    for a in s.Q:
        if best_q > s.Q[a]:
            best_q = s.Q[a]
            best_a = a
    s.best_action = best_a
        
    if random.uniform(0,1) < 0.1: #10%でランダムに
        return random.choice(agent.actions)
    else:
        return best_a

Q学習の1ステップの処理

ここからsarsaと少し違います。

In [6]:
alpha = 0.5
gamma = 1.0

def q_proc(s_pos,a):
    s = states[s_pos[0]][s_pos[1]]
    s_next_pos = state_transition(s_pos,a)
    s_next = states[s_next_pos[0]][s_next_pos[1]]
    a_next = e_greedy(s_next)
    
    q = (1.0-alpha)*s.Q[a] + alpha * (1.0 + gamma * s_next.Q[s_next.best_action])
    print("s:" + str(s_pos)+ " a:" + a + " s':" + str(s_next_pos))
    print("----")
    return s_next_pos, a_next, q

def one_trial():
    agent.pos = (random.randrange(size),random.randrange(size))
    a = e_greedy(states[agent.pos[0]][agent.pos[1]])
    if states[agent.pos[0]][agent.pos[1]].goal:
        return
          
    while True:
        #draw(None)
        s_next, a_next, q = q_proc(agent.pos,a)
        states[agent.pos[0]][agent.pos[1]].Q[a] = q
        agent.pos = s_next
        a = a_next
        if states[agent.pos[0]][agent.pos[1]].goal:
            break
            
            
for i in range(100):
    one_trial()
    draw("all")
s:(1, 2) a:up s':(1, 2)
----
s:(1, 2) a:up s':(1, 2)
----
s:(1, 2) a:down s':(1, 1)
----
s:(1, 1) a:up s':(1, 2)
----
s:(1, 2) a:left s':(0, 2)
----
s:(0, 2) a:up s':(0, 2)
----
s:(0, 2) a:up s':(0, 2)
----
s:(0, 2) a:down s':(0, 1)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:left s':(0, 2)
----
s:(0, 2) a:left s':(0, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:down s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:right s':(0, 0)
----
s:(0, 0) a:right s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:down s':(1, 0)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:left s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:down s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:left s':(0, 1)
----
s:(0, 1) a:up s':(0, 1)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:left s':(0, 2)
----
s:(0, 2) a:down s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 0) a:right s':(2, 0)
----
s:(2, 0) a:left s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:down s':(2, 1)
----
s:(2, 1) a:down s':(2, 0)
----
s:(2, 0) a:down s':(2, 0)
----
s:(2, 0) a:down s':(2, 0)
----
s:(2, 0) a:right s':(2, 0)
----
s:(2, 0) a:right s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:down s':(1, 1)
----
s:(1, 1) a:left s':(0, 1)
----
s:(0, 1) a:down s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:up s':(1, 2)
----
s:(1, 2) a:left s':(0, 2)
----
s:(0, 2) a:up s':(0, 2)
----
s:(0, 2) a:up s':(0, 2)
----
s:(0, 2) a:down s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:down s':(1, 0)
----
s:(1, 0) a:right s':(2, 0)
----
s:(2, 0) a:left s':(1, 0)
----
s:(1, 0) a:up s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:left s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:right s':(2, 1)
----
s:(2, 1) a:right s':(2, 1)
----
s:(2, 1) a:right s':(2, 1)
----
s:(2, 1) a:down s':(2, 0)
----
s:(2, 0) a:up s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(2, 0) a:down s':(2, 0)
----
s:(2, 0) a:down s':(2, 0)
----
s:(2, 0) a:left s':(1, 0)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:left s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:right s':(0, 0)
----
s:(0, 0) a:right s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:up s':(0, 1)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:left s':(0, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:up s':(1, 2)
----
s:(1, 2) a:up s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(2, 1) a:left s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:down s':(0, 0)
----
s:(0, 0) a:right s':(1, 0)
----
s:(1, 0) a:left s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:right s':(1, 0)
----
s:(1, 0) a:up s':(1, 0)
----
s:(1, 0) a:right s':(2, 0)
----
s:(2, 0) a:right s':(2, 0)
----
s:(2, 0) a:right s':(2, 0)
----
s:(2, 0) a:left s':(1, 0)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:right s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:left s':(1, 1)
----
s:(1, 1) a:up s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(2, 0) a:down s':(2, 0)
----
s:(2, 0) a:right s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 0) a:right s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:down s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:left s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(0, 0) a:up s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:up s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(0, 0) a:right s':(0, 0)
----
s:(0, 0) a:right s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 0) a:right s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:up s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(0, 1) a:down s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:up s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:down s':(0, 1)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:up s':(0, 2)
----
s:(0, 2) a:up s':(0, 2)
----
s:(0, 2) a:left s':(0, 2)
----
s:(0, 2) a:left s':(0, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:left s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:down s':(1, 0)
----
s:(1, 0) a:left s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:down s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(1, 2)
----
s:(1, 2) a:down s':(1, 1)
----
s:(1, 1) a:left s':(0, 1)
----
s:(0, 1) a:right s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:down s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:up s':(1, 2)
----
s:(1, 2) a:left s':(0, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:up s':(1, 2)
----
s:(1, 2) a:up s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(1, 0) a:right s':(2, 0)
----
s:(2, 0) a:left s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(0, 0) a:right s':(1, 0)
----
s:(1, 0) a:right s':(1, 0)
----
s:(1, 0) a:right s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(0, 1) a:down s':(0, 0)
----
s:(0, 0) a:up s':(0, 1)
----
s:(0, 1) a:up s':(0, 2)
----
s:(0, 2) a:right s':(1, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(2, 0) a:up s':(2, 0)
----
s:(2, 0) a:up s':(2, 1)
----
s:(2, 1) a:right s':(2, 1)
----
s:(2, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(1, 0) a:left s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:left s':(0, 0)
----
s:(0, 0) a:right s':(1, 0)
----
s:(1, 0) a:up s':(1, 1)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 1) a:right s':(2, 1)
----
s:(2, 1) a:up s':(2, 2)
----
s:(1, 2) a:right s':(2, 2)
----
s:(1, 2) a:right s':(2, 2)
----