LightDark Demo

In [1]:
importall POMDPs
using POMDPToolbox
using Distributions
using Parameters
using Plots
using StaticArrays

Problem Definition

In [2]:
@with_kw struct SimpleLightDark <: POMDPs.POMDP{Int,Int,Float64}
    discount::Float64       = 0.95
    correct_r::Float64      = 100.0
    incorrect_r::Float64    = -100.0
    light_loc::Int          = 10
    radius::Int             = 60
end
discount(p::SimpleLightDark) = p.discount
isterminal(p::SimpleLightDark, s::Number) = !(s in -p.radius:p.radius)

const ACTIONS = [-10, -1, 0, 1, 10]
actions(p::SimpleLightDark) = ACTIONS
n_actions(p::SimpleLightDark) = length(actions(p))
const ACTION_INDS = Dict(a=>i for (i,a) in enumerate(actions(SimpleLightDark())))
action_index(p::SimpleLightDark, a::Int) = ACTION_INDS[a]

states(p::SimpleLightDark) = -p.radius:p.radius + 1
n_states(p::SimpleLightDark) = length(states(p))
state_index(p::SimpleLightDark, s::Int) = s+p.radius+1

function transition(p::SimpleLightDark, s::Int, a::Int) 
    if a == 0
        return SparseCat(SVector(p.radius+1), SVector(1.0))
    else
        return SparseCat(SVector(clamp(s+a, -p.radius, p.radius)), SVector(1.0))
    end
end

observation(p::SimpleLightDark, sp) = Normal(sp, abs(sp - p.light_loc) + 0.0001)

function reward(p::SimpleLightDark, s, a)
    if a == 0
        return s == 0 ? p.correct_r : p.incorrect_r
    else
        return -1.0
    end
end

function initial_state_distribution(p::SimpleLightDark)
    ps = ones(2*div(p.radius,2)+1)
    ps /= length(ps)
    return SparseCat(div(-p.radius,2):div(p.radius,2), ps)
end;

Visualization

In [3]:
using Plots

function plothist(pomdp, hist, heading="LightDark")
    tmax = 80
    smin = -10
    smax = 20
    vsh = collect(filter(s->!isterminal(pomdp,s), state_hist(hist)[1:end-1]))
    bh = belief_hist(hist)

    pts = Int[]
    pss = Int[]
    pws = Float64[]

    for t in 0:length(bh)-1
        b = bh[t+1]
        for s in smin:smax
            w = 10.0*sqrt(pdf(b, s))
            if 0.0<w<1.0
                w = 1.0
            end
            push!(pts, t)
            push!(pss, s)
            push!(pws, w)
        end
    end

    T = linspace(0.0, tmax)
    S = linspace(-1.0, 21.0)
    inv_grays = cgrad([RGB(1.0, 1.0, 1.0),RGB(0.0,0.0,0.0)])
    p = contour(T, S, (t,s)->abs(s-pomdp.light_loc),
            bg_inside=:black,
            fill=true,
            xlim=(0, tmax),
            ylim=(smin, smax),
            color=inv_grays,
            xlabel="Time",
            ylabel="State",
            cbar=false,
            legend=:topright,
            title=@sprintf("%s (Reward: %8.2f)", heading, discounted_reward(hist))
           )
    plot!(p, [0, tmax], [0, 0], linewidth=1, color="green", label="Goal", line=:dash)
    scatter!(p, pts, pss, color="lightblue", label="Belief Particles", markersize=pws, marker=stroke(0.1, 0.3))
    plot!(p, 0:length(vsh)-1, vsh, linewidth=3, color="orangered", label="Trajectory")
    return p
end;
In [4]:
using ParticleFilters
rng = MersenneTwister(7)

p = SimpleLightDark()
pf = SIRParticleFilter(p, 10000, rng=rng)

h = sim(p, updater=pf, initial_state=1, initial_obs=initial_state_distribution(p), max_steps=80, rng=rng) do b
    return rand(rng, [-1,1])
end

plothist(p, h)
Out[4]:
0 20 40 60 80 -10 0 10 20 LightDark (Reward: -19.67) Time State Goal Belief Particles Trajectory
In [5]:
using QMDP

solver = QMDPSolver()

policy = solve(solver, p, verbose=true);
[Iteration 1   ] residual:        100 | iteration runtime:      0.084 ms, (  8.41E-05 s total)
[Iteration 2   ] residual:         95 | iteration runtime:      0.040 ms, (  0.000124 s total)
[Iteration 3   ] residual:       90.3 | iteration runtime:      0.029 ms, (  0.000153 s total)
[Iteration 4   ] residual:       85.7 | iteration runtime:      0.026 ms, (  0.000179 s total)
[Iteration 5   ] residual:       81.5 | iteration runtime:      0.026 ms, (  0.000205 s total)
[Iteration 6   ] residual:       77.4 | iteration runtime:      0.026 ms, (  0.000231 s total)
[Iteration 7   ] residual:       73.5 | iteration runtime:      0.026 ms, (  0.000256 s total)
[Iteration 8   ] residual:       8.17 | iteration runtime:      0.026 ms, (  0.000282 s total)
[Iteration 9   ] residual:       3.98 | iteration runtime:      0.026 ms, (  0.000308 s total)
[Iteration 10  ] residual:          0 | iteration runtime:      0.027 ms, (  0.000335 s total)
In [6]:
srand(rng, 14)
hr = HistoryRecorder(max_steps=80, initial_state=1)
h = simulate(hr, p, policy, pf)
plothist(p, h, "QMDP")
WARNING: The initial_state argument for HistoryRecorder is deprecated. The initial state should be specified as the last argument to simulate(...).
Out[6]:
0 20 40 60 80 -10 0 10 20 QMDP (Reward: -19.67) Time State