I've been doing a lot of work in Keras recently, and as with all things deep learning, it's helpful to track the loss and accuracy metrics of your algorithm during training. If you're using TensorFlow you can simply use TensorBoard to graph everything. Unfortunately, I was using the Theano backend so I had to get creative.
This, then, serves as a small guide to how I achieved what I needed for my purposes. In a nutshell, I used the callback functions in Keras to initialize a streaming connection to plot.ly, and then at the end of each epoch I updated the graph. I also wrote a little heartbeat thread to keep the connection open during training.
Let's get into the code...
First, we grab the stream tokens from the credentials file. In order to use stream tokens, you need to set them up in your plot.ly account. Go to https://plot.ly/settings/api and hit "Add new token". I added four, which was sufficient for my needs. To use them, they need to be stored in your ~/.plotly/.credentials file alongside your username and API key. The plot.ly getting started guide has a full explanation if you've never done this before.
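For orientation, here is a minimal sketch of what pulling the tokens out by hand might look like. It assumes the credentials file is JSON with a stream_ids list sitting next to username and api_key; the values are placeholders rather than real credentials, and read_stream_tokens is a hypothetical helper, not part of plot.ly.

```python
import json

# Placeholder contents standing in for ~/.plotly/.credentials (assumed JSON layout).
EXAMPLE_CREDENTIALS = """
{
    "username": "your-username",
    "api_key": "your-api-key",
    "stream_ids": ["token-a", "token-b", "token-c", "token-d"]
}
"""


def read_stream_tokens(raw_json):
    """Return the list of stream tokens from credentials JSON text."""
    credentials = json.loads(raw_json)
    # Fall back to an empty list if no tokens have been added yet.
    return credentials.get("stream_ids", [])


tokens = read_stream_tokens(EXAMPLE_CREDENTIALS)
print(len(tokens), tokens[-1])
```

In practice the tls.get_credentials_file() call used in this post does the file reading for you; the sketch is only meant to show where the tokens live.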
Next, I generate stream IDs. These can take two forms. Here, I'm simply creating dictionaries with two keys, token and maxpoints. Alternatively, you could use the plotly.graph_objs method as below:
stream_id1 = go.Stream(token=stream_tokens[-1],maxpoints=600)
import plotly.tools as tls
stream_tokens = tls.get_credentials_file()['stream_ids']
stream_id1 = dict(token=stream_tokens[-1], maxpoints=600)
stream_id2 = dict(token=stream_tokens[-2], maxpoints=600)
stream_id3 = dict(token=stream_tokens[-3], maxpoints=600)
stream_id4 = dict(token=stream_tokens[-4], maxpoints=600)
Now I define the traces we plan on plotting. In my case, I've got four traces for the loss and the jaccard index, for both the training data and the validation data. Each trace requires a stream ID. Note that the x and y values for the traces are just empty lists at this point, since I'm not ready to plot anything yet.
import plotly.graph_objs as go
loss = go.Scatter(x=[], y=[], name='Loss', stream=stream_id1, mode='lines+markers',yaxis='y', line=dict(shape='spline',smoothing=1))
jaccard = go.Scatter(x=[], y=[], name='Jaccard', stream=stream_id2, mode='lines+markers',yaxis='y2', line=dict(shape='spline',smoothing=1))
val_loss = go.Scatter(x=[], y=[], name='Val Loss', stream=stream_id3, mode='lines+markers',yaxis='y', line=dict(shape='spline',smoothing=1))
val_jaccard = go.Scatter(x=[], y=[], name='Val Jaccard', stream=stream_id4, mode='lines+markers',yaxis='y2', line=dict(shape='spline',smoothing=1))
Once the traces are set up, I can create the data and layout objects and, finally, the figure object. I also configured a few things about the plot here. One handy thing is a log axis for the loss: loss often decays roughly exponentially, which shows up as a roughly straight line on a log axis, so you'll be able to tell when the decrease in the loss has truly leveled off.
data = go.Data([loss, jaccard, val_loss, val_jaccard])
layout = go.Layout(title='Loss/Jaccard for run {}'.format(run),
                   yaxis=dict(title='Loss', type='log', autorange=True, rangemode='tozero'),
                   yaxis2=dict(title='Jaccard index', overlaying='y', side='right', range=[0, 1]))
plotly_fig = go.Figure(data=data, layout=layout)
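To see why a log axis is handy for loss curves, here is a small standalone sketch in plain Python with no plotting. It assumes an idealized loss that shrinks by a constant factor every epoch, which is an illustration rather than real training data.

```python
import math

# Idealized loss curve: starts at 1.0 and shrinks by 20% every epoch (assumed).
initial_loss = 1.0
decay = 0.8
losses = [initial_loss * decay ** epoch for epoch in range(6)]

# On a log axis the plotted height is log10(loss).
log_losses = [math.log10(value) for value in losses]

# Step between consecutive epochs on a linear axis vs. a log axis.
linear_steps = [b - a for a, b in zip(losses, losses[1:])]
log_steps = [b - a for a, b in zip(log_losses, log_losses[1:])]

print("linear steps:", [round(step, 4) for step in linear_steps])
print("log steps:   ", [round(step, 4) for step in log_steps])
```

The log steps are all the same size, which is what a straight line looks like, while the linear steps keep shrinking. When the line on the log axis bends flat, the loss has genuinely stopped improving.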
Finally, we make the stream objects. These will be used to open and close connections to the stream API, as well as to write data to the graph.
import plotly.plotly as py
loss_stream = py.Stream(stream_id=stream_tokens[-1])
jaccard_stream = py.Stream(stream_id=stream_tokens[-2])
val_loss_stream = py.Stream(stream_id=stream_tokens[-3])
val_jaccard_stream = py.Stream(stream_id=stream_tokens[-4])
I was initially confused by the difference between stream IDs and stream objects, so it's helpful to think of it this way: once a streaming graph is all set up and pushed to plot.ly, it will be empty - just a set of axes and traces, ready to receive data. Defining the stream IDs and assigning them to traces helps the plot.ly server to know where to put the data when it starts arriving.
The stream objects, on the other hand, are Python objects in your local environment that handle the connection to the server. These objects have open(), close() and write() methods which are used to handle connections to the plot.ly server and send data.
The streaming API automatically closes connections that don't receive data for more than about 60 seconds, which can be a major hurdle in neural network training since a single epoch often takes longer than a minute (sometimes much longer if you're training on a CPU). To counter this, the stream objects also have a heartbeat() method which keeps the connection alive.
I used Python's threading package to start a new thread with a while loop that sends a heartbeat to each stream every five seconds, ensuring that the connections don't close between epochs. The thread runs as a daemon, allowing the rest of the program to continue while this happens in the background. The threading event object begins as not "set", and can be changed to "set" using its set() method. I simply call this at the end of training and the while loop terminates, ending the thread.
If anyone knows of a better way to do this, please let me know!
import time
import threading
training = threading.Event()
def heartbeater(training):
    # Ping every stream until the event is set at the end of training
    while not training.is_set():
        loss_stream.heartbeat()
        jaccard_stream.heartbeat()
        val_loss_stream.heartbeat()
        val_jaccard_stream.heartbeat()
        time.sleep(5)

t = threading.Thread(target=heartbeater, args=(training,))
# Daemon threads don't block the interpreter from exiting
t.daemon = True
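One possible variation on the heartbeat loop above, offered as a sketch rather than a definitive answer: threading.Event.wait() accepts a timeout, so it can double as the sleep and let the thread stop as soon as the event is set. FakeStream is a hypothetical stand-in so the example runs without a plot.ly connection, and the short interval is only there to keep the demo quick.

```python
import threading
import time


class FakeStream:
    """Hypothetical stand-in for a stream object; it only counts heartbeats."""

    def __init__(self):
        self.beats = 0

    def heartbeat(self):
        self.beats += 1


def heartbeater(streams, stop_event, interval):
    # Event.wait() returns True as soon as the event is set, so the loop
    # exits promptly instead of finishing a full sleep first.
    while not stop_event.wait(interval):
        for stream in streams:
            stream.heartbeat()


streams = [FakeStream() for _ in range(4)]
stop_event = threading.Event()
worker = threading.Thread(
    target=heartbeater, args=(streams, stop_event, 0.01), daemon=True
)
worker.start()

time.sleep(0.2)  # stands in for the training run
stop_event.set()
worker.join(timeout=1)
print([stream.beats for stream in streams])
```

Note that this version waits one interval before the first heartbeat, whereas the loop above sends one immediately. With a five second interval and a timeout of about a minute, that difference should not matter.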
Now that everything is set up, I'm ready to make the callback class that Keras will use during training. Keras has a list of callbacks that you can take advantage of at the beginning and end of training, each epoch, and each batch. I'll start with on_train_begin() and open the streams, start the heartbeater thread and initialize empty lists for the logs.
Then, using on_epoch_end(), I update the logs with the latest metric and loss values, and use the stream objects' write() method to send the data to plot.ly.
Finally, we use on_train_end() to wrap things up. Here, I redefine the traces, but without streams this time, and supply them with the full x and y values from the logs. In other words, I'm plotting the graph statically. I make new layout, data and figure objects with these new traces and then send the plot to plot.ly using py.plot(). Since the filename argument is the same as it was for the streaming plot, this figure will replace the stream figure that has been developing as we train. Thus, the plot at this URL will be a streaming plot while training and then a static plot once training is over.
The reason for this last step is simple: when I start training a new network, I'll start streaming again (albeit to a new graph, with a unique filename) and the data on this graph will be lost. I'm not sure why this happens, it seems crazy to me, but it does. This last step creates a regular plot.ly graph with all the data on it for posterity purposes. Now you can tweak the learning rate or add more dropout, train again and see if the network behaves differently.
As an aside, it's possible to include the argument fileopts='extend' in the py.plot() call, which causes plot.ly to append any new data to an existing plot (if the filename argument matches an existing file; otherwise a new file is created). This could be used if you decide to load weights from a previously trained model and train for more epochs. The model.fit() method has an argument initial_epoch which allows you to specify where to start training from. Be careful here with the x values of your graph: if they get out of sync between runs, the resulting plot will have odd overlaps.
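A quick way to sanity-check the x values before extending a plot is sketched below. It assumes, as I understand the Keras fit() behaviour, that epochs is the index of the final epoch rather than a count of additional epochs, and that the epoch number handed to on_epoch_end() starts at initial_epoch. The helper epoch_indices is hypothetical.

```python
def epoch_indices(initial_epoch, epochs):
    """Epoch numbers a run covers, treating `epochs` as the final epoch index."""
    return list(range(initial_epoch, epochs))


first_run = epoch_indices(initial_epoch=0, epochs=10)
resumed_run = epoch_indices(initial_epoch=10, epochs=15)

# No shared x values means the extended plot will not double back on itself.
overlap = set(first_run) & set(resumed_run)
combined_x = first_run + resumed_run
print(sorted(overlap), combined_x[-1])
```

If the resumed run were started with initial_epoch=0 instead, the overlap would contain epochs 0 through 9 and the appended trace would be drawn on top of the earlier one.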
The last thing we do is terminate the heartbeater thread by calling the threading event object's set() method and close all of the streams.
import keras
class streamer(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        # Open streams
        loss_stream.open()
        jaccard_stream.open()
        val_loss_stream.open()
        val_jaccard_stream.open()
        # Start heartbeater thread
        t.start()
        # Initialize logs
        self.x = []
        self.losses = []
        self.val_losses = []
        self.jaccards = []
        self.val_jaccards = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Record the latest values
        self.x.append(epoch)  # potentially switch to self.epoch?
        self.losses.append(logs.get('loss'))
        self.val_losses.append(logs.get('val_loss'))
        self.jaccards.append(logs.get('jaccard'))
        self.val_jaccards.append(logs.get('val_jaccard'))
        # Send the full history so far to each stream
        loss_stream.write(dict(x=self.x, y=self.losses))
        jaccard_stream.write(dict(x=self.x, y=self.jaccards))
        val_loss_stream.write(dict(x=self.x, y=self.val_losses))
        val_jaccard_stream.write(dict(x=self.x, y=self.val_jaccards))

    def on_train_end(self, logs=None):
        # Define traces (no streams this time, full data from the logs)
        loss = go.Scatter(x=self.x, y=self.losses, name='Loss', mode='lines+markers', yaxis='y', line=dict(shape='spline', smoothing=1))
        jaccard = go.Scatter(x=self.x, y=self.jaccards, name='Jaccard', mode='lines+markers', yaxis='y2', line=dict(shape='spline', smoothing=1))
        val_loss = go.Scatter(x=self.x, y=self.val_losses, name='Val Loss', mode='lines+markers', yaxis='y', line=dict(shape='spline', smoothing=1))
        val_jaccard = go.Scatter(x=self.x, y=self.val_jaccards, name='Val Jaccard', mode='lines+markers', yaxis='y2', line=dict(shape='spline', smoothing=1))
        # Define data and layout
        data = go.Data([loss, jaccard, val_loss, val_jaccard])
        layout = go.Layout(title='Loss/Jaccard for run {}'.format(run),
                           yaxis=dict(title='Loss', type='log', autorange=True, rangemode='tozero'),
                           yaxis2=dict(title='Jaccard index', overlaying='y', side='right', range=[0, 1]))
        # Generate plot
        plotly_fig = go.Figure(data=data, layout=layout)
        url = py.plot(plotly_fig, filename='training_run_{}'.format(run))
        # Kill the heartbeater
        training.set()
        # Close the streams
        loss_stream.close()
        jaccard_stream.close()
        val_loss_stream.close()
        val_jaccard_stream.close()

streamer = streamer()
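If you want to check the callback logic without Keras or a plot.ly connection, something like the sketch below may help. It is not the class above: RecordingStream and MetricStreamer are hypothetical stand-ins that reuse the same hook names (on_train_begin, on_epoch_end, on_train_end) and simply record what would have been sent.

```python
class RecordingStream:
    """Hypothetical stand-in for a stream object; records what it is sent."""

    def __init__(self):
        self.is_open = False
        self.written = []

    def open(self):
        self.is_open = True

    def write(self, payload):
        self.written.append(payload)

    def close(self):
        self.is_open = False


class MetricStreamer:
    """Same hook names as a Keras callback, but with no Keras dependency."""

    def __init__(self, streams):
        # Maps a key in the `logs` dict to the stream that plots it.
        self.streams = streams

    def on_train_begin(self, logs=None):
        for stream in self.streams.values():
            stream.open()
        self.x = []
        self.history = {name: [] for name in self.streams}

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.x.append(epoch)
        for name, stream in self.streams.items():
            self.history[name].append(logs.get(name))
            # Send copies so later appends do not alter what was recorded.
            stream.write(dict(x=list(self.x), y=list(self.history[name])))

    def on_train_end(self, logs=None):
        for stream in self.streams.values():
            stream.close()


streams = {"loss": RecordingStream(), "val_loss": RecordingStream()}
callback = MetricStreamer(streams)
callback.on_train_begin()
callback.on_epoch_end(0, {"loss": 0.9, "val_loss": 1.1})
callback.on_epoch_end(1, {"loss": 0.7, "val_loss": 1.0})
callback.on_train_end()
print(streams["loss"].written[-1])
```

Driving the hooks by hand like this makes it easy to confirm that each write carries the full history so far, which is what the streaming plot receives at the end of every epoch.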
Now I'm ready to train the network. Before calling the model's fit() method, send the plot to plot.ly using the py.plot() method. I use the line below:
url = py.plot(plotly_fig, filename='training_run_{}'.format(run))
This saves the plot URL to a variable (I use pushbullet to send this to my phone so I have the link on hand, then I go for a beer and let the network train itself). I also have a variable called run, which increments every time I train the model so that each iteration is stored separately. This is useful for saving logs, weights, models, etc. with unique file names that can be tracked.
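How the run counter is kept is up to you. One simple option, sketched here as an assumption rather than what I actually use, is to store the last run number in a small text file and bump it before each training session. The helper next_run and the file name run_counter.txt are both hypothetical.

```python
import os
import tempfile


def next_run(counter_path):
    """Read the last run number from a file, add one, and save it back."""
    try:
        with open(counter_path) as handle:
            last_run = int(handle.read().strip())
    except (FileNotFoundError, ValueError):
        # No counter yet (or unreadable contents): start from zero.
        last_run = 0
    run = last_run + 1
    with open(counter_path, "w") as handle:
        handle.write(str(run))
    return run


# Demonstration in a throwaway directory; a real project would use a fixed path.
with tempfile.TemporaryDirectory() as workdir:
    counter_file = os.path.join(workdir, "run_counter.txt")
    first = next_run(counter_file)
    second = next_run(counter_file)
    filename = "training_run_{}".format(second)

print(first, second, filename)
```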
Next, add the streamer to your list of callbacks in the model's fit() method and start training! You can use tls.embed(url) to watch the graph. Just note that you'll have to do this in a different notebook to the one you're training in. You can find an example of mine below:
tls.embed('https://plot.ly/~grant2d2/24')