#!/usr/bin/env python # coding: utf-8 # In[1]: get_ipython().run_line_magic('pylab', 'inline') # # Test Cases for LSTM Training # This worksheet contains code that generates a variety of LSTM test cases. The output files are suitable for use with `clstmseq`. # In[2]: from pylab import * from scipy.ndimage import filters default_ninput = 2 default_n = 29 # Here is a simple utility class to write out sequence data to an HDF5 file quickly. # In[3]: import h5py import numpy as np class H5SeqData: def __init__(self,fname,N=None): self.fname = fname h5 = h5py.File("rnntest-"+fname+".h5","w") self.h5 = h5 dt = h5py.special_dtype(vlen=np.dtype('float32')) it = np.dtype('int32') self.inputs = h5.create_dataset("inputs",(1,), maxshape=(None,),compression="gzip",dtype=dt) self.inputs_dims = h5.create_dataset("inputs_dims",(1,2), maxshape=(None,2), dtype=it) self.outputs = h5.create_dataset("outputs",(1,),maxshape=(None,),compression="gzip",dtype=dt) self.outputs_dims = h5.create_dataset("outputs_dims",(1,2), maxshape=(None,2), dtype=it) self.fill = 0 if N is not None: self.resize(N) def close(self): self.h5.close() self.h5 = None def __enter__(self): print "writing",self.fname return self def __exit__(self, type, value, traceback): self.close() print "done writing",self.fname def resize(self,n): self.inputs.resize((n,)) self.inputs_dims.resize((n,2)) self.outputs.resize((n,)) self.outputs_dims.resize((n,2)) def add(self,inputs,outputs): self.inputs[self.fill] = inputs.ravel() self.inputs_dims[self.fill] = array(inputs.shape,'i') self.outputs[self.fill] = outputs.ravel() self.outputs_dims[self.fill] = array(outputs.shape,'i') self.fill += 1 N = 50000 # In[4]: def genfile(fname,f): with H5SeqData(fname,N) as db: for i in range(N): xs,ys = f() db.add(xs,ys) # In[5]: def plotseq(fname,index=17): h5 = h5py.File(fname,"r") try: inputs = h5["inputs"][index].reshape(*h5["inputs_dims"][index]) outputs = h5["outputs"][index].reshape(*h5["outputs_dims"][index]) 
plot(inputs[:,0],'r-',linewidth=5,alpha=0.5) if inputs.shape[1]>1: plot(inputs[:,1:],'r-',linewidth=1,alpha=0.3) plot(outputs,'b--') finally: h5.close() # In[6]: def generate_threshold(n=default_n,ninput=default_ninput,threshold=0.5,example=0): "No temporal dependencies, just threshold of the sum of the inputs." x = rand(n,ninput) y = 1.0*(sum(x,axis=1)>threshold*ninput).reshape(n,1) return x,y genfile("threshold", generate_threshold) # In[7]: plotseq("rnntest-threshold.h5") # In[8]: def generate_mod(n=default_n,ninput=default_ninput,m=3,example=0): "Generate a regular beat every m steps. The input is random." x = rand(n,ninput) y = 1.0*(arange(n,dtype='i')%m==0).reshape(n,1) return x,y genfile("mod3", generate_mod) # In[9]: plotseq("rnntest-mod3.h5") # In[10]: def generate_dmod(n=default_n,ninput=default_ninput,m=3,example=0): """Generate a regular beat every m steps, the input is random except for the first dimension, which contains a downbeat at the very beginning.""" x = rand(n,ninput) y = 1.0*(arange(n,dtype='i')%m==0).reshape(n,1) x[:,0] = 0 x[0,0] = 1 return x,y genfile("dmod3", generate_dmod) genfile("dmod4", lambda:generate_dmod(m=4)) genfile("dmod5", lambda:generate_dmod(m=5)) genfile("dmod6", lambda:generate_dmod(m=6)) # In[11]: plotseq("rnntest-dmod3.h5") # In[12]: def generate_imod(n=default_n,ninput=default_ninput,m=3,p=0.2,example=0): """Generate an output for every m input pulses.""" if example: x = array(arange(n)%4==1,'i') else: x = array(rand(n)
roll(x,-1))*(x>roll(x,1))
y = (add.accumulate(x)%m==1)*x*1.0
x = array(vstack([x]*ninput).T,'f')
y = y.reshape(len(y),1)
return x,y
# Write the "smod" datasets: generate_smod with its default arguments, plus
# variants called with m=4 and m=5.
genfile("smod3", generate_smod)
genfile("smod4", lambda:generate_smod(m=4))
genfile("smod5", lambda:generate_smod(m=5))
# In[15]:
# Plot one stored sequence (plotseq's default index) from the smod3 file.
plotseq("rnntest-smod3.h5")
# In[16]:
def generate_anbn(ninput=default_ninput, n=default_n, k=default_n // 3,
                  example=0):
    """A simple detector for a^nb^n.

    The input is a run of ones followed by zeros; the target is zero
    everywhere except for a single 1 at index 2*run.  Note that this does
    not train the network to distinguish this language from other
    languages.

    ninput  -- number of input channels; every channel carries the same signal
    n       -- sequence length
    k       -- random run lengths are drawn from 1..k-1
    example -- if true, use the fixed run length n//3 instead of a random one

    Returns (inputs, outputs) with shapes (n, ninput) and (n, 1).
    """
    inputs = zeros(n)
    outputs = zeros(n)
    if example:
        run = n // 3
    else:
        run = 1 + int((k - 1) * rand())
    # The marker index 2*run must lie inside the sequence.  The default k is
    # derived from default_n rather than n, so a shorter n (or a large k)
    # could otherwise index past the end and raise IndexError.
    run = min(run, max((n - 1) // 2, 0))
    inputs[:run] = 1
    outputs[2 * run] = 1
    outputs = outputs.reshape(len(outputs), 1)
    return vstack([inputs] * ninput).T, outputs
# Write the "anbn" dataset using generate_anbn with its default arguments.
genfile("anbn", generate_anbn)
# In[17]:
# Plot one stored sequence (plotseq's default index) from the anbn file.
plotseq("rnntest-anbn.h5")
# In[18]:
def generate_timing(ninput=default_ninput,n=default_n,t=5,example=0):
"""A simple timing related task: output a spike if no spike occurred within
t time steps before."""
x = 0
inputs = []
while x