A custom plotting library written entirely in Python, on top of ipycanvas!

In [ ]:
from math import pi

import numpy as np

import branca

from ipywidgets import Image, IntSlider, VBox

from ipycanvas import Canvas, MultiCanvas, hold_canvas
In [ ]:
class Plot(MultiCanvas):
    """Base class for the plots in this notebook, drawn on a 3-layer MultiCanvas.

    Layer usage across the subclasses:

    * ``self[0]`` -- background: draw area, grid lines and tick labels
      (see :meth:`draw_background`);
    * ``self[1]`` -- the plotted marks;
    * ``self[2]`` -- transient interaction feedback.

    Parameters
    ----------
    x, y : array-like
        Data used to compute the data <-> pixel mappings.
    color : array-like, optional
        Values mapped through ``scheme``; ``None`` leaves ``self.colormap`` unset.
    scheme : branca colormap
        Colormap rescaled to ``[min(color), max(color)]``.
    """

    def __init__(self, x, y, color=None, scheme=branca.colormap.linear.RdBu_11):
        # sync_image_data=True keeps the canvas pixels available on the Python
        # side, which get_image_data / to_file rely on.
        super(Plot, self).__init__(3, width=800, height=600, sync_image_data=True)

        self.color = color
        self.scheme = scheme

        self.background_color = '#f7f7f7'

        self.init_plot(x, y)

    def init_plot(self, x, y, color=None, scheme=None):
        """(Re)build the data <-> pixel mappings and the colormap.

        Sets ``self.drawarea``, ``self.scale_x`` / ``self.scale_y`` (data ->
        pixel), ``self.unscale_x`` / ``self.unscale_y`` (pixel -> data) and
        ``self.colormap``. ``color`` and ``scheme`` replace the stored values
        only when they are not ``None``.
        """
        self.x = x
        self.y = y
        self.color = color if color is not None else self.color
        self.scheme = scheme if scheme is not None else self.scheme

        padding = 0.1
        padding_x = padding * self.size[0]
        padding_y = padding * self.size[1]

        # The draw area is inset by `padding` of the canvas size on each side.
        # It is stored as (left, top, width, height) -- not as a max corner --
        # which is how draw_background and the mappings below consume it.
        area_left = padding_x
        area_top = padding_y
        area_width = self.size[0] - 2 * padding_x
        area_height = self.size[1] - 2 * padding_y
        self.drawarea = (area_left, area_top, area_width, area_height)

        min_x, max_x = np.min(x), np.max(x)
        min_y, max_y = np.min(y), np.max(y)

        dx = max_x - min_x
        dy = max_y - min_y

        # A constant series has zero span, which would make the mappings divide
        # by zero (NaN/inf pixels). Centre such data in a unit-wide range instead.
        if dx == 0:
            min_x, dx = min_x - 0.5, 1.0
        if dy == 0:
            min_y, dy = min_y - 0.5, 1.0

        # Data -> pixel. The y axis is flipped because pixel y grows downwards.
        self.scale_x = lambda x: area_width * (x - min_x) / dx + area_left
        self.scale_y = lambda y: area_height * (1 - (y - min_y) / dy) + area_top

        # Pixel -> data: exact inverses of the mappings above.
        self.unscale_x = lambda sx: (sx - area_left) * dx / area_width + min_x
        self.unscale_y = lambda sy: (1 - (sy - area_top) / area_height) * dy + min_y

        self.colormap = None
        if self.color is not None:
            self.colormap = self.scheme.scale(np.min(self.color), np.max(self.color))

    def draw_background(self):
        """Draw the draw-area backdrop, a 10x10 grid and the tick labels on layer 0."""
        left, top, width, height = self.drawarea
        right, bottom = left + width, top + height

        background = self[0]

        # Translucent backdrop behind the plot.
        background.fill_style = self.background_color
        background.global_alpha = 0.3
        background.fill_rect(left, top, width, height)
        background.global_alpha = 1

        # Grid lines and ticks.
        n_lines = 10
        background.fill_style = 'black'
        background.stroke_style = '#8c8c8c'
        background.line_width = 1

        for i in range(n_lines):
            fraction = i / (n_lines - 1)
            line_x = width * fraction + left
            line_y = height * fraction + top

            # Vertical grid line at line_x, horizontal grid line at line_y.
            background.stroke_line(line_x, top, line_x, bottom)
            background.stroke_line(left, line_y, right, line_y)

            # y tick label, right-aligned just left of the draw area.
            background.text_align = 'right'
            background.text_baseline = 'middle'
            background.fill_text('{0:.2e}'.format(self.unscale_y(line_y)), left * 0.95, line_y)

            # x tick label, centred just below the draw area.
            background.text_align = 'center'
            background.text_baseline = 'top'
            background.fill_text('{0:.2e}'.format(self.unscale_x(line_x)), line_x, bottom + top * 0.05)
In [ ]:
class ScatterPlot(Plot):
    """Scatter plot whose marks can be dragged with the mouse.

    Parameters
    ----------
    x, y : numpy arrays
        Mark positions in data coordinates. Dragging a mark writes its new
        position back into these arrays in place.
    size : numpy array
        Per-mark value passed as the radius to ``fill_circle`` / ``stroke_circle``.
    color : numpy array
        Per-mark value mapped through ``scheme`` to the fill color.
    scheme : branca colormap
    stroke_color : str
        Outline color of the marks.
    """

    def __init__(self, x, y, size, color, scheme=branca.colormap.linear.RdBu_11, stroke_color='black'):
        super(ScatterPlot, self).__init__(x, y, color, scheme)

        self.dragging = False
        self.sizes = size
        self.stroke_color = stroke_color

        # Only marks that have every attribute are drawn.
        self.n_marks = min(x.shape[0], y.shape[0], size.shape[0], color.shape[0])

        # Index of the dragged mark, -1 when none.
        self.i_mark = -1

        # Mouse events are listened to on the top (interaction) layer.
        self[2].on_mouse_down(self.mouse_down_handler)
        self[2].on_mouse_move(self.mouse_move_handler)
        self[2].on_mouse_up(self.mouse_up_handler)

        self.draw()

    def draw(self):
        """Clear every layer, then redraw the background and all marks."""
        with hold_canvas(self):
            self.clear()
            plot_layer = self[1]

            plot_layer.save()

            self.draw_background()

            plot_layer.stroke_style = self.stroke_color

            for idx in range(self.n_marks):
                plot_layer.fill_style = self.colormap(self.color[idx])

                mark_x = self.scale_x(self.x[idx])
                mark_y = self.scale_y(self.y[idx])
                mark_size = self.sizes[idx]

                plot_layer.fill_circle(mark_x, mark_y, mark_size)
                plot_layer.stroke_circle(mark_x, mark_y, mark_size)

            plot_layer.restore()

    def mouse_down_handler(self, pixel_x, pixel_y):
        """Start dragging the mark under the cursor, if any.

        Marks are tested from last to first: later marks are drawn on top, so
        this picks the mark the user actually sees under the cursor.
        """
        plot_layer = self[1]

        for idx in reversed(range(self.n_marks)):
            center_x = self.scale_x(self.x[idx])
            center_y = self.scale_y(self.y[idx])
            radius = self.sizes[idx]

            # Circular hit test, matching the shape that is drawn.
            if (pixel_x - center_x) ** 2 + (pixel_y - center_y) ** 2 < radius ** 2:
                self.i_mark = idx
                self.dragging = True

                # Paint over the mark with the background color; while dragging,
                # the mark is drawn at the cursor on the interaction layer.
                with hold_canvas(plot_layer):
                    plot_layer.fill_style = self.background_color
                    plot_layer.stroke_style = self.colormap(self.color[idx])

                    plot_layer.fill_circle(center_x, center_y, radius)
                    plot_layer.stroke_circle(center_x, center_y, radius)
                break

    def mouse_move_handler(self, pixel_x, pixel_y):
        """Move the dragged mark to the cursor and update its data coordinates."""
        if self.dragging and self.i_mark != -1:
            interaction_layer = self[2]

            unscaled_x = self.unscale_x(pixel_x)
            unscaled_y = self.unscale_y(pixel_y)

            with hold_canvas(interaction_layer):
                interaction_layer.clear()
                interaction_layer.fill_style = self.colormap(self.color[self.i_mark])
                interaction_layer.stroke_style = self.stroke_color

                # Write the new position back into the (caller-owned) arrays.
                self.x[self.i_mark] = unscaled_x
                self.y[self.i_mark] = unscaled_y

                interaction_layer.fill_circle(pixel_x, pixel_y, self.sizes[self.i_mark])
                interaction_layer.stroke_circle(pixel_x, pixel_y, self.sizes[self.i_mark])

    def mouse_up_handler(self, pixel_x, pixel_y):
        """Finish a drag: redraw the marks at their new positions.

        A click that did not grab a mark changes nothing, so the (potentially
        expensive) full redraw only happens after a drag.
        """
        was_dragging = self.dragging
        self.dragging = False
        self.i_mark = -1

        if was_dragging:
            self.draw()

        self[2].clear()
In [ ]:
class LinePlot(Plot):
    """Polyline plot of ``y`` against ``x``.

    Parameters
    ----------
    x, y : numpy arrays
        Sample coordinates; if their lengths differ, only the common prefix
        is drawn.
    line_color : str
    line_width : number
    """

    def __init__(self, x, y, line_color='#749cb8', line_width=2):
        super(LinePlot, self).__init__(x, y)

        self.line_color = line_color
        self.line_width = line_width

        self.draw()

    def update(self, x, y, line_color=None, line_width=None):
        """Replace the data (and optionally the line style), then redraw.

        The axis ranges are recomputed from the new data.
        """
        self.init_plot(x, y)

        if line_color is not None:
            self.line_color = line_color
        if line_width is not None:
            self.line_width = line_width

        self.draw()

    def draw(self):
        """Clear every layer, then redraw the background and the polyline."""
        with hold_canvas(self):
            self.clear()
            plot_layer = self[1]
            plot_layer.save()

            self.draw_background()

            # Only samples with both an x and a y value can be drawn; np.stack
            # would raise on arrays of different lengths.
            n_points = min(self.x.shape[0], self.y.shape[0])
            points = np.stack(
                (self.scale_x(self.x[:n_points]), self.scale_y(self.y[:n_points])),
                axis=1,
            )

            plot_layer.stroke_style = self.line_color
            plot_layer.line_width = self.line_width
            plot_layer.line_join = 'bevel'
            plot_layer.line_cap = 'round'

            plot_layer.stroke_lines(points)

            plot_layer.restore()
In [ ]:
class HeatmapPlot(Plot):
    """Heatmap of a 2-D ``color`` grid sampled at ``x`` and ``y``.

    ``color[i][j]`` is drawn as the cell centred on ``(x[i], y[j])``; for data
    built with ``np.meshgrid`` this corresponds to ``indexing='ij'``.
    """

    def __init__(self, x, y, color, scheme=branca.colormap.linear.RdBu_11):
        super(HeatmapPlot, self).__init__(x, y, color, scheme)

        self.draw()

    @staticmethod
    def _cell_edges(centers):
        """Return the n + 1 cell edges for n cell centers (pixel coordinates).

        Interior edges sit halfway between neighbouring centers; the outer
        edges are clamped to the first/last center so the cells stay within
        the data range.
        """
        centers = np.asarray(centers, dtype=float)
        midpoints = (centers[:-1] + centers[1:]) / 2
        return np.concatenate((centers[:1], midpoints, centers[-1:]))

    def draw(self):
        """Clear every layer, then redraw the background and every heatmap cell."""
        with hold_canvas(self):
            self.clear()
            plot_layer = self[1]
            plot_layer.save()

            self.draw_background()

            # Only cells that have both a coordinate and a color value are drawn.
            n_x = min(len(self.x), self.color.shape[0])
            n_y = min(len(self.y), self.color.shape[1])

            if n_x > 0 and n_y > 0:
                edges_x = self._cell_edges(self.scale_x(np.asarray(self.x[:n_x], dtype=float)))
                edges_y = self._cell_edges(self.scale_y(np.asarray(self.y[:n_y], dtype=float)))

                for x_idx in range(n_x):
                    # The +0.5 / -0.5 pixel overlap presumably hides seams between
                    # adjacent cells (for increasing y, pixel heights are negative).
                    width = edges_x[x_idx + 1] - edges_x[x_idx] + 0.5

                    for y_idx in range(n_y):
                        plot_layer.fill_style = self.colormap(self.color[x_idx][y_idx])

                        height = edges_y[y_idx + 1] - edges_y[y_idx] - 0.5

                        plot_layer.fill_rect(edges_x[x_idx], edges_y[y_idx], width, height)

            plot_layer.restore()

Scatter plot

In [ ]:
# Number of marks in the scatter demo below.
n_points = 1000

Scatter marks are draggable! Press the mouse button on a mark and move the mouse while holding it down.

In [ ]:
# Seeded generator so the demo is reproducible under Restart & Run All.
scatter_rng = np.random.default_rng(0)

x = scatter_rng.random(n_points)
y = scatter_rng.random(n_points)
sizes = scatter_rng.integers(2, 8, n_points)  # mark radii in [2, 8)
colors = scatter_rng.random(n_points) * 10 - 2

plot = ScatterPlot(x, y, sizes, colors, branca.colormap.linear.viridis, stroke_color='white')
plot

You can retrieve the pixels of the entire canvas, or of a sub-region of it, as an array using the `get_image_data` method.

In [ ]:
# Read back the 50x100 region whose top-left corner is at pixel (200, 300).
arr = plot.get_image_data(200, 300,  # top-left corner
                          50, 100)   # width, height
arr.shape
In [ ]:
# Outline the extracted region in red on the marks layer...
marks_layer = plot[1]
marks_layer.stroke_style = 'red'
marks_layer.line_width = 2
marks_layer.stroke_rect(200, 300, 50, 100)

# ...and paste the extracted pixels into a canvas of the same size.
c = Canvas(width=50, height=100)
c.put_image_data(arr, 0, 0)
c

You can also save it to a PNG file using the `to_file` method.

In [ ]:
# Write the current canvas image to disk (Plot constructs the canvas with sync_image_data=True).
# NOTE(review): the image data is synced from the front-end, so this presumably needs the
# widget above to have been rendered first — confirm when running non-interactively.
plot.to_file('my_scatter.png')
In [ ]:
# Display the PNG written by the previous cell (Image is imported in the top import cell).
Image.from_file('my_scatter.png')

Line plot

In [ ]:
# About three periods of a sine wave, sampled at 500 evenly spaced points.
x = np.linspace(start=0, stop=20, num=500)
y = np.sin(x)

LinePlot(x, y, line_width=3)
In [ ]:
slider = IntSlider(description='Pow:', min=1, max=10, step=1)

# Use names private to this cell: later cells rebind the global `x`, and the
# slider callback would otherwise silently plot that unrelated array.
power_x = np.linspace(-20, 20, 500)
power_y = np.power(power_x, slider.value)

power_plot = LinePlot(power_x, power_y, line_color='#32a852', line_width=3)


def on_slider_change(change):
    """Redraw the curve as power_x ** (new slider value)."""
    power_plot.update(power_x, np.power(power_x, change['new']))


slider.observe(on_slider_change, 'value')

VBox((power_plot, slider))
In [ ]:
n = 2_000
x = np.linspace(0, 100, n)
# Gaussian random walk; seeded so the figure is reproducible.
y = np.cumsum(np.random.default_rng(0).standard_normal(n))

LinePlot(x, y, line_width=3)

Heatmap

In [ ]:
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
# HeatmapPlot reads color[i][j] as the value at (x[i], y[j]), i.e. the 'ij'
# layout; meshgrid's default 'xy' layout would render the transposed surface.
x_grid, y_grid = np.meshgrid(x, y, indexing='ij')
color = np.sin(x_grid + y_grid**2) + np.cos(x_grid**2 + y_grid**2)
assert color.shape == (x.size, y.size)

HeatmapPlot(x, y, color, scheme=branca.colormap.linear.RdYlBu_05)