Spring Layout Animation

This notebook was used to develop a customized spring layout. It uses Bokeh to animate the algorithm.

You will need to run this notebook to see the animation.

Setup

Here are the conda commands to set up an environment for this notebook:

conda install ipython-notebook
conda install numba bokeh
In [ ]:
import random
import math
import numpy as np

from IPython.html.widgets import interact

from numba import jit
from bokeh import plotting as bp
from bokeh.models import GlyphRenderer, DataRange1d


# Route Bokeh output inline into this notebook so the figure below renders here.
bp.output_notebook()

Populate Data

In [ ]:
# Fixed seed so that Restart & Run All reproduces the same graph, sizes and
# starting layout. Both RNGs are used: `random` for edges, numpy for the rest.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

node_count = 30
edge_count = 20
nodes = list(range(node_count))
# Per-node size, uniform in [10, 30). Drawn as the circle diameter
# (radius = mass / 2) and used for spring rest lengths and collision distances.
masses = np.random.random(node_count) * 20 + 10

def build_edges(node_list=None, count=None):
    """Build a random simple graph as a list of ``(a, b)`` node pairs.

    Each edge joins two distinct nodes, and no unordered pair is repeated:
    ``(a, b)`` and ``(b, a)`` count as the same edge. This prevents a spring
    from being applied (and drawn) twice.

    Parameters
    ----------
    node_list : sequence, optional
        Node identifiers to draw from. Defaults to the module-level ``nodes``.
    count : int, optional
        Number of edges to create. Defaults to the module-level ``edge_count``.

    Returns
    -------
    list of tuple
        ``count`` edges in the order they were drawn.

    Raises
    ------
    ValueError
        If ``count`` exceeds the number of distinct node pairs available.
    """
    if node_list is None:
        node_list = nodes
    if count is None:
        count = edge_count
    node_list = list(node_list)

    n = len(node_list)
    max_edges = n * (n - 1) // 2
    if count > max_edges:
        raise ValueError(
            "cannot build {} distinct edges from {} nodes (at most {})".format(
                count, n, max_edges))

    edges = []
    seen = set()
    while len(edges) < count:
        # random.sample picks two distinct positions, so there are no self-loops.
        sel_a, sel_b = random.sample(node_list, 2)
        key = frozenset((sel_a, sel_b))
        if key in seen:
            continue  # duplicate or reversed edge; draw again
        seen.add(key)
        edges.append((sel_a, sel_b))

    return edges
    
edges = build_edges()

# Initial node positions: independent normal draws per axis (mean 0, std 1000).
# NOTE(review): the figure shows only [-500, 500] on each axis, so many nodes
# start off-screen. The spring forces are presumably expected to pull them in.
xpos = np.random.normal(size=node_count, scale=1000)
ypos = np.random.normal(size=node_count, scale=1000)

Draw Graph

In [ ]:
fig = bp.figure(width=600, height=600, x_range=(-500, 500), y_range=(-500, 500))

# Nodes: circles of radius mass / 2. The name lets update() find their data source.
fig.circle(xpos, ypos, radius=masses / 2, fill_alpha=0.6, name="nodes")

# Edges: an integer (edge_count, 2) index array. Fancy-indexing the position
# arrays with it yields each segment's two endpoint coordinates. dtype/reshape
# keep this valid (shape (0, 2)) even when there are no edges; a bare
# np.array([]) would be float and could not be used as an index.
arr_edges = np.array(edges, dtype=int).reshape(-1, 2)
xlines = xpos[arr_edges]
ylines = ypos[arr_edges]
fig.multi_line(xlines.tolist(), ylines.tolist(), name="edges", line_alpha=0.5)

# Hide grid and axis lines so only the graph is visible.
fig.grid.grid_line_color = None
fig.axis.axis_line_color = None
fig.axis.major_tick_line_color = None
In [ ]:
# Snapshot of the starting layout, kept during development so the layout
# cells below can reset to it and be re-run without regenerating the graph.
orig_xpos, orig_ypos = xpos.copy(), ypos.copy()

The Layout Algorithm

In [ ]:
# Restore the starting layout, so this cell can be re-run (e.g. while tuning
# the algorithm) from the same state without re-running the data cells.
np.copyto(xpos, orig_xpos)
np.copyto(ypos, orig_ypos)

# Damping for the Verlet integrator: each step, the implied velocity
# (x - x_prev) is scaled by (1 - DAMPENING).
DAMPENING = 0.2

# Previous-step positions for Verlet integration. Because they start equal to
# the current positions, every node starts at rest.
past_xpos, past_ypos = xpos.copy(), ypos.copy()

@jit
def calc_force(i, j, fxs, fys, xs, ys, masses, strength):
    """Add the spring force between nodes i and j into the force buffers.

    The spring's rest length is the mean of the two masses. The force on node
    i has magnitude ``strength * (rest_length - distance)`` along the j -> i
    direction:
      * positive (pushing i away from j) when the nodes are closer than the
        rest length;
      * negative (pulling them together) when they are farther apart.
    Node j receives the equal and opposite force. ``fxs``/``fys`` are modified
    in place.
    """
    delta_x = xs[i] - xs[j]
    delta_y = ys[i] - ys[j]

    # atan2 yields a defined direction even when the two nodes coincide.
    angle = math.atan2(delta_y, delta_x)
    distance = math.hypot(delta_x, delta_y)

    rest_length = 0.5 * (masses[i] + masses[j])
    magnitude = strength * (rest_length - distance)

    push_x = magnitude * math.cos(angle)
    push_y = magnitude * math.sin(angle)

    # Equal and opposite contributions for the two endpoints.
    fxs[i] += push_x
    fys[i] += push_y
    fxs[j] -= push_x
    fys[j] -= push_y

    
def update_force(fxs, fys, xs, ys, edges, masses):
    """Accumulate the spring forces for one layout step into ``fxs``/``fys``.

    Two families of springs are applied through ``calc_force``:
      * every edge gets a full-strength spring (strength 1.0);
      * every node pair (a, b) with a < b gets a weak spring of strength
        1 / len(xs).
    Contributions are added on top of whatever the buffers already hold.
    """
    for a, b in edges:
        calc_force(a, b, fxs, fys, xs, ys, masses, 1.0)

    n_x, n_y = len(xs), len(ys)
    # Only used when n_x > 0; the guard avoids dividing by zero for empty input.
    weak = 1 / n_x if n_x else 0.0
    for a in range(n_x):
        for b in range(a + 1, n_y):
            calc_force(a, b, fxs, fys, xs, ys, masses, weak)

@jit
def collision_avoid(xs, ys, masses):
    """Push overlapping nodes apart, modifying ``xs``/``ys`` in place.

    Nodes i and j collide when their distance is below
    (masses[i] + masses[j]) / 1.8. This is slightly more than
    (masses[i] + masses[j]) / 2, the distance at which circles drawn with
    radius = mass / 2 would just touch.

    For each colliding pair, both nodes are moved by half the overlap along x
    and, independently, along y, away from each other. The shift is therefore
    axis-aligned rather than along the line joining the centres.

    Pairs are processed in sequence and updates are applied immediately, so
    later pairs see positions already adjusted by earlier ones.
    """
    for i in range(len(xs)):
        for j in range(i + 1, len(ys)):
            dx = xs[i] - xs[j]
            dy = ys[i] - ys[j]
            dist = math.hypot(dx, dy)
            opt_dist = (masses[i] + masses[j])/1.8
            if dist < opt_dist:
                # Each node moves half the overlap, so along each axis the
                # pair's separation grows by the full overlap.
                offset = (opt_dist - dist)/2
                # NOTE(review): in this branch |dx| <= dist < opt_dist, so this
                # check (and the dy check below) appears to always be true.
                if abs(dx) < opt_dist:
                    # Move i away from j along x; dx == 0 pushes i toward +x.
                    sign = -1 if dx < 0 else 1
                    xs[i] += offset * sign
                    xs[j] -= offset * sign
                if abs(dy) < opt_dist:
                    # Same along y; dy == 0 pushes i toward +y.
                    sign = -1 if dy < 0 else 1
                    ys[i] += offset * sign
                    ys[j] -= offset * sign

                    
def spring_fit_once(xs, ys, edges, masses, dt, past_xs=None, past_ys=None,
                    dampening=None):
    """Advance the layout by one damped Verlet step, updating positions in place.

    Parameters
    ----------
    xs, ys : ndarray
        Current node coordinates. Overwritten with the new positions.
    edges : sequence of (i, j) index pairs
        Graph edges, forwarded to ``update_force``.
    masses : ndarray
        Per-node sizes, used by the force model and by collision avoidance.
        For integration, every node is treated as having unit mass.
    dt : float
        Integration time step.
    past_xs, past_ys : ndarray, optional
        Positions from the previous step. Overwritten with the current
        positions. They default to the module-level ``past_xpos`` /
        ``past_ypos``, so existing calls behave as before.
    dampening : float, optional
        Per-step damping factor. Defaults to the module-level ``DAMPENING``.
    """
    if past_xs is None:
        past_xs = past_xpos
    if past_ys is None:
        past_ys = past_ypos
    if dampening is None:
        dampening = DAMPENING

    num = len(xs)
    # Accumulate forces in float64, matching the positions. The previous
    # float32 buffers truncated every contribution added by calc_force.
    fxs = np.zeros(num, dtype=np.float64)
    fys = np.zeros(num, dtype=np.float64)
    update_force(fxs, fys, xs, ys, edges, masses)
    dtdt = dt * dt

    # Unit mass for every node, so acceleration == force.
    # Damped position Verlet: x_new = x + (1 - d) * (x - x_prev) + a * dt^2,
    # written here in expanded form.
    new_xs = (2 - dampening) * xs - (1 - dampening) * past_xs + fxs * dtdt
    new_ys = (2 - dampening) * ys - (1 - dampening) * past_ys + fys * dtdt

    collision_avoid(new_xs, new_ys, masses)

    past_xs[:] = xs
    past_ys[:] = ys

    xs[:] = new_xs
    ys[:] = new_ys

    
# Look up the data sources behind the glyphs named in the drawing cell, so
# update() can replace their coordinates.
ds_cir = fig.select(dict(name="nodes", type=GlyphRenderer))[0].data_source

renderer = fig.select(dict(name="edges", type=GlyphRenderer))
ds_lines = renderer[0].data_source

def update(dt=1/50):
    """Advance the layout by one step and push the new geometry to the plot.

    ``spring_fit_once`` updates ``xpos``/``ypos`` in place. The names are never
    rebound, so no ``global`` declaration is needed. The new node centres and
    edge endpoints are then written into the plot's data sources and pushed
    to the notebook.

    Parameters
    ----------
    dt : float, optional
        Integration time step passed to ``spring_fit_once`` (default 1/50).
    """
    spring_fit_once(xpos, ypos, edges, masses, dt=dt)

    ds_cir.data['x'] = xpos
    ds_cir.data['y'] = ypos
    ds_cir.push_notebook()

    # Each row (i, j) of arr_edges becomes the two-point line [pos[i], pos[j]].
    ds_lines.data['xs'] = xpos[arr_edges].tolist()
    ds_lines.data['ys'] = ypos[arr_edges].tolist()
    ds_lines.push_notebook()
    

Insert Plot

In [ ]:
# Render the figure here. The update() calls below push new coordinates into
# its data sources via push_notebook().
bp.show(fig)

Animate

Run `update()` for a fixed number of iterations. Each call advances the layout by one step and redraws the plot.

In [ ]:
# Number of layout steps to animate; each update() call redraws the plot once.
N_FRAMES = 2000

for _ in range(N_FRAMES):
    update()
In [ ]: