# Energies for the Schrödinger equation in a well with a potential

In [1]:
from sympy import *
from sympy import symbols
import numpy as np
import matplotlib.pyplot as plt#, mpld3
from matplotlib import rcParams
from scipy.linalg import eigh

In [3]:
%matplotlib notebook
rcParams['font.family'] = 'serif'
rcParams['font.size'] = 16
rcParams['axes.labelsize'] = 18
rcParams['axes.titlesize'] = 20

init_printing()

In [4]:
# Independent variable; approx_sol reads this module-level symbol directly.
x = Symbol('x')

In [5]:
def approx_sol(num, approx_type='poly'):
    """Compute the trial function using polynomials or sines

    The expansion is written in terms of the module-level symbol ``x``.

    Parameters
    ----------
    num : int
        Number of terms in the expansion.
    approx_type : {'poly', 'sine'}, optional
        Basis used for the expansion. ``'poly'`` uses
        ``x*(1 - x)*x**k``, which vanishes at ``x = 0`` and ``x = 1``;
        ``'sine'`` uses ``sin((k + 1)*x)``, which vanishes at ``x = 0``
        and at every multiple of ``pi``. Default is ``'poly'``.

    Returns
    -------
    u_n : Sympy expression
        Trial function for the given basis.
    c : (num) tuple of Sympy symbols
        Coefficients ``c0, ..., c{num-1}`` of the expansion.

    Raises
    ------
    ValueError
        If ``approx_type`` is neither ``'poly'`` nor ``'sine'``.

    """
    # Reject unknown bases up front; otherwise no branch below would
    # assign u_n and the return statement would fail obscurely.
    if approx_type not in ('poly', 'sine'):
        raise ValueError(
            "approx_type must be 'poly' or 'sine', got %r" % (approx_type,)
        )

    c = symbols('c0:%d' % num)
    if approx_type == 'poly':
        u_n = x*(1 - x)*sum(c[k]*x**k for k in range(num))
    else:
        u_n = sum(c[k]*sin((k + 1)*x) for k in range(num))

    return u_n, c

In [6]:
def functs(u, V, x, x0, x1):
    """Energy and normalization integrals for a trial function

    Parameters
    ----------
    u : Sympy expression
        Approximant solution.
    V : Sympy expression
        Potential, multiplied by ``u**2`` in the energy integrand.
    x : Sympy symbol
        Independent variable.
    x0, x1 : Sympy expressions or numbers
        Lower and upper integration limits.

    Returns
    -------
    Hf : Sympy expression
        Integral of ``(1/2)*(du/dx)**2 + V*u**2`` over ``[x0, x1]``.
    Sf : Sympy expression
        Integral of ``u**2`` over ``[x0, x1]``.

    """
    limits = (x, x0, x1)
    kinetic = S(1)/2*diff(u, x)**2
    potential = V*u**2
    # Expand before integrating so the integrand is a plain sum of terms.
    energy_density = expand(kinetic + potential)
    norm_density = expand(u**2)
    return integrate(energy_density, limits), integrate(norm_density, limits)

In [7]:
# Five-term sine expansion: u_n = sum_k c_k*sin((k + 1)*x).
n = 5
u_n, c = approx_sol(num=n, approx_type='sine')

In [8]:
# Potential V(x) = x*(2*pi - x), integrated over the interval [0, 2*pi].
Hf, Sf = functs(u=u_n, V=x*(2*pi - x), x=x, x0=0, x1=2*pi)

In [9]:
# Entry (row, col) is the second derivative of Hf w.r.t. c[row] and c[col].
Hmat = Matrix(n, n, lambda row, col: Hf.diff(c[row], c[col]))

In [10]:
# Entry (row, col) is the second derivative of Sf w.r.t. c[row] and c[col].
Mmat = Matrix(n, n, lambda row, col: Sf.diff(c[row], c[col]))

In [11]:
# Generalized symmetric eigenproblem Hmat @ v = lambda * Mmat @ v; the
# symbolic matrices are converted to float arrays before the numeric solve.
vals, vecs = eigh(
    a=np.array(Hmat).astype(float),
    b=np.array(Mmat).astype(float),
)

In [12]:
# Show the eigenvalues computed by eigh for the generalized problem.
vals

Out[12]:
array([ 5.72324464,  9.04413044, 11.55005606, 14.95945642, 19.85359933])
In [13]:
# Sample each eigenvector's sine expansion on a 201-point grid over
# [0, 2*pi]; column `mode` of vecs holds that mode's coefficients.
x_vec = np.linspace(0, 2*np.pi, 201)
efuns = [
    sum(vecs[term, mode]*np.sin((term + 1)*x_vec) for term in range(n))
    for mode in range(n)
]

In [14]:
# Draw every sampled eigenfunction on a single set of axes.
plt.figure(figsize=(6, 4))
for efun in efuns:
    plt.plot(x_vec, efun)

plt.xlabel("x")
plt.ylabel("Wave function");

In [15]:
from IPython.core.display import HTML
def css_styling():