# LightDark Demo

importall POMDPs
using POMDPToolbox
using Distributions
using Parameters
using Plots
using StaticArrays


# Problem Definition

# One-dimensional light-dark POMDP with integer states, integer actions and
# real-valued observations.
#
# The agent sits at an integer position in -radius:radius; the extra state
# radius + 1 is the terminal state entered after declaring (action 0).
# Observations are Gaussian around the true position, with a standard
# deviation that grows with the distance from `light_loc`. Declaring at
# position 0 earns `correct_r`, declaring anywhere else earns `incorrect_r`,
# and every movement costs 1.
@with_kw struct SimpleLightDark <: POMDPs.POMDP{Int,Int,Float64}
    discount::Float64       = 0.95
    correct_r::Float64      = 100.0
    incorrect_r::Float64    = -100.0
    light_loc::Int          = 10
    # Half-width of the position range. `state_index` already relied on this
    # field, but it was missing from the struct.
    radius::Int             = 60
end
discount(p::SimpleLightDark) = p.discount

# Any state outside -radius:radius (i.e. the post-declaration state
# radius + 1) ends the episode.
isterminal(p::SimpleLightDark, s::Int) = !(s in -p.radius:p.radius)

# Actions: move by -10, -1, +1 or +10, or 0 to declare the current position.
const ACTIONS = [-10, -1, 0, 1, 10]
actions(p::SimpleLightDark) = ACTIONS
n_actions(p::SimpleLightDark) = length(actions(p))
const ACTION_INDS = Dict(a=>i for (i,a) in enumerate(actions(SimpleLightDark())))
action_index(p::SimpleLightDark, a::Int) = ACTION_INDS[a]

# States: positions -radius:radius followed by the terminal state radius + 1.
states(p::SimpleLightDark) = -p.radius:(p.radius + 1)
n_states(p::SimpleLightDark) = length(states(p))
# Maps -radius -> 1, ..., radius + 1 -> n_states(p).
state_index(p::SimpleLightDark, s::Int) = s+p.radius+1

# Deterministic dynamics. Declaring (a == 0) jumps to the terminal state;
# any other action moves by `a`, clamped to stay inside -radius:radius.
function transition(p::SimpleLightDark, s::Int, a::Int)
    if a == 0
        return SparseCat(SVector(p.radius + 1), SVector(1.0))
    else
        return SparseCat(SVector(clamp(s + a, -p.radius, p.radius)), SVector(1.0))
    end
end

# Gaussian observation of the new position. The std grows linearly with the
# distance from the light; the small offset keeps it strictly positive at
# the light itself.
observation(p::SimpleLightDark, sp) = Normal(sp, abs(sp - p.light_loc) + 0.0001)

# Declaring (a == 0) at position 0 earns correct_r and anywhere else earns
# incorrect_r; every other action costs 1.
function reward(p::SimpleLightDark, s, a)
    if a == 0
        return s == 0 ? p.correct_r : p.incorrect_r
    else
        return -1.0
    end
end

# Uniform distribution over the non-terminal positions -radius:radius.
function initial_state_distribution(p::SimpleLightDark)
    ss = -p.radius:p.radius
    ps = ones(length(ss))
    ps /= length(ps)
    return SparseCat(ss, ps)
end;


# Visualization

using Plots

# Plot a simulated LightDark history and return the plot object.
#
# The plot has four layers:
# - a filled contour background shaded by |s - pomdp.light_loc|;
# - the goal line at state 0;
# - one light-blue marker per (time, state) pair for states smin:smax, whose
#   size grows with the square root of the belief probability;
# - the true state trajectory, with the last state and terminal states removed.
#
# The title shows `heading` and the history's discounted reward.
function plothist(pomdp, hist, heading="LightDark")
    tmax = 80
    smin, smax = -10, 20

    visited = state_hist(hist)[1:end-1]
    trajectory = collect(filter(s -> !isterminal(pomdp, s), visited))
    beliefs = belief_hist(hist)

    # Belief markers: time is the outer loop, state the inner one.
    times = Int[]
    locs = Int[]
    sizes = Float64[]
    for (k, b) in enumerate(beliefs), s in smin:smax
        sz = 10.0*sqrt(pdf(b, s))
        # Raise tiny-but-nonzero markers to size 1 so they remain visible.
        sz = 0.0 < sz < 1.0 ? 1.0 : sz
        push!(times, k - 1)
        push!(locs, s)
        push!(sizes, sz)
    end

    tgrid = linspace(0.0, tmax)
    sgrid = linspace(-1.0, 21.0)
    inv_grays = cgrad([RGB(1.0, 1.0, 1.0), RGB(0.0, 0.0, 0.0)])
    light_distance = (t, s) -> abs(s - pomdp.light_loc)
    plot_title = @sprintf("%s (Reward: %8.2f)", heading, discounted_reward(hist))

    plt = contour(tgrid, sgrid, light_distance;
                  bg_inside=:black,
                  fill=true,
                  xlim=(0, tmax),
                  ylim=(smin, smax),
                  color=inv_grays,
                  xlabel="Time",
                  ylabel="State",
                  cbar=false,
                  legend=:topright,
                  title=plot_title)
    plot!(plt, [0, tmax], [0, 0]; linewidth=1, color="green", label="Goal", line=:dash)
    scatter!(plt, times, locs; color="lightblue", label="Belief Particles",
             markersize=sizes, marker=stroke(0.1, 0.3))
    plot!(plt, 0:length(trajectory)-1, trajectory; linewidth=3, color="orangered",
          label="Trajectory")
    return plt
end;

using ParticleFilters
# Baseline run: a policy that ignores the belief and steps -1 or +1 uniformly
# at random. The belief is tracked by a 10000-particle SIR filter, and the
# particle filter and simulator share one seeded RNG.
rng = MersenneTwister(7)

p = SimpleLightDark()
pf = SIRParticleFilter(p, 10000, rng=rng)

h = sim(b -> rand(rng, [-1, 1]), p;
        updater=pf,
        initial_state=1,
        # NOTE(review): the initial state distribution is passed as
        # `initial_obs`, presumably so `sim` seeds the filter's belief with it.
        initial_obs=initial_state_distribution(p),
        max_steps=80,
        rng=rng)

plothist(p, h)

using QMDP

# Solve the model with QMDP and simulate the resulting policy from state 1,
# tracking the belief with the particle filter `pf`.
solver = QMDPSolver()

policy = solve(solver, p, verbose=true);

# Reseed the shared RNG and pass it to the HistoryRecorder as well. Without
# `rng=rng`, the recorder's own random draws ignore this reseed, and the run
# is not reproducible (unlike the earlier `sim(...; rng=rng)` call).
srand(rng, 14)
hr = HistoryRecorder(max_steps=80, initial_state=1, rng=rng)
h = simulate(hr, p, policy, pf)
plothist(p, h, "QMDP")

