Conjugate gradient method

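Consider solving
$$ -u_{xx} = f(x), \qquad x \in [0,1] $$
with boundary conditions
$$ u(0) = u(1) = 0 $$
We choose
$$ f(x) = 1 $$
for which the exact solution is
$$ u(x) = \frac{1}{2}x(1-x) $$
Partition $[0,1]$ into $n$ intervals of equal length, with spacing and grid points
$$ h = \frac{1}{n}, \qquad x_i = i h, \qquad i=0,1,\ldots,n $$
The finite difference scheme is
\begin{eqnarray*}
u_0 &=& 0 \\
-\frac{u_{i-1} - 2 u_i + u_{i+1}}{h^2} &=& f_i, \qquad i=1,2,\ldots,n-1 \\
u_n &=& 0
\end{eqnarray*}
This gives a matrix equation
$$ Au = f $$
After the boundary values $u_0 = u_n = 0$ are eliminated, the matrix $A = \frac{1}{h^2}\,\mathrm{tridiag}(-1, 2, -1)$ acting on the interior unknowns $u_1, \ldots, u_{n-1}$ is symmetric positive definite, which is what the conjugate gradient method requires.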
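Eliminating the boundary values leaves an $(n-1)\times(n-1)$ tridiagonal system for $u_1,\ldots,u_{n-1}$. The sketch below builds that matrix explicitly with a hypothetical helper `build_matrix` and solves it directly with `np.linalg.solve`. This dense solve is only a reference to compare against, not how the notebook proceeds. Because the exact solution is quadratic, the central difference reproduces $u''$ exactly, so the discrete solution should agree with $u(x_i)$ up to rounding.

```python
import numpy as np

def build_matrix(n):
    # Hypothetical helper (not in the original notebook): the interior
    # (n-1)x(n-1) matrix A = (1/h^2) * tridiag(-1, 2, -1).
    h = 1.0 / n
    m = n - 1                          # interior unknowns u_1..u_{n-1}
    return (2.0 * np.eye(m) - np.eye(m, k=1) - np.eye(m, k=-1)) / h**2

n = 100
x = np.linspace(0.0, 1.0, n + 1)
A = build_matrix(n)
f = np.ones(n - 1)                     # f_i = 1 at the interior points

# Dense direct solve; feasible only because n is small.
u = np.zeros(n + 1)                    # u_0 = u_n = 0 stay fixed
u[1:n] = np.linalg.solve(A, f)

ue = 0.5 * x * (1.0 - x)
print("symmetric:", np.allclose(A, A.T))
print("positive definite:", np.linalg.eigvalsh(A).min() > 0)
print("max error:", np.abs(u - ue).max())
```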

Algorithm

  • Set initial guess $u_0 = 0$, $r_0 = f - A u_0 = f$, $p_0 = 0$
  • For $k=0,1,\ldots$
    • If $\| r_k \| < TOL \cdot \|f\|$, then stop
    • If $k=0$, $\beta_1 = 0$
    • If $k > 0$, $\beta_{k+1} = \frac{r_k^\top r_k}{r_{k-1}^\top r_{k-1}}$
    • $p_{k+1} = r_k + \beta_{k+1} p_k$
    • $\alpha_{k+1} = \frac{r_k^\top r_k}{p_{k+1}^\top A p_{k+1}}$
    • $u_{k+1} = u_k + \alpha_{k+1} p_{k+1}$
    • $r_{k+1} = r_k - \alpha_{k+1} A p_{k+1}$
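The steps above can be tried on a small dense problem. The sketch below uses a hypothetical helper `cg_dense` and a random symmetric positive definite test matrix (an assumption; this is not the finite difference matrix). It follows the same update formulas, with the residual updated using $A p_{k+1}$, and records the search directions so one can check that they are mutually $A$-conjugate, $p_i^\top A p_j \approx 0$ for $i \ne j$, and that the iteration stops after at most $N$ steps up to rounding.

```python
import numpy as np

def cg_dense(A, f, tol=1e-10, itmax=None):
    """Hypothetical sketch: CG for a dense symmetric positive definite A,
    following the steps above; also returns the directions p_1, p_2, ..."""
    N = len(f)
    itmax = 2 * N if itmax is None else itmax
    u = np.zeros(N)              # u_0 = 0
    r = f.copy()                 # r_0 = f - A u_0 = f
    p = np.zeros(N)              # p_0 = 0
    rr, rr_old = r @ r, None
    tol_abs = tol * np.linalg.norm(f)
    dirs = []
    while np.sqrt(rr) >= tol_abs and len(dirs) < itmax:
        beta = 0.0 if rr_old is None else rr / rr_old   # beta_1 = 0
        p = r + beta * p
        Ap = A @ p
        alpha = rr / (p @ Ap)
        u = u + alpha * p
        r = r - alpha * Ap       # residual update uses A p, not p
        rr_old, rr = rr, r @ r
        dirs.append(p)
    return u, np.array(dirs)

rng = np.random.default_rng(0)
N = 8
M = rng.standard_normal((N, N))
A = M @ M.T + N * np.eye(N)      # SPD by construction (assumed test matrix)
f = rng.standard_normal(N)

u, P = cg_dense(A, f)
C = P @ A @ P.T                  # C[i, j] = p_i^T A p_j
off = C - np.diag(np.diag(C))
print("iterations:", len(P))
print("relative residual:", np.linalg.norm(f - A @ u) / np.linalg.norm(f))
print("largest relative off-diagonal of P A P^T:",
      np.abs(off).max() / np.abs(np.diag(C)).max())
```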

Code

In [1]:
import numpy as np
from matplotlib import pyplot as plt

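The function `ax` computes the matrix-vector product $Au$ at the interior points $i = 1, \ldots, n-1$; the boundary entries of the result are left at zero.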

In [2]:
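def ax(h,u):
    # Apply the finite difference operator -u_xx at the interior points.
    # r[0] and r[n] stay zero, so the boundary values are never updated.
    n = len(u) - 1
    r = np.zeros(n+1)
    for i in range(1,n):
        r[i] = -(u[i-1]-2*u[i]+u[i+1])/h**2
    return r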
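The loop in `ax` visits one grid point at a time. A vectorized variant using NumPy slicing computes the same product; the name `ax_vec` is ours, not the notebook's. The sketch also checks that, on vectors with zero boundary values, the operator is symmetric, $v^\top A w = w^\top A v$, which CG relies on.

```python
import numpy as np

def ax_loop(h, u):
    # Same computation as the notebook's `ax`: -u_xx at interior points.
    n = len(u) - 1
    r = np.zeros(n + 1)
    for i in range(1, n):
        r[i] = -(u[i-1] - 2*u[i] + u[i+1]) / h**2
    return r

def ax_vec(h, u):
    # Hypothetical vectorized variant: array slices replace the Python loop.
    # r[0] and r[-1] stay zero, as in the loop version.
    r = np.zeros(len(u))
    r[1:-1] = -(u[:-2] - 2*u[1:-1] + u[2:]) / h**2
    return r

n = 100
h = 1.0 / n
rng = np.random.default_rng(1)
v = rng.standard_normal(n + 1)
w = rng.standard_normal(n + 1)
print(np.allclose(ax_loop(h, v), ax_vec(h, v)))   # True

# With zero boundary values, v^T A w = w^T A v: the operator is symmetric.
v[0] = v[-1] = 0.0
w[0] = w[-1] = 0.0
print(np.isclose(v @ ax_vec(h, w), w @ ax_vec(h, v)))
```

For $n = 100$ the loop is cheap either way; slicing mainly pays off on finer grids, where the matrix-vector product dominates each CG iteration.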
In [3]:
xmin, xmax = 0.0, 1.0
n = 100
h = (xmax - xmin)/n

x = np.linspace(0.0, 1.0, n+1)
f = np.ones(n+1)
ue= 0.5*x*(1.0-x)

TOL   = 1.0e-6
itmax = 100

u   = np.zeros(n+1)
p   = np.zeros(n+1)
res = np.array(f)

# At the first and last grid points the solution is fixed to zero.
# Setting the residual to zero there keeps the search direction,
# and hence the solution, unchanged at these points.
res[0] = 0.0
res[n] = 0.0

f_norm = np.linalg.norm(f)
res_old, res_new = 0.0, 0.0
for it in range(itmax):
    res_new = np.linalg.norm(res)
    print("%d %.12g" % (it, res_new))
    if res_new < TOL * f_norm:
        break
    if it==0:
        beta = 0.0
    else:
        beta = res_new**2 / res_old**2
    p = res + beta * p
    ap= ax(h,p)
    alpha = res_new**2 / p.dot(ap)
    u += alpha * p
    res -= alpha * ap
    res_old = res_new

print("Number of iterations = %d" % it)
plt.plot(x,ue,x,u)
plt.legend(("Exact","Numerical"));
0 9.94987437107
1 69.2928567747
2 67.8785680462
3 66.4642761188
4 65.049980784
5 63.6356818145
6 62.2213789625
7 60.8070719571
8 59.3927605016
9 57.9784442703
10 56.5641229049
11 55.1497960105
12 53.7354631505
13 52.3211238411
14 50.9067775448
15 49.4924236626
16 48.078061525
17 46.6636903813
18 45.249309387
19 43.8349175886
20 42.4205139054
21 41.0060971076
22 39.5916657897
23 38.1772183376
24 36.7627528893
25 35.3482672843
26 33.9337590019
27 32.519225083
28 31.10466203
29 29.6900656786
30 28.2754310312
31 26.8607520371
32 25.4460213
33 24.0312296814
34 22.6163657558
35 21.2014150471
36 19.7863589374
37 18.3711730709
38 16.9558249578
39 15.5402702679
40 14.1244468918
41 12.7082650271
42 11.2915897906
43 9.87420882907
44 8.45576726264
45 7.03562363974
46 5.61248608016
47 4.18330013267
48 2.73861278753
49 1.22474487139
50 1.22558586276e-14
Number of iterations = 50
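The printed residuals are not monotone: the norm jumps from about 9.95 to about 69.3 after the first step, and the iteration stops after 50 steps even though there are $n-1 = 99$ unknowns. Both are expected properties of CG rather than bugs. CG minimizes the $A$-norm of the error, $\|e_k\|_A = \sqrt{e_k^\top A e_k}$, which decreases monotonically, while the residual 2-norm need not. In exact arithmetic CG also terminates in at most as many steps as there are distinct eigenvalues of $A$ whose eigenvectors appear in $r_0$; since $f = 1$ is symmetric about $x = 1/2$, only the eigenvectors $\sin(k\pi x_i)$ with odd $k$ contribute, which gives 50. The sketch below checks both claims on the interior system; it builds $A$ as a dense matrix, an assumption made for convenience rather than the notebook's `ax` function.

```python
import numpy as np

n = 100
h = 1.0 / n
idx = np.arange(1, n)              # interior indices 1..n-1
f = np.ones(n - 1)                 # f = 1 at the interior points

# Eigenvectors of (1/h^2) tridiag(-1, 2, -1) are v_k[i] = sin(k*pi*i/n),
# k = 1..n-1. Project f onto each of them and count the nonzero components.
k = np.arange(1, n)
V = np.sin(np.pi * np.outer(k, idx) / n)   # row k-1 holds v_k
coef = V @ f
active = np.abs(coef) > 1e-8 * np.abs(coef).max()
print("eigen-components present in f:", active.sum())

# Re-run CG on the interior system, tracking ||r||_2 and ||e||_A.
A = (2.0 * np.eye(n - 1) - np.eye(n - 1, k=1) - np.eye(n - 1, k=-1)) / h**2
x = idx * h
ue = 0.5 * x * (1.0 - x)           # exact solution at the interior points
u = np.zeros(n - 1)
r = f.copy()
p = np.zeros(n - 1)
rr_old = None
res_norms, err_Anorms = [], []
for it in range(n):
    e = ue - u
    res_norms.append(np.sqrt(r @ r))
    err_Anorms.append(np.sqrt(e @ A @ e))
    if res_norms[-1] < 1e-6 * np.linalg.norm(f):
        break
    rr = r @ r
    beta = 0.0 if rr_old is None else rr / rr_old
    p = r + beta * p
    Ap = A @ p
    alpha = rr / (p @ Ap)
    u = u + alpha * p
    r = r - alpha * Ap
    rr_old = rr

print("iterations:", it)
print("residual increased at step 1:", res_norms[1] > res_norms[0])
print("A-norm of error nonincreasing:",
      all(b <= a * (1 + 1e-12) for a, b in zip(err_Anorms, err_Anorms[1:])))
```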