using BenchmarkTools
using Plots

# Parameters and asset grid of a consumption-savings problem with production
# a^alpha, solved on a discrete grid by value function iteration (see `VFI`).
mutable struct Model
    # primitive parameters
    beta::Float64   # subjective discount factor
    sigma::Float64  # relative risk aversion (CRRA coefficient)
    delta::Float64  # depreciation rate
    alpha::Float64  # capital share
    # discretized asset space
    agrid::Vector{Float64}
end

"""
    argmax(mat)

Return, for every row of the matrix `mat`, the column index of that row's
largest element (the first such column on ties), as a `Vector{Int}`.

Throws `ArgumentError` when `mat` has rows but no columns.

NOTE: on Julia >= 0.7 this definition shadows `Base.argmax` inside the
defining module; the name is kept so existing callers keep working.
"""
function argmax(mat)
    nrow = size(mat, 1)
    if nrow > 0 && size(mat, 2) == 0
        throw(ArgumentError("argmax: matrix has rows but no columns"))
    end
    bestcol = ones(Int, nrow)
    # Columns outermost so `mat` is traversed in memory (column-major) order.
    for col in 2:size(mat, 2), row in 1:nrow
        if mat[row, col] > mat[row, bestcol[row]]   # strict: keep first maximiser
            bestcol[row] = col
        end
    end
    return bestcol
end

"""
    VFI(m::Model; max_iter=1000, tol=1e-5)

Solve, on the grid `m.agrid`, the Bellman equation

    V(a) = max_{a'} u(a^alpha + (1 - delta) * a - a') + beta * V(a')

by value function iteration. `u` is CRRA utility with coefficient `m.sigma`
(`log` when `sigma == 1`). Choices implying non-positive consumption get a
large negative penalty utility instead.

Iteration stops once the sup-norm change of the value function falls below
`tol`, or after `max_iter` sweeps, in which case a message is printed.

Returns `(agrid, v, pol_a)`: the grid, the value function and the policy
(next-period asset level chosen at each grid point), the last two as
`na × 1` matrices. The policy is the maximiser from the last sweep performed,
whether or not the iteration converged; it stays all zeros only when no sweep
ran (`max_iter < 1`).
"""
function VFI(m::Model; max_iter::Int=1000, tol::Float64=1e-5)
    penalty = -999999999.9   # utility assigned to infeasible (c <= 0) choices
    agrid = m.agrid
    na = length(agrid)
    disc = m.beta

    # util[j, i]: one-period utility of moving from agrid[j] to agrid[i].
    util = zeros(na, na)
    @inbounds for i in 1:na, j in 1:na
        cons = agrid[j]^m.alpha + (1 - m.delta) * agrid[j] - agrid[i]
        # Check feasibility *before* evaluating u: cons^(1 - sigma) throws a
        # DomainError for cons < 0 whenever 1 - sigma is not an integer.
        if cons <= 0
            util[j, i] = penalty
        elseif m.sigma == 1
            util[j, i] = log(cons)   # CRRA limit as sigma -> 1
        else
            util[j, i] = cons^(1 - m.sigma) / (1 - m.sigma)
        end
    end

    v0 = zeros(na, 1)          # current guess of the value function
    Tv = zeros(na, 1)          # Bellman update of v0 (swapped with v0 each sweep)
    a_index = zeros(Int, na)   # maximising next-period grid index per row
    pol_a = zeros(na, 1)       # policy function

    err = Inf   # so that running zero sweeps is reported as non-convergence
    for t in 1:max_iter
        # Tv[j] = max_i util[j, i] + beta * v0[i], with the argmax recorded in
        # the same pass. Columns outermost walks `util` in memory order; the
        # strict `>` keeps the first maximiser on ties, like `findmax`.
        fill!(Tv, -Inf)
        @inbounds for i in 1:na
            cont = disc * v0[i]
            for j in 1:na
                val = util[j, i] + cont
                if val > Tv[j]
                    Tv[j] = val
                    a_index[j] = i
                end
            end
        end

        err = 0.0
        @inbounds for j in 1:na
            err = max(err, abs(Tv[j] - v0[j]))   # sup-norm change
        end
        v0, Tv = Tv, v0   # v0 now holds the newest iterate; reuse the old buffer
        if err < tol
            break
        end
    end
    if err >= tol
        println("VFI does not converge in $(max_iter) times")
    end

    # a_index is only populated once at least one sweep has run.
    if max_iter >= 1
        @inbounds for j in 1:na
            pol_a[j] = agrid[a_index[j]]
        end
    end
    return (m.agrid, v0, pol_a)
end

# Calibration
beta = 0.95    # subjective discount factor
sigma = 2.0    # relative risk aversion
delta = 0.1    # depreciation rate
alpha = 0.33   # capital share

# Steady state: alpha * k^(alpha - 1) = 1/beta - (1 - delta)
aterm = 1.0 / beta - (1.0 - delta)
kstar = (alpha / aterm)^(1.0 / (1.0 - alpha))

# Asset grid: na evenly spaced points on [0.1 k*, 2 k*]. Built with the same
# interpolation formula as the old `linspace`, which was removed in Julia 1.0.
amin = 0.1 * kstar
amax = 2 * kstar
na = 250
model = Model(beta, sigma, delta, alpha,
              [((na - k) * amin + (k - 1) * amax) / (na - 1) for k in 1:na])

agrid, v0, pol_a = VFI(model)

# `$model` interpolates the global so the benchmark times VFI itself rather
# than the dynamic lookup of a non-constant global.
@benchmark VFI($model)

plot(agrid, pol_a, color="blue", linewidth=1.5, label="Policy Function")
plot!(agrid, agrid, color="red", linewidth=1.5, label="45 degree Line")
plot(agrid, v0, color="blue", linewidth=1.5, label="Value Function")

@code_warntype VFI(model)
@time VFI(model)