import k3d
import numpy as np
import scipy.special
import scipy.misc
def r(x, y, z):
    """Return the Euclidean distance of the point (x, y, z) from the origin.

    Arguments may be scalars or NumPy arrays; arrays are combined with the
    usual broadcasting rules.
    """
    return np.sqrt(x**2 + y**2 + z**2)


def theta(x, y, z):
    """Return the polar angle of (x, y, z), measured from the +z axis.

    The result lies in [0, pi].  It is computed with ``arctan2`` rather than
    ``arccos(z / r)`` so that the origin yields 0 instead of NaN and no
    division is performed.
    """
    return np.arctan2(np.hypot(x, y), z)


def phi(x, y, z):
    """Return the azimuthal angle of (x, y, z) in the xy-plane.

    The result lies in (-pi, pi] and is measured from the +x axis.
    ``arctan2`` keeps the quadrant information that ``arctan(y / x)`` loses
    and is well defined on the plane x = 0.  ``z`` is unused but kept so all
    three coordinate helpers share the same call signature.
    """
    return np.arctan2(y, x)
# Length scale dividing every radius below (presumably the Bohr radius in
# atomic units -- confirm against the intended unit system).
a0 = 1.


def _sph_harm(m, l, azimuth, polar):
    """Evaluate the spherical harmonic Y_l^m at the given angles.

    ``scipy.special.sph_harm`` is deprecated since SciPy 1.15 in favour of
    ``sph_harm_y``, which takes ``(n, m, polar, azimuth)`` instead of
    ``(m, n, azimuth, polar)``.  Prefer the new function when it exists and
    fall back to the old one on older SciPy releases.
    """
    sph_harm_y = getattr(scipy.special, "sph_harm_y", None)
    if sph_harm_y is not None:
        return sph_harm_y(l, m, polar, azimuth)
    return scipy.special.sph_harm(m, l, azimuth, polar)


def R(r, n, l):
    """Return the (unnormalised) radial factor for quantum numbers n, l.

    With ``rho = 2 r / (n a0)`` this evaluates
    ``rho**l * exp(-rho / 2) * L_{n-l-1}^{2l+1}(rho)``, where ``L`` is the
    generalized Laguerre polynomial.  No normalisation constant is applied.

    r -- radius, scalar or array.
    n -- principal quantum number (integer, n > l).
    l -- orbital quantum number (non-negative integer).
    """
    rho = 2 * r / n / a0
    laguerre = scipy.special.genlaguerre(n - l - 1, 2 * l + 1)
    return rho**l * np.exp(-rho / 2) * laguerre(rho)


def WF(r, theta, phi, n, l, m):
    """Return the complex wavefunction value R(r, n, l) * Y_l^m(theta, phi).

    r     -- radius.
    theta -- polar angle (from the +z axis).
    phi   -- azimuthal angle.
    n, l, m -- quantum numbers forwarded to ``R`` and the spherical harmonic.
    """
    return R(r, n, l) * _sph_harm(m, l, phi, theta)


def absWF(r, theta, phi, n, l, m):
    """Return the squared magnitude ``|WF(r, theta, phi, n, l, m)|**2``."""
    return abs(WF(r, theta, phi, n, l, m))**2
# Sampling grid: a cube spanning [-a, a] along each axis.  The imaginary step
# makes np.ogrid treat its magnitude as a point count (endpoints included),
# so each axis gets 100 samples.
N = 100j
a = 200.0
# np.ogrid returns three open (broadcastable) axes; store them as float32.
x, y, z = (
    axis.astype(np.float32)
    for axis in np.ogrid[-a:a:N, -a:a:N, -a:a:N]
)
# Probability density of the n=1, l=0, m=0 state on the full grid.
orbital = absWF(r(x, y, z), theta(x, y, z), phi(x, y, z), 1, 0, 0)  # 1s
# Build a k3d scene that animates |psi|^2 over every (l, m) pair for a fixed
# principal quantum number, with a 2D text label naming the current state.
plot = k3d.plot()
plot.display()
plot.grid_auto_fit = False

E = 10  # used as the principal quantum number n for every frame
volume_animation = {}
label_animation = {}

# The spherical coordinates of the grid do not depend on l or m, so evaluate
# them once instead of once per frame.
radius = r(x, y, z)
polar = theta(x, y, z)
azimuth = phi(x, y, z)

frame = 0
for l in range(E):
    print(l, '/', E-1, end='\r')
    for m in range(-l, l + 1):
        psi2 = absWF(radius, polar, azimuth, E, l, m)
        # Time-series keys are timestamps spaced 0.1 apart.  Deriving them
        # from an integer counter avoids the rounding drift of repeatedly
        # adding 0.1 (which produced keys like '0.30000000000000004').
        key = str(frame / 10)
        # Scale each frame to a peak of 1 so one color_range fits all states.
        volume_animation[key] = psi2 / np.max(psi2)
        # Raw string: '\q' is not a valid escape in a normal string literal.
        label_animation[key] = r'n=%d \quad l=%d \quad m=%d' % (E, l, m)
        frame += 1

plot += k3d.text2d(label_animation, (0., 0.))
plot += k3d.volume(
    volume_animation,
    color_map=k3d.colormaps.basic_color_maps.CoolWarm,
    color_range=(0.0, 0.1),
)
plot.colorbar_object_id = 0
plot.start_auto_play()
# NOTE(review): stopping immediately after starting cancels playback when this
# runs as one script; presumably these were separate notebook cells -- confirm.
plot.stop_auto_play()