Grid World Tutorial: POMDPs.jl for Complete Beginners

In this tutorial, we provide a simple example of how to define a Markov decision process (MDP) using the POMDPs.jl interface. Once a problem is defined this way, you can use any of the solvers that support the interface; here we show how to use the value iteration and Monte Carlo Tree Search (MCTS) solvers. We assume that you have some knowledge of basic programming but are not necessarily familiar with all of Julia's features, so we try to explain the language-specific features used in POMDPs.jl along the way. We do assume that you know the grid world problem and are familiar with the formal definition of an MDP. Let's get started!


You need to install a few modules in order to use this notebook. If you have all the modules below installed, great! If not, run the following commands:

# install the POMDPs.jl interface
Pkg.add("POMDPs")

using POMDPs # we'll use the POMDPs.add function to install packages that are part of JuliaPOMDP

# install the Value Iteration solver
POMDPs.add("DiscreteValueIteration")

# install the MCTS solver
POMDPs.add("MCTS")

# install support tools we'll use for simulation
POMDPs.add("POMDPToolbox")
In [1]:
# first import the POMDPs.jl interface
using POMDPs

# import our helper Distributions.jl module
using Distributions

# POMDPToolbox has some glue code to help us use Distributions.jl
using POMDPToolbox

Problem Overview

In Grid World, we are trying to control an agent who has trouble moving in the desired direction. In our problem, we have four reward states on a $10\times 10$ grid. Each position on the grid represents a state, and the reward states are terminal (the agent stops receiving reward after reaching them). The agent has four actions to choose from: up, down, left, right. The agent moves in the desired direction with a probability of 0.7, and with a probability of 0.1 in each of the remaining three directions. The problem has the following form:

[Figure: example grid world]
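For a cell away from the walls, the movement rule above can be written as the following transition probabilities (near a wall, the transition function defined later in this tutorial handles off-grid moves differently):

$$
T(s' \mid s, a) = \begin{cases}
0.7 & \text{if } s' \text{ is the neighbor of } s \text{ in direction } a, \\
0.1 & \text{if } s' \text{ is one of the other three neighbors of } s, \\
0 & \text{otherwise,}
\end{cases}
$$

so the probabilities over all next states sum to $0.7 + 3 \times 0.1 = 1$.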

MDP Type

In POMDPs.jl, an MDP is defined by creating a subtype of the MDP abstract type. The types of the states and actions for the MDP are declared as parameters of the MDP type. For example, if our states and actions are both represented by integers we can define our MDP type in the following way:

mutable struct MyMDP <: MDP{Int64, Int64} # MDP{StateType, ActionType}
end


MyMDP is a subtype of the abstract MDP type defined in POMDPs.jl. Let's first define types to represent grid world states and actions, and then we'll go through defining our Grid World MDP type.


The data container below represents the state of the agent in the grid world.

In [2]:
struct GridWorldState 
    x::Int64 # x position
    y::Int64 # y position
    done::Bool # are we in a terminal state?
end

Below are some convenience functions for working with the GridWorldState.

In [3]:
# initial state constructor
GridWorldState(x::Int64, y::Int64) = GridWorldState(x,y,false)
# checks if the position of two states are the same
posequal(s1::GridWorldState, s2::GridWorldState) = s1.x == s2.x && s1.y == s2.y
posequal (generic function with 1 method)
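Why define posequal at all? The default == on an immutable struct compares every field, including done, so two states in the same cell are not equal if only one of them is terminal. The standalone sketch below illustrates this, assuming Julia 1.x; GWState is a hypothetical stand-in for GridWorldState so the block runs without POMDPs.jl. The last line matters later, because the transition function checks `state in mdp.reward_states`, and `in` on a vector uses ==.

```julia
# Hypothetical stand-in for GridWorldState so this runs on its own (Julia 1.x)
struct GWState
    x::Int
    y::Int
    done::Bool
end
GWState(x::Int, y::Int) = GWState(x, y, false) # same convenience constructor as above

# compares positions only, ignoring the done flag
posequal(s1::GWState, s2::GWState) = s1.x == s2.x && s1.y == s2.y

s = GWState(9, 3)        # not yet terminal
t = GWState(9, 3, true)  # same cell, terminal

s == GWState(9, 3, false) # true: == on an immutable struct compares every field
s == t                    # false: the done flags differ
posequal(s, t)            # true: same position
s in [GWState(9, 3)]      # true: `in` on a vector uses ==, so done must match too
```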


Since our action is simply the direction the agent chooses to go (i.e. up, down, left, right), we can use a Symbol to represent it. Symbols are similar to strings, but they are interned (two symbols with the same name are the same object, which makes comparisons cheap), they typically consist of a single word, and their literals begin with ":". See this page for a technical discussion of what they are. Note that in this case we will not define a custom type for our action; instead, we represent it directly with a symbol, so our action looks like:

action = :up # can also be :down, :left, :right
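A few properties of Symbols that this tutorial relies on can be checked directly in Julia 1.x. The variable is named `a` here rather than `action`, since `action` is also a function exported by POMDPs.jl.

```julia
a = :up                          # a Symbol literal
a isa Symbol                     # true
a === Symbol("up")               # true: symbols are interned, so this is the same object
String(a)                        # "up": convert to a String when needed
a in (:up, :down, :left, :right) # true: cheap membership test against the four actions
```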


The GridWorld data container is defined below. It holds all the information we need to define the MDP tuple $$(\mathcal{S}, \mathcal{A}, T, R).$$

In [4]:
# the grid world mdp type
mutable struct GridWorld <: MDP{GridWorldState, Symbol} # Note that our MDP is parameterized by the state and the action
    size_x::Int64 # x size of the grid
    size_y::Int64 # y size of the grid
    reward_states::Vector{GridWorldState} # the states in which agent receives reward
    reward_values::Vector{Float64} # reward values for those states
    tprob::Float64 # probability of transitioning to the desired state
    discount_factor::Float64 # discount factor
end

Before moving on, I want to create a constructor for GridWorld for convenience. Currently, if I want to create an instance of GridWorld, I have to pass in all of the fields inside the GridWorld container (size_x, size_y, etc.). The function below returns a GridWorld with all the fields filled in with default values.

In [5]:
# we use keyword arguments so we can change any of the values we pass in 
function GridWorld(;sx::Int64=10, # size_x
                    sy::Int64=10, # size_y
                    rs::Vector{GridWorldState}=[GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)], # reward states
                    rv::Vector{Float64}=[-10.,-5,10,3], # reward values
                    tp::Float64=0.7, # tprob
                    discount_factor::Float64=0.9) # discount factor
    return GridWorld(sx, sy, rs, rv, tp, discount_factor)
end

# we can now create a GridWorld mdp instance like this:
mdp = GridWorld()
mdp.reward_states # mdp contains all the default values from the constructor
4-element Array{GridWorldState,1}:
 GridWorldState(4, 3, false)
 GridWorldState(4, 6, false)
 GridWorldState(9, 3, false)
 GridWorldState(8, 8, false)
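Keyword arguments let you override any subset of the defaults by name, in any order. Here is a minimal standalone sketch of the same pattern, assuming Julia 1.x; GridParams is a hypothetical struct that mirrors a few GridWorld fields, and its default values are illustrative.

```julia
# Hypothetical struct mirroring a few GridWorld fields (Julia 1.x)
struct GridParams
    size_x::Int
    size_y::Int
    tprob::Float64
    discount_factor::Float64
end

# keyword-only constructor: every field has a default that can be overridden by name
GridParams(; sx::Int=10, sy::Int=10, tp::Float64=0.7, discount_factor::Float64=0.9) =
    GridParams(sx, sy, tp, discount_factor)

p_default = GridParams()     # all defaults
p_noisy = GridParams(tp=0.5) # override only the transition probability
```

Because the keywords are typed, a call such as GridParams(tp=1) (an Int rather than a Float64) throws a TypeError; write 1.0 instead.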


Let's look at how we can define the state and action spaces for our problem.

State Space

The state space of an MDP is the set of all states in the problem. There are two primary operations we want our spaces to support: we want to be able to iterate over the state space (for value iteration, for example), and sometimes we want to be able to sample from the state space (used in some POMDP solvers). In this notebook, we will only look at iterable state spaces.

Since we can iterate over elements of an array, and our problem is small, we can store all of our states in an array. If your problem is very large (tens of millions of states), it might be worthwhile to create a custom type to define the problem's state space. See this post on stackoverflow on making simple iterators.
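As an illustration of the custom-type option, here is a sketch of a lazy state space that generates states on demand instead of storing them. It assumes Julia 1.x, where a type becomes iterable by adding methods to Base.iterate (the older Julia versions this tutorial appears to target used a different start/next/done protocol). GWState and GridStateSpace are hypothetical names; the enumeration order matches the states(...) function below: x fastest, then y, then done.

```julia
# Hypothetical stand-in for GridWorldState (Julia 1.x)
struct GWState
    x::Int
    y::Int
    done::Bool
end

# Hypothetical lazy state space: stores only the grid size
struct GridStateSpace
    sx::Int
    sy::Int
end

Base.length(sp::GridStateSpace) = 2 * sp.sx * sp.sy # every cell, with done = false and true
Base.eltype(::Type{GridStateSpace}) = GWState

# the iteration state i is a 1-based counter over all states
function Base.iterate(sp::GridStateSpace, i::Int=1)
    i > length(sp) && return nothing
    j = i - 1
    x = j % sp.sx + 1                # x varies fastest
    y = (j ÷ sp.sx) % sp.sy + 1      # then y
    done = j >= sp.sx * sp.sy        # the second half of the states are terminal
    return GWState(x, y, done), i + 1
end

sp = GridStateSpace(10, 10)
all_states = collect(sp)             # 200 states, built only when requested
```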

In [6]:
function POMDPs.states(mdp::GridWorld)
    s = GridWorldState[] # initialize an array of GridWorldStates
    # loop over all our states; remember that, besides the position,
    # each state has one binary variable: done (d)
    for d = 0:1, y = 1:mdp.size_y, x = 1:mdp.size_x
        push!(s, GridWorldState(x,y,d)) # d (0 or 1) is converted to a Bool
    end
    return s
end

Here, the code function POMDPs.states(mdp::GridWorld) means that we want to take the function called states(...) from the POMDPs.jl module and add another method to it. The states(...) function in POMDPs.jl doesn't know about our GridWorld type. However, when states(...) is now called with a GridWorld, it will dispatch to the method we defined above! This is the power of multiple dispatch, and one of the features that makes working with MDPs and POMDPs easier in Julia.

The solvers that support the POMDPs.jl interface know that a function called states(...) exists in the interface. However, they do not know the behavior of that function for GridWorld. That means in order for the solvers to use this behavior all we have to do is pass an instance of our GridWorld type into the solver. When states(...) is called in the solver with the GridWorld type, the function above will be called.
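The same mechanism can be seen in a tiny self-contained sketch, assuming Julia 1.x. ToyInterface plays the role of POMDPs.jl, count_states plays the role of solver code written only against the interface, and ToyProblem plays the role of GridWorld; all three names are hypothetical.

```julia
# A hypothetical "interface" module standing in for POMDPs.jl
module ToyInterface
    # declare a generic function with no methods yet
    function states end
    # "solver" code written only against the interface
    count_states(problem) = length(states(problem))
end

# a problem type defined outside the interface module
struct ToyProblem
    n::Int
end

# add a method to the interface's function, just like POMDPs.states(mdp::GridWorld)
ToyInterface.states(p::ToyProblem) = collect(1:p.n)

# the "solver" now works on ToyProblem without ever having heard of it
ToyInterface.count_states(ToyProblem(5))
```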

In [7]:
mdp = GridWorld()
state_space = states(mdp);
state_space[1] # the first state in the state space
GridWorldState(1, 1, false)

Action Space

The action space is the set of all actions available to the agent. In the grid world problem, the action space consists of up, down, left, and right. We can define the action space by implementing a new method of the actions function.

In [8]:
POMDPs.actions(mdp::GridWorld) = [:up, :down, :left, :right];

Now that we've defined our state and action spaces, we are halfway through our MDP tuple: $$ (\mathcal{S}, \mathcal{A}, T, R) $$


Since MDPs are probabilistic models, we have to deal with probability distributions. In this section, we outline how to define probability distributions and what tools are available to help you with the task.

Transition Distribution

If you are familiar with MDPs, you know that the transition function $T(s' \mid s, a)$ captures the dynamics of the system. Specifically, $T(s' \mid s, a)$ is a real value that defines the probability of transitioning to state $s'$ given that you took action $a$ in state $s$. The transition distribution $T(\cdot \mid s, a)$ is a slightly different construct: it is the actual distribution over the states that our agent can reach given that it is in state $s$ and took action $a$. In other words, it is the distribution over $s'$.

For this grid world example there are only a few states that the agent can transition to, so there are only a few states that have nonzero probability in $T(\cdot \mid s, a)$. Thus, we will use the sparse categorical distribution (SparseCat) from POMDPToolbox. Distributions.jl also contains some distributions, but in many cases, a custom distribution type will need to be defined - see the source code for SparseCat for an example.

A SparseCat object contains a vector of states and an associated vector of their probabilities. The probabilities of all other states are implied to be zero.
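To make the idea concrete, here is a minimal sketch of a sparse categorical distribution, assuming Julia 1.x. MySparseCat, prob, and sample are hypothetical names; the sketch mirrors the idea behind POMDPToolbox's SparseCat but is not its implementation or API.

```julia
using Random

# Hypothetical minimal sparse categorical distribution (not POMDPToolbox's SparseCat)
struct MySparseCat{S}
    vals::Vector{S}        # the outcomes with nonzero probability
    probs::Vector{Float64} # probs[i] is the probability of vals[i]
end

# probability of x; any value not listed in vals has probability zero
function prob(d::MySparseCat, x)
    p = 0.0
    for (v, pv) in zip(d.vals, d.probs)
        if v == x
            p += pv
        end
    end
    return p
end

# draw one outcome by walking the cumulative distribution
function sample(rng::AbstractRNG, d::MySparseCat)
    u = rand(rng)
    c = 0.0
    for (v, pv) in zip(d.vals, d.probs)
        c += pv
        if u < c
            return v
        end
    end
    # round-off can leave c slightly below 1; fall back to the last outcome with positive probability
    return d.vals[findlast(p -> p > 0, d.probs)]
end

d = MySparseCat([:a, :b, :c], [0.7, 0.2, 0.1])
prob(d, :a)                   # 0.7
prob(d, :z)                   # 0.0 (not in the support)
sample(MersenneTwister(1), d) # one of :a, :b, :c
```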

Transition Model

In this section we will define the system dynamics of the grid world MDP. In POMDPs.jl, we work with transition distributions $T(\cdot \mid s, a)$, so we want to write a function that generates the distribution over $s'$ for a given $(s, a)$ pair.

In grid world, the dynamics of the system are fairly simple. We move in the specified direction with some pre-defined probability. This is the tprob parameter in our GridWorld MDP (it is set to 0.7 in the DMU book example). When the agent is in a reward state, it collects that state's reward and then moves to a terminal state, after which it can no longer accumulate reward.

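In the transition function, we build a SparseCat whose states are the neighbors reachable from the given state-action pair and whose probabilities are the probabilities of reaching each of those neighbors. If the intended move would take the agent off the grid, it stays in its current cell with probability 1; otherwise, the remaining probability 1 - tprob is split evenly among the other in-bounds neighbors.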

In [9]:
# transition helpers
function inbounds(mdp::GridWorld,x::Int64,y::Int64)
    if 1 <= x <= mdp.size_x && 1 <= y <= mdp.size_y
        return true
    else
        return false
    end
end

inbounds(mdp::GridWorld, state::GridWorldState) = inbounds(mdp, state.x, state.y);
In [10]:
function POMDPs.transition(mdp::GridWorld, state::GridWorldState, action::Symbol)
    a = action
    x = state.x
    y = state.y

    if state.done
        return SparseCat([GridWorldState(x, y, true)], [1.0])
    elseif state in mdp.reward_states # reward states lead to the terminal state
        return SparseCat([GridWorldState(x, y, true)], [1.0])
    end

    neighbors = [
        GridWorldState(x+1, y, false), # right
        GridWorldState(x-1, y, false), # left
        GridWorldState(x, y-1, false), # down
        GridWorldState(x, y+1, false), # up
        ] # See Performance Note below
    targets = Dict(:right=>1, :left=>2, :down=>3, :up=>4) # See Performance Note below
    target = targets[a]
    probability = fill(0.0, 4)

    if !inbounds(mdp, neighbors[target])
        # If the move would take the agent out of bounds,
        # stay in the same cell with probability 1
        return SparseCat([GridWorldState(x, y)], [1.0])
    else
        probability[target] = mdp.tprob

        oob_count = sum(!inbounds(mdp, n) for n in neighbors) # number of out of bounds neighbors

        new_probability = (1.0 - mdp.tprob)/(3-oob_count)

        for i = 1:4 # spread the remaining probability over the other in-bounds neighbors
            if inbounds(mdp, neighbors[i]) && i != target
                probability[i] = new_probability
            end
        end
    end

    return SparseCat(neighbors, probability)
end

Performance Note: It is inefficient to create mutable objects like dictionaries and vectors in low-level code like the transition function because it requires dynamic memory allocation. This code is written for clarity rather than speed. Better speed could be realized by putting the Dict in the mdp object or using if statements instead, and replacing the vector with a StaticArrays.SVector. However, a much more important consideration for performance is type stability, which this function maintains because it always returns a SparseCat{Vector{GridWorldState},Vector{Float64}} object.
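As a sketch of the suggestions in the note, the version below replaces the Dict with if statements and the vectors with plain 4-tuples, so every path returns the same concrete type and no Dict or Vector is created per call. It assumes Julia 1.x, handles only the non-terminal case, and uses hypothetical names (GWState, displacement, inbounds_xy, transition_tuples) so it runs without POMDPs.jl; it uses tuples rather than StaticArrays to stay dependency-free.

```julia
# Hypothetical stand-in for GridWorldState (Julia 1.x)
struct GWState
    x::Int
    y::Int
    done::Bool
end

# map an action to a displacement with if statements instead of a Dict
function displacement(a::Symbol)
    if a == :right
        return (1, 0)
    elseif a == :left
        return (-1, 0)
    elseif a == :down
        return (0, -1)
    elseif a == :up
        return (0, 1)
    end
    error("invalid action: $a")
end

inbounds_xy(sx::Int, sy::Int, x::Int, y::Int) = 1 <= x <= sx && 1 <= y <= sy

# non-terminal case only: returns (neighbors, probabilities) as 4-tuples
function transition_tuples(sx::Int, sy::Int, tprob::Float64, s::GWState, a::Symbol)
    dirs = ((1, 0), (-1, 0), (0, -1), (0, 1)) # right, left, down, up (same order as above)
    neighbors = ntuple(i -> GWState(s.x + dirs[i][1], s.y + dirs[i][2], false), 4)
    ok = ntuple(i -> inbounds_xy(sx, sy, neighbors[i].x, neighbors[i].y), 4)
    target = findfirst(==(displacement(a)), dirs)
    if !ok[target]
        # same rule as above: moving off the grid keeps the agent in place
        stay = GWState(s.x, s.y, false)
        return (stay, stay, stay, stay), (1.0, 0.0, 0.0, 0.0)
    end
    spread = (1.0 - tprob) / (3 - count(!, ok)) # share for each other in-bounds neighbor
    probs = ntuple(i -> i == target ? tprob : (ok[i] ? spread : 0.0), 4)
    return neighbors, probs
end

ns, ps = transition_tuples(10, 10, 0.7, GWState(5, 5, false), :up)
```

Repeated entries in the "stay" case are harmless as long as the probabilities are interpreted as a list of (state, probability) pairs that sum to 1.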

Reward Model

The reward model $R(s,a,s')$ is a function that returns the reward of being in state $s$, taking an action $a$ from that state, and ending up in state $s'$. In our problem, the reward depends only on $s$: the agent receives a reward value (which could be positive or negative) when it acts from one of the reward states, and zero otherwise.

In [11]:
function POMDPs.reward(mdp::GridWorld, state::GridWorldState, action::Symbol, statep::GridWorldState)
    if state.done
        return 0.0 # no reward once the terminal state is reached
    end
    r = 0.0
    n = length(mdp.reward_states)
    for i = 1:n
        if posequal(state, mdp.reward_states[i])
            r += mdp.reward_values[i]
        end
    end
    return r
end

Miscellaneous Functions

We are almost done! Just a few simple functions left. First let's implement two functions that return the sizes of our state and action spaces.

In [12]:
POMDPs.n_states(mdp::GridWorld) = 2*mdp.size_x*mdp.size_y
POMDPs.n_actions(mdp::GridWorld) = 4

Now, we implement the discount function.

In [13]:
POMDPs.discount(mdp::GridWorld) = mdp.discount_factor;

The last thing we need is indexing functions, which map each state and action to an integer index. Solvers use these indices to store values in arrays, such as the discrete utility (value) array. We will use the sub2ind() function from Julia Base to help us here.

In [14]:
function POMDPs.state_index(mdp::GridWorld, state::GridWorldState)
    sd = Int(state.done + 1) # 1 if not done, 2 if done
    return sub2ind((mdp.size_x, mdp.size_y, 2), state.x, state.y, sd)
end

function POMDPs.action_index(mdp::GridWorld, act::Symbol)
    if act==:up
        return 1
    elseif act==:down
        return 2
    elseif act==:left
        return 3
    elseif act==:right
        return 4
    end
    error("Invalid GridWorld action: $act")
end
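sub2ind was removed from Julia in version 1.0; in Julia 1.x, LinearIndices computes the same column-major index. The sketch below (hypothetical names, no POMDPs.jl needed) checks that this index agrees with the order in which states(...) enumerates states, which is what lets a solver treat the state index as a position in the state array.

```julia
# Hypothetical Julia 1.x replacement for the sub2ind call above
function state_index_sketch(sx::Int, sy::Int, x::Int, y::Int, done::Bool)
    return LinearIndices((sx, sy, 2))[x, y, Int(done) + 1]
end

# check against the loop order used by states(): done outermost, then y, with x fastest
function check_ordering(sx::Int, sy::Int)
    k = 0
    for d in (false, true), y in 1:sy, x in 1:sx
        k += 1
        state_index_sketch(sx, sy, x, y, d) == k || return false
    end
    return k == 2 * sx * sy # every state visited once, matching n_states
end

check_ordering(10, 10)                   # true
state_index_sketch(10, 10, 4, 3, false)  # 24 = 4 + (3 - 1) * 10
```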

Finally let's define a function that checks if a state is terminal.

In [15]:
POMDPs.isterminal(mdp::GridWorld, s::GridWorldState) = s.done


Now that we have defined the problem, we should simulate it to see it working. The function sim(::MDP) from POMDPToolbox provides a convenient do-block syntax for exploring the behavior of the MDP. The do block receives the state as its argument and should return an action. In this way it acts as a "hook" into the simulation and allows quick ad hoc policies to be defined.
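A do block is Julia syntax for passing an anonymous function as the first argument of a call. The sketch below shows this with a hypothetical toy_sim function (not POMDPToolbox's sim) whose "dynamics" simply add the chosen action to an integer state; it assumes Julia 1.x.

```julia
# Hypothetical stand-in for sim(): calls the policy function on each state
function toy_sim(policy_fn::Function, start::Int; max_steps::Int=5)
    s = start
    visited = [s]
    for _ in 1:max_steps
        a = policy_fn(s)  # the body of the do block runs here
        s += a            # toy deterministic dynamics: add the action to the state
        push!(visited, s)
    end
    return visited
end

# do-block form: the block becomes an anonymous function passed as the first argument
path1 = toy_sim(0, max_steps=3) do s
    return 1
end

# the equivalent explicit form with an anonymous function
path2 = toy_sim(s -> 1, 0, max_steps=3)
```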

In [16]:
mdp = GridWorld()
sim(mdp, GridWorldState(4,1), max_steps=10) do s
    println("state is: $s")
    a = :right
    println("moving $a")
    return a
end
state is: GridWorldState(4, 1, false)
moving right
state is: GridWorldState(5, 1, false)
moving right
state is: GridWorldState(6, 1, false)
moving right
state is: GridWorldState(7, 1, false)
moving right
state is: GridWorldState(8, 1, false)
moving right
state is: GridWorldState(9, 1, false)
moving right
state is: GridWorldState(10, 1, false)
moving right
state is: GridWorldState(10, 1, false)
moving right
state is: GridWorldState(10, 1, false)
moving right
state is: GridWorldState(10, 1, false)
moving right

Value Iteration Solver

Value iteration is a dynamic programming approach for solving MDPs. See the Wikipedia article for a brief explanation. The solver can be found here. If you haven't installed the solver yet, you can run the following from the Julia REPL:

using POMDPs
POMDPs.add("DiscreteValueIteration")

to download the module.

Each POMDPs.jl solver provides two data types for you to interface with. The first is the Solver type, which contains solver parameters. The second is the Policy type. Let's see how we can use them to get an optimal action at a given state.

In [17]:
# first let's load the value iteration module
using DiscreteValueIteration

# initialize the problem
mdp = GridWorld()

# initialize the solver
# max_iterations: maximum number of iterations value iteration runs for (default is 100)
# belres: the Bellman residual threshold used as the solver's stopping criterion (default is 1e-3)
solver = ValueIterationSolver(max_iterations=100, belres=1e-3)

# initialize the policy by passing in your problem
policy = ValueIterationPolicy(mdp) 

# solve for an optimal policy
# if verbose=false, the text output will be suppressed (false by default)
solve(solver, mdp, policy, verbose=true);
[Iteration 1   ] residual:         10 | iteration runtime:      0.323 ms, (  0.000323 s total)
[Iteration 2   ] residual:        6.3 | iteration runtime:      0.285 ms, (  0.000608 s total)
[Iteration 3   ] residual:       4.54 | iteration runtime:      0.345 ms, (  0.000953 s total)
[Iteration 4   ] residual:       3.39 | iteration runtime:      0.454 ms, (   0.00141 s total)
[Iteration 5   ] residual:       2.57 | iteration runtime:      0.269 ms, (   0.00168 s total)
[Iteration 6   ] residual:       1.92 | iteration runtime:      0.256 ms, (   0.00193 s total)
[Iteration 7   ] residual:       1.39 | iteration runtime:      0.256 ms, (   0.00219 s total)
[Iteration 8   ] residual:       1.07 | iteration runtime:      0.255 ms, (   0.00244 s total)
[Iteration 9   ] residual:      0.861 | iteration runtime:      0.258 ms, (    0.0027 s total)
[Iteration 10  ] residual:      0.662 | iteration runtime:      0.567 ms, (   0.00327 s total)
[Iteration 11  ] residual:      0.489 | iteration runtime:      0.257 ms, (   0.00352 s total)
[Iteration 12  ] residual:      0.405 | iteration runtime:      6.774 ms, (    0.0103 s total)
[Iteration 13  ] residual:      0.341 | iteration runtime:      0.287 ms, (    0.0106 s total)
[Iteration 14  ] residual:      0.244 | iteration runtime:      0.281 ms, (    0.0109 s total)
[Iteration 15  ] residual:      0.166 | iteration runtime:      0.248 ms, (    0.0111 s total)
[Iteration 16  ] residual:      0.106 | iteration runtime:      0.247 ms, (    0.0114 s total)
[Iteration 17  ] residual:     0.0638 | iteration runtime:      0.303 ms, (    0.0117 s total)
[Iteration 18  ] residual:     0.0369 | iteration runtime:      0.257 ms, (    0.0119 s total)
[Iteration 19  ] residual:     0.0208 | iteration runtime:      0.254 ms, (    0.0122 s total)
[Iteration 20  ] residual:     0.0115 | iteration runtime:      0.263 ms, (    0.0124 s total)
[Iteration 21  ] residual:    0.00621 | iteration runtime:      0.257 ms, (    0.0127 s total)
[Iteration 22  ] residual:    0.00333 | iteration runtime:      0.265 ms, (     0.013 s total)
[Iteration 23  ] residual:    0.00177 | iteration runtime:      0.254 ms, (    0.0132 s total)
[Iteration 24  ] residual:   0.000934 | iteration runtime:      0.261 ms, (    0.0135 s total)
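The log reports the Bellman residual, the largest change in any state's value during an iteration; the solver stops once it falls below belres. To show what each iteration computes, here is a minimal synchronous value iteration sketch, assuming Julia 1.x and an MDP given as dense arrays T[s, a, s'] and R[s, a]. This representation and the function name value_iteration are hypothetical; it is not DiscreteValueIteration.jl's implementation.

```julia
# Minimal synchronous value iteration sketch (Julia 1.x, no packages needed)
# T[s, a, sp]: probability of reaching sp after taking a in s
# R[s, a]:     expected immediate reward for taking a in s
# γ:           discount factor
function value_iteration(T::Array{Float64,3}, R::Matrix{Float64}, γ::Float64;
                         max_iterations::Int=100, belres::Float64=1e-3)
    ns, na = size(R)
    V = zeros(ns)
    # one-step lookahead value of action a in state s under value estimate Vc
    q(s, a, Vc) = R[s, a] + γ * sum(T[s, a, sp] * Vc[sp] for sp in 1:ns)
    for _ in 1:max_iterations
        Vnew = [maximum(q(s, a, V) for a in 1:na) for s in 1:ns] # Bellman backup
        residual = maximum(abs.(Vnew .- V))                      # Bellman residual
        V = Vnew
        residual < belres && break
    end
    # greedy policy: the action index with the highest lookahead value in each state
    greedy = [argmax([q(s, a, V) for a in 1:na]) for s in 1:ns]
    return V, greedy
end

# two-state example: in state 1, action 2 earns +1 and moves to the absorbing state 2
T = zeros(2, 2, 2)
T[1, 1, 1] = 1.0 # action 1 in state 1: stay
T[1, 2, 2] = 1.0 # action 2 in state 1: move to state 2
T[2, 1, 2] = 1.0 # state 2 is absorbing under both actions
T[2, 2, 2] = 1.0
R = [0.0 1.0; 0.0 0.0]
V, greedy = value_iteration(T, R, 0.9)
```

Here V converges to [1.0, 0.0] and the greedy action in state 1 is action 2.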

Now, we can use the policy along with the action(...) function to get the optimal action in a given state.

In [18]:
# say we are in state (9,2)
s = GridWorldState(9,2)
a = action(policy, s)

Remember that the state (9,3) has an immediate reward of +10.0, so the policy we found is moving up, as expected!

In [19]:
s = GridWorldState(8,3)
a = action(policy, s)


Monte-Carlo Tree Search Solver

Monte-Carlo Tree Search (MCTS) is another MDP solver. It is an online method that looks for the best action from only the current state by building a search tree. A nice overview of MCTS can be found here. Run the following command to download the module:

using POMDPs
POMDPs.add("MCTS")

Let's quickly run through an example of using the solver:

In [20]:
using MCTS

# initialize the problem
mdp = GridWorld()

# initialize the solver with hyperparameters
# n_iterations: the number of iterations that each search runs for
# depth: the depth of the tree (how far away from the current state the algorithm explores)
# exploration_constant: how much weight to put on exploratory actions.
# A good rule of thumb is to set the exploration constant to what you expect the upper bound on your average expected reward to be.
solver = MCTSSolver(n_iterations=1000,
                    depth=20, # example value
                    exploration_constant=10.0) # example value: roughly the largest reward (+10) in this problem

# initialize the planner by calling the `solve` function. For online solvers, `solve` only sets up
# the planner; the search itself runs each time `action` is called.
planner = solve(solver, mdp)

# to get the action:
s = GridWorldState(9,2)
a = action(planner, s)
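The exploration constant weights an exploration bonus when the search chooses which action to try at a tree node. As a sketch, here is the standard UCB1 selection rule; we assume MCTS.jl uses a rule of this form, but its exact details may differ. ucb_select and its arguments are hypothetical names, and the code assumes Julia 1.x.

```julia
# Standard UCB1 action selection (a sketch; not MCTS.jl's code)
# Q[a]: mean return estimate for action a
# N[a]: number of times a has been tried at this node
# c:    exploration constant
function ucb_select(Q::Vector{Float64}, N::Vector{Int}, c::Float64)
    total = sum(N)
    best, best_score = 0, -Inf
    for a in eachindex(Q)
        N[a] == 0 && return a # try every action at least once
        score = Q[a] + c * sqrt(log(total) / N[a]) # value estimate plus exploration bonus
        if score > best_score
            best, best_score = a, score
        end
    end
    return best
end

ucb_select([1.0, 0.5], [100, 1], 0.0)  # 1: with c = 0 the rule is purely greedy
ucb_select([1.0, 0.5], [100, 1], 10.0) # 2: a large c favors the rarely tried action
```

A larger exploration constant makes the bonus dominate the value estimates, which is why the rule of thumb ties it to the scale of the rewards.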

Let's simulate using the planner to determine a good action at each timestep.

In [21]:
# we'll use POMDPToolbox for simulation
using POMDPToolbox # if you don't have this module, install it by running POMDPs.add("POMDPToolbox")

s = GridWorldState(4,1) # this is our starting state
hist = HistoryRecorder(max_steps=1000)

hist = simulate(hist, mdp, planner, s) # use the MCTS planner to choose each action

println("Total discounted reward: $(discounted_reward(hist))")
Total discounted reward: 2.5418658283290014

Now we can view the state-action history using the eachstep function.

In [22]:
for (s, a, sp) in eachstep(hist, "s,a,sp")
    @printf("s: %-26s  a: %-6s  s': %-26s\n", s, a, sp)
end
s: GridWorldState(4, 1, false)  a: right   s': GridWorldState(5, 1, false)
s: GridWorldState(5, 1, false)  a: right   s': GridWorldState(6, 1, false)
s: GridWorldState(6, 1, false)  a: right   s': GridWorldState(7, 1, false)
s: GridWorldState(7, 1, false)  a: right   s': GridWorldState(6, 1, false)
s: GridWorldState(6, 1, false)  a: right   s': GridWorldState(7, 1, false)
s: GridWorldState(7, 1, false)  a: right   s': GridWorldState(8, 1, false)
s: GridWorldState(8, 1, false)  a: up      s': GridWorldState(7, 1, false)
s: GridWorldState(7, 1, false)  a: right   s': GridWorldState(8, 1, false)
s: GridWorldState(8, 1, false)  a: up      s': GridWorldState(8, 2, false)
s: GridWorldState(8, 2, false)  a: right   s': GridWorldState(9, 2, false)
s: GridWorldState(9, 2, false)  a: up      s': GridWorldState(10, 2, false)
s: GridWorldState(10, 2, false)  a: up      s': GridWorldState(10, 3, false)
s: GridWorldState(10, 3, false)  a: left    s': GridWorldState(9, 3, false)
s: GridWorldState(9, 3, false)  a: up      s': GridWorldState(9, 3, true)
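The history also explains the total discounted reward printed earlier. Only the last of the 14 steps starts in a reward state, (9,3), worth +10, so the total should be $10\gamma^{13}$. The sketch below checks this, assuming $\gamma = 0.9$ (a value that reproduces the printed total) and that the t-th reward is weighted by $\gamma^{t-1}$; discounted_return is a hypothetical helper, not POMDPToolbox's discounted_reward.

```julia
# discounted return: sum over t of γ^(t-1) * r_t, with t starting at 1
function discounted_return(rewards::Vector{Float64}, γ::Float64)
    total = 0.0
    for (t, r) in enumerate(rewards)
        total += γ^(t - 1) * r
    end
    return total
end

# rewards read off the 14-step history: zero except the final step from (9, 3)
rewards = [zeros(13); 10.0]
discounted_return(rewards, 0.9) # ≈ 2.5418658283290014, i.e. 10 * 0.9^13
```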

To see what the planner is doing, we can look at the tree created when it plans at a particular state, for example, the first state in the history.

In [23]:
using D3Trees

# first, run the planner on the state
s = state_hist(hist)[1]
a = action(planner, s)

# show the tree (click the node to expand)
D3Tree(planner, s)