Jack manages two locations for a nationwide car rental company. Each day, some number of customers arrive at each location to rent cars. If Jack has a car available, he rents it out and is credited 10 dollars by the national company. If he is out of cars at that location, then the business is lost. Cars become available for renting the day after they are returned. To help ensure that cars are available where they are needed, Jack can move them between the two locations overnight, at a cost of 2 dollars per car moved. We assume that the numbers of cars requested and returned at each location are Poisson random variables, meaning that the probability that the number is n is $\frac{\lambda^n}{n!}e^{-\lambda}$, where λ is the expected number. Suppose λ is 3 and 4 for rental requests at the first and second locations and 3 and 2 for returns. To simplify the problem slightly, we assume that there can be no more than 20 cars at each location (any additional cars are returned to the nationwide company, and thus disappear from the problem) and a maximum of five cars can be moved from one location to the other in one night. We take the discount rate to be $\gamma = 0.9$ and formulate this as a continuing finite MDP, where the time steps are days, the state is the number of cars at each location at the end of the day, and the actions are the net numbers of cars moved between the two locations overnight.
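As a quick check of the formula: with $\lambda = 3$, the probability of exactly three rental requests on a given day is $\frac{3^3}{3!}e^{-3} \approx 0.224$.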
#######################################################################
# Copyright (C) #
# 2016 Shangtong Zhang(zhangshangtong.cpp@gmail.com) #
# 2016 Kenta Shimada(hyperkentakun@gmail.com) #
# 2017 Aja Rangaswamy (aja004@gmail.com) #
# Permission given to modify the code as long as you keep this #
# declaration at the top #
#######################################################################
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.stats import poisson
%matplotlib inline
# maximum # of cars in each location
MAX_CARS = 20
# maximum # of cars to move during night
MAX_MOVE_OF_CARS = 5
# expectation for rental requests in first location
RENTAL_REQUEST_FIRST_LOC = 3
# expectation for rental requests in second location
RENTAL_REQUEST_SECOND_LOC = 4
# expectation for # of cars returned in first location
RETURNS_FIRST_LOC = 3
# expectation for # of cars returned in second location
RETURNS_SECOND_LOC = 2
DISCOUNT = 0.9
# credit earned by a car
RENTAL_CREDIT = 10
# cost of moving a car
MOVE_CAR_COST = 2
# all possible actions
actions = np.arange(-MAX_MOVE_OF_CARS, MAX_MOVE_OF_CARS + 1)
# An upper bound for the poisson distribution
# If n is greater than or equal to this value, the probability of getting n is truncated to 0
POISSON_UPPER_BOUND = 11
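# Sanity check (my addition, not part of the original notebook): the tail mass
# discarded by this truncation is negligible for the largest lambda used here (4).
print(1 - poisson.cdf(POISSON_UPPER_BOUND - 1, RENTAL_REQUEST_SECOND_LOC))  # ~0.003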
# Cache of poisson probabilities so each (n, lam) pair is computed only once
# @lam: lambda must be less than 10 for the integer key below to be unique
poisson_cache = dict()
def poisson_probability(n, lam):
global poisson_cache
key = n * 10 + lam
if key not in poisson_cache:
# cache the result so the same (n, lam) pair is never recomputed
# the key acts as a simple hash of (n, lam)
poisson_cache[key] = poisson.pmf(n, lam)
return poisson_cache[key]
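# Example usage (my addition, not in the original): Poisson(3) at n = 3 is
# 3**3 / 3! * e**-3, about 0.224; a repeated call is served from the cache.
print(poisson_probability(3, RENTAL_REQUEST_FIRST_LOC))  # ~0.224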
def expected_return(state, action, state_value, constant_returned_cars):
"""
@state: [# of cars in first location, # of cars in second location]
@action: positive if moving cars from first location to second location,
negative if moving cars from second location to first location
@state_value: state value matrix
@constant_returned_cars: if set to True, the model is simplified so that
the # of cars returned during the day is a constant (the Poisson mean)
rather than a random draw from the poisson distribution; this greatly reduces
computation time and leaves the optimal policy and value function almost unchanged
"""
# initialize total return
returns = 0.0
# cost for moving cars
returns -= MOVE_CAR_COST * abs(action)
# moving cars
NUM_OF_CARS_FIRST_LOC = min(state[0] - action, MAX_CARS)
NUM_OF_CARS_SECOND_LOC = min(state[1] + action, MAX_CARS)
# go through all possible rental requests
for rental_request_first_loc in range(POISSON_UPPER_BOUND):
for rental_request_second_loc in range(POISSON_UPPER_BOUND):
# probability for current combination of rental requests
prob = poisson_probability(rental_request_first_loc, RENTAL_REQUEST_FIRST_LOC) * \
poisson_probability(rental_request_second_loc, RENTAL_REQUEST_SECOND_LOC)
num_of_cars_first_loc = NUM_OF_CARS_FIRST_LOC
num_of_cars_second_loc = NUM_OF_CARS_SECOND_LOC
'''
Note the author's naming convention: the ALL-CAPS names hold the actual
post-move car counts at the two locations, while the lower-case names are
temporary working variables used in the calculation below.
'''
# valid rental requests should be less than actual # of cars
valid_rental_first_loc = min(num_of_cars_first_loc, rental_request_first_loc)
valid_rental_second_loc = min(num_of_cars_second_loc, rental_request_second_loc)
# get credits for renting
reward = (valid_rental_first_loc + valid_rental_second_loc) * RENTAL_CREDIT
num_of_cars_first_loc -= valid_rental_first_loc
num_of_cars_second_loc -= valid_rental_second_loc
'''
The two nested for-loops enumerate every combination of rental requests
so that the return can be computed as an expectation.
'''
if constant_returned_cars:
# get returned cars, those cars can be used for renting tomorrow
returned_cars_first_loc = RETURNS_FIRST_LOC
returned_cars_second_loc = RETURNS_SECOND_LOC
num_of_cars_first_loc = min(num_of_cars_first_loc + returned_cars_first_loc, MAX_CARS)
num_of_cars_second_loc = min(num_of_cars_second_loc + returned_cars_second_loc, MAX_CARS)
returns += prob * (reward + DISCOUNT * state_value[num_of_cars_first_loc, num_of_cars_second_loc])
'''
Important: as the docstring above explains, treating the number of
returned cars as a constant barely affects the result, whereas also
drawing returns from the poisson distribution makes the loop O(n^4)
and greatly increases the running time.
'''
else:
for returned_cars_first_loc in range(POISSON_UPPER_BOUND):
for returned_cars_second_loc in range(POISSON_UPPER_BOUND):
prob_return = poisson_probability(
returned_cars_first_loc, RETURNS_FIRST_LOC) * poisson_probability(returned_cars_second_loc, RETURNS_SECOND_LOC)
num_of_cars_first_loc_ = min(num_of_cars_first_loc + returned_cars_first_loc, MAX_CARS)
num_of_cars_second_loc_ = min(num_of_cars_second_loc + returned_cars_second_loc, MAX_CARS)
prob_ = prob_return * prob
returns += prob_ * (reward + DISCOUNT *
state_value[num_of_cars_first_loc_, num_of_cars_second_loc_])
return returns
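# Sanity check (my addition, not part of the original notebook): with an all-zero
# value function and no cars moved, the expected one-step return from state
# [10, 10] should be close to the expected rental income (3 + 4) * 10 = 70
# dollars, slightly less because requests are capped at the cars available and
# the Poisson tails beyond POISSON_UPPER_BOUND are truncated.
print(expected_return([10, 10], 0, np.zeros((MAX_CARS + 1, MAX_CARS + 1)), True))  # ~69.6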
def figure_4_2(constant_returned_cars=True):
value = np.zeros((MAX_CARS + 1, MAX_CARS + 1))
policy = np.zeros(value.shape, dtype=int)
iterations = 0
_, axes = plt.subplots(2, 3, figsize=(40, 20))  # one panel per policy, plus one for the optimal value
plt.subplots_adjust(wspace=0.1, hspace=0.2)  # tighten the spacing between subplots
axes = axes.flatten()
while True:
fig = sns.heatmap(np.flipud(policy), cmap="YlGnBu", ax=axes[iterations])
fig.set_ylabel('# cars at first location', fontsize=30)
fig.set_yticks(list(reversed(range(MAX_CARS + 1))))
fig.set_xlabel('# cars at second location', fontsize=30)
fig.set_title('policy {}'.format(iterations), fontsize=30)
# policy evaluation (in-place)
while True:
old_value = value.copy()
for i in range(MAX_CARS + 1):
for j in range(MAX_CARS + 1):
new_state_value = expected_return([i, j], policy[i, j], value, constant_returned_cars)
value[i, j] = new_state_value
'''
Style note: the intermediate variable new_state_value makes the
in-place update easier to read.
'''
max_value_change = abs(old_value - value).max()
print('max value change {}'.format(max_value_change))
'''
Keep sweeping until the value function of the current policy has converged.
'''
if max_value_change < 1e-4:
break
# policy improvement
policy_stable = True
for i in range(MAX_CARS + 1):
for j in range(MAX_CARS + 1):
old_action = policy[i, j]
action_returns = []
for action in actions:
if (0 <= action <= i) or (-j <= action <= 0):
action_returns.append(expected_return([i, j], action, value, constant_returned_cars))
else:
action_returns.append(-np.inf)
new_action = actions[np.argmax(action_returns)]  # greedy action under the current value function
policy[i, j] = new_action
if policy_stable and old_action != new_action:
# check policy_stable first: once the flag is already False
# there is no need to set it to False again
policy_stable = False
print('policy stable {}'.format(policy_stable))
if policy_stable:
fig = sns.heatmap(np.flipud(value), cmap="YlGnBu", ax=axes[-1])
fig.set_ylabel('# cars at first location', fontsize=30)
fig.set_yticks(list(reversed(range(MAX_CARS + 1))))
fig.set_xlabel('# cars at second location', fontsize=30)
fig.set_title('optimal value', fontsize=30)
break
iterations += 1
plt.show()
figure_4_2()
max value change 196.62783361783852
... (policy-evaluation sweeps continue until the change drops below 1e-4) ...
max value change 9.230592559106299e-05
policy stable False
max value change 72.93565506480746
...
max value change 9.134763803331225e-05
policy stable False
max value change 4.7865793901779625
...
max value change 9.822812404536307e-05
policy stable False
max value change 0.5643315459673204
...
max value change 9.361574132071837e-05
policy stable False
max value change 0.04079438312567163
...
max value change 6.173777182993945e-05
policy stable True
The cell below never seems to finish when run in Jupyter; I suspect Jupyter cannot handle the multiprocessing it uses.
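One common workaround (my note, not from the original authors) is to move the multiprocessing code into a plain .py script with a guarded entry point: on platforms where the default start method is spawn (Windows, macOS), worker processes re-import the main module and cannot pickle functions or classes defined interactively in a notebook. A minimal, hypothetical sketch:
# demo_pool.py -- hypothetical standalone script, not part of the original code
import multiprocessing as mp

def square(x):
    return x * x

if __name__ == '__main__':
    # the guard keeps child processes from re-executing the pool creation
    with mp.Pool(processes=4) as p:
        print(p.map(square, range(8)))  # [0, 1, 4, 9, 16, 25, 36, 49]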
# This file is contributed by Tahsincan Köse and implements a synchronous policy evaluation, while car_rental.py
# implements an asynchronous (in-place) policy evaluation. This file also uses multiprocessing for acceleration and contains
# an answer to Exercise 4.5
import numpy as np
import matplotlib.pyplot as plt
import math
import tqdm
import multiprocessing as mp
from functools import partial
import time
import itertools
%matplotlib inline
############# PROBLEM SPECIFIC CONSTANTS #######################
MAX_CARS = 20
MAX_MOVE = 5
MOVE_COST = -2
ADDITIONAL_PARK_COST = -4
RENT_REWARD = 10
# expectation for rental requests in first location
RENTAL_REQUEST_FIRST_LOC = 3
# expectation for rental requests in second location
RENTAL_REQUEST_SECOND_LOC = 4
# expectation for # of cars returned in first location
RETURNS_FIRST_LOC = 3
# expectation for # of cars returned in second location
RETURNS_SECOND_LOC = 2
################################################################
poisson_cache = dict()
def poisson(n, lam):
global poisson_cache
key = n * 10 + lam
if key not in poisson_cache:
poisson_cache[key] = math.exp(-lam) * math.pow(lam, n) / math.factorial(n)
return poisson_cache[key]
class PolicyIteration:
def __init__(self, truncate, parallel_processes, delta=1e-2, gamma=0.9, solve_4_5=False):
self.TRUNCATE = truncate
self.NR_PARALLEL_PROCESSES = parallel_processes
self.actions = np.arange(-MAX_MOVE, MAX_MOVE + 1)
self.inverse_actions = {el: ind[0] for ind, el in np.ndenumerate(self.actions)}
self.values = np.zeros((MAX_CARS + 1, MAX_CARS + 1))
self.policy = np.zeros(self.values.shape, dtype=int)
self.delta = delta
self.gamma = gamma
self.solve_extension = solve_4_5
def solve(self):
iterations = 0
total_start_time = time.time()
while True:
start_time = time.time()
self.values = self.policy_evaluation(self.values, self.policy)
elapsed_time = time.time() - start_time
print(f'PE => Elapsed time {elapsed_time} seconds')
start_time = time.time()
policy_change, self.policy = self.policy_improvement(self.actions, self.values, self.policy)
elapsed_time = time.time() - start_time
print(f'PI => Elapsed time {elapsed_time} seconds')
if policy_change == 0:
break
iterations += 1
total_elapsed_time = time.time() - total_start_time
print(f'Optimal policy is reached after {iterations} iterations in {total_elapsed_time} seconds')
# out-of-place (synchronous) policy evaluation
def policy_evaluation(self, values, policy):
global MAX_CARS
while True:
new_values = np.copy(values)
k = np.arange(MAX_CARS + 1)
# cartesian product
all_states = ((i, j) for i, j in itertools.product(k, k))
results = []
with mp.Pool(processes=self.NR_PARALLEL_PROCESSES) as p:
'''
Multiprocessing: partial() pins the policy and values arguments of the
helper, and the pool maps the resulting function over all_states
(the original note was unsure whether "critical section" is the right term here).
'''
cook = partial(self.expected_return_pe, policy, values)
results = p.map(cook, all_states)
for v, i, j in results:
new_values[i, j] = v
difference = np.abs(new_values - values).sum()
print(f'Difference: {difference}')
values = new_values
if difference < self.delta:
print(f'Values are converged!')
return values
def policy_improvement(self, actions, values, policy):
new_policy = np.copy(policy)
expected_action_returns = np.zeros((MAX_CARS + 1, MAX_CARS + 1, np.size(actions)))
cooks = dict()
with mp.Pool(processes=8) as p:
for action in actions:
k = np.arange(MAX_CARS + 1)
all_states = ((i, j) for i, j in itertools.product(k, k))
cooks[action] = partial(self.expected_return_pi, values, action)
results = p.map(cooks[action], all_states)
for v, i, j, a in results:
expected_action_returns[i, j, self.inverse_actions[a]] = v
for i in range(expected_action_returns.shape[0]):
for j in range(expected_action_returns.shape[1]):
new_policy[i, j] = actions[np.argmax(expected_action_returns[i, j])]
policy_change = (new_policy != policy).sum()
print(f'Policy changed in {policy_change} states')
return policy_change, new_policy
# O(n^4) computation for all possible requests and returns
def bellman(self, values, action, state):
expected_return = 0
if self.solve_extension:
if action > 0:
# Free shuttle to the second location
expected_return += MOVE_COST * (action - 1)
else:
expected_return += MOVE_COST * abs(action)
else:
expected_return += MOVE_COST * abs(action)
for req1 in range(0, self.TRUNCATE):
for req2 in range(0, self.TRUNCATE):
# moving cars
num_of_cars_first_loc = int(min(state[0] - action, MAX_CARS))
num_of_cars_second_loc = int(min(state[1] + action, MAX_CARS))
# valid rental requests should be less than actual # of cars
real_rental_first_loc = min(num_of_cars_first_loc, req1)
real_rental_second_loc = min(num_of_cars_second_loc, req2)
# get credits for renting
reward = (real_rental_first_loc + real_rental_second_loc) * RENT_REWARD
if self.solve_extension:
if num_of_cars_first_loc >= 10:
reward += ADDITIONAL_PARK_COST
if num_of_cars_second_loc >= 10:
reward += ADDITIONAL_PARK_COST
num_of_cars_first_loc -= real_rental_first_loc
num_of_cars_second_loc -= real_rental_second_loc
# probability for current combination of rental requests
prob = poisson(req1, RENTAL_REQUEST_FIRST_LOC) * \
poisson(req2, RENTAL_REQUEST_SECOND_LOC)
for ret1 in range(0, self.TRUNCATE):
for ret2 in range(0, self.TRUNCATE):
num_of_cars_first_loc_ = min(num_of_cars_first_loc + ret1, MAX_CARS)
num_of_cars_second_loc_ = min(num_of_cars_second_loc + ret2, MAX_CARS)
prob_ = poisson(ret1, RETURNS_FIRST_LOC) * \
poisson(ret2, RETURNS_SECOND_LOC) * prob
# Classic Bellman equation for state-value
# prob_ corresponds to p(s'|s,a) for each possible s' -> (num_of_cars_first_loc_,num_of_cars_second_loc_)
expected_return += prob_ * (
reward + self.gamma * values[num_of_cars_first_loc_, num_of_cars_second_loc_])
return expected_return
# Parallelization required separate helper functions for evaluation and improvement
# Expected return calculator for Policy Evaluation
def expected_return_pe(self, policy, values, state):
action = policy[state[0], state[1]]
expected_return = self.bellman(values, action, state)
return expected_return, state[0], state[1]
# Expected return calculator for Policy Improvement
def expected_return_pi(self, values, action, state):
if not ((action >= 0 and state[0] >= action) or (action < 0 and state[1] >= abs(action))):
return -float('inf'), state[0], state[1], action
expected_return = self.bellman(values, action, state)
return expected_return, state[0], state[1], action
def plot(self):
print(self.policy)
plt.figure()
plt.xlim(0, MAX_CARS + 1)
plt.ylim(0, MAX_CARS + 1)
plt.table(cellText=self.policy, loc=(0, 0), cellLoc='center')
plt.show()
TRUNCATE = 9
solver = PolicyIteration(TRUNCATE, parallel_processes=4, delta=1e-1, gamma=0.9, solve_4_5=True)
solver.solve()
solver.plot()
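If the class is run as a standalone script as suggested above, a baseline solution without the Exercise 4.5 modifications (free shuttle to the second location and the overnight parking cost) can be obtained the same way; a minimal sketch, assuming the PolicyIteration class above is available in the same script:
if __name__ == '__main__':
    # sketch (my addition): solve_4_5=False keeps the plain Example 4.2 dynamics
    baseline = PolicyIteration(TRUNCATE, parallel_processes=4, delta=1e-1, gamma=0.9, solve_4_5=False)
    baseline.solve()
    baseline.plot()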