Ref: Akash Srivastava and Charles Sutton. Autoencoding Variational Inference for Topic Models. In ICLR, 2017.
from keras import backend as K
from keras.layers import Input, Dense, Lambda, Activation, Dropout, BatchNormalization, Layer
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import reuters
from keras.callbacks import EarlyStopping
import numpy as np
Using TensorFlow backend.
V = 10922  # vocabulary size: words with freq <= 5 are removed
(x_train, _), (_, _) = reuters.load_data(start_char=None, oov_char=None, index_from=-1, num_words=V)  # index_from=-1 makes the word indices zero-origin
word_index = reuters.get_word_index()
index2word = {v-1: k for k, v in word_index.items()} # zero-origin word index
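At this point each training document is still a list of zero-origin word indices, so it can be decoded back to text. A quick sanity check (not in the original notebook):

# Illustrative: decode the first ten word indices of document 0.
print(' '.join(index2word[w] for w in x_train[0][:10]))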
x_train = np.array([np.bincount(doc, minlength=V) for doc in x_train])  # bag-of-words count vectors, shape (num_docs, V)
x_train = x_train[:8000, :]  # keep 8,000 documents, a multiple of the batch size
num_hidden = 100
num_topic = 20
batch_size = 100
alpha = 1./20  # symmetric Dirichlet hyperparameter, alpha_k = 1/num_topic
# Laplace approximation of the Dirichlet prior in the softmax basis
mu1 = np.log(alpha) - 1./num_topic*num_topic*np.log(alpha)  # prior mean (zero for a symmetric alpha)
sigma1 = 1./alpha*(1 - 2./num_topic) + 1./(num_topic**2)*num_topic/alpha  # diagonal prior variance
inv_sigma1 = 1./sigma1
log_det_sigma = num_topic*np.log(sigma1)  # log-determinant of the diagonal covariance
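Written out, these are the Laplace-approximation parameters of the logistic-normal prior that stands in for the Dirichlet; with the symmetric \alpha = 1/20 and K = 20 topics used here they evaluate to

\mu_{1,k} = \log\alpha - \frac{1}{K}\sum_i \log\alpha = 0, \qquad \Sigma_{1,kk} = \frac{1}{\alpha}\Bigl(1 - \frac{2}{K}\Bigr) + \frac{1}{K^2}\cdot\frac{K}{\alpha} = 20 \times 0.9 + 1 = 19.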
x = Input(batch_shape=(batch_size, V))
h = Dense(num_hidden, activation='softplus')(x)
h = Dense(num_hidden, activation='softplus')(h)
# parameters of the variational posterior q(z|x) = N(z_mean, diag(exp(z_log_var)))
z_mean = BatchNormalization()(Dense(num_topic)(h))
z_log_var = BatchNormalization()(Dense(num_topic)(h))
def sampling(args):
    # reparameterization trick: z = mu + sigma * epsilon with epsilon ~ N(0, I)
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(batch_size, num_topic), mean=0., stddev=1.)
    return z_mean + K.exp(z_log_var / 2) * epsilon
unnormalized_z = Lambda(sampling, output_shape=(num_topic,))([z_mean, z_log_var])
theta = Activation('softmax')(unnormalized_z)  # per-document topic proportions
theta = Dropout(0.5)(theta)
# decoder: linear map to the vocabulary, batch-normalized, then softmax
doc = Dense(units=V)(theta)
doc = BatchNormalization()(doc)
doc = Activation('softmax')(doc)
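This is the ProdLDA decoder: instead of LDA's mixture over topic-word distributions, each word is drawn from a softmax of a linear combination of the topic weight vectors, with batch normalization applied before the softmax:

w_n \mid \beta, \theta \sim \mathrm{Multinomial}\bigl(1, \mathrm{softmax}(\mathrm{BN}(\beta\,\theta))\bigr).

The softmax of a sum of weight vectors is an (unnormalized) product of experts, which is what gives ProdLDA its name.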
# Custom loss layer that attaches the negative ELBO via add_loss
class CustomVariationalLayer(Layer):
    def __init__(self, **kwargs):
        self.is_placeholder = True
        super(CustomVariationalLayer, self).__init__(**kwargs)

    def vae_loss(self, x, inference_x):
        # reconstruction term: log-likelihood of the observed word counts
        decoder_loss = K.sum(x * K.log(inference_x), axis=-1)
        # negative KL divergence from q(z|x) to the Laplace-approximated prior
        # N(0, sigma1 * I); the mu1 term drops out because mu1 = 0 here
        encoder_loss = -0.5*(K.sum(inv_sigma1*K.exp(z_log_var) + K.square(z_mean)*inv_sigma1 - 1 - z_log_var, axis=-1) + log_det_sigma)
        return -K.mean(encoder_loss + decoder_loss)

    def call(self, inputs):
        x = inputs[0]
        inference_x = inputs[1]
        loss = self.vae_loss(x, inference_x)
        self.add_loss(loss, inputs=inputs)
        # We won't actually use the output.
        return x
y = CustomVariationalLayer()([x, doc])
prodLDA = Model(x, y)
prodLDA.compile(optimizer=Adam(lr=0.001, beta_1=0.99), loss=None)
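The UserWarning emitted below is expected: the ELBO is attached with add_loss, so compile receives loss=None and fit is called without target data. As a quick wiring check (not in the original run), the architecture can be printed with:

prodLDA.summary()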
UserWarning: Output "custom_variational_layer_1" missing from loss dictionary. We assume this was done on purpose, and we will not be expecting any data to be passed to "custom_variational_layer_1" during training.
prodLDA.fit(x_train, verbose=1, batch_size=batch_size, validation_split=0.1, callbacks=[EarlyStopping(patience=3)], epochs=20)
Train on 7200 samples, validate on 800 samples
Epoch 1/20 - 16s - loss: 1294.8741 - val_loss: 1296.3189
Epoch 2/20 - 13s - loss: 1236.1560 - val_loss: 1274.4536
Epoch 3/20 - 13s - loss: 1212.5309 - val_loss: 1254.4617
Epoch 4/20 - 14s - loss: 1192.9693 - val_loss: 1223.0534
Epoch 5/20 - 16s - loss: 1177.8462 - val_loss: 1189.5439
Epoch 6/20 - 17s - loss: 1163.8908 - val_loss: 1163.1716
Epoch 7/20 - 14s - loss: 1149.6908 - val_loss: 1141.6182
Epoch 8/20 - 15s - loss: 1137.0046 - val_loss: 1115.0331
Epoch 9/20 - 15s - loss: 1123.7918 - val_loss: 1087.5322
Epoch 10/20 - 15s - loss: 1113.6162 - val_loss: 1070.2166
Epoch 11/20 - 16s - loss: 1101.6660 - val_loss: 1054.5499
Epoch 12/20 - 16s - loss: 1091.5757 - val_loss: 1046.8468
Epoch 13/20 - 15s - loss: 1084.2708 - val_loss: 1036.8501
Epoch 14/20 - 14s - loss: 1073.5024 - val_loss: 1025.1848
Epoch 15/20 - 13s - loss: 1066.2755 - val_loss: 1020.8630
Epoch 16/20 - 13s - loss: 1058.9499 - val_loss: 1014.4392
Epoch 17/20 - 13s - loss: 1051.1697 - val_loss: 1011.2585
Epoch 18/20 - 13s - loss: 1043.8778 - val_loss: 1002.2133
Epoch 19/20 - 13s - loss: 1038.1715 - val_loss: 998.5989
Epoch 20/20 - 14s - loss: 1031.7404 - val_loss: 990.7487
<keras.callbacks.History at 0x11d005e48>
# topic-word weights beta: kernel of the decoder's Dense(V) layer; get_weights()[-6]
# skips the Dense bias and the four parameters of the following BatchNormalization
exp_beta = np.exp(prodLDA.get_weights()[-6]).T
phi = (exp_beta/np.sum(exp_beta, axis=0)).T  # normalize each topic into a distribution over the vocabulary
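Indexing get_weights()[-6] is brittle. Had the decoder's Dense layer been given a name when the graph was built (hypothetical, e.g. Dense(units=V, name='beta')), the same kernel could be fetched by name:

# Hypothetical alternative, assuming the layer was created with name='beta'.
beta = prodLDA.get_layer('beta').get_weights()[0]  # kernel, shape (num_topic, V)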
for k, phi_k in enumerate(phi):
    print('topic: {}'.format(k))
    for w in np.argsort(phi_k)[::-1][:10]:
        print(index2word[w], phi_k[w])
    print()
topic: 0  mln 0.000111143  billion 0.000108751  vs 0.00010815  4 0.000106495  2 0.000106403  dlrs 0.000106268  0 0.000105673  1 0.00010521  87 0.000104661  tonnes 0.000103869
topic: 1  offices 9.75551e-05  nogales 9.74515e-05  guard 9.73854e-05  automotive 9.7355e-05  unpaid 9.72935e-05  alarm 9.72468e-05  kilometers 9.72122e-05  dixon 9.71581e-05  library 9.71567e-05  independently 9.70571e-05
topic: 2  the 0.000103879  of 0.000102854  offer 0.000102746  a 0.000102459  dlrs 0.000101379  pesos 0.000101333  williams 0.000101298  to 0.00010123  share 0.000101195  norcros 0.000101193
topic: 3  the 0.000106743  trade 0.000104344  to 0.000103626  japan 0.000103166  yeutter 0.000102792  clayton 0.000102494  states 0.000102354  semiconductors 0.000102334  united 0.000102321  venice 0.000102175
topic: 4  vs 0.000121324  shr 0.000114225  cts 0.00011384  net 0.000113774  000 0.000113085  mln 0.000112964  loss 0.000110497  revs 0.00010924  shrs 0.000108045  avg 0.00010794
topic: 5  the 0.000120084  to 0.000113982  of 0.000113214  a 0.000110542  in 0.000110401  and 0.000109963  said 0.000109562  that 0.000104584  for 0.000103823  banks 0.000103126
topic: 6  twa 0.000100107  usair 9.96932e-05  idc 9.91386e-05  twa's 9.91085e-05  offer 9.88553e-05  alvite 9.88493e-05  usair's 9.87797e-05  ecuador's 9.87569e-05  said 9.85903e-05  lawsuit 9.83429e-05
topic: 7  in 0.000106211  pct 0.000105936  0 0.000105215  1 0.000104341  rose 0.000104238  2 0.000104187  unadjusted 0.000103774  87 0.000103751  09 0.000103438  billion 0.000103206
topic: 8  div 0.000107786  qtly 0.000107582  prior 0.000104525  record 0.000104042  pay 0.000103519  juergen 0.000103384  eckenfelder 0.000103248  decades 0.000103185  playing 0.000103119  overhanging 0.000103105
topic: 9  the 0.000118233  to 0.000110448  of 0.000109871  in 0.000108754  said 0.000108514  and 0.000107417  a 0.000106782  economists 0.000103912  fed 0.000103817  pct 0.000103388
topic: 10  shares 9.83596e-05  gold 9.79488e-05  it 9.76937e-05  offer 9.72621e-05  hillards 9.7235e-05  ton 9.715e-05  assistance 9.70506e-05  filing 9.70396e-05  rated 9.70243e-05  debentures 9.69472e-05
topic: 11  vs 0.000123313  mln 0.000119359  dlrs 0.000117206  cts 0.000116674  net 0.000115827  shr 0.000114957  000 0.000113607  loss 0.00011315  oper 0.000111243  1 0.000110801
topic: 12  undisclosed 0.000100513  nogales 9.92817e-05  bolivia's 9.90571e-05  inc 9.90198e-05  refine 9.87756e-05  haq 9.87713e-05  vulnerability 9.83347e-05  eckenfelder 9.82606e-05  unitary 9.8222e-05  remittances 9.81787e-05
topic: 13  undisclosed 9.84391e-05  covenants 9.83275e-05  inc 9.82571e-05  corp 9.81662e-05  shrinking 9.81463e-05  bolivia's 9.80812e-05  completed 9.80501e-05  nogales 9.8049e-05  lieberman 9.7992e-05  rawl 9.79763e-05
topic: 14  shr 0.00011173  vs 0.000110093  cts 0.00010956  net 0.000109266  revs 0.000109084  vulnerability 0.000101041  avg 0.000101024  lieberman 0.000100964  note 0.00010096  calculating 0.000100699
topic: 15  vs 0.000123786  shr 0.000116089  net 0.000115306  cts 0.000114294  loss 0.000112816  000 0.000112405  revs 0.000111282  mln 0.000110697  profit 0.000109111  avg 0.000108816
topic: 16  the 0.000117573  to 0.000111937  of 0.000110009  and 0.000108801  said 0.000108758  in 0.000108275  a 0.000107725  opec 0.000106638  oil 0.000105008  prices 0.000104015
topic: 17  the 0.000111432  in 0.000106903  to 0.000106766  of 0.00010578  said 0.000104663  and 0.000104482  a 0.000103509  mln 0.000103303  year 0.000102863  pct 0.000102565
topic: 18  the 0.000115889  to 0.00011093  of 0.000108026  rep 0.000106465  a 0.000106093  and 0.000105682  subcommittee 0.000105528  said 0.000105319  trade 0.000105141  bill 0.000105072
topic: 19  div 0.000106618  qtly 0.000105465  prior 0.000103936  record 0.000103885  decades 0.000103455  eckenfelder 0.000103418  matane 0.000103244  harkin 0.00010324  donohue's 0.000103181  preferring 0.000103011
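Per-document topic proportions can be read off the same graph. A sketch (not in the original notebook) that reuses the theta tensor defined above; Dropout is inactive at prediction time, and the fixed batch_shape means predict needs a multiple of batch_size:

# Sketch: sub-model from the input x to the topic proportions theta.
encoder = Model(x, theta)
doc_topics = encoder.predict(x_train[:batch_size], batch_size=batch_size)
print(doc_topics.shape)  # (100, 20)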