#!/usr/bin/env python
# coding: utf-8

# NOTE(review): this script is an exported Jupyter notebook and only runs under
# IPython (it calls get_ipython()). The reviewed copy had its line breaks
# collapsed and every span from a "<" to the next ">" stripped out. Statements
# rebuilt from the surviving context are tagged "RECONSTRUCTED" below and must
# be checked against the original notebook.

# In[1]:

get_ipython().run_line_magic('pylab', 'inline')


# # Test Cases for LSTM Training
#
# This worksheet contains code that generates a variety of LSTM test cases.
# The output files are suitable for use with `clstmseq`.

# In[2]:

from pylab import *
# scipy.ndimage.filters is a deprecated alias namespace; use scipy.ndimage.
from scipy import ndimage

default_ninput = 2   # default number of input channels per time step
default_n = 29       # default number of time steps per sequence


# Here is a simple utility class to write out sequence data to an HDF5 file
# quickly.

# In[3]:

import h5py
import numpy as np


class H5SeqData:
    """Writer for (input, output) sequence pairs in ``rnntest-<fname>.h5``.

    Each sample is stored flattened in the variable-length float32 datasets
    ``inputs`` / ``outputs``; its 2-D shape is stored in the int32 datasets
    ``inputs_dims`` / ``outputs_dims`` so a reader can reshape it.
    Use it as a context manager so the file is closed when writing ends.
    """

    def __init__(self, fname, N=None):
        """Create the file and its four datasets, pre-sized to N rows if given."""
        self.fname = fname
        h5 = h5py.File("rnntest-" + fname + ".h5", "w")
        self.h5 = h5
        dt = h5py.special_dtype(vlen=np.dtype('float32'))
        it = np.dtype('int32')
        self.inputs = h5.create_dataset(
            "inputs", (1,), maxshape=(None,), compression="gzip", dtype=dt)
        self.inputs_dims = h5.create_dataset(
            "inputs_dims", (1, 2), maxshape=(None, 2), dtype=it)
        self.outputs = h5.create_dataset(
            "outputs", (1,), maxshape=(None,), compression="gzip", dtype=dt)
        self.outputs_dims = h5.create_dataset(
            "outputs_dims", (1, 2), maxshape=(None, 2), dtype=it)
        self.fill = 0  # index of the next row to write
        if N is not None:
            self.resize(N)

    def close(self):
        """Close the HDF5 file; safe to call more than once."""
        if self.h5 is not None:
            self.h5.close()
            self.h5 = None

    def __enter__(self):
        print("writing", self.fname)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.close()
        print("done writing", self.fname)

    def resize(self, n):
        """Set the number of rows of all four datasets to n."""
        self.inputs.resize((n,))
        self.inputs_dims.resize((n, 2))
        self.outputs.resize((n,))
        self.outputs_dims.resize((n, 2))

    def add(self, inputs, outputs):
        """Append one sample; inputs and outputs must be 2-D arrays."""
        # Grow by one row when full, so a writer created without N (whose
        # datasets start with a single row) can hold more than one sample.
        if self.fill >= self.inputs.shape[0]:
            self.resize(self.fill + 1)
        self.inputs[self.fill] = inputs.ravel()
        self.inputs_dims[self.fill] = array(inputs.shape, 'i')
        self.outputs[self.fill] = outputs.ravel()
        self.outputs_dims[self.fill] = array(outputs.shape, 'i')
        self.fill += 1


N = 50000  # default number of samples written per test-case file


# In[4]:

def genfile(fname, f, n=None):
    """Write n samples (default: the module-level N) to rnntest-<fname>.h5.

    f is called with no arguments once per sample and must return an
    (inputs, outputs) pair of 2-D arrays.
    """
    if n is None:
        n = N
    with H5SeqData(fname, n) as db:
        for _ in range(n):
            xs, ys = f()
            db.add(xs, ys)


# In[5]:

def plotseq(fname, index=17):
    """Plot sample `index` of an HDF5 file written by H5SeqData.

    Input channel 0 is drawn as a thick red line, any further input channels
    as thin red lines, and the outputs as a dashed blue line.
    """
    with h5py.File(fname, "r") as h5:
        inputs = h5["inputs"][index].reshape(*h5["inputs_dims"][index])
        outputs = h5["outputs"][index].reshape(*h5["outputs_dims"][index])
    plot(inputs[:, 0], 'r-', linewidth=5, alpha=0.5)
    if inputs.shape[1] > 1:
        plot(inputs[:, 1:], 'r-', linewidth=1, alpha=0.3)
    plot(outputs, 'b--')


# In[6]:

def generate_threshold(n=default_n, ninput=default_ninput, threshold=0.5, example=0):
    "No temporal dependencies, just threshold of the sum of the inputs."
    x = rand(n, ninput)
    y = 1.0 * (sum(x, axis=1) > threshold * ninput).reshape(n, 1)
    return x, y

genfile("threshold", generate_threshold)


# In[7]:

plotseq("rnntest-threshold.h5")


# In[8]:

def generate_mod(n=default_n, ninput=default_ninput, m=3, example=0):
    "Generate a regular beat every m steps. The input is random."
    x = rand(n, ninput)
    y = 1.0 * (arange(n, dtype='i') % m == 0).reshape(n, 1)
    return x, y

genfile("mod3", generate_mod)


# In[9]:

plotseq("rnntest-mod3.h5")


# In[10]:

def generate_dmod(n=default_n, ninput=default_ninput, m=3, example=0):
    """Generate a regular beat every m steps, the input is random except
    for the first dimension, which contains a downbeat at the very beginning."""
    x = rand(n, ninput)
    y = 1.0 * (arange(n, dtype='i') % m == 0).reshape(n, 1)
    x[:, 0] = 0
    x[0, 0] = 1
    return x, y

genfile("dmod3", generate_dmod)
genfile("dmod4", lambda: generate_dmod(m=4))
genfile("dmod5", lambda: generate_dmod(m=5))
genfile("dmod6", lambda: generate_dmod(m=6))


# In[11]:

plotseq("rnntest-dmod3.h5")


# In[12]:

def generate_imod(n=default_n, ninput=default_ninput, m=3, p=0.2, example=0):
    """Generate an output for every m input pulses.

    With `example` set, pulses occur at the steps i with i % 4 == 1;
    otherwise each step carries a pulse independently with probability p.
    """
    if example:
        x = array(arange(n) % 4 == 1, 'i')
    else:
        x = array(rand(n) < p, 'i')  # RECONSTRUCTED
    # RECONSTRUCTED: the rest of this function was lost; it mirrors the
    # surviving tail of generate_smod below.
    y = (add.accumulate(x) % m == 1) * x * 1.0
    x = array(vstack([x] * ninput).T, 'f')
    y = y.reshape(len(y), 1)
    return x, y

# RECONSTRUCTED: the genfile calls of this cell were lost; these mirror the
# smod cell below.
genfile("imod3", generate_imod)
genfile("imod4", lambda: generate_imod(m=4))
genfile("imod5", lambda: generate_imod(m=5))


# In[13]:

# RECONSTRUCTED: this cell was lost; other cells plot the "3" variant.
plotseq("rnntest-imod3.h5")


# In[14]:

def generate_smod(n=default_n, ninput=default_ninput, m=3, r=0.5, example=0):
    """Generate an output for every m input pulses, where the pulses are the
    local maxima of Gaussian-smoothed (width r) noise."""
    # RECONSTRUCTED: the signature and the first statement were lost; only
    # the local-maximum line and everything after it survived.
    x = ndimage.gaussian_filter(rand(n), r)
    # 1.0 where x is strictly larger than both neighbours (roll wraps around).
    x = 1.0 * (x > roll(x, -1)) * (x > roll(x, 1))
    # Fire on the 1st, (m+1)th, (2m+1)th, ... pulse.
    y = (add.accumulate(x) % m == 1) * x * 1.0
    x = array(vstack([x] * ninput).T, 'f')
    y = y.reshape(len(y), 1)
    return x, y

genfile("smod3", generate_smod)
genfile("smod4", lambda: generate_smod(m=4))
genfile("smod5", lambda: generate_smod(m=5))


# In[15]:

plotseq("rnntest-smod3.h5")


# In[16]:

def generate_anbn(ninput=default_ninput, n=default_n, k=default_n // 3, example=0):
    """A simple detector for a^nb^n.

    Note that this does not train the network to distinguish this language
    from other languages. The input is l ones followed by zeros, and the
    single output spike is at step 2*l (l = n//3 for the example, otherwise
    drawn from 1..k-1), so 2*k must stay below n.
    """
    inputs = zeros(n)
    outputs = zeros(n)
    if example:
        l = n // 3
    else:
        l = 1 + int((k - 1) * rand())
    inputs[:l] = 1
    outputs[2 * l] = 1
    outputs = outputs.reshape(len(outputs), 1)
    return vstack([inputs] * ninput).T, outputs

genfile("anbn", generate_anbn)


# In[17]:

plotseq("rnntest-anbn.h5")


# In[18]:

def _spike_times(n, t):
    """Return increasing spike positions 0 = s0 < s1 < ... , all below n.

    RECONSTRUCTED: the loop body shared by the timing generators was lost.
    Gaps here are uniform integers in [1, 2*t] (mean about t), which is a
    guess; restore the original gap distribution from the notebook.
    """
    x = 0
    times = []
    while x < n:
        times.append(x)
        x += 1 + int(2 * t * rand())
    return times


def generate_timing(ninput=default_ninput, n=default_n, t=5, example=0):
    """A simple timing related task: output a spike if no spike occurred
    within t time steps before."""
    inputs = _spike_times(n, t)
    # RECONSTRUCTED: inputs[0] only acts as the predecessor of inputs[1]
    # and is dropped from the emitted input below.
    outputs = [inputs[i] for i in range(1, len(inputs))
               if inputs[i] - inputs[i - 1] > t]
    inputs = inputs[1:]
    xs = zeros((n, ninput))
    xs[inputs, :] = 1.0
    ys = zeros((n, 1))
    ys[outputs, :] = 1.0
    return xs, ys

genfile("timing", generate_timing)


# In[19]:

def generate_revtiming(ninput=default_ninput, n=default_n, t=5, example=0):
    """A simple timing related task: output a spike if no spike occurs
    within t time steps after. This cannot be learned using a causal model
    (it requires a reverse model)."""
    inputs = _spike_times(n, t)
    # RECONSTRUCTED: the last spike only acts as the successor of the one
    # before it and is dropped from the emitted input below.
    outputs = [inputs[i] for i in range(len(inputs) - 1)
               if inputs[i + 1] - inputs[i] > t]
    inputs = inputs[:-1]
    xs = zeros((n, ninput))
    xs[inputs, :] = 1.0
    ys = zeros((n, 1))
    ys[outputs, :] = 1.0
    return xs, ys

genfile("revtiming", generate_revtiming)


# In[20]:

def generate_biditiming(ninput=default_ninput, n=default_n, t=5, example=0):
    """Output a spike for every input spike that is at least t steps away
    from both its predecessor and its successor (needs both directions)."""
    inputs = _spike_times(n, t)
    # RECONSTRUCTED: loop header; the first and last spikes only act as
    # neighbours and are dropped from the emitted input below.
    outputs = [inputs[i] for i in range(1, len(inputs) - 1)
               if inputs[i + 1] - inputs[i] >= t and inputs[i] - inputs[i - 1] >= t]
    inputs = inputs[1:-1]
    xs = zeros((n, ninput))
    xs[inputs, :] = 1.0
    ys = zeros((n, 1))
    ys[outputs, :] = 1.0
    return xs, ys

genfile("biditiming", generate_biditiming)


# In[21]:

def detect_12(x):
    """Return a 0/1 array marking each step where channel 1 fires after an
    earlier channel-0 firing with only silent steps in between.

    Each row of x is thresholded at 0.5. Rows must have exactly two
    channels; any other pattern (including both high) resets the detector.
    """
    n = len(x)
    y = zeros(n)
    state = 0  # 1 once channel 0 has fired and not yet been consumed
    for i in range(n):
        s = tuple(1 * (x[i] > 0.5))
        if s == (0, 0):
            pass
        elif s == (1, 0):
            state = 1
        elif s == (0, 1) and state == 1:
            y[i] = 1
            state = 0
        else:
            state = 0
    return y


# In[22]:

def generate_detect(n=default_n, ninput=default_ninput, m=3, r=0.5, example=0):
    """Generates a random sequence of bits and outputs a "1" whenever there
    is a sequence of inputs 01-00*-10.

    As (channel 0, channel 1) pairs this is (1,0), any number of (0,0), then
    (0,1); see detect_12. The input always has two channels, so `ninput`
    (and `m`) are ignored.
    """
    x = rand(n, 2)
    x = ndimage.gaussian_filter(x, (r, 0))
    # Keep only per-channel local maxima along the time axis.
    x = 1.0 * (x > roll(x, -1, 0)) * (x > roll(x, 1, 0))
    y = detect_12(x)
    return x, y.reshape(len(y), 1)

genfile("detect", generate_detect)


# In[23]:

def generate_revdetect(n=default_n, ninput=default_ninput, m=3, r=0.5, example=0):
    """Reverse of generate_detect."""
    xs, ys = generate_detect(n=n, ninput=ninput, m=m, r=r, example=example)
    return array(xs)[::-1], array(ys)[::-1]

genfile("revdetect", generate_revdetect)


# In[24]:

def generate_bididetect(n=default_n, ninput=default_ninput, m=3, r=0.5, example=0):
    """Generate a particular pattern whenever there is some input trigger.

    The output is 1 only where the forward pattern detector fires on the
    sequence and also fires on the time-reversed sequence.
    """
    xs, ys = generate_detect(n=n, ninput=ninput, m=m, r=r, example=example)
    rys = detect_12(xs[::-1])[::-1].reshape(len(ys), 1)
    return array(xs), array(ys * rys)

genfile("bididetect", generate_bididetect)


# In[25]:

def generate_predict_and_sync():
    """Similar to smod, but the correct output is provided one step after
    the required prediction for resynchronization."""
    pass  # TODO: not implemented


# In[26]:

def generate_distracted_recall():
    """Distracted sequence recall example."""
    pass  # TODO: not implemented


# In[27]:

def generate_morse():
    """Morse code encoding/decoding."""
    pass  # TODO: not implemented


# In[28]:

def genseq_timing1(n=30, threshold=0.2, m=4, example=0):
    """Returns an output for every input within m time steps.
    A 1 -> N -> 1 problem."""
    # NOTE(review): the body of this function was lost from the reviewed copy
    # (only the opening "x = (rand(n)" survived), so it is not reconstructed.
    raise NotImplementedError(
        "genseq_timing1: body lost in extraction; restore from the original notebook")


# In[29]:
#
# NOTE(review): this cell was lost from the reviewed copy except for its final
# statements, which read:
#     ... > threshold)).reshape(len(x), 1)
#     x[:, c] *= scale
#     return x, y
# Restore the full cell from the original notebook.


# In[30]:

def genseq_delay(n=30, threshold=0.2, d=1):
    """Return a random pulse train and the same train delayed by d steps.

    Each of the n steps carries a pulse with probability `threshold`.
    A negative d shifts the output earlier instead, which needs a reverse
    model. Steps shifted in from outside the sequence are zero.
    """
    x = array(rand(n) < threshold, 'f')  # RECONSTRUCTED
    y = roll(x, d)                        # RECONSTRUCTED
    # roll wraps around; blank the steps that wrapped in from the other end.
    if d > 0:
        y[:d] = 0
    elif d < 0:
        y[d:] = 0
    return x.reshape(n, 1), y.reshape(n, 1)

genfile("delay1", genseq_delay)
genfile("delay2", lambda: genseq_delay(d=2))
genfile("delay3", lambda: genseq_delay(d=3))
genfile("rdelay1", lambda: genseq_delay(d=-1))
genfile("rdelay2", lambda: genseq_delay(d=-2))
genfile("rdelay3", lambda: genseq_delay(d=-3))


# In[31]:

plotseq("rnntest-delay2.h5")


# # Test Run with `clstmseq`
#
# Here is a simple example of sequence training with `clstmseq`. It takes one
# of the HDF5 files generated above as input. By default, it uses every tenth
# training sample as part of a test set. The `TESTERR` it reports is the mean
# squared error (MSE) and the binary error rate (assuming a threshold of 0.5).

# In[34]:

get_ipython().system('lrate=1e-3 report_every=5000 ntrain=20000 test_every=10000 ../clstmseq rnntest-delay1.h5')


# In[ ]: