This notebook expands on the Walk1D example; see the Walk1D.jl file for the non-notebook version and walk1d.pdf for the write-up version.
See the documentation for more details.
In this self-contained tutorial, we define a simple problem for adaptive stress testing (AST) to find failures. This problem, called Walk1D, samples random walking distances from a standard normal distribution $\mathcal{N}(0,1)$ and defines failures as walking past a certain threshold (set to ±10 in this example). AST either selects the seed that deterministically controls the sampled value from the distribution (i.e., from the transition model) or directly samples the provided environmental distributions; these action modes are determined by the seed-action and sample-action options. AST guides the simulation towards failure events using a notion of distance to failure, while simultaneously trying to find the set of actions that maximizes the log-likelihood of the samples.
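Concretely, the walk starts at $x_0 = 0$ and evolves as $x_{t+1} = x_t + a_t$ with $a_t \sim \mathcal{N}(0, 1)$. A failure event occurs when $|x_t| \ge 10$ within the time horizon ($t \le 30$ in this example), and the distance to failure used to guide the search is $d(x_t) = \max(10 - |x_t|, 0)$; these quantities correspond directly to the parameters and interface functions defined below.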
Some definitions to note for this example problem are described in the gray-box and black-box sections below. First, we load the packages used in this tutorial:
using POMDPStressTesting # this package
using Distributions # for the Normal distribution
using Parameters # for @with_kw default struct parameters
The simulator and environment are treated as gray-box because we need access to the state-transition distributions and their associated likelihoods. Refer to the gray-box definition section in the documentation for further details.
First, we define the parameters of our simulation.
@with_kw mutable struct Walk1DParams
    startx::Float64 = 0   # Starting x-position
    threshx::Float64 = 10 # ± boundary threshold
    endtime::Int64 = 30   # Simulation end time
end
Walk1DParams
Next, we define a GrayBox.Simulation structure which stores simulation-related values.
@with_kw mutable struct Walk1DSim <: GrayBox.Simulation
    params::Walk1DParams = Walk1DParams()     # Parameters
    x::Float64 = 0                            # Current x-position
    t::Int64 = 0                              # Current time
    distribution::Distribution = Normal(0, 1) # Transition distribution
end
Walk1DSim
Then, we define our GrayBox.Environment distributions. When using the ASTSampleAction, as opposed to the ASTSeedAction, we need to provide access to the sampleable environment.
GrayBox.environment(sim::Walk1DSim) = GrayBox.Environment(:x => sim.distribution)
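To make the types concrete, here is a minimal sketch of drawing a value from this environment and wrapping it as an environment sample. It assumes GrayBox.Sample stores a value together with its log-probability and that GrayBox.EnvironmentSample maps symbols to such samples, matching how transition! accesses them below.
env = GrayBox.environment(Walk1DSim())  # Dict(:x => Normal(0, 1))
x_dist = env[:x]                        # the transition distribution
value = rand(x_dist)                    # a sampled walking distance
sample = GrayBox.EnvironmentSample(:x => GrayBox.Sample(value, logpdf(x_dist, value)))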
We override the transition function from the GrayBox interface, which takes an environment sample as input. We apply the sample in our simulator and return the log-likelihood.
function GrayBox.transition!(sim::Walk1DSim, sample::GrayBox.EnvironmentSample)
    sim.t += 1                  # Keep track of time
    sim.x += sample[:x].value   # Move agent using sampled value from input
    return logpdf(sample)::Real # Summation handled by `logpdf()`
end
The system under test, in this case a simple one-dimensional moving agent, is always treated as black-box. The following interface functions are overridden to minimally interact with the system, and use outputs from the system to determine failure event indications and distance metrics. Refer to the black-box definition section of the documentation for further details.
Now we override the BlackBox interface, starting with the function that initializes the simulation object. Interface functions ending in ! may modify the sim object in place.
function BlackBox.initialize!(sim::Walk1DSim)
    sim.t = 0                 # Reset time
    sim.x = sim.params.startx # Reset to the starting x-position
end
We define how close we are to a failure event using a non-negative distance metric.
BlackBox.distance(sim::Walk1DSim) = max(sim.params.threshx - abs(sim.x), 0)
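For illustration, evaluating this distance at a couple of positions using the definitions above: at $x = 7.5$ the agent is still 2.5 units from the ±10 boundary, and once the threshold is crossed the distance saturates at zero.
sim = Walk1DSim()
sim.x = 7.5
BlackBox.distance(sim) # 2.5
sim.x = -10.2
BlackBox.distance(sim) # 0.0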
We define an indication that a failure event occurred.
BlackBox.isevent(sim::Walk1DSim) = abs(sim.x) >= sim.params.threshx
Similarly, we define an indication that the simulation (or system) is in a terminal state.
BlackBox.isterminal(sim::Walk1DSim) = BlackBox.isevent(sim) || sim.t >= sim.params.endtime
Lastly, we use our defined interface to evaluate the system under test. Using the input sample, we return the log-likelihood, distance to an event, and event indication.
function BlackBox.evaluate!(sim::Walk1DSim, sample::GrayBox.EnvironmentSample)
    logprob::Real = GrayBox.transition!(sim, sample) # Step simulation
    d::Real = BlackBox.distance(sim)                 # Calculate miss distance
    event::Bool = BlackBox.isevent(sim)              # Check event indication
    return (logprob::Real, d::Real, event::Bool)
end
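Together, these interface functions are everything AST needs to drive the simulator. As a purely illustrative sketch (not part of the package workflow), a random rollout could exercise them as follows, again assuming the GrayBox.Sample constructor takes a value and its log-probability:
function random_rollout(sim::Walk1DSim)
    BlackBox.initialize!(sim)  # reset time and position
    while !BlackBox.isterminal(sim)
        env = GrayBox.environment(sim)
        value = rand(env[:x])  # naive random draw
        sample = GrayBox.EnvironmentSample(:x => GrayBox.Sample(value, logpdf(env[:x], value)))
        logprob, d, event = BlackBox.evaluate!(sim, sample)
        event && println("Failure at t = $(sim.t), x = $(sim.x)")
    end
    return sim.x
end
In practice, search!(planner) below performs this loop for us, replacing the naive random draws with the solver's guided choices.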
Setting up our simulation, we instantiate our simulation object and pass it to the Markov decision process (MDP) object of the adaptive stress testing formulation. We use Monte Carlo tree search (MCTS) with progressive widening on the action space as our solver. Hyperparameters are passed to MCTSPWSolver, which is a simple wrapper around MCTS.jl, the POMDPs.jl implementation of MCTS. Lastly, we solve the MDP to produce a planner. Note we are using the ASTSampleAction.
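For context, progressive widening limits the number of distinct child actions expanded at a tree node to roughly $k N^{\alpha}$, where $N$ is the node's visit count; the k_action and alpha_action hyperparameters below control this growth.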
The objects produced by this setup are:
- planner: used to play out the search
- planner.mdp::ASTMDP: the main MDP problem formulation object for AST (this holds reward metrics)
- planner.mdp.sim::Walk1DSim: the main simulation object, holding all simulation information (e.g., current x-position, settings for the simulation, etc.)
- solver::MCTSPWSolver: holds solver-specific parameters and is used to generate the planner
function setup_ast(seed=0)
    # Create gray-box simulation object
    sim::GrayBox.Simulation = Walk1DSim()

    # AST MDP formulation object
    mdp::ASTMDP = ASTMDP{ASTSampleAction}(sim)
    mdp.params.debug = true # record metrics
    mdp.params.top_k = 10   # record top k best trajectories
    mdp.params.seed = seed  # set RNG seed for determinism

    # Hyperparameters for MCTS-PW as the solver
    solver = MCTSPWSolver(n_iterations=1000,        # number of algorithm iterations
                          exploration_constant=1.0, # UCT exploration
                          k_action=1.0,             # action widening
                          alpha_action=0.5,         # action widening
                          depth=sim.params.endtime) # tree depth

    # Get online planner (no work done, yet)
    planner = solve(solver, mdp)

    return planner
end
setup_ast (generic function with 2 methods)
After setup, we search for failures using the planner and output the best action trace.
planner = setup_ast();
action_trace = search!(planner)
Progress: 100%|█████████████████████████████████████████| Time: 0:00:00
10-element Array{ASTAction,1}: ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}
We can also playback specific trajectories and print intermediate $x$-values.
final_state = playback(planner, action_trace, sim->sim.x)
0.0 0.015327970235574946 0.30370765717228826 1.4997663258450742 2.109577482502034 2.952826155527804 4.426053954911257 6.803697217082685 8.03548741774146 9.178959112657289 10.490317369782634
ASTState t_index: Int64 11 parent: ASTState action: ASTSampleAction hash: UInt64 0x12ff54a6c31b4cce q_value: Float64 -1.7787687724700838 terminal: Bool true
Finally, we can print metrics associated with the AST run for further analysis.
failure_rate = print_metrics(planner)
First failure: 23 of 62506
Number of failures: 510
Failure rate: 0.81592%
0.8159216715195341
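The reported failure rate is simply the ratio of these counts: $510 / 62506 \approx 0.816\%$.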
When using the MCTSPWSolver, we can output the tree from the search! function and visualize it using D3Trees.jl.
d3tree = visualize(planner) # re-runs the search to output the tree, then visualizes it
Progress: 100%|█████████████████████████████████████████| Time: 0:00:00
Attempting to display the tree. If the tree is large, this may take some time.
Note: D3Trees.jl requires an internet connection. If no tree appears, please check your connection. To help fix this, please see this issue. You may also diagnose errors with the javascript console (Ctrl-Shift-J in chrome).
POMDPStressTesting.jl comes with a variety of solvers:

Reinforcement learning
- MCTSPWSolver: Monte Carlo tree search with action progressive widening

Deep reinforcement learning
- TRPOSolver: Trust region policy optimization
- PPOSolver: Proximal policy optimization

Stochastic optimization
- CEMSolver: Cross-entropy method

Baselines
- RandomSearchSolver: Standard/naive Monte Carlo randomized search

We can easily take our ASTMDP object (planner.mdp) and re-solve the MDP using a different solver, in this case the CEMSolver.
mdp = planner.mdp # reused from above `setup_ast()`
solver = CEMSolver(n_iterations=1000, episode_length=mdp.sim.params.endtime)
planner = solve(solver, mdp)
action_trace = search!(planner)
Progress: 100%|█████████████████████████████████████████| Time: 0:00:05
7-element Array{ASTAction,1}: ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample} ASTSampleAction sample: Dict{Symbol,POMDPStressTesting.AST.GrayBox.Sample}
Again, we play back the best action trace and print out the metrics (this time for the CEMSolver).
final_state = playback(planner, action_trace, sim->sim.x)
failure_rate = print_metrics(planner)
0.0 -1.4681565073056968 -2.8790682106764236 -4.5318598317096015 -6.002385457333659 -7.38293658747257 -8.641661367101024 -10.000020551648118
First failure: 23 of 840947
Number of failures: 100793
Failure rate: 11.98565%
11.985654268342714
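The same re-solve pattern applies to the other solvers as well. For example, a random-search baseline could be run with the following sketch, assuming RandomSearchSolver accepts the same n_iterations and episode_length keyword arguments used for CEMSolver above:
solver = RandomSearchSolver(n_iterations=1000, episode_length=mdp.sim.params.endtime)
planner = solve(solver, mdp)
action_trace = search!(planner)
failure_rate = print_metrics(planner)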
The episodic_figures function plots the episodic metrics, including the running miss distance mean, the minimum miss distance, and the cumulative number of failures, each over episodes (i.e., iterations).
Note that we use Requires.jl to handle the PyPlot and Seaborn dependencies, so to plot, first install those two packages and then load them.
using Pkg
Pkg.add("PyPlot")
Pkg.add("Seaborn")
using PyPlot
using Seaborn
episodic_figures(planner.mdp.metrics; gui=false)
The distribution_figures function plots the miss distance distribution and the log-likelihood distribution.
distribution_figures(planner.mdp.metrics; gui=false)