# View more, visit my tutorial page: https://morvanzhou.github.io/tutorials/
# My YouTube channel: https://www.youtube.com/user/MorvanZhou
# More about reinforcement learning:
# https://morvanzhou.github.io/tutorials/machine-learning/reinforcement-learning/
#
# Dependencies: torch, numpy, gym
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import gym
# Hyper Parameters
BATCH_SIZE = 32            # minibatch size sampled from the replay memory
LR = 0.01                  # learning rate for the Adam optimizer
EPSILON = 0.9              # epsilon-greedy policy: probability of taking the greedy action
GAMMA = 0.9                # reward discount factor
TARGET_REPLACE_ITER = 100  # target-network sync frequency (in learn() steps)
MEMORY_CAPACITY = 2000     # replay memory size (number of transitions)

env = gym.make('CartPole-v0')
env = env.unwrapped        # strip the TimeLimit wrapper so episodes are not cut off early
N_ACTIONS = env.action_space.n             # CartPole: 2 discrete actions (push left / right)
N_STATES = env.observation_space.shape[0]  # CartPole: 4-dim observation
# NOTE(review): the line below was gym log output pasted into the source as bare
# code (a SyntaxError); kept as a comment for reference.
# [2017-06-20 22:23:40,418] Making new env: CartPole-v0
class Net(nn.Module):
    """Small MLP mapping a state vector to one Q-value per action.

    Architecture: state -> Linear(n_states, n_hidden) -> ReLU -> Linear(n_hidden, n_actions).
    """

    def __init__(self, n_states=None, n_actions=None, n_hidden=10):
        """Build the network.

        Args:
            n_states: input dimension; defaults to the module-level N_STATES
                (keeps the original no-argument `Net()` call working).
            n_actions: output dimension; defaults to the module-level N_ACTIONS.
            n_hidden: width of the single hidden layer (10, as in the original).
        """
        super(Net, self).__init__()
        n_states = N_STATES if n_states is None else n_states
        n_actions = N_ACTIONS if n_actions is None else n_actions
        self.fc1 = nn.Linear(n_states, n_hidden)
        self.fc1.weight.data.normal_(0, 0.1)   # small random init keeps early Q-values near zero
        self.out = nn.Linear(n_hidden, n_actions)
        self.out.weight.data.normal_(0, 0.1)   # initialization

    def forward(self, x):
        """Return Q-values, shape (batch, n_actions), for state batch x of shape (batch, n_states)."""
        x = F.relu(self.fc1(x))
        return self.out(x)
class DQN(object):
    """Deep Q-Network agent with an experience-replay buffer and a target network."""

    def __init__(self):
        # Two identical networks: eval_net is trained every step; target_net is a
        # periodically-synced frozen copy used for bootstrapping the TD target.
        self.eval_net, self.target_net = Net(), Net()
        self.learn_step_counter = 0   # counts learn() calls, drives target sync
        self.memory_counter = 0       # total transitions ever stored
        # Each row stores one transition: [s (N_STATES), a, r, s_ (N_STATES)].
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()

    def choose_action(self, x):
        """Return an epsilon-greedy action (int) for a single state vector x."""
        x = Variable(torch.unsqueeze(torch.FloatTensor(x), 0))  # shape (1, N_STATES)
        if np.random.uniform() < EPSILON:  # exploit: greedy w.r.t. eval_net
            actions_value = self.eval_net.forward(x)
            # int() works for both the old (1, 1)-shaped and the modern (1,)-shaped
            # index tensor; the original `.data.numpy()[0, 0]` raises IndexError on
            # modern PyTorch where max(dim)[1] is 1-D.
            action = int(torch.max(actions_value, 1)[1])
        else:  # explore: uniform random action
            action = np.random.randint(0, N_ACTIONS)
        return action

    def store_transition(self, s, a, r, s_):
        """Store one transition, overwriting the oldest entry once full (ring buffer)."""
        transition = np.hstack((s, [a, r], s_))
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        """Sample a minibatch from memory and take one TD-learning step on eval_net."""
        # Periodically copy eval_net's weights into the target network.
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # Sample BATCH_SIZE transitions (uniformly, with replacement, as original).
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]
        b_s = Variable(torch.FloatTensor(b_memory[:, :N_STATES]))
        b_a = Variable(torch.LongTensor(b_memory[:, N_STATES:N_STATES + 1].astype(int)))
        b_r = Variable(torch.FloatTensor(b_memory[:, N_STATES + 1:N_STATES + 2]))
        b_s_ = Variable(torch.FloatTensor(b_memory[:, -N_STATES:]))

        # Q(s, a) of the action actually taken, shape (batch, 1).
        q_eval = self.eval_net(b_s).gather(1, b_a)
        q_next = self.target_net(b_s_).detach()  # no gradient flows through the target net
        # .view(BATCH_SIZE, 1) is required on modern PyTorch where max(1)[0] is 1-D;
        # without it `b_r + GAMMA * ...` broadcasts (batch, 1) + (batch,) into a
        # (batch, batch) matrix and the MSE loss is silently wrong.
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)
        loss = self.loss_func(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
dqn = DQN()

print('\nCollecting experience...')
for i_episode in range(400):
    s = env.reset()
    ep_r = 0
    while True:
        env.render()
        a = dqn.choose_action(s)

        # Take the action in the environment.
        s_, r, done, info = env.step(a)

        # Reshape the reward: CartPole's default +1-per-step reward says nothing
        # about how good a state is. Reward the cart for staying near the centre
        # (r1) and the pole for staying near upright (r2) instead.
        x, x_dot, theta, theta_dot = s_
        r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
        r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
        r = r1 + r2

        dqn.store_transition(s, a, r, s_)
        ep_r += r

        # Start learning only once the replay memory has been filled;
        # episode summaries are printed only in this learning phase.
        if dqn.memory_counter > MEMORY_CAPACITY:
            dqn.learn()
            if done:
                print('Ep: ', i_episode,
                      '| Ep_r: ', round(ep_r, 2))
        if done:
            break
        s = s_
# Sample training output (captured from a run; episode rewards grow as the agent learns):
#   Collecting experience...
#   Ep: 201 | Ep_r: 1.59   Ep: 210 | Ep_r: 3.6    Ep: 225 | Ep_r: 10.29
#   Ep: 243 | Ep_r: 98.0   Ep: 256 | Ep_r: 3074.75  Ep: 270 | Ep_r: 863.2
#   Ep: 293 | Ep_r: 1101.12  Ep: 345 | Ep_r: 1510.69  Ep: 368 | Ep_r: 1220.84
#   Ep: 399 | Ep_r: 614.0