Decomposition: $$ \min_{x\in\R^n} \;\; f(x) = \underbrace{F(x)}_{\text{convex, smooth}} + \underbrace{\Psi(x)}_{\text{convex, nonsmooth, simple}} $$
Example: convex regularization problems, e.g. least squares with an $\ell_1$ penalty.
We define the proximal operator of a convex function $h$ as
$$ \prox_{h}(z) := \argmin_{x\in\R^n} \; \left\{ \frac{1}{2} \|x - z\|^2 + h(x) \right\} $$
The proximal gradient descent update can be written as
\begin{align*} x_{k+1} &= \argmin_{x\in\R^n} \; \left\{ \frac{1}{2\alpha_k}\| x - (x_k -\alpha_k \nabla F(x_k)) \|^2 + \Psi(x) \right\} \\ &= \prox_{\alpha_k \Psi} ( x_k - \alpha_k \nabla F(x_k)) \end{align*}
That is,
$$ x_{k+1} = x_k - \alpha_k G_{\alpha_k}^\Psi(x_k), \qquad G_{\alpha}^\Psi(x) := \frac{1}{\alpha} \left( x - \prox_{\alpha \Psi}( x - \alpha \nabla F(x)) \right) $$
$G_{\alpha_k}^\Psi(x_k)$ can be seen as a generalized gradient: when $\Psi \equiv 0$, the prox is the identity and $G_{\alpha_k}^\Psi(x_k) = \nabla F(x_k)$.
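Each iteration needs to compute the prox operation:
$$ \prox_{\Psi}(z) := \argmin_{x\in\R^n} \; \left\{ \frac{1}{2} \|x - z\|^2 + \Psi(x) \right\} $$
We call $\Psi$ "simple" if its prox operation can be computed efficiently.
Many interesting regularizers $\Psi$ have this property. For example, for $\Psi(x) = \lambda \|x\|_1$ the prox has the closed form
$$ [\prox_{\lambda\|\cdot\|_1}(z)]_i = \text{sign}(z_i) \max\{ |z_i| - \lambda, 0 \}, \;\; i=1,\dots,n $$
This is the so-called soft-thresholding operator.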
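To see where soft-thresholding comes from, take $\Psi(x) = \lambda\|x\|_1$ with $\lambda > 0$ (the regularizer of the problem solved in the code of this notebook; this choice is an assumption of the sketch). The prox problem separates over coordinates, so it suffices to consider a scalar $z$:
$$ \prox_{\lambda|\cdot|}(z) = \argmin_{x\in\R} \; \left\{ \frac{1}{2} (x - z)^2 + \lambda |x| \right\} $$
The optimality condition $0 \in x - z + \lambda\, \partial|x|$ gives three cases:
\begin{align*} x > 0: &\quad x = z - \lambda, \text{ which requires } z > \lambda \\ x < 0: &\quad x = z + \lambda, \text{ which requires } z < -\lambda \\ x = 0: &\quad z \in [-\lambda, \lambda] \end{align*}
All three cases are collected by $x = \text{sign}(z)\max\{|z| - \lambda, 0\}$, applied to each coordinate $z_i$.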
$F$ is required to be $L$-strongly smooth, i.e., its gradient is Lipschitz continuous with constant $L$:
$$ \| \nabla F(x) - \nabla F(y) \|_2 \le L \|x - y\|_2, \;\; \forall x, y \in\R^n $$
Then with a constant stepsize $\alpha_k = \alpha \le \frac{1}{L}$, the objective decreases monotonically ($f(x_{k+1}) \le f(x_k)$) and converges at the rate
$$ f(x_{k}) - f(x^*) \le \frac{\|x_0 - x^*\|^2}{2\alpha k} $$
where $x^*$ is a minimizer of $f$.
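The monotone decrease can be checked numerically. The following is a minimal sketch, assuming Julia 1.x with only the standard library, the $\ell_1$-regularized least-squares objective, and a small random toy problem; the helper names `soft`, `objective` and `pg_objectives` are hypothetical and not part of the notebook.

```julia
using LinearAlgebra, Random

# Soft-thresholding: prox of t*||.||_1 evaluated at z
soft(z, t) = sign.(z) .* max.(abs.(z) .- t, 0)

# Objective f(x) = 0.5*||y - A*x||^2 + λ*||x||_1
objective(A, y, x, λ) = 0.5 * norm(y - A * x)^2 + λ * norm(x, 1)

# Run `iters` proximal gradient steps with constant stepsize α and record f(x_k)
function pg_objectives(A, y, λ, α; iters=50)
    x = zeros(size(A, 2))
    fs = [objective(A, y, x, λ)]
    for _ in 1:iters
        x = soft(x - α * (A' * (A * x - y)), α * λ)
        push!(fs, objective(A, y, x, λ))
    end
    return fs
end

Random.seed!(0)                 # toy data, not from the notebook
A = randn(20, 50)
y = randn(20)
L = opnorm(A)^2                 # Lipschitz constant of the gradient of F
fs = pg_objectives(A, y, 0.5, 1 / L)
println("monotone decrease: ", all(diff(fs) .<= 1e-10))
```

The small tolerance `1e-10` only absorbs floating-point rounding; with $\alpha = 1/L$ the theory predicts no increase at all.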
# Proximal gradient algorithm to solve
#
# min_x .5 || y - Ax ||^2 + λ ||x||_1
#
# Author: Sangkyun Lee (sangkyun.lee@cs.tu-dortmund.de)
#
#
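# Soft-thresholding: prox operator of λ*||.||_1, applied elementwise
function prox_L1(z, λ)
    x = copy(z)   # copy, so that the input z is not overwritten
    for i=1:length(x)
        x[i] = sign(z[i])*max(abs(z[i]) - λ, 0)
    end
    return x
end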
# Proximal gradient descent with constant stepsize 1/L
function pgd(A, y, λ, L; maxiter=1000, ϵ=1e-3)
    m, n = size(A)   # m observations, n unknowns (k is reserved for the iteration counter)
    x = zeros(n)
    z = zeros(n)
    Ax = zeros(m)
    opt = Inf
    Ay = A'*y
    for k=1:maxiter
        # gradient of the smooth part: A'(Ax - y)
        g = -Ay + A'*Ax
        # prox step with stepsize 1/L
        x = prox_L1(x - (1/L)*g, λ/L)
        Ax = A*x
        # every 50 iterations: report progress and test optimality
        if(k%50==0)
            f = .5*vecnorm(y-Ax,2)^2 + λ*vecnorm(x,1)
            g = -Ay + A'*Ax
            # z = minimum-norm subgradient of f at x
            for i=1:length(x)
                if(abs(x[i])>0)
                    z[i] = g[i] + λ*sign(x[i])
                else
                    z[i] = sign(g[i])*max(abs(g[i]) - λ, 0)
                end
            end
            opt = vecnorm(z, Inf)
            @printf "%04d: f=%5.3e, opt=%5.3e, nnz=%04d, L=%5.3e\n" k f opt countnz(x) L
            if(opt < ϵ)
                println("Optimal solution found at iter $k")
                break
            end
        end
    end
    println("")
    return(x)
end
pgd (generic function with 1 method)
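The `opt` value printed by `pgd` is the $\infty$-norm of the minimum-norm subgradient of $f$ at $x$: where $x_i \neq 0$ the component is $g_i + \lambda\,\text{sign}(x_i)$, and where $x_i = 0$ it is the element of $g_i + \lambda[-1,1]$ closest to zero. The sketch below restates that test in vectorized form for Julia 1.x (the notebook itself runs on Julia 0.4); the function name `min_norm_subgradient` and the toy data are assumptions made for illustration.

```julia
using LinearAlgebra

# Minimum-norm element of the subdifferential of
#   f(x) = 0.5*||y - A*x||^2 + λ*||x||_1
function min_norm_subgradient(A, y, x, λ)
    g = A' * (A * x - y)                    # gradient of the smooth part
    return [xi != 0 ? gi + λ * sign(xi) : sign(gi) * max(abs(gi) - λ, 0)
            for (xi, gi) in zip(x, g)]
end

# Toy check: with A = I the minimizer is the soft-thresholding of y
A = Matrix(1.0I, 3, 3)
y = [3.0, -0.5, 1.5]
λ = 1.0
xstar = sign.(y) .* max.(abs.(y) .- λ, 0)   # mathematically [2, 0, 0.5]
println(norm(min_norm_subgradient(A, y, xstar, λ), Inf))     # 0.0 at the minimizer
println(norm(min_norm_subgradient(A, y, zeros(3), λ), Inf))  # 2.0 at x = 0
```

A value of zero certifies optimality, which is why `pgd` stops once `opt` falls below `ϵ`.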
Solve the inverse problem: can we reconstruct the original signal from only a few observations?
The sensing matrix $A \in \R^{k\times n}$ must satisfy the restricted isometry property (RIP) for recovery:
$$ \exists \epsilon \in (0,1): \;\; (1-\epsilon) \|x\|_2^2 \le \|Ax\|_2^2 \le (1+\epsilon) \|x\|_2^2 $$
for all $s$-sparse vectors $x\in\R^n$.
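The inequality can be probed empirically by sampling random $s$-sparse vectors. This is only a sketch in Julia 1.x: sampling cannot certify the RIP, and the scaling of the Gaussian matrix by $1/\sqrt{k}$ (so that $\|Ax\|_2^2 \approx \|x\|_2^2$ on average) is an assumption made here, whereas the notebook's experiments use an unscaled `randn(k,n)`. The helper `isometry_ratio` is hypothetical.

```julia
using LinearAlgebra, Random

Random.seed!(1)
n, s, k = 1000, 10, 138            # sizes borrowed from the first experiment in this notebook
A = randn(k, n) / sqrt(k)          # scaling by 1/sqrt(k) is an assumption made here

# Ratio ||A*x||^2 / ||x||^2 for one random s-sparse vector
function isometry_ratio(A, s)
    n = size(A, 2)
    x = zeros(n)
    x[randperm(n)[1:s]] = randn(s)  # random support, random values
    return norm(A * x)^2 / norm(x)^2
end

ratios = [isometry_ratio(A, s) for _ in 1:200]
println("min ratio = ", minimum(ratios), ", max ratio = ", maximum(ratios))
```

Ratios that stay within $(1-\epsilon, 1+\epsilon)$ for some $\epsilon < 1$ are consistent with the RIP on the sampled vectors.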
Random matrices satisfy the RIP with high probability if
$$ k \approx s \log(n/s) \ll n $$
[Baraniuk, Davenport, DeVore, & Wakin, 2008]
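As a worked instance of this rule, the snippet below reproduces the two measurement counts used in the experiments of this notebook, which multiply $s\log(n/s)$ by a constant (3 for the synthetic signal, 2 for the image). The function name `measurements` is hypothetical; the constants and sizes are taken from the notebook's code and printed output.

```julia
# Number of measurements k = round(c * s * log(n/s)) for a constant c
measurements(n, s; c=1) = round(Int, c * s * log(n / s))

println(measurements(1000, 10; c=3))    # 138, the k of the synthetic experiment
println(measurements(10000, 625; c=2))  # 3466, the k of the image experiment
```

In both cases $k$ is well below $n$ (about 14% and 35% of the dimension, respectively).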
When $A$ satisfies the RIP, we can recover the true $s$-sparse signal exactly with very high probability by solving
$$ \min_{x\in\R^n} \;\; \frac12 \|y - Ax\|^2 + \lambda \|x\|_1 $$
# Noisy Compressed sensing
#
using StatsBase, PyPlot
n=1000
s=10
x=zeros(n)
x[sample(1:n,s,replace=false)] = rand([-1,1], s).*sqrt(log(n))  # sample without replacement so that exactly s entries are nonzero
k=round(Int, 3*s*log(n/s))
A=randn(k,n)
println("A: k=$k, n=$n\n")
# noisy measurement
y=A*x + .1*randn(k)
# reconstruction
z = pgd(A, y, 10, norm(A,2)^2, maxiter=10000)
println("True signal: sparsity=$(countnz(x)), dim=$n")
println("Recovered signal: sparsity=$(countnz(z)), mse=$(vecnorm(x-z,2)^2/n)")
sx=find(abs(x).>0)
sz=find(abs(z).>0)
print(sx)
print(sz)
subplot(211); plot(1:n, x); ylabel("True signal")
subplot(212); plot(1:n, z); ylabel("Recovered signal")
A: k=138, n=1000

0050: f=4.599e+02, opt=1.642e+01, nnz=0332, L=1.812e+03
0100: f=4.068e+02, opt=1.200e+01, nnz=0238, L=1.812e+03
0150: f=3.730e+02, opt=9.265e+00, nnz=0191, L=1.812e+03
0200: f=3.497e+02, opt=8.749e+00, nnz=0163, L=1.812e+03
0250: f=3.300e+02, opt=9.368e+00, nnz=0140, L=1.812e+03
0300: f=3.134e+02, opt=8.884e+00, nnz=0118, L=1.812e+03
0350: f=2.992e+02, opt=8.112e+00, nnz=0104, L=1.812e+03
0400: f=2.868e+02, opt=7.489e+00, nnz=0094, L=1.812e+03
0450: f=2.757e+02, opt=6.859e+00, nnz=0078, L=1.812e+03
0500: f=2.662e+02, opt=5.948e+00, nnz=0055, L=1.812e+03
0550: f=2.606e+02, opt=3.804e+00, nnz=0033, L=1.812e+03
0600: f=2.593e+02, opt=9.205e-01, nnz=0012, L=1.812e+03
0650: f=2.593e+02, opt=6.517e-02, nnz=0011, L=1.812e+03
0700: f=2.593e+02, opt=4.182e-03, nnz=0010, L=1.812e+03
0750: f=2.593e+02, opt=2.920e-04, nnz=0010, L=1.812e+03
Optimal solution found at iter 750

True signal: sparsity=10, dim=1000
Recovered signal: sparsity=10, mse=7.241539535283949e-5
[42,374,496,524,592,753,767,797,848,873][42,374,496,524,592,753,767,797,848,873]
[Figure: two stacked panels plotted against the index 1..n; top: true signal, bottom: recovered signal]
using Images, ImageView, Interact, Colors, PyPlot, Wavelets
pic = load("./tu100.jpg");
# # pic = load("./checker100.jpg"); # BW image!
# pic = load("./puzzle100.png");
# pic = load("./german100.gif");
arrays = float(separate(pic.data));
pic
# Wavelet transform
wv = wavelet(WT.haar)
#wv = wavelet(WT.db6)
#wv = wavelet(WT.sym4)
#wv = wavelet(WT.coif4)
#wv = wavelet(WT.batt2)
#wv = wavelet(WT.vaid)
Rw = dwt(arrays[:,:,1], wv);
Gw = dwt(arrays[:,:,2], wv);
Bw = dwt(arrays[:,:,3], wv);
@printf "nnz R=%f\n" countnz(Rw)/size(Gw,1)/size(Gw,2)
@printf "nnz G=%f\n" countnz(Gw)/size(Gw,1)/size(Gw,2)
@printf "nnz B=%f\n" countnz(Bw)/size(Gw,1)/size(Gw,2)
nnz R=0.270200
nnz G=0.265300
nnz B=0.281500
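Why a wavelet transform yields few nonzero coefficients can be seen on a one-dimensional toy signal. The following sketch implements a single level of the orthonormal Haar transform by hand in Julia 1.x; it is not the `dwt` routine of the Wavelets package used above, and the function name `haar_level` and the signal are assumptions for illustration.

```julia
# One level of the orthonormal Haar transform for a signal of even length
function haar_level(x)
    a = (x[1:2:end] .+ x[2:2:end]) ./ sqrt(2)   # approximation (scaled pairwise sums)
    d = (x[1:2:end] .- x[2:2:end]) ./ sqrt(2)   # detail (scaled pairwise differences)
    return vcat(a, d)
end

x = [4.0, 4.0, 4.0, 4.0, 1.0, 1.0, 7.0, 7.0]   # toy piecewise-constant signal
w = haar_level(x)
println(count(!iszero, w), " of ", length(w), " coefficients are nonzero")
```

Every detail coefficient vanishes wherever the signal is constant across a pair, so smooth or piecewise-constant regions of an image contribute mostly zeros.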
# Sparsify the signal
th = 1
im_m, im_n = size(Rw)
Rw = reshape(threshold(Rw[:], HardTH(), th), im_m,im_n)
Gw = reshape(threshold(Gw[:], HardTH(), th), im_m,im_n)
Bw = reshape(threshold(Bw[:], HardTH(), th), im_m,im_n)
@printf "nnz R=%f\n" countnz(Rw)/size(Gw,1)/size(Gw,2)
@printf "nnz G=%f\n" countnz(Gw)/size(Gw,1)/size(Gw,2)
@printf "nnz B=%f\n" countnz(Bw)/size(Gw,1)/size(Gw,2)
s = countnz(Gw)
original = Image(map(RGB, idwt(Rw,wv), idwt(Gw,wv), idwt(Bw,wv)))
nnz R=0.062700
nnz G=0.062500
nnz B=0.052500
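The sparsification step uses hard thresholding, which differs from the soft-thresholding inside the proximal gradient method. A minimal side-by-side sketch in Julia 1.x follows; the function names are hypothetical, and zeroing entries with magnitude exactly equal to the threshold is an assumption of this sketch rather than a statement about `HardTH()` in the Wavelets package.

```julia
# Hard thresholding: zero out entries with magnitude <= th, keep the others unchanged
hard_threshold(w, th) = [abs(wi) > th ? wi : zero(wi) for wi in w]

# Soft thresholding (the prox of th*||.||_1): additionally shrinks the survivors by th
soft_threshold(w, th) = [sign(wi) * max(abs(wi) - th, 0) for wi in w]

w = [-2.5, 0.4, 0.0, 0.9, 3.0]
println(hard_threshold(w, 1.0))   # [-2.5, 0.0, 0.0, 0.0, 3.0]
println(soft_threshold(w, 1.0))   # [-1.5, 0.0, 0.0, 0.0, 2.0]
```

Both produce the same support here, but only hard thresholding leaves the large coefficients untouched.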
# Prepare compressed sensing data
# true signals as vectors
r = Rw[:]
g = Gw[:]
b = Bw[:]
# dimension
n = length(r)
# sensing matrix
k = round(Int, 2*s*log(n/s))
println("no observations k=$k, dimension n=$n")
no observations k=3466, dimension n=10000
# create sensing matrix and observations
A = randn(k, n);
yR = A*r;
yG = A*g;
yB = A*b;
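# L = norm(A,2)^2   (exact value, commented out in favor of a cheaper estimate)
# Heuristic estimate: probe with a few random unit vectors z
# and double the largest value of ||z'A||^2 found.
L = 0;
for t=1:5   # loop variable renamed from r, which is also the name of the signal vector above
    z = randn(k);
    z = z/norm(z);
    L = max(L, vecnorm(z'*A,2)^2)
end
L=2L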
20169.8539672471
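Each probe $\|z^T A\|_2^2$ with a unit vector $z$ is a lower bound on $\|A\|_2^2$, so the doubled value above is a heuristic and not a guaranteed upper bound on the Lipschitz constant. One alternative, sketched here for Julia 1.x under the assumption that a few hundred matrix-vector products are affordable, is power iteration on $A^T A$; the function name `power_iteration_sq_norm` and the small test matrix are hypothetical.

```julia
using LinearAlgebra, Random

# Estimate ||A||_2^2 (largest eigenvalue of A'A) by power iteration on A'A
function power_iteration_sq_norm(A; iters=200)
    v = normalize(randn(size(A, 2)))
    σ2 = 0.0
    for _ in 1:iters
        w = A' * (A * v)        # one multiplication by A'A
        σ2 = norm(w)            # ||A'A v|| with ||v|| = 1, never exceeds ||A||_2^2
        v = w / σ2
    end
    return σ2
end

Random.seed!(2)
A = randn(30, 80)               # small toy matrix for the check
est = power_iteration_sq_norm(A)
println("power iteration: ", est, "   opnorm(A)^2: ", opnorm(A)^2)
```

Because the estimate approaches $\|A\|_2^2$ from below, a small safety factor above 1 can still be applied before using it as $L$.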
if(nprocs() < 4)
addprocs(3)
end
#
# The following functions are from:
#
# http://stackoverflow.com/questions/27677399/julia-how-to-copy-data-to-another-processor-in-julia
#
function sendto(p::Int; args...)
for (nm, val) in args
@spawnat(p, eval(Main, Expr(:(=), nm, val)))
end
end
function sendto(ps::Vector{Int}; args...)
for p in ps
sendto(p; args...)
end
end
getfrom(p::Int, nm::Symbol; mod=Main) = fetch(@spawnat(p, getfield(mod, nm)))
function passobj(src::Int, target::Vector{Int}, nm::Symbol;
from_mod=Main, to_mod=Main)
r = RemoteRef(src)
@spawnat(src, put!(r, getfield(from_mod, nm)))
for to in target
@spawnat(to, eval(to_mod, Expr(:(=), nm, fetch(r))))
end
nothing
end
function passobj(src::Int, target::Int, nm::Symbol; from_mod=Main, to_mod=Main)
passobj(src, [target], nm; from_mod=from_mod, to_mod=to_mod)
end
function passobj(src::Int, target, nms::Vector{Symbol};
from_mod=Main, to_mod=Main)
for nm in nms
passobj(src, target, nm; from_mod=from_mod, to_mod=to_mod)
end
end
passobj (generic function with 3 methods)
@everywhere λ = 1000;
sendto(workers(), A=A)
sendto(workers()[1], yC=yR)
sendto(workers()[2], yC=yG)
sendto(workers()[3], yC=yB)
sendto(workers(), λ=λ)
sendto(workers(), L=L)
@everywhere include("../julia/proximalGD.jl")
rrefs = [@spawnat pid pgd(A, yC, λ, L) for pid in workers()]
xR,xG,xB = [fetch(rref) for rref in rrefs];
From worker 2: 0050: f=2.160e+06, opt=5.475e+02, nnz=2914, L=2.017e+04
From worker 3: 0050: f=2.367e+06, opt=5.322e+02, nnz=3079, L=2.017e+04
From worker 4: 0050: f=1.820e+06, opt=7.042e+02, nnz=2433, L=2.017e+04
From worker 2: 0100: f=2.029e+06, opt=3.993e+02, nnz=1996, L=2.017e+04
From worker 3: 0100: f=2.222e+06, opt=4.210e+02, nnz=2154, L=2.017e+04
From worker 4: 0100: f=1.672e+06, opt=4.726e+02, nnz=1413, L=2.017e+04
From worker 2: 0150: f=1.981e+06, opt=2.229e+02, nnz=1463, L=2.017e+04
From worker 3: 0150: f=2.161e+06, opt=2.864e+02, nnz=1592, L=2.017e+04
From worker 4: 0150: f=1.647e+06, opt=1.112e+02, nnz=0798, L=2.017e+04
From worker 2: 0200: f=1.971e+06, opt=7.559e+01, nnz=1096, L=2.017e+04
From worker 3: 0200: f=2.143e+06, opt=1.074e+02, nnz=1165, L=2.017e+04
From worker 4: 0200: f=1.647e+06, opt=9.609e+00, nnz=0713, L=2.017e+04
From worker 2: 0250: f=1.970e+06, opt=1.588e+01, nnz=1006, L=2.017e+04
From worker 3: 0250: f=2.142e+06, opt=2.344e+01, nnz=1010, L=2.017e+04
From worker 4: 0250: f=1.647e+06, opt=8.260e-01, nnz=0706, L=2.017e+04
From worker 2: 0300: f=1.970e+06, opt=3.378e+00, nnz=0985, L=2.017e+04
From worker 3: 0300: f=2.142e+06, opt=5.136e+00, nnz=0978, L=2.017e+04
From worker 4: 0300: f=1.647e+06, opt=7.313e-02, nnz=0706, L=2.017e+04
From worker 2: 0350: f=1.970e+06, opt=6.755e-01, nnz=0981, L=2.017e+04
From worker 3: 0350: f=2.142e+06, opt=9.977e-01, nnz=0970, L=2.017e+04
From worker 4: 0350: f=1.647e+06, opt=6.461e-03, nnz=0706, L=2.017e+04
From worker 2: 0400: f=1.970e+06, opt=1.311e-01, nnz=0981, L=2.017e+04
From worker 3: 0400: f=2.142e+06, opt=1.802e-01, nnz=0969, L=2.017e+04
From worker 4: 0400: f=1.647e+06, opt=5.705e-04, nnz=0706, L=2.017e+04
From worker 4: Optimal solution found at iter 400
From worker 4:
From worker 2: 0450: f=1.970e+06, opt=2.554e-02, nnz=0980, L=2.017e+04
From worker 3: 0450: f=2.142e+06, opt=3.266e-02, nnz=0968, L=2.017e+04
From worker 2: 0500: f=1.970e+06, opt=4.896e-03, nnz=0980, L=2.017e+04
From worker 3: 0500: f=2.142e+06, opt=5.915e-03, nnz=0968, L=2.017e+04
From worker 2: 0550: f=1.970e+06, opt=9.415e-04, nnz=0980, L=2.017e+04
From worker 2: Optimal solution found at iter 550
From worker 2:
From worker 3: 0550: f=2.142e+06, opt=1.073e-03, nnz=0968, L=2.017e+04
From worker 3: 0600: f=2.142e+06, opt=1.947e-04, nnz=0968, L=2.017e+04
From worker 3: Optimal solution found at iter 600
From worker 3:
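The three color channels are independent problems, which is what makes the per-worker dispatch above possible. In Julia 1.x, where `RemoteRef` is no longer available, the same pattern can be sketched with `pmap` from the `Distributed` standard library. This is an alternative shown for illustration, not the notebook's method: `solve_channel` is a hypothetical stand-in for a call such as `pgd(A, yC, λ, L)`, and the short vectors stand in for `yR`, `yG`, `yB`.

```julia
using Distributed

# Stand-in for the per-channel solver, defined on every process
@everywhere solve_channel(y) = sum(abs2, y)

channels = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # placeholders for yR, yG, yB
results = pmap(solve_channel, channels)            # one task per channel
println(results)                                   # [5.0, 25.0, 61.0]
```

`pmap` returns the results in the order of the inputs, and it falls back to the current process when no workers have been added.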
@printf "nnz R=%f\n" countnz(xR)/n
@printf "nnz G=%f\n" countnz(xG)/n
@printf "nnz B=%f\n" countnz(xB)/n
xR = reshape(xR, im_m,im_n);
xG = reshape(xG, im_m,im_n);
xB = reshape(xB, im_m,im_n);
recovered = Image(map(RGB, idwt(xR,wv), idwt(xG,wv), idwt(xB,wv))')
arr_ori = float(separate(original.data));
arr_rec = float(separate(recovered.data));
mse = [sum((arr_ori[:,:,i]-arr_rec[:,:,i]).^2)/n for i in 1:3]
@show(mse)
println("Recovered:")
recovered
nnz R=0.098000
nnz G=0.096800
nnz B=0.070600
mse = Any[0.0975256019108048,0.04138135801115502,0.34997008264656954]
Recovered: