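The scan function is a very important concept in Theano. It lets us process sequential (time-series) data and thereby express many complex computations. In principle it resembles a heavily encapsulated for loop: at every time step it calls the same callback function on that step's data, and at the end it stacks the per-step results in time order.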
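As a rough mental model only (this is not how Theano implements scan, and the helper name `simple_scan` is made up for this sketch), the idea of "the same callback at every step, then stack the results" can be written in plain Python:

```python
def simple_scan(fn, sequence, initial):
    """Hypothetical helper: call fn(element, previous_output) at each step
    and collect every step's output in time order."""
    outputs = []
    previous = initial
    for element in sequence:
        previous = fn(element, previous)  # same callback at every step
        outputs.append(previous)          # stack the per-step results
    return outputs

# Running product of 1, 2, 3, 4, starting from an initial value of 1.
running_product = simple_scan(lambda x, prev: prev * x, [1, 2, 3, 4], 1)
print(running_product)  # [1, 2, 6, 24]
```

Unlike this loop, the real scan builds a symbolic graph, so nothing is computed until the compiled Theano function is called.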
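Compared with a plain Python for loop, the Theano documentation lists these advantages of scan: the number of iterations can be part of the symbolic graph, GPU transfers are minimized, gradients through the sequential steps are computed automatically, it is slightly faster than a Python for loop around a compiled Theano function, and it can lower overall memory usage.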
The drawback of scan is that it is complicated and hard to debug, which made it the first hurdle I ran into while learning Theano.
The signature of the scan function:
theano.scan(fn, sequences=None, outputs_info=None, non_sequences=None, n_steps=None, truncate_gradient=-1, go_backwards=False, mode=None, name=None, profile=False, allow_gc=None, strict=False)
fn receives its arguments in this order: sequences -> outputs_info -> non_sequences
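For example, if sequences=[a, b], then at step t scan passes fn the slices a[t] and b[t] of every variable in that list. Note that scan iterates over the first axis, so use dimshuffle to move the time dimension to axis 0 before passing a variable in.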
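The calling convention can be illustrated with a small plain-Python emulation. This is a sketch under simplifying assumptions (a single recurrent output, ordinary Python lists instead of symbolic variables); `scan_like` and `step` are hypothetical names, not part of Theano:

```python
def scan_like(fn, sequences, outputs_info, non_sequences):
    """Hypothetical emulation of scan's argument order for one recurrent output."""
    state = outputs_info
    outputs = []
    # zip(*sequences) walks every sequence along its first axis,
    # which is why the time dimension has to be axis 0.
    for slices in zip(*sequences):
        # argument order: sequences -> outputs_info -> non_sequences
        state = fn(*slices, state, *non_sequences)
        outputs.append(state)
    return outputs

def step(a_t, b_t, prev, scale):
    # a_t, b_t: slices of the sequences; prev: previous output; scale: constant
    return prev + scale * (a_t + b_t)

a = [1, 2, 3]
b = [10, 20, 30]
result = scan_like(step, sequences=[a, b], outputs_info=0, non_sequences=[2])
print(result)  # [22, 66, 132]
```

Here `step` sees the time slices first, then the previous output, then the value that stays the same at every step.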
Let us use an accumulator as an example to study how scan actually works: $$ sum(n) = \sum_{i=0}^n i$$ This can be written as the following recurrence: $$sum(n)=sum(n-1)+n$$
import theano
import theano.tensor as T
import numpy as np

n = T.iscalar()
# fn receives the current sequence element i first, then the previous output acc_sum
acc_out, updates = theano.scan(lambda i, acc_sum: acc_sum + i,
                               sequences=T.arange(n + 1),
                               outputs_info=T.constant(np.float64(0)))
accumulate_sum = theano.function([n], acc_out)
print(accumulate_sum(5))
[ 0. 1. 3. 6. 10. 15.]
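To obtain the running sum, I use 0 as the initial state in outputs_info. I then pass sequences an integer sequence that starts at 0 and increases by 1; T.arange is Theano's counterpart of np.arange. Next I define an anonymous function whose first argument i is the i-th element of the sequence and whose second argument acc_sum is the output of the previous step, i.e. $sum(i-1)$; its return value is $sum(i)=sum(i-1)+i$. scan computes the output of every step in time order and stacks the results, which the compiled function returns as a single np.ndarray: $[sum(0), sum(1),...,sum(n)]$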
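The stacked result is simply a cumulative sum, so it can be cross-checked without Theano. A minimal check using only the standard library, with n fixed to 5 to match the call above:

```python
from itertools import accumulate

n = 5
# accumulate yields sum(0), sum(1), ..., sum(n), one value per step
running_sums = list(accumulate(range(n + 1)))
print(running_sums)  # [0, 1, 3, 6, 10, 15]
```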
The following examples of the scan function should help deepen the understanding.
Given a positive integer $n$, we use scan to evaluate the following expression: $$\sum_{i=1}^n i^2$$
n = T.iscalar()
# the sequence starts at 0, which adds nothing to the sum of squares
acc_out, updates = theano.scan(lambda i, acc_sum: acc_sum + i**2,
                               sequences=T.arange(n + 1),
                               outputs_info=T.constant(np.int64(0)))
acc_out = acc_out[-1]  # keep only the last step, i.e. the full sum
accumulate_square_sum = theano.function([n], acc_out)
print(accumulate_square_sum(5))
55
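The value 55 can be verified independently. The sketch below recomputes the sum with an ordinary loop and compares it with the closed form $n(n+1)(2n+1)/6$; both function names are made up for illustration:

```python
def square_sum(n):
    """Sum of i**2 for i = 0..n, accumulated step by step like the scan above."""
    total = 0
    for i in range(n + 1):
        total += i ** 2
    return total

def square_sum_closed_form(n):
    """Closed form n(n+1)(2n+1)/6 of the same sum."""
    return n * (n + 1) * (2 * n + 1) // 6

print(square_sum(5), square_sum_closed_form(5))  # 55 55
```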
The Fibonacci sequence is defined by the recurrence
$$ x(n)=x(n-1)+x(n-2),(n\geq 2)$$
where $x(0)=0, x(1)=1$.
Its scan implementation is as follows:
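n = T.iscalar()
x0 = T.ivector()  # the two initial values [x(0), x(1)]
# taps=[-1, -2]: fn receives x(t-1) first, then x(t-2)
fib_out, updates = theano.scan(lambda xtm1, xtm2: xtm1 + xtm2,
                               outputs_info=[dict(initial=x0, taps=[-1, -2])],
                               n_steps=n)
fib = theano.function([x0, n], fib_out)
# the output holds x(2), ..., x(n+1); the initial values are not included
print(fib(np.int32([0, 1]), 10))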
[ 1 2 3 5 8 13 21 34 55 89]
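To see why the output starts at 1 rather than 0, here is a plain-Python sketch of what the two taps do. It only mimics the bookkeeping and is not Theano code; `fib_scan_like` is a hypothetical name:

```python
def fib_scan_like(x0, n_steps):
    """Emulate taps=[-1, -2]: each step sees the two most recent values."""
    xtm2, xtm1 = x0  # x0[0] plays the role of x(t-2), x0[1] of x(t-1)
    outputs = []
    for _ in range(n_steps):
        xt = xtm1 + xtm2
        outputs.append(xt)      # only newly computed values are stacked
        xtm2, xtm1 = xtm1, xt   # shift the window by one step
    return outputs

fib_values = fib_scan_like([0, 1], 10)
print(fib_values)  # [1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
```

The initial values only seed the recurrence, so the first stacked element is $x(2)$.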
Pi can be computed from the following integral: $$\pi=2\int_{-1}^1 \sqrt{1-x^2}dx$$
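inp = T.dvector()  # sample points x_t inside (-1, 1)
dx = T.dscalar()   # width of one interval
# each step adds the strip 2*sqrt(1 - x_t^2)*dx to the running total pi_sum
pi, updates = theano.scan(lambda xt, pi_sum: pi_sum + 2. * T.sqrt(1 - xt**2) * dx,
                          sequences=[inp],
                          outputs_info=T.constant(np.float64(0.)))
pi = pi[-1]  # keep only the final running total
cal_pi = theano.function([inp, dx], pi)
n_interval = 100000
print(cal_pi(np.linspace(-1, 1, n_interval)[1:-1], 2. / n_interval))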
3.14156113248
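The same Riemann sum can be sketched without Theano. Note one deliberate difference, which is a choice made for this sketch and not taken from the original code: the sample points are interval midpoints, so the spacing matches dx exactly. `approximate_pi` is a hypothetical name.

```python
import math

def approximate_pi(n_interval):
    """Midpoint Riemann sum of 2*sqrt(1 - x^2) over [-1, 1]."""
    dx = 2.0 / n_interval
    pi_sum = 0.0
    for k in range(n_interval):
        xt = -1.0 + (k + 0.5) * dx                      # midpoint of the k-th interval
        pi_sum += 2.0 * math.sqrt(1.0 - xt ** 2) * dx   # same update as the scan step
    return pi_sum

approx = approximate_pi(100000)
print(approx, abs(approx - math.pi))
```

The running total `pi_sum` plays the role of the recurrent output, and the final value corresponds to `pi[-1]` in the Theano version.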
We use scan to implement the Taylor expansion of $e^x$: $$ e^x=1+x+\frac{1}{2!}x^2+\frac{1}{3!}x^3+\cdots = \sum_{n=0}^\infty \frac{1}{n!}x^n$$
n = T.iscalar()
x = T.dscalar()

def fn(k, power, exp_sum, x):
    # power carries x**k, exp_sum carries the partial sum up to the k-th term
    power = power * x
    return power, exp_sum + 1. / T.gamma(k + 1) * power  # gamma(k+1) = k!

result, updates = theano.scan(fn, sequences=T.arange(1, n),
                              outputs_info=[T.constant(np.float64(1.)), T.constant(np.float64(1.))],
                              non_sequences=x)
exp_ = result[1][-1]  # last value of the second output (the partial sum)
calc_exp = theano.function([n, x], exp_)
print("calc_exp(1)=%f" % calc_exp(15, 1))
print("calc_exp(0)=%f" % calc_exp(15, 0))
print("calc_exp(-1)=%f" % calc_exp(15, -1))
print("whether calc_exp(1) equals np.exp(1):%s" % np.allclose(calc_exp(15, 1), np.exp(1)))
calc_exp(1)=2.718282
calc_exp(0)=1.000000
calc_exp(-1)=0.367879
whether calc_exp(1) equals np.exp(1):True
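For comparison, here is a plain-Python sketch of the same partial sum. It uses a slightly different recurrence than the scan version, which is this sketch's choice and not the original author's: instead of carrying $x^k$ and dividing by $k!$ separately, it carries the whole term $x^k/k!$ and updates it by a factor of $x/k$, which avoids evaluating a factorial at every step. `exp_taylor` is a hypothetical name.

```python
import math

def exp_taylor(x, n_terms):
    """Partial sum of the Taylor series of e**x using terms k = 0..n_terms-1."""
    term = 1.0      # x**0 / 0!
    exp_sum = 1.0
    for k in range(1, n_terms):  # same range as T.arange(1, n)
        term = term * x / k      # turns x**(k-1)/(k-1)! into x**k/k!
        exp_sum += term
    return exp_sum

for value in (1.0, 0.0, -1.0):
    print(value, exp_taylor(value, 15), math.exp(value))
```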