Grid World Tutorial: POMDPs.jl for Complete Beginners

In this tutorial, we provide a simple example of how to define a Markov decision process (MDP) using the POMDPs.jl interface. Once the problem is defined this way, you can use the solvers that the interface supports; here we show how to use the value iteration and Monte Carlo Tree Search solvers. We assume that you have some knowledge of basic programming, but are not necessarily familiar with all of Julia's features, so we try to explain the language-specific features used in POMDPs.jl as we go. We do assume that you know the grid world problem and are familiar with the formal definition of an MDP. Let's get started!

Dependencies

You need to install a few modules in order to use this notebook. If you have all the modules below installed, great! If not, run the following commands:

# install the POMDPs.jl interface
Pkg.clone("https://github.com/JuliaPOMDP/POMDPs.jl.git")

using POMDPs # we'll use the POMDPs.add function to install packages that are part of JuliaPOMDP

# install the Value Iteration solver
POMDPs.add("DiscreteValueIteration")

# install the MCTS solver
POMDPs.add("MCTS")

# install support tools we'll use for simulation
POMDPs.add("POMDPToolbox")
In [1]:
# first import the POMDPs.jl interface
using POMDPs

# import our helper Distributions.jl module
using Distributions

# POMDPToolbox has some glue code to help us use Distributions.jl
using POMDPToolbox

Problem Overview

In Grid World, we are trying to control an agent who has trouble moving in the desired direction. In our problem, there are four reward states on a $10\times 10$ grid. Each position on the grid represents a state, and the reward states are terminal (the agent stops receiving reward after reaching them). The agent has four actions to choose from: up, down, left, and right. The agent moves in the desired direction with probability 0.7, and in each of the remaining three directions with probability 0.1.

MDP Type

An MDP must contain a state space and an action space. In POMDPs.jl, we define an MDP type by parameterizing it with a state type and an action type (the Julia manual's section on parametric types explains these in detail). For example, if both our states and our actions are represented by integers, we can define our MDP type in the following way:

type MyMDP <: MDP{Int64, Int64} # MDP{StateType, ActionType}

end

The MyMDP type inherits from the abstract MDP type defined in POMDPs.jl. If you are interested in learning more about the type system and inheritance in Julia, check out this blog post. Let's first define our states and actions, and then we'll go through defining our Grid World MDP type.
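To see what the type parameters buy us, here is a standalone sketch written for Julia 1.x (where `type` became `struct`/`mutable struct` and abstract types are declared with `abstract type`). It does not use POMDPs.jl: `ToyMDP`, `CoinMDP`, `statetype` and `actiontype` are hypothetical names standing in for the real `MDP{S,A}` hierarchy. It shows that a generic function can read the state and action types off any subtype through dispatch.

```julia
# Standalone sketch for Julia 1.x; ToyMDP is a hypothetical stand-in for POMDPs.jl's MDP{S,A}.
abstract type ToyMDP{S,A} end

# A concrete problem whose states are Int64 and whose actions are Symbols.
struct CoinMDP <: ToyMDP{Int64,Symbol}
    n_coins::Int64
end

# Generic helpers: S and A are bound from whichever subtype is passed in.
statetype(::ToyMDP{S,A}) where {S,A} = S
actiontype(::ToyMDP{S,A}) where {S,A} = A

m = CoinMDP(3)
println(statetype(m), " ", actiontype(m))   # Int64 Symbol
println(m isa ToyMDP)                        # true
```

This is the same mechanism that lets a solver written against the abstract type work with any problem that declares its state and action types.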

States

The data container below represents the state of the agent in the grid world.

In [2]:
type GridWorldState 
    x::Int64 # x position
    y::Int64 # y position
    bumped::Bool # did we bump the wall?
    done::Bool # are we in a terminal state?
end

Below are some convenience functions for working with the GridWorldState.

In [3]:
# initial state constructor
GridWorldState(x::Int64, y::Int64) = GridWorldState(x,y,false,false)
GridWorldState(x::Int64, y::Int64, bumped::Bool) = GridWorldState(x,y,bumped,false)
# checks if the position of two states are the same
posequal(s1::GridWorldState, s2::GridWorldState) = s1.x == s2.x && s1.y == s2.y
# copies state s2 to s1
function Base.copy!(s1::GridWorldState, s2::GridWorldState) 
    s1.x = s2.x
    s1.y = s2.y
    s1.bumped = s2.bumped
    s1.done = s2.done
    s1
end
# if you want to use Monte Carlo Tree Search, you will need to define the functions below
Base.hash(s::GridWorldState, h::UInt64 = zero(UInt64)) = hash(s.x, hash(s.y, hash(s.bumped, hash(s.done, h))))
Base.:(==)(s1::GridWorldState,s2::GridWorldState) = s1.x == s2.x && s1.y == s2.y && s1.bumped == s2.bumped && s1.done == s2.done;
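Why do `hash` and `==` matter? A likely reason is that tree-search code keeps statistics in dictionaries keyed by state, and Julia's `Dict` finds keys using `hash` and `isequal` (which falls back to `==`). For a mutable type, the default `==` is object identity, so a freshly constructed state with the same fields would not be found. The standalone Julia 1.x sketch below, using a hypothetical `Cell` type rather than `GridWorldState`, illustrates this.

```julia
# Standalone sketch for Julia 1.x; Cell is a hypothetical stand-in for GridWorldState.
mutable struct Cell
    x::Int
    y::Int
end

counts = Dict{Cell,Int}()
counts[Cell(1, 2)] = 5
# Without custom methods, == on a mutable struct means identity (===),
# so a new Cell(1, 2) is treated as a different key.
println(haskey(counts, Cell(1, 2)))    # false

# Value-based equality and hashing, mirroring the GridWorldState definitions above.
Base.:(==)(a::Cell, b::Cell) = a.x == b.x && a.y == b.y
Base.hash(c::Cell, h::UInt) = hash(c.x, hash(c.y, h))

counts2 = Dict{Cell,Int}()
counts2[Cell(1, 2)] = 5
println(haskey(counts2, Cell(1, 2)))   # true
```

Whenever `==` is overridden, `hash` must be overridden consistently (equal objects must hash equally), which is why the notebook defines both together.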

Actions

Since our action is simply the direction the agent chooses to go (i.e. up, down, left, right), we can use a Symbol to represent it. Symbols are interned identifiers in Julia: two symbols with the same name are the same object, so they are cheap to compare and make readable labels. However, in our case a string, or even an integer, could serve the same purpose, so feel free to use whatever you're most comfortable with. Note that we will not define a type for our action; instead, we represent it directly with a Symbol, so an action looks like this:

action = :up # can also be :down, :left, :right
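A short standalone Julia 1.x sketch of working with Symbols as action labels. `MOVES` is a hypothetical lookup table from each action to its (dx, dy) step; it follows the same convention as the transition function later in this notebook, where `:up` increases y.

```julia
# Standalone sketch for Julia 1.x; MOVES is a hypothetical action-to-displacement table.
const MOVES = Dict(:up => (0, 1), :down => (0, -1), :left => (-1, 0), :right => (1, 0))

a = :up
println(typeof(a))            # Symbol
println(a === Symbol("up"))   # true: symbols with the same name are the same object
dx, dy = MOVES[a]
println((3 + dx, 4 + dy))     # (3, 5)
```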

MDP

The GridWorld data container is defined below. It holds all the information we need to define the MDP tuple $$(\mathcal{S}, \mathcal{A}, T, R).$$

In [4]:
# the grid world mdp type
type GridWorld <: MDP{GridWorldState, Symbol} # Note that our MDP is parameterized by the state and the action
    size_x::Int64 # x size of the grid
    size_y::Int64 # y size of the grid
    reward_states::Vector{GridWorldState} # the states in which the agent receives reward
    reward_values::Vector{Float64} # reward values for those states
    bounds_penalty::Float64 # penalty for bumping the wall
    tprob::Float64 # probability of transitioning to the desired state
    discount_factor::Float64 # discount factor
end

Before moving on, I want to create a convenience constructor for GridWorld. Currently, if I want to create an instance of GridWorld, I have to pass in all of the fields inside the GridWorld container (size_x, size_y, etc.). The function below returns a GridWorld with every field filled in with a default value.

In [5]:
# we use keyword arguments so we can change any of the values we pass in 
function GridWorld(;sx::Int64=10, # size_x
                    sy::Int64=10, # size_y
                    rs::Vector{GridWorldState}=[GridWorldState(4,3), GridWorldState(4,6), GridWorldState(9,3), GridWorldState(8,8)], # reward states
                    rv::Vector{Float64}=[-10.,-5,10,3], # reward values
                    penalty::Float64=-1.0, # bounds penalty
                    tp::Float64=0.7, # tprob
                    discount_factor::Float64=0.9)
    return GridWorld(sx, sy, rs, rv, penalty, tp, discount_factor)
end

# we can now create a GridWorld mdp instance like this:
mdp = GridWorld()
mdp.reward_states # mdp contains all the default values from the constructor
Out[5]:
4-element Array{GridWorldState,1}:
 GridWorldState(4,3,false,false)
 GridWorldState(4,6,false,false)
 GridWorldState(9,3,false,false)
 GridWorldState(8,8,false,false)
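The keyword-argument constructor pattern can be tried on its own. The standalone Julia 1.x sketch below uses a hypothetical, trimmed-down `GridParams` container instead of the full GridWorld type; every keyword has a default, so callers override only the fields they care about.

```julia
# Standalone sketch for Julia 1.x; GridParams is a hypothetical, trimmed-down GridWorld.
struct GridParams
    size_x::Int64
    size_y::Int64
    tprob::Float64
    discount_factor::Float64
end

# Outer constructor taking only keyword arguments, each with a default value.
GridParams(; sx::Int64=10, sy::Int64=10, tp::Float64=0.7, discount_factor::Float64=0.9) =
    GridParams(sx, sy, tp, discount_factor)

p1 = GridParams()               # all defaults
p2 = GridParams(sx=5, tp=1.0)   # override two fields, keep the rest
println((p1.size_x, p1.tprob))  # (10, 0.7)
println((p2.size_x, p2.tprob))  # (5, 1.0)
```

Because the keyword constructor takes no positional arguments, it does not clash with the default four-argument constructor Julia generates for the struct.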

Spaces

Let's look at how we can define the state and action spaces for our problem.

State Space

The state space in an MDP represents all the states in the problem. There are two primary functionalities that we want our spaces to support: we want to be able to iterate over the state space (for Value Iteration, for example), and sometimes we want to be able to sample from the state space (used in some POMDP solvers). In this notebook, we will only look at iterable state spaces.

Since we can iterate over elements of an array, and our problem is small, we can store all of our states in an array. If your problem is very large (tens of millions of states), it might be worthwhile to create a custom type to define the problem's state space. See this post on stackoverflow on making simple iterators.
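As an illustration of the custom-type alternative, here is a standalone Julia 1.x sketch of a lazy state space that decodes each state from a counter instead of storing all of them in an array. `GridStates` is a hypothetical name, and states are plain `(x, y, bumped, done)` tuples rather than GridWorldState objects. It implements the standard `Base.iterate`/`Base.length` iteration protocol, with x varying fastest as in the `states` loop below.

```julia
# Standalone sketch for Julia 1.x; GridStates is a hypothetical lazy state space.
struct GridStates
    size_x::Int
    size_y::Int
end

Base.length(sp::GridStates) = 4 * sp.size_x * sp.size_y   # x, y, bumped, done
Base.eltype(::Type{GridStates}) = NTuple{4,Int}

# Decode linear position i into (x, y, bumped, done); x varies fastest, then y, bumped, done.
function Base.iterate(sp::GridStates, i::Int=1)
    i > length(sp) && return nothing
    x, y, b, d = Tuple(CartesianIndices((sp.size_x, sp.size_y, 2, 2))[i])
    return ((x, y, b - 1, d - 1), i + 1)
end

space = GridStates(10, 10)
println(length(space))          # 400
println(first(space))           # (1, 1, 0, 0)
println(collect(space)[end])    # (10, 10, 1, 1)
```

For a 400-state problem the array is perfectly fine; the lazy version only pays off when materializing every state would use too much memory.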

In [6]:
function POMDPs.states(mdp::GridWorld)
    s = GridWorldState[] # initialize an array of GridWorldStates
    # loop over all our states, remember there are two binary variables:
    # done (d) and bumped (b)
    for d = 0:1, b = 0:1, y = 1:mdp.size_y, x = 1:mdp.size_x
        push!(s, GridWorldState(x,y,b,d))
    end
    return s
end;

Here, the code function POMDPs.states(mdp::GridWorld) means that we want to take the function called states(...) from the POMDPs.jl module and add another method to it. The states(...) function in POMDPs.jl doesn't know about our GridWorld type. However, now when states(...) is called with a GridWorld, Julia will dispatch to the method we defined above! This is the awesome thing about multiple dispatch, and one of the features that should make working with MDPs/POMDPs easier in Julia.

The solvers that support the POMDPs.jl interface know that a function called states(...) exists in the interface. However, they do not know the behavior of that function for GridWorld. So, for the solvers to use this behavior, all we have to do is pass an instance of our GridWorld type into the solver. When states(...) is called in the solver with the GridWorld type, the function above will be called.
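The pattern can be reproduced without POMDPs.jl. In the standalone Julia 1.x sketch below, a hypothetical `Interface` module declares a function with no methods plus a "solver" that only knows that function's name; user code then adds a method for its own type, just as `function POMDPs.states(mdp::GridWorld)` does. `Interface`, `MyProblem` and `solve_by_enumeration` are illustrative names.

```julia
# Standalone sketch for Julia 1.x; Interface, MyProblem and solve_by_enumeration are hypothetical.
module Interface
    function states end                 # generic function declared with no methods
    # A "solver" that relies only on the generic function name.
    solve_by_enumeration(problem) = length(states(problem))
end

struct MyProblem
    n::Int
end

# Add a method to the interface function for our own type.
Interface.states(p::MyProblem) = collect(1:p.n)

println(Interface.solve_by_enumeration(MyProblem(7)))   # 7

# A type with no matching method raises a MethodError, as in In [8] below.
try
    Interface.states(42)
catch err
    println(err isa MethodError)                         # true
end
```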

In [7]:
# let's use the constructor for GridWorld we defined earlier
mdp = GridWorld()
state_space = states(mdp);
state_space[1]
Out[7]:
GridWorldState(1,1,false,false)

In Julia, a method may only work on a specific type. It is up to the program writer to specify what type a given function will work on. This might seem tedious at first, but the Julia type system adds a lot of flexibility to large software frameworks, and allows us to easily extend functionality of any type.

So what happens if we pass something else to states(...)?

In [8]:
type TestMDP <: MDP{Int64, Int64}
    x::Int64
end
tmdp = TestMDP(1)
states(tmdp)
MethodError: no method matching states(::TestMDP)
Closest candidates are:
  states{S,A}(::Union{POMDPs.MDP{S,A},POMDPs.POMDP{S,A,O}}, ::S) at /home/zach/.julia/v0.5/POMDPs/src/space.jl:26
  states(::GridWorld) at In[6]:2
  states(::POMDPs.MDP{Bool,A}) at /home/zach/.julia/v0.5/POMDPToolbox/src/convenience/implementations.jl:20
  ...

The interface in POMDPs.jl defines the function names (API) that the supported MDP/POMDP solvers are allowed to call. In Julia, the functions do not belong to a specific type, but are dispatched based on the type passed in (i.e. functions are generic). Once we define the states(mdp::TestMDP) function, the solvers will be able to call it. If you forget to define a function that a solver requires, an error like the one above will be thrown.

In value iteration, for example, the solver will iterate over your state space by doing the following:

In [9]:
mdp = GridWorld()
state_space = states(mdp);
for s in iterator(state_space)
    # value iteration applies the bellman operator to your state s
end

Note that, since we used a vector for the state space, the rand(::AbstractRNG, ::AbstractVector) method is already defined and we can use it to sample from the space.

In [10]:
rand(Base.GLOBAL_RNG, state_space)
Out[10]:
GridWorldState(10,10,false,false)

Action Space

The action space is the set of all actions available to the agent. In the grid world problem, the action space consists of up, down, left, and right. We can define the action space by implementing a new method of the actions function.

In [11]:
POMDPs.actions(mdp::GridWorld) = [:up, :down, :left, :right];

Now that we've defined our state and action spaces, we are halfway through our MDP tuple: $$ (\mathcal{S}, \mathcal{A}, T, R) $$

Distributions

Since MDPs are probabilistic models, we have to deal with probability distributions. In this section, we outline how to define probability distributions, and what tools are available to help you with the task.

Transition Distribution

If you are familiar with MDPs, you know that the transition function $T(s' \mid s, a)$ captures the dynamics of the system. Specifically, $T(s' \mid s, a)$ is a real value that defines the probability of transitioning to state $s'$ given that you took action $a$ in state $s$. The transition distribution $T(\cdot \mid s, a)$ is a slightly different construct: it is the actual distribution over the states that our agent can reach given that it is in state $s$ and took action $a$. In other words, this is the distribution over $s'$.

There are many ways to implement transition distributions for your problem. Your choice of distribution, as well as how you implement it, will heavily depend on your problem. Distributions.jl provides support for many common univariate and multivariate distributions. Below is how we implement the one for grid world.

In [12]:
type GridWorldDistribution
    neighbors::Array{GridWorldState} # the states s' in the distribution
    probs::Array{Float64} # the probability corresponding to each state s'
end
In [13]:
function create_transition_distribution(mdp::GridWorld)
    # can have at most five neighbors in grid world
    neighbors =  [GridWorldState(i,i) for i = 1:5]
    probabilities = zeros(5) + 1.0/5.0
    return GridWorldDistribution(neighbors, probabilities)
end;

Notice that at most five neighboring states $s'$ can be reached from a given state-action pair: one neighboring cell in each direction (four in total), plus the cell the agent currently occupies.

The next function we want is iterator(...). For discrete distributions, iterator returns an iterator over the states in the distribution (here, this is just the neighbors array in our distribution type). The function takes the following form:

In [14]:
function POMDPs.iterator(d::GridWorldDistribution)
    return d.neighbors
end;

Let's implement the probability density function (really, this is a probability mass function since the distribution is discrete, but we overload the pdf function name to serve as both). Below is a fairly inefficient implementation of pdf, since it scans the neighbors linearly on every call. For the discrete distribution in our problem, the pdf function returns the probability of drawing the state s from the distribution d.

In [15]:
function POMDPs.pdf(d::GridWorldDistribution, s::GridWorldState)
    for (i, sp) in enumerate(d.neighbors)
        if s == sp
            return d.probs[i]
        end
    end   
    return 0.0
end;
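A cheaper lookup is possible by precomputing a dictionary from state to probability. The standalone Julia 1.x sketch below uses plain `(x, y)` tuples as states for brevity; `DictDist` and `dict_pdf` are hypothetical names, not part of POMDPs.jl or Distributions.jl. Merging duplicate support points is a design choice: with the linear scan above, only the first matching neighbor would be counted.

```julia
# Standalone sketch for Julia 1.x; DictDist and dict_pdf are hypothetical names.
struct DictDist
    probs::Dict{Tuple{Int,Int},Float64}
end

function DictDist(support::Vector{Tuple{Int,Int}}, p::Vector{Float64})
    probs = Dict{Tuple{Int,Int},Float64}()
    for (s, ps) in zip(support, p)
        probs[s] = get(probs, s, 0.0) + ps   # merge duplicate support points
    end
    return DictDist(probs)
end

# Average O(1) lookup; states outside the support have probability 0.0.
dict_pdf(d::DictDist, s) = get(d.probs, s, 0.0)

# Moving :up from the interior cell (5, 5) with tprob = 0.7.
d = DictDist([(5, 6), (6, 5), (4, 5), (5, 4)], [0.7, 0.1, 0.1, 0.1])
println(dict_pdf(d, (5, 6)))   # 0.7
println(dict_pdf(d, (9, 9)))   # 0.0
```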

Finally, we want to implement a sampling function that can draw samples from our distribution. Once again, there are many ways to do this, but we recommend using Distributions.jl. Below we draw a neighbor index with sample and WeightVec, the weighted-sampling utilities from StatsBase.jl, the statistics package that Distributions.jl builds on.

In [16]:
function POMDPs.rand(rng::AbstractRNG, d::GridWorldDistribution)
    ns = d.neighbors[sample(rng, WeightVec(d.probs))] # sample a neighbor state according to the distribution d
    return ns
end;
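If you would rather not depend on a library for this step, weighted sampling can be done by inverting the cumulative distribution. The standalone Julia 1.x sketch below shows the technique; `sample_index` is a hypothetical helper, not a StatsBase or POMDPs.jl function, and the neighbor list mirrors moving `:up` from the interior cell (5, 5).

```julia
using Random

# Standalone sketch for Julia 1.x: inverse-CDF sampling; sample_index is a hypothetical helper.
function sample_index(rng::AbstractRNG, probs::AbstractVector{<:Real})
    u = rand(rng) * sum(probs)      # uniform draw on [0, total mass)
    c = 0.0
    for (i, p) in enumerate(probs)
        c += p
        u < c && return i           # first index whose cumulative mass exceeds u
    end
    return findlast(>(0), probs)    # guard against floating-point round-off at the top end
end

# right, left, down, up, stay (same order as the transition function later on)
neighbors = [(6, 5), (4, 5), (5, 4), (5, 6), (5, 5)]
probs = [0.1, 0.1, 0.1, 0.7, 0.0]
rng = MersenneTwister(1)
draws = [neighbors[sample_index(rng, probs)] for _ in 1:10_000]
println(count(==((5, 6)), draws) / length(draws))   # close to 0.7
```

Entries with zero probability are never returned, because they do not increase the cumulative sum.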

One thing that might be unfamiliar in this cell is AbstractRNG. What in the world is that? This is an abstract type that represents a random number generator. This could be Julia's default random number generator (Base.GLOBAL_RNG) or one that functions independently of the default RNG. This gives more control over random number seeding, which makes simulations reproducible. Let's take a look at an example.

Below we initialize a grid world MDP, a transition distribution, and a random number generator. We use the MersenneTwister type as our AbstractRNG. MersenneTwister is a Julia type used for pseudorandom number generation.

In [17]:
mdp = GridWorld()
# the function below initializes our distribution d to have the states at:
# (1,1) (2,2) (3,3) (4,4) (5,5)
# we should expect to sample only these states from d
d = create_transition_distribution(mdp) 
rng = MersenneTwister(1) # this is an rng type in Julia

for i = 1:5
    s = rand(rng, d)
    println(s)
end
GridWorldState(2,2,false,false)
GridWorldState(2,2,false,false)
GridWorldState(2,2,false,false)
GridWorldState(1,1,false,false)
GridWorldState(3,3,false,false)

Now if we seed the MersenneTwister with the same number again, we will sample the same sequence of states from the distribution d.

In [18]:
rng = MersenneTwister(1) 

for i = 1:5
    s = rand(rng, d)
    println(s)
end
GridWorldState(2,2,false,false)
GridWorldState(2,2,false,false)
GridWorldState(2,2,false,false)
GridWorldState(1,1,false,false)
GridWorldState(3,3,false,false)

If we set the MersenneTwister seed to something else, we will get a different sequence of states.

In [19]:
rng = MersenneTwister(2) 

for i = 1:5
    s = rand(rng, d)
    println(s)
end
GridWorldState(2,2,false,false)
GridWorldState(3,3,false,false)
GridWorldState(2,2,false,false)
GridWorldState(5,5,false,false)
GridWorldState(3,3,false,false)

Using an AbstractRNG is very similar to setting a random seed. This type of functionality is very useful when you are trying to generate the same sequence of states in a Monte Carlo simulation for example.
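The same reproducibility property can be checked in plain Julia 1.x, where `MersenneTwister` lives in the `Random` standard library. This standalone sketch (with a hypothetical `draw` helper) shows that identically seeded generators produce identical streams, and that draws from the global RNG in between do not disturb a separately created generator.

```julia
using Random

# Standalone sketch for Julia 1.x; draw is a hypothetical helper returning n draws from 1..5.
draw(rng, n) = [rand(rng, 1:5) for _ in 1:n]

a = draw(MersenneTwister(1), 5)
b = draw(MersenneTwister(1), 5)
println(a == b)    # true: same seed, same sequence

r1 = MersenneTwister(7)
s1 = [rand(r1) for _ in 1:3]
r2 = MersenneTwister(7)
s2 = Float64[]
for _ in 1:3
    rand(100)          # consumes the global RNG only
    push!(s2, rand(r2))
end
println(s1 == s2)  # true: r2's stream is unaffected by the global draws
```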

To recap, there are three functionalities that we require your distributions to support. We want to be able to sample from them using the rand(...) function, obtain the probability mass/density of an element using the pdf(...) function, and iterate over the elements of the distribution (every element with nonzero probability must be included) using iterator(...).

Transition Model

In this section, we will define the system dynamics of the grid world MDP. In POMDPs.jl, we work with transition distributions $T(\cdot \mid s, a)$, so we want to write a function that generates the distribution over $s'$ for a given $(s, a)$ pair.

In grid world, the dynamics of the system are fairly simple. We move in the specified direction with some predefined probability. This is the tprob parameter in our GridWorld MDP (it is set to 0.7 in the DMU book example). If we bump against a wall, we receive a penalty. If we reach a reward state, we've reached a terminal state and can no longer accumulate reward.

In the transition function, we fill the neighbors in our distribution d with the states reachable from the state-action pair, and we fill the probs in d with the probability of reaching each neighbor.
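For concreteness, here is what the transition function in the next cells works out to with tprob $=0.7$. For an interior state $s=(x,y)$ and action up, the remaining mass $1-0.7$ is split evenly over the three other directions, and staying put has probability zero:

$$T((x,y+1)\mid s,\text{up}) = 0.7, \qquad T((x+1,y)\mid s,\text{up}) = T((x-1,y)\mid s,\text{up}) = T((x,y-1)\mid s,\text{up}) = \frac{1-0.7}{3} = 0.1, \qquad T((x,y)\mid s,\text{up}) = 0.$$

Against the left wall ($x=1$, $1<y<10$), the left neighbor is out of bounds, so oob_count $=1$ and the remaining mass is split over two directions:

$$T((1,y+1)\mid s,\text{up}) = 0.7, \qquad T((2,y)\mid s,\text{up}) = T((1,y-1)\mid s,\text{up}) = \frac{1-0.7}{3-1} = 0.15.$$

If the intended move itself leaves the grid (for example, up from $y=10$), the agent stays where it is: $T(s\mid s,\text{up}) = 1$.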

In [20]:
# transition helpers
function inbounds(mdp::GridWorld,x::Int64,y::Int64)
    if 1 <= x <= mdp.size_x && 1 <= y <= mdp.size_y
        return true
    else
        return false
    end
end

function inbounds(mdp::GridWorld,state::GridWorldState)
    x = state.x #point x of state
    y = state.y
    return inbounds(mdp, x, y)
end

function fill_probability!(p::Vector{Float64}, val::Float64, index::Int64)
    for i = 1:length(p)
        if i == index
            p[i] = val
        else
            p[i] = 0.0
        end
    end
end;
In [21]:
function POMDPs.transition(mdp::GridWorld, state::GridWorldState, action::Symbol)
    a = action
    x = state.x
    y = state.y 

    neighbors = [
        GridWorldState(x+1, y, false), # right
        GridWorldState(x-1, y, false), # left
        GridWorldState(x, y-1, false), # down
        GridWorldState(x, y+1, false), # up
        GridWorldState(x, y, false)    # stay
       ]

    d = GridWorldDistribution(neighbors, Array(Float64, 5)) 
    
    probability = d.probs
    fill!(probability, 0.0)

    if state.done
        fill_probability!(probability, 1.0, 5)
        neighbors[5].done = true
        return d
    end

    for i = 1:5 neighbors[i].done = false end 
    reward_states = mdp.reward_states
    reward_values = mdp.reward_values
    n = length(reward_states)
    if state in mdp.reward_states
        fill_probability!(probability, 1.0, 5)
        neighbors[5].done = true
        return d
    end

    # The following match the definition of neighbors
    # given above
    target_neighbor = 0
    if a == :right
        target_neighbor = 1
    elseif a == :left
        target_neighbor = 2
    elseif a == :down
        target_neighbor = 3
    elseif a == :up
        target_neighbor = 4
    end
    # @assert target_neighbor > 0

    if !inbounds(mdp, neighbors[target_neighbor])
        # If would transition out of bounds, stay in
        # same cell with probability 1
        fill_probability!(probability, 1.0, 5)
    else
        probability[target_neighbor] = mdp.tprob

        oob_count = 0 # number of out of bounds neighbors
        
        for i = 1:length(neighbors)
             if !inbounds(mdp, neighbors[i])
                oob_count += 1
                @assert probability[i] == 0.0
             end
        end

        new_probability = (1.0 - mdp.tprob)/(3-oob_count)

        for i = 1:4 # do not include neighbor 5
            if inbounds(mdp, neighbors[i]) && i != target_neighbor
                probability[i] = new_probability
            end
        end
    end

    return d
end;
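The probability logic above can be checked in isolation. The standalone Julia 1.x sketch below mirrors it with `(x, y)` tuples instead of GridWorldState objects and leaves out the terminal-state handling; `transition_probs` is a hypothetical name. Only in-bounds neighbors are stored, so a wall simply removes an entry and its share goes to the other directions.

```julia
# Standalone sketch for Julia 1.x mirroring POMDPs.transition above (terminal handling omitted).
# transition_probs is a hypothetical name; states are plain (x, y) tuples.
function transition_probs(x, y, a; size_x=10, size_y=10, tprob=0.7)
    inb((i, j)) = 1 <= i <= size_x && 1 <= j <= size_y
    moves = Dict(:right => (x + 1, y), :left => (x - 1, y), :down => (x, y - 1), :up => (x, y + 1))
    target = moves[a]
    inb(target) || return Dict((x, y) => 1.0)     # blocked: stay put with probability 1
    others = [n for (act, n) in moves if act != a && inb(n)]
    p = Dict(target => tprob)
    for n in others
        p[n] = (1.0 - tprob) / length(others)     # split the remaining mass evenly
    end
    return p
end

println(transition_probs(5, 5, :up))    # 0.7 to (5, 6), about 0.1 to each other neighbor
println(transition_probs(1, 5, :up))    # 0.7 to (1, 6), about 0.15 to (2, 5) and (1, 4)
println(transition_probs(5, 10, :up))   # Dict((5, 10) => 1.0)
```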

Reward Model

The reward model $R(s,a,s')$ is a function that returns the reward of being in state $s$, taking an action $a$ from that state, and ending up in state $s'$. In our problem, we are rewarded for reaching a terminal reward state (this could be positive or negative), and we are penalized for bumping into a wall.

In [22]:
function POMDPs.reward(mdp::GridWorld, state::GridWorldState, action::Symbol, statep::GridWorldState) # only the current state is used here
    if state.done
        return 0.0
    end
    r = 0.0
    reward_states = mdp.reward_states
    reward_values = mdp.reward_values
    n = length(reward_states)
    for i = 1:n
        if posequal(state, reward_states[i])
            r += reward_values[i]
        end
    end
    if state.bumped
        r += mdp.bounds_penalty
    end
    return r
end;
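Since the reward above depends only on the current state, it can also be written as a dictionary lookup. In this standalone Julia 1.x sketch, `REWARDS` copies the default reward states and values from the GridWorld() constructor, and `state_reward` is a hypothetical helper that takes plain coordinates and flags instead of a GridWorldState.

```julia
# Standalone sketch for Julia 1.x; REWARDS copies GridWorld()'s defaults, state_reward is hypothetical.
const REWARDS = Dict((4, 3) => -10.0, (4, 6) => -5.0, (9, 3) => 10.0, (8, 8) => 3.0)
const BOUNDS_PENALTY = -1.0

function state_reward(x, y; bumped=false, done=false)
    done && return 0.0                  # no reward once the episode has ended
    r = get(REWARDS, (x, y), 0.0)       # reward-state value, or 0.0 elsewhere
    bumped && (r += BOUNDS_PENALTY)     # wall penalty
    return r
end

println(state_reward(9, 3))               # 10.0
println(state_reward(9, 3, done=true))    # 0.0
println(state_reward(1, 1, bumped=true))  # -1.0
```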

Miscellaneous Functions

We are almost done! Just a few simple functions left. First let's implement two functions that return the sizes of our state and action spaces.

In [23]:
POMDPs.n_states(mdp::GridWorld) = 4*mdp.size_x*mdp.size_y
POMDPs.n_actions(mdp::GridWorld) = 4;

Now, we implement the discount function.

In [24]:
POMDPs.discount(mdp::GridWorld) = mdp.discount_factor;

The last thing we need is indexing functions. These map each state and action to an integer index, which lets solvers such as value iteration store utilities in a discrete array. We will use the sub2ind() function from Julia Base to help us here.

In [25]:
function POMDPs.state_index(mdp::GridWorld, state::GridWorldState)
    sb = Int(state.bumped + 1)
    sd = Int(state.done + 1)
    return sub2ind((mdp.size_x, mdp.size_y, 2, 2), state.x, state.y, sb, sd)
end
function POMDPs.action_index(mdp::GridWorld, act::Symbol)
    if act==:up
        return 1
    elseif act==:down
        return 2
    elseif act==:left
        return 3
    elseif act==:right
        return 4
    end
    error("Invalid GridWorld action: $act")
end;
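Note that sub2ind was removed from Base in Julia 1.0. A standalone Julia 1.x sketch of the same column-major state index using `LinearIndices`; `grid_state_index` is a hypothetical helper taking plain coordinates and flags.

```julia
# Standalone sketch for Julia 1.x: LinearIndices replaces sub2ind; grid_state_index is hypothetical.
function grid_state_index(x, y, bumped::Bool, done::Bool; size_x=10, size_y=10)
    # Same layout as state_index above: x fastest, then y, then bumped, then done.
    return LinearIndices((size_x, size_y, 2, 2))[x, y, bumped + 1, done + 1]
end

println(grid_state_index(1, 1, false, false))    # 1
println(grid_state_index(10, 10, true, true))    # 400
println(grid_state_index(3, 2, true, false))     # 3 + 10*1 + 100*1 = 113
```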

Finally let's define a function that checks if a state is terminal.

In [26]:
POMDPs.isterminal(mdp::GridWorld, s::GridWorldState) = s.done

Simulations

Now that we have defined the problem, we should simulate it to see it working. The function sim(::MDP) from POMDPToolbox provides a convenient do-block syntax for exploring the behavior of the MDP. The do block receives the state as its argument and should return an action. In this way, it acts as a "hook" into the simulation and allows quick ad hoc policies to be defined.

In [27]:
mdp = GridWorld()
mdp.tprob=1.0
sim(mdp, GridWorldState(4,1), max_steps=10) do s
    println("state is: $s")
    a = :right
    println("moving $a")
    return a
end;
state is: GridWorldState(4,1,false,false)
moving right
state is: GridWorldState(5,1,false,false)
moving right
state is: GridWorldState(6,1,false,false)
moving right
state is: GridWorldState(7,1,false,false)
moving right
state is: GridWorldState(8,1,false,false)
moving right
state is: GridWorldState(9,1,false,false)
moving right
state is: GridWorldState(10,1,false,false)
moving right
state is: GridWorldState(10,1,false,false)
moving right
state is: GridWorldState(10,1,false,false)
moving right
state is: GridWorldState(10,1,false,false)
moving right
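The do-block syntax itself is plain Julia: `f(args...) do s ... end` passes the anonymous function `s -> ...` as the first argument of `f`. The standalone Julia 1.x sketch below demonstrates this with a hypothetical `run_policy` function on a one-dimensional corridor (a stand-in for `sim`, not its real implementation); like the output above, the agent stays put when it walks into the wall.

```julia
# Standalone sketch for Julia 1.x; run_policy is a hypothetical stand-in for sim.
function run_policy(policy, start::Int; max_steps::Int=5, size::Int=10)
    s = start
    visited = [s]
    for _ in 1:max_steps
        a = policy(s)                                   # the do-block body runs here
        s = clamp(s + (a == :right ? 1 : -1), 1, size)  # walls keep the agent in 1..size
        push!(visited, s)
    end
    return visited
end

# These two calls are equivalent.
h1 = run_policy(8, max_steps=4) do s
    :right
end
h2 = run_policy(s -> :right, 8, max_steps=4)
println(h1)          # [8, 9, 10, 10, 10]
println(h1 == h2)    # true
```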

Value Iteration Solver

Value iteration is a dynamic programming approach for solving MDPs. See the Wikipedia article for a brief explanation. The solver is provided by the DiscreteValueIteration package from JuliaPOMDP. If you haven't installed the solver yet, you can run the following from the Julia REPL

using POMDPs
POMDPs.add("DiscreteValueIteration")

to download the module.

Each POMDPs.jl solver provides two data types for you to interface with. The first is the Solver type, which contains the solver parameters. The second is the Policy type. Let's see how we can use them to get an optimal action at a given state.

In [28]:
# first let's load the value iteration module
using DiscreteValueIteration

# initialize the problem
mdp = GridWorld()

# initialize the solver
# max_iterations: maximum number of iterations value iteration runs for (default is 100)
# belres: the Bellman residual threshold below which the solver stops (default is 1e-3)
solver = ValueIterationSolver(max_iterations=100, belres=1e-3)

# initialize the policy by passing in your problem
policy = ValueIterationPolicy(mdp) 

# solve for an optimal policy
# if verbose=false, the text output will be suppressed (false by default)
solve(solver, mdp, policy, verbose=true);
Iteration : 1, residual: 13.5359385595155, iteration run-time: 0.068839337, total run-time: 0.068839337
Iteration : 2, residual: 6.2999734279499995, iteration run-time: 0.06464727, total run-time: 0.133486607
Iteration : 3, residual: 4.535916776339399, iteration run-time: 0.064842247, total run-time: 0.198328854
Iteration : 4, residual: 3.3932543047909687, iteration run-time: 0.065329747, total run-time: 0.263658601
Iteration : 5, residual: 2.571299327001349, iteration run-time: 0.078932557, total run-time: 0.342591158
Iteration : 6, residual: 1.9200652060312606, iteration run-time: 0.065602144, total run-time: 0.40819330200000004
Iteration : 7, residual: 1.3944077772964636, iteration run-time: 0.072909027, total run-time: 0.48110232900000005
Iteration : 8, residual: 1.0733013837754315, iteration run-time: 0.072141148, total run-time: 0.5532434770000001
Iteration : 9, residual: 0.8612394855324683, iteration run-time: 0.07183643, total run-time: 0.6250799070000002
Iteration : 10, residual: 0.6617298799960059, iteration run-time: 0.06501717, total run-time: 0.6900970770000001
Iteration : 11, residual: 0.4890198391566638, iteration run-time: 0.071625396, total run-time: 0.7617224730000001
Iteration : 12, residual: 0.40524432611098704, iteration run-time: 0.069108493, total run-time: 0.8308309660000001
Iteration : 13, residual: 0.3407925086598921, iteration run-time: 0.071954861, total run-time: 0.9027858270000001
Iteration : 14, residual: 0.24439341221410416, iteration run-time: 0.067886658, total run-time: 0.9706724850000001
Iteration : 15, residual: 0.16597725930419127, iteration run-time: 0.076314954, total run-time: 1.0469874390000002
Iteration : 16, residual: 0.10594068263682299, iteration run-time: 0.066390445, total run-time: 1.1133778840000002
Iteration : 17, residual: 0.06381799508797092, iteration run-time: 0.069514989, total run-time: 1.1828928730000001
Iteration : 18, residual: 0.03694120249471089, iteration run-time: 0.067998157, total run-time: 1.25089103
Iteration : 19, residual: 0.020784548267542835, iteration run-time: 0.068923668, total run-time: 1.319814698
Iteration : 20, residual: 0.011453236948841594, iteration run-time: 0.065371024, total run-time: 1.3851857220000001
Iteration : 21, residual: 0.006213722610454386, iteration run-time: 0.07037953, total run-time: 1.4555652520000002
Iteration : 22, residual: 0.0033314062328759775, iteration run-time: 0.064843806, total run-time: 1.5204090580000003
Iteration : 23, residual: 0.0017698244797910156, iteration run-time: 0.069627455, total run-time: 1.5900365130000003
Iteration : 24, residual: 0.0009335331707231997, iteration run-time: 0.066500927, total run-time: 1.6565374400000004
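The log above shows the solver stopping once the Bellman residual drops below belres. A standalone Julia 1.x sketch of that loop on a much smaller, hypothetical problem: a deterministic one-dimensional corridor of five cells where the last cell pays +10 and ends the episode. It does not use DiscreteValueIteration; `corridor_vi` and its keyword names are illustrative, chosen to echo the solver parameters above.

```julia
# Standalone sketch for Julia 1.x (not DiscreteValueIteration): value iteration on a
# hypothetical deterministic corridor. Cell n pays +10 and ends the episode; all other
# cells give 0 reward. Actions move one cell left or right; walls keep the agent in place.
function corridor_vi(; n::Int=5, discount::Float64=0.9, belres::Float64=1e-6, max_iterations::Int=100)
    steps = (-1, 1)                          # left, right
    V = zeros(n)
    for _ in 1:max_iterations
        Vnew = similar(V)
        for s in 1:n
            if s == n
                Vnew[s] = 10.0               # terminal payoff; nothing is collected afterwards
            else
                # Bellman backup: reward 0 plus the best discounted next-state value.
                Vnew[s] = maximum(discount * V[clamp(s + d, 1, n)] for d in steps)
            end
        end
        residual = maximum(abs.(Vnew .- V))  # Bellman residual for this sweep
        V = Vnew
        residual < belres && break           # stop once values stop changing
    end
    # Greedy policy for the non-terminal cells.
    policy = [V[clamp(s + 1, 1, n)] >= V[clamp(s - 1, 1, n)] ? :right : :left for s in 1:n-1]
    return V, policy
end

V, policy = corridor_vi()
println(round.(V, digits=3))   # [6.561, 7.29, 8.1, 9.0, 10.0]
println(policy)                # [:right, :right, :right, :right]
```

Each sweep propagates the +10 one cell further back, discounted by 0.9, which is why the residual shrinks geometrically in the solver log above.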

Now, we can use the policy along with the action(...) function to get the optimal action in a given state.

In [29]:
# say we are in state (9,2)
s = GridWorldState(9,2)
a = action(policy, s)
Out[29]:
:up

Remember that the state (9,3) has an immediate reward of +10.0, so the policy we found moves up, as expected!

In [30]:
s = GridWorldState(8,3)
a = action(policy, s)
Out[30]:
:right


Monte-Carlo Tree Search Solver

Monte-Carlo Tree Search (MCTS) is another MDP solver. It is an online method that searches for the best action from the current state only, by building a search tree. A nice overview of MCTS can be found here. Run the following command to download the module:

using POMDPs
POMDPs.add("MCTS")

Let's quickly run through an example of using the solver:

In [41]:
using MCTS

# initialize the problem
mdp = GridWorld()

# initialize the solver
# the hyper parameters in MCTS can be tricky to set properly
# n_iterations: the number of iterations that each search runs for
# depth: the depth of the tree (how far away from the current state the algorithm explores)
# exploration constant: this is how much weight to put into exploratory actions. 
# A good rule of thumb is to set the exploration constant to what you expect the upper bound on your average expected reward to be.
solver = MCTSSolver(n_iterations=100,
                    depth=20,
                    exploration_constant=10.0,
                    enable_tree_vis=true)

# initialize the policy by passing in your problem and the solver
policy = MCTSPolicy(solver, mdp)

# we don't need to call solve for MCTS: the search runs online each time action(...) is called

# to get the action:
s = GridWorldState(9,2)
a = action(policy, s)
Out[41]:
:up
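To build intuition for the exploration constant, here is a standalone Julia 1.x sketch of the UCB1 rule that MCTS variants commonly use to decide which action to try next from a tree node: the score is $Q(a) + c\sqrt{\ln N / n(a)}$, where $N$ is the node's total visit count and $n(a)$ the action's count. `ucb_select`, `Q` and `N` are hypothetical names; the exploration_constant passed to MCTSSolver presumably plays a role analogous to `c`, but this sketch does not reproduce MCTS.jl's internals.

```julia
# Standalone sketch for Julia 1.x of UCB1 action selection; names are hypothetical.
function ucb_select(Q::Dict{Symbol,Float64}, N::Dict{Symbol,Int}, c::Float64)
    total = sum(values(N))
    best, best_score = :none, -Inf
    for (a, q) in Q
        n = N[a]
        # Unvisited actions get an infinite score so each is tried at least once.
        score = n == 0 ? Inf : q + c * sqrt(log(total) / n)
        if score > best_score
            best, best_score = a, score
        end
    end
    return best
end

Q = Dict(:up => 5.0, :right => 4.0)
N = Dict(:up => 10, :right => 1)
println(ucb_select(Q, N, 0.0))    # up: pure exploitation picks the higher estimate
println(ucb_select(Q, N, 10.0))   # right: a large c favors the rarely tried action
```

This is why the exploration constant should be on the scale of the rewards: with c near zero the search barely explores, while a very large c spreads visits almost uniformly.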

Let's simulate the policy.

In [42]:
# we'll use POMDPToolbox for simulation
using POMDPToolbox # if you don't have this module install it by running POMDPs.add("POMDPToolbox")

s = GridWorldState(4,1) # this is our starting state
hist = HistoryRecorder(max_steps=1000)

hist = simulate(hist, mdp, policy, s)

println("Total discounted reward: $(discounted_reward(hist))")
Total discounted reward: 2.8242953648100015
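The discounted reward is $\sum_t \gamma^{t-1} r_t$. In the state history below, the only nonzero reward (+10 at (9,3)) arrives at step 13, so the total should be $0.9^{12}\times 10 \approx 2.8243$, which matches the value printed above. A standalone Julia 1.x sketch of the computation; `discounted_return` is a hypothetical helper, not the POMDPToolbox function.

```julia
# Standalone sketch for Julia 1.x; discounted_return is a hypothetical helper.
function discounted_return(rewards, discount)
    total = 0.0
    for (t, r) in enumerate(rewards)
        total += discount^(t - 1) * r   # reward at step t is discounted by discount^(t-1)
    end
    return total
end

rewards = vcat(zeros(12), 10.0)          # twelve zero-reward steps, then +10
println(discounted_return(rewards, 0.9)) # about 2.8243
```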
In [43]:
hist.state_hist # look at the state history
Out[43]:
14-element Array{GridWorldState,1}:
 GridWorldState(4,1,false,false)
 GridWorldState(4,1,false,false)
 GridWorldState(5,1,false,false)
 GridWorldState(6,1,false,false)
 GridWorldState(7,1,false,false)
 GridWorldState(7,2,false,false)
 GridWorldState(8,2,false,false)
 GridWorldState(9,2,false,false)
 GridWorldState(8,2,false,false)
 GridWorldState(7,2,false,false)
 GridWorldState(8,2,false,false)
 GridWorldState(9,2,false,false)
 GridWorldState(9,3,false,false)
 GridWorldState(9,3,false,true) 
In [39]:
action(policy, hist.state_hist[1])
Out[39]:
:right
In [40]:
TreeVisualizer(policy, hist.state_hist[1])