In [1]:
%matplotlib inline
In [2]:
import numpy
import sympy

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
from matplotlib import cm
from matplotlib import pyplot as plt

A Naive Approach

...consisting of creating a complex grid and simply taking the square root of the grid.

In [3]:
a,b = -1,1
n = 16
x,y = numpy.mgrid[a:b:(1j*n),a:b:(1j*n)]
z = x + 1j*y
w = z**2
In [4]:
fig = plt.figure(figsize=(16,8))

# plot the real part
ax_real = fig.add_subplot(1,2,1,projection='3d')
ax_real.plot_surface(z.real, z.imag, w.real,
                    rstride=1, cstride=1, cmap=cm.jet)

# plot the imaginary part
ax_imag = fig.add_subplot(1,2,2,projection='3d')
ax_imag.plot_surface(z.real, z.imag, w.imag,
                    rstride=1, cstride=1, cmap=cm.jet)
Out[4]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x3ae0f10>
In [5]:
a,b = -1,1
n = 16
x,y = numpy.mgrid[a:b:(1j*n),a:b:(1j*n)]
z = x + 1j*y
w = numpy.sqrt(z)
In [6]:
fig = plt.figure(figsize=(16,8))

# plot the real part
ax_real = fig.add_subplot(1,2,1,projection='3d')
ax_real.plot_surface(z.real, z.imag, w.real,
                    rstride=1, cstride=1, cmap=cm.jet)

# plot the imaginary part
ax_imag = fig.add_subplot(1,2,2,projection='3d')
ax_imag.plot_surface(z.real, z.imag, w.imag,
                    rstride=1, cstride=1, cmap=cm.jet)
Out[6]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x3aeec90>

The Issue

In many (all) complex square-root functions a branch cut is chosen (usually the negative real axis) and only one branch is computed in order to make square root single-valued.

From complex analysis we learn that we can use polar coordinates to easily verify that $w = \sqrt{z}$, which I will read as $w^2 = z$ since the latter makes it clear that I won't be chosing a branch of $w$ / making a branch cut in the $z$ plane. Because the branch point $z=0$ is 2-ramified / has branching number two I will need to make two rotations to capture all of the behavior near the branch point.

In [7]:
branching_number = 2

Nr = 16
Ntheta = 32

# compute the theta,R domain
theta = numpy.linspace(0,2*numpy.pi*branching_number, Ntheta)
r = numpy.linspace(0,1,Nr)
Theta, R = numpy.meshgrid(theta,r)

z = R*numpy.exp(1j*Theta)

# compute w^2 = z. THE KEY IDEA is to pass the exponentiation by 1/2 into exp().
w = numpy.sqrt(R)*numpy.exp(1j*Theta/2)
In [8]:
fig = plt.figure(figsize=(16,8))

# plot the real part
ax_real = fig.add_subplot(1,2,1,projection='3d')
ax_real.plot_surface(z.real, z.imag, w.real,
                    rstride=1, cstride=1, cmap=cm.jet, alpha=0.5)

# plot the imaginary part
ax_imag = fig.add_subplot(1,2,2,projection='3d')
ax_imag.plot_surface(z.real, z.imag, w.imag,
                    rstride=1, cstride=1, cmap=cm.jet, alpha=0.5)
Out[8]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x4152110>
In [9]:
# plot abs(w)
fig = plt.figure(figsize=(8,8))
ax_real = fig.add_subplot(1,1,1,projection='3d')

ax_real.plot_surface(z.real, z.imag, abs(w),
                     rstride=1, cstride=1, cmap=cm.jet, alpha=0.5)
Out[9]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x43cf3d0>
In [10]:
%load simple.py
In [ ]:
import numpy
import sympy

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
from matplotlib import cm, colors
from matplotlib import pyplot as plt

branching_number = 2

Nr = 16
Ntheta = 32

# compute the theta,R domain
theta = numpy.linspace(0,2*numpy.pi*branching_number, Ntheta)
r = numpy.linspace(0,1,Nr)
Theta, R = numpy.meshgrid(theta,r)

z = R*numpy.exp(1j*Theta)

# compute w^2 = z. THE KEY IDEA is to pass the exponentiation by 1/2 into exp().
w = numpy.sqrt(R)*numpy.exp(1j*Theta/2)

# color by argument
arguments = numpy.angle(w)
norm = colors.Normalize(arguments.min(), arguments.max())
color = cm.jet(norm(arguments))

fig = plt.figure(figsize=(16,8))

# plot the real part
ax_real = fig.add_subplot(1,2,1,projection='3d')
ax_real.plot_surface(z.real, z.imag, w.real,
                    rstride=1, cstride=1, alpha=0.5, facecolors=color)

# plot the imaginary part
ax_imag = fig.add_subplot(1,2,2,projection='3d')
ax_imag.plot_surface(z.real, z.imag, w.imag,
                    rstride=1, cstride=1, alpha=0.5, facecolors=color)
In [11]:
import numpy
import sympy

from mpl_toolkits.mplot3d import Axes3D
import matplotlib
from matplotlib import cm, colors
from matplotlib import pyplot as plt

branching_number = 3

Nr = 8
Ntheta = 128

# compute the theta,R domain
theta = numpy.linspace(0,2*numpy.pi*branching_number, Ntheta)
r = numpy.linspace(0,1,Nr)
Theta, R = numpy.meshgrid(theta,r)

z = R*numpy.exp(1j*Theta)

# compute w^2 = z. THE KEY IDEA is to pass the exponentiation by 1/2 into exp().
#w = numpy.sqrt(R)*numpy.exp(1j*Theta/2)
w = numpy.log(R) + 1j*Theta

# color by argument
arguments = numpy.angle(w)
norm = colors.Normalize(arguments.min(), arguments.max())
color = cm.hsv(norm(arguments))

fig = plt.figure(figsize=(16,8))

# plot the real part
ax_real = fig.add_subplot(1,2,1,projection='3d')
ax_real.plot_surface(z.real, z.imag, abs(w),
                    rstride=1, cstride=1, alpha=0.5, facecolors=color)

# plot the imaginary part
ax_imag = fig.add_subplot(1,2,2,projection='3d')
ax_imag.plot_surface(z.real, z.imag, w.imag,
                    rstride=1, cstride=1, alpha=0.5, facecolors=color)
-c:23: RuntimeWarning: divide by zero encountered in log
/local/certik/bld/profile/d2nhm3s6qzwn/lib/python2.7/site-packages/numpy/core/numeric.py:1486: RuntimeWarning: invalid value encountered in multiply
  y = a[2]*b[0] - a[0]*b[2]
/local/certik/bld/profile/d2nhm3s6qzwn/lib/python2.7/site-packages/mpl_toolkits/mplot3d/axes3d.py:1673: RuntimeWarning: invalid value encountered in divide
  for n in normals])
/local/certik/bld/profile/d2nhm3s6qzwn/lib/python2.7/site-packages/matplotlib/colors.py:394: RuntimeWarning: invalid value encountered in greater
  if (c.ravel() > 1).any() or (c.ravel() < 0).any():
/local/certik/bld/profile/d2nhm3s6qzwn/lib/python2.7/site-packages/matplotlib/colors.py:394: RuntimeWarning: invalid value encountered in less
  if (c.ravel() > 1).any() or (c.ravel() < 0).any():
Out[11]:
<mpl_toolkits.mplot3d.art3d.Poly3DCollection at 0x48ef8d0>
/local/certik/bld/profile/d2nhm3s6qzwn/lib/python2.7/site-packages/mpl_toolkits/mplot3d/proj3d.py:156: RuntimeWarning: invalid value encountered in divide
  txs, tys, tzs = vecw[0]/w, vecw[1]/w, vecw[2]/w