from functools import reduce
def to_bits(*l):
    # Pack a set of cell indices (0-8) into a 9-bit board mask.
    return reduce(lambda r, i: r | (1 << i), l, 0)
def winning_patterns():
    row = to_bits(0, 1, 2)   # top row; shift by 3 and 6 for the other rows
    col = to_bits(0, 3, 6)   # left column; shift by 1 and 2 for the other columns
    return [row, row << 3, row << 6,
            col, col << 1, col << 2,
            to_bits(0, 4, 8), to_bits(2, 4, 6)]   # the two diagonals
WINNING_PATTERNS = winning_patterns()
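# Sanity check (not in the original notebook): bit i of a board encodes cell (i // 3, i % 3),
# so the eight patterns cover the three rows, three columns and two diagonals.
assert to_bits(0, 1, 2) == 0b000000111      # top row
assert to_bits(0, 3, 6) == 0b001001001      # left column
assert len(WINNING_PATTERNS) == 8
assert any(p & to_bits(0, 4, 8) == p for p in WINNING_PATTERNS)   # main diagonal is a win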
# Imports assumed from their use below; the last expression displays floatX ('float32').
import numpy as np
import theano
import theano.tensor as T
import lasagne
floatX = theano.config.floatX
floatX
Using gpu device 0: GeForce GTX 965M (CNMeM is disabled)
'float32'
# Enumerate all 2**9 = 512 board bitmasks as 1x3x3 inputs; the target is 1 when
# the board contains at least one winning line.
all_input_data = np.zeros((2**9, 1, 3, 3), dtype=floatX)
all_target_data = np.zeros(2**9, dtype=floatX)
for board in range(2**9):
    b = 1
    for i in range(3):
        for j in range(3):
            all_input_data[board, 0, i, j] = bool(board & b)
            b <<= 1
    all_target_data[board] = any(p & board == p for p in WINNING_PATTERNS)
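# Quick check of the generated dataset (not in the original): every one of the 512
# bitboards becomes a 1x3x3 image, labelled 1 if it contains a winning line.
print(all_input_data.shape, all_target_data.shape)      # (512, 1, 3, 3) (512,)
print(int(all_target_data.sum()), "winning boards out of", all_target_data.shape[0])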
#theano.config.exception_verbosity = "high"
#theano.config.optimizer= 'None'
input_var = T.tensor4('inputs')
target_var = T.vector('targets')
l_in = lasagne.layers.InputLayer(shape=(None, 1, 3, 3), input_var=input_var)
#l_in_drop = lasagne.layers.DropoutLayer(l_in, p=0.2)
l_hidden = lasagne.layers.DenseLayer(l_in, num_units=64, nonlinearity=lasagne.nonlinearities.tanh, W=lasagne.init.GlorotUniform())
#l_hidden_drop = lasagne.layers.DropoutLayer(l_hidden, p=0.5)
l_out = lasagne.layers.DenseLayer(l_hidden, num_units=1, nonlinearity=lasagne.nonlinearities.sigmoid, W=lasagne.init.GlorotUniform())
prediction = lasagne.layers.get_output(l_out).flatten()
#loss = lasagne.objectives.binary_crossentropy(prediction, target_var)
loss = lasagne.objectives.squared_error(prediction, target_var*0.99)
loss = loss.mean()
params = lasagne.layers.get_all_params(l_out, trainable=True)
print(params)
#updates = lasagne.updates.nesterov_momentum(loss, params, learning_rate=0.01, momentum=0.9)
updates = lasagne.updates.adam(loss, params)
[W, b, W, b]
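# Optional check (not in the original): the two weight/bias pairs listed above amount to
# 9*64 + 64 parameters in the hidden layer plus 64 + 1 in the output layer.
print(lasagne.layers.count_params(l_out))   # expected: 705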
test_prediction = lasagne.layers.get_output(l_out, deterministic=True).flatten()
test_loss = lasagne.objectives.binary_crossentropy(test_prediction, target_var)
test_loss = test_loss.mean()
test_acc = lasagne.objectives.binary_accuracy(T.gt(test_prediction, 0.5), target_var).mean()
train_fn = theano.function([input_var, target_var], loss, updates=updates)
prediction_fn = theano.function([input_var], T.gt(test_prediction, 0.5))   # threshold at 0.5, matching the accuracy measure
val_fn = theano.function([input_var, target_var], [test_loss, test_acc])
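# Example use of the compiled functions (not in the original). prediction_fn takes a
# batch of shape (N, 1, 3, 3); board index 7 (= to_bits(0, 1, 2)) has a full top row.
# Before training, both outputs are essentially random.
print(prediction_fn(all_input_data[7:8]))
print(val_fn(all_input_data, all_target_data))   # [cross-entropy loss, accuracy]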
from random import randint
N = 100
for epoch in range(100000):
    idx = np.random.randint(2**9, size=N)
    input_data = all_input_data[idx, :]
    target_data = all_target_data[idx]
    if epoch % 1 == 0:
        aloss, accuracy = val_fn(all_input_data, all_target_data)
        if epoch % 100 == 0:
            print(epoch, "accuracy", accuracy, aloss)
        if accuracy == 1.0:
            break
    loss = train_fn(input_data, target_data)
print(epoch, "accuracy", val_fn(all_input_data, all_target_data)[1])
0 accuracy 0.5859375 0.6583781838417053
100 accuracy 0.708984375 0.5453886985778809
200 accuracy 0.828125 0.46027469635009766
300 accuracy 0.8359375 0.40126198530197144
400 accuracy 0.849609375 0.3687146306037903
500 accuracy 0.84765625 0.3505007028579712
600 accuracy 0.84375 0.33722493052482605
700 accuracy 0.859375 0.3303086757659912
800 accuracy 0.853515625 0.32213646173477173
900 accuracy 0.84765625 0.3188232183456421
1000 accuracy 0.861328125 0.3121160864830017
1100 accuracy 0.853515625 0.3075151741504669
1200 accuracy 0.857421875 0.30120742321014404
1300 accuracy 0.87109375 0.2947361171245575
1400 accuracy 0.8828125 0.2909179925918579
1500 accuracy 0.880859375 0.2801564931869507
1600 accuracy 0.884765625 0.2742465138435364
1700 accuracy 0.88671875 0.26584911346435547
1800 accuracy 0.892578125 0.25704246759414673
1900 accuracy 0.90234375 0.25057801604270935
2000 accuracy 0.908203125 0.24204590916633606
2100 accuracy 0.916015625 0.2357223927974701
2200 accuracy 0.91796875 0.2258404940366745
2300 accuracy 0.9375 0.22131146490573883
2400 accuracy 0.947265625 0.21085698902606964
2500 accuracy 0.951171875 0.20293736457824707
2600 accuracy 0.9609375 0.19260381162166595
2700 accuracy 0.96875 0.18439024686813354
2800 accuracy 0.97265625 0.17622195184230804
2900 accuracy 0.9765625 0.16739627718925476
3000 accuracy 0.9765625 0.15894639492034912
3100 accuracy 0.9765625 0.14979851245880127
3200 accuracy 0.986328125 0.141909658908844
3300 accuracy 0.986328125 0.13496193289756775
3400 accuracy 0.990234375 0.12670326232910156
3500 accuracy 0.98828125 0.11964216083288193
3600 accuracy 0.994140625 0.11327987909317017
3700 accuracy 0.9921875 0.10799919068813324
3800 accuracy 0.99609375 0.10304692387580872
3900 accuracy 0.994140625 0.0961313247680664
4000 accuracy 0.99609375 0.09087066352367401
4100 accuracy 0.994140625 0.08672349154949188
4200 accuracy 0.998046875 0.08285956084728241
4235 accuracy 1.0
for x in params:
    print(x.get_value())
[parameter dump: hidden-layer weights W of shape (9, 64), hidden biases b of shape (64,), output weights W of shape (64, 1), and the scalar output bias]
params[0].get_value()[:, 6]
array([ 0.80458415, 0.39832583, -0.2455962 , -0.67453992, 0.14639822, 1.73385942, 0.89292306, 0.36023572, -0.2375681 ], dtype=float32)
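# One way to inspect what hidden unit 6 has learned (not from the original notebook):
# reshape its nine input weights back into the 3x3 board layout. Cells with large
# positive weights are the ones that excite this unit.
print(params[0].get_value()[:, 6].reshape(3, 3))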