#!/usr/bin/env python
# coding: utf-8

# # Streaming results from Keras to plot.ly
# 
# I've been doing a lot of work in Keras recently, and as with all things deep learning, it's helpful to track the loss and accuracy metrics of your algorithm *during* training. If you're using Tensorflow you can simply use [Tensorboard](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/README.md) to graph everything. Unfortunately, I was using the Theano backend, so I had to get creative.
# 
# This post, then, is a small guide to how I achieved what I needed. In a nutshell, I used the Keras callback functions to initialize a streaming connection to [plot.ly](http://plot.ly), and then at the end of each epoch I updated the graph. I also wrote a little heartbeat thread to keep the connection open during training.
# 
# Let's get into the code...

# First, we grab the stream tokens from the credentials file. In order to use stream tokens, you need to set them up in your plot.ly account. Go to https://plot.ly/settings/api and hit "Add new token". I added four, which was sufficient for my needs. To use them, they need to be stored in your `~/.plotly/.credentials` file alongside your username and API key. The [plot.ly getting started guide](https://plot.ly/python/getting-started/) has a full explanation if you've never done this before.
# 
# Next, I generate stream IDs. These can take two forms. Here, I'm creating simple dictionaries which have two keys, `token` and `maxpoints`. Alternatively, you could use the `plotly.graph_objs` `Stream` object, as below:
# 
# `stream_id1 = go.Stream(token=stream_tokens[-1], maxpoints=600)`

# In[1]:

import plotly.tools as tls

stream_tokens = tls.get_credentials_file()['stream_ids']

# maxpoints sets how many points are kept on the streaming plot at any one time
stream_id1 = dict(token=stream_tokens[-1], maxpoints=600)
stream_id2 = dict(token=stream_tokens[-2], maxpoints=600)
stream_id3 = dict(token=stream_tokens[-3], maxpoints=600)
stream_id4 = dict(token=stream_tokens[-4], maxpoints=600)
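# If you haven't stored your credentials yet, one way to write the `~/.plotly/.credentials` file from Python is `plotly.tools.set_credentials_file`. This is just a minimal sketch: the username, API key and tokens below are placeholders, so swap in your own values from the settings page.
# 
#     tls.set_credentials_file(username='your_username',
#                              api_key='your_api_key',
#                              stream_ids=['token1', 'token2', 'token3', 'token4'])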
# Now I define the traces we plan on plotting. In my case, I've got four traces: the loss and the [jaccard index](https://en.wikipedia.org/wiki/Jaccard_index), for both the training data and the validation data. Each trace requires a stream ID. Note that the x and y values for the traces are just empty lists at this point, since I'm not ready to plot anything yet.

# In[ ]:

import plotly.graph_objs as go

loss = go.Scatter(x=[], y=[], name='Loss', stream=stream_id1, mode='lines+markers',
                  yaxis='y', line=dict(shape='spline', smoothing=1))
jaccard = go.Scatter(x=[], y=[], name='Jaccard', stream=stream_id2, mode='lines+markers',
                     yaxis='y2', line=dict(shape='spline', smoothing=1))
val_loss = go.Scatter(x=[], y=[], name='Val Loss', stream=stream_id3, mode='lines+markers',
                      yaxis='y', line=dict(shape='spline', smoothing=1))
val_jaccard = go.Scatter(x=[], y=[], name='Val Jaccard', stream=stream_id4, mode='lines+markers',
                         yaxis='y2', line=dict(shape='spline', smoothing=1))

# Once the traces are set up, I can create the data and layout objects and, finally, the figure object. I also configure a few plot options that I needed here. One handy choice is a log axis for the loss: loss typically decays roughly exponentially, so on a log axis it plots as something close to a straight line and you'll be able to tell when the decrease in the loss has truly leveled off.

# In[ ]:

# `run` is an integer I increment for each training run (see further down); define it before this cell.
data = go.Data([loss, jaccard, val_loss, val_jaccard])
layout = go.Layout(title='Loss/Jaccard for run {}'.format(run),
                   yaxis=dict(title='Loss', type='log', autorange=True, rangemode='tozero'),
                   yaxis2=dict(title='Jaccard index', overlaying='y', side='right', range=[0, 1]))
plotly_fig = go.Figure(data=data, layout=layout)

# Finally, we make the stream objects. These will be used to open and close connections to the streaming API, as well as to write data to the graph.

# In[ ]:

import plotly.plotly as py

loss_stream = py.Stream(stream_id=stream_tokens[-1])
jaccard_stream = py.Stream(stream_id=stream_tokens[-2])
val_loss_stream = py.Stream(stream_id=stream_tokens[-3])
val_jaccard_stream = py.Stream(stream_id=stream_tokens[-4])

# I was initially confused by the difference between stream IDs and stream objects, so it's helpful to think of it this way: once a streaming graph is all set up and pushed to plot.ly, it will be empty - just a set of axes and traces, ready to receive data. Defining the stream IDs and assigning them to traces tells the plot.ly server where to put the data when it starts arriving.
# 
# The stream objects, on the other hand, are Python objects in your local environment that handle the connection to the server. These objects have `open()`, `close()` and `write()` methods which are used to manage connections to the plot.ly server and send data.
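# Just to show the shape of those calls, here's a toy sketch (not part of the training flow, and the point values are made up):
# 
#     loss_stream.open()                    # start the connection
#     loss_stream.write(dict(x=0, y=1.0))   # push a single point to the 'Loss' trace
#     loss_stream.close()                   # shut the connection down again
# 
# In the real workflow the figure is posted to plot.ly first, and the writes happen inside the Keras callback defined below.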
# #### A note about keeping the connections open
# 
# The streaming API automatically closes connections that don't receive data for more than about 60 seconds, which can be a major hurdle in neural network training since a single epoch often takes longer than a minute (sometimes *much* longer if you're training on a CPU). To counter this, the stream objects also have a `heartbeat()` method which keeps the connection alive.
# 
# I used Python's threading package to start a new thread with a while loop that sends a heartbeat to each stream every five seconds, ensuring that the connections don't close between epochs. The thread runs as a daemon, allowing the rest of the processes to continue while this happens in the background. The threading event object begins as *not* "set", and can be changed to "set" using its `set()` method. I simply call this at the end of training and the while loop terminates, ending the thread.
# 
# If anyone knows of a better way to do this, please let me know!

# In[ ]:

import time
import threading

training = threading.Event()

def heartbeater(training):
    # Keep all four streaming connections alive until training is flagged as finished
    while not training.is_set():
        loss_stream.heartbeat()
        jaccard_stream.heartbeat()
        val_loss_stream.heartbeat()
        val_jaccard_stream.heartbeat()
        time.sleep(5)
    return

t = threading.Thread(target=heartbeater, args=(training,))
t.daemon = True

# Now that everything is set up, I'm ready to make the callback class that Keras will use during training. Keras has a list of callbacks that you can take advantage of at the beginning and end of training, each epoch, and each batch. I'll start with `on_train_begin()`, where I open the streams, start the heartbeater thread and initialize empty lists for the logs.
# 
# Then, using `on_epoch_end()`, I update the logs with the latest metric and loss values, and use the stream objects' `write()` method to send the data to plot.ly.
# 
# Finally, I use `on_train_end()` to wrap things up. Here, I redefine the traces, but without streams this time, and I supply them with the full x and y values from the logs. I'm basically plotting the graph statically. I make new layout, data and figure objects with these new traces and then send the plot to plot.ly using `py.plot()`. Since the `filename` argument is the same as it was for the streaming plot, this figure will replace the stream figure that has been developing as we train. Thus, the plot at this URL will be a streaming plot *while* training and then a static plot once training is over.
# 
# The reason for this last step is simple: when I start training a new network, I'll start streaming again (albeit to a new graph, with a unique filename) and the data on this graph will be lost. I'm not sure why this happens (it seems crazy to me), but it does. This last step creates a regular plot.ly graph with all the data on it for posterity purposes. Now you can tweak the learning rate or add more dropout, train again and see if the network behaves differently.
# 
# As an aside, it's possible to include the argument `fileopt='extend'` in the `py.plot()` call, which causes plot.ly to append any new data to an existing plot (if the `filename` argument matches an existing file; otherwise a new file is created). This could be used if you decide to load weights from a previously trained model and train for more epochs. The `model.fit()` method has an argument `initial_epoch` which allows you to specify where to start training from. Be careful here with the x values of your graph - if they get out of sync between runs, the resulting plot will have odd overlaps.
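# To make that concrete, a resumed run might look roughly like this. This is a hypothetical sketch: `X_train`, `y_train` and `n_previous_epochs` come from your own code, the weights filename is made up, and `streamer` is the callback instance defined in the next cell.
# 
#     model.load_weights('weights_run_{}.h5'.format(run))
#     url = py.plot(plotly_fig, filename='training_run_{}'.format(run), fileopt='extend')
#     model.fit(X_train, y_train,
#               epochs=n_previous_epochs + 50,      # `nb_epoch` in Keras 1.x
#               initial_epoch=n_previous_epochs,
#               callbacks=[streamer])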
# The last thing we do is terminate the heartbeater thread by calling the threading event object's `set()` method, and close all of the streams.

# In[ ]:

import keras

class streamer(keras.callbacks.Callback):

    def on_train_begin(self, logs={}):
        # Open streams
        loss_stream.open()
        jaccard_stream.open()
        val_loss_stream.open()
        val_jaccard_stream.open()
        # Start the heartbeater thread
        t.start()
        # Initialize logs
        self.x = []
        self.losses = []
        self.val_losses = []
        self.jaccards = []
        self.val_jaccards = []

    def on_epoch_end(self, epoch, logs={}):
        # Record this epoch's loss and metric values...
        self.x.append(epoch)  # potentially switch to self.epoch?
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.jaccards.append(logs.get('jaccard'))
        self.val_jaccards.append(logs.get('val_jaccard'))
        # ...and push the full history to each stream
        loss_stream.write(dict(x=self.x, y=self.losses))
        jaccard_stream.write(dict(x=self.x, y=self.jaccards))
        val_loss_stream.write(dict(x=self.x, y=self.val_losses))
        val_jaccard_stream.write(dict(x=self.x, y=self.val_jaccards))

    def on_train_end(self, logs={}):
        # Define traces (no streams this time, with the full x/y history)
        loss = go.Scatter(x=self.x, y=self.losses, name='Loss', mode='lines+markers',
                          yaxis='y', line=dict(shape='spline', smoothing=1))
        jaccard = go.Scatter(x=self.x, y=self.jaccards, name='Jaccard', mode='lines+markers',
                             yaxis='y2', line=dict(shape='spline', smoothing=1))
        val_loss = go.Scatter(x=self.x, y=self.val_losses, name='Val Loss', mode='lines+markers',
                              yaxis='y', line=dict(shape='spline', smoothing=1))
        val_jaccard = go.Scatter(x=self.x, y=self.val_jaccards, name='Val Jaccard', mode='lines+markers',
                                 yaxis='y2', line=dict(shape='spline', smoothing=1))
        # Define data and layout
        data = go.Data([loss, jaccard, val_loss, val_jaccard])
        layout = go.Layout(title='Loss/Jaccard for run {}'.format(run),
                           yaxis=dict(title='Loss', type='log', autorange=True, rangemode='tozero'),
                           yaxis2=dict(title='Jaccard index', overlaying='y', side='right', range=[0, 1]))
        # Generate the static plot, replacing the streaming figure at the same filename
        plotly_fig = go.Figure(data=data, layout=layout)
        url = py.plot(plotly_fig, filename='training_run_{}'.format(run))
        # Kill the heartbeater
        training.set()
        # Close the streams
        loss_stream.close()
        jaccard_stream.close()
        val_loss_stream.close()
        val_jaccard_stream.close()

streamer = streamer()  # instantiate the callback (note: this shadows the class name)

# Now I'm ready to train the network. Before calling the model's `fit()` method, send the plot to plot.ly using the `py.plot()` method. I use the line below:
# 
# `url = py.plot(plotly_fig, filename='training_run_{}'.format(run))`
# 
# This saves the plot URL to a variable (I use pushbullet to send this to my phone so I have the link on hand, then I go for a beer and let the network train itself). I also have a variable called `run`, which increments every time I train the model so that each iteration is stored separately. This is useful for saving logs, weights, models, etc. with unique file names that can be tracked.
# 
# Next, add the streamer to your list of callbacks in the model's `fit()` method and start training! (There's a rough end-to-end sketch of this wiring at the bottom of the post.) You can use `tls.embed(url)` to watch the graph. Just note that you'll have to do this in a different notebook from the one you're training in. You can find an example of mine below:

# In[9]:

tls.embed('https://plot.ly/~grant2d2/24')

# And that's a wrap. Happy training. You can see the code that I wrote this for [here](https://github.com/grantbey/deep_microscopy). Get at me on twitter ([@grantbey](https://twitter.com/grantbey)) if you have any questions!
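# P.S. For reference, here's roughly how the pieces above hang together in one place. This is a sketch rather than my exact training script: `model`, `X_train`, `y_train`, `X_val` and `y_val` stand in for your own model and data.
# 
#     url = py.plot(plotly_fig, filename='training_run_{}'.format(run), auto_open=False)
#     model.fit(X_train, y_train,
#               validation_data=(X_val, y_val),
#               epochs=100,                # `nb_epoch` in Keras 1.x
#               callbacks=[streamer])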