import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense,Activation
# Load MNIST and prepare the training split for a fully connected network.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Flatten each 28x28 image into a 784-element vector. reshape(-1, ...) infers
# the sample count instead of hard-coding 60000, and returns a new array
# rather than assigning to .shape in place (which fails on non-contiguous
# arrays).
x_train = x_train.reshape(-1, 28 * 28)
# Scale pixel values into [0, 1]. Cast to float32 first: dividing the uint8
# array by 255 would otherwise produce float64, while Keras layers default to
# float32 (and float32 also halves the memory footprint).
x_train = x_train.astype('float32') / 255
# One-hot encode the digit labels (10 classes).
y_train = keras.utils.to_categorical(y_train, num_classes=10)
# NOTE(review): x_test / y_test are left in their raw form here; apply the same
# reshape / scaling / one-hot steps before using them for evaluation.
# Multi-layer perceptron: 784 inputs -> 786 -> 256 -> 160 -> 10-way softmax.
model = Sequential()
# NOTE(review): 786 hidden units may be a typo for 784 (the input size); any
# width is valid, so confirm the intent.
model.add(Dense(786, input_dim=28 * 28))
model.add(Activation('relu'))
model.add(Dense(256))
model.add(Activation('relu'))
model.add(Dense(160))
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))
# The compile() keyword is `metrics` and takes a list; the previous
# `metric='accuracy'` is not a compile() option, so accuracy was not tracked.
model.compile(optimizer=keras.optimizers.SGD(lr=0.0001, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
# Wrap the compiled Keras model as a tf.estimator Estimator. No model_dir is
# passed, so checkpoints are written to a temporary directory.
estimator = keras.estimator.model_to_estimator(keras_model=model)
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp
WARNING:tensorflow:Using temporary folder as model directory: /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp
INFO:tensorflow:Using config: {'_train_distribute': None, '_save_checkpoints_secs': 600, '_num_worker_replicas': 1, '_global_id_in_cluster': 0, '_save_summary_steps': 100, '_keep_checkpoint_every_n_hours': 10000, '_master': '', '_evaluation_master': '', '_keep_checkpoint_max': 5, '_session_config': None, '_save_checkpoints_steps': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x133446780>, '_task_id': 0, '_service': None, '_task_type': 'worker', '_tf_random_seed': None, '_log_step_count_steps': 100, '_model_dir': '/var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp', '_is_chief': True, '_num_ps_replicas': 0}
INFO:tensorflow:Using config: {'_train_distribute': None, '_save_checkpoints_secs': 600, '_num_worker_replicas': 1, '_global_id_in_cluster': 0, '_save_summary_steps': 100, '_keep_checkpoint_every_n_hours': 10000, '_master': '', '_evaluation_master': '', '_keep_checkpoint_max': 5, '_session_config': None, '_save_checkpoints_steps': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x133446780>, '_task_id': 0, '_service': None, '_task_type': 'worker', '_tf_random_seed': None, '_log_step_count_steps': 100, '_model_dir': '/var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp', '_is_chief': True, '_num_ps_replicas': 0}
# Inspect the Keras model's input layer name(s); the feature dict passed to
# numpy_input_fn below is keyed by this name.
model.input_names
['dense_9_input']
# Feed the in-memory training arrays to the Estimator. The features dict must
# be keyed by the Keras model's input layer name. Read it from the model
# instead of hard-coding it: the auto-generated name (e.g. 'dense_9_input')
# can change depending on how many layers were created in the session.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={model.input_names[0]: x_train},
    y=y_train,
    # num_epochs=None repeats the data indefinitely, so `steps` below is what
    # bounds training. With num_epochs=1 the input ran out after a single pass
    # (about 470 steps at the default batch size), well before 2000 steps.
    num_epochs=None,
    # Shuffle the training examples rather than feeding them in fixed order.
    shuffle=True)
estimator.train(input_fn=train_input_fn, steps=2000)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp/keras_model.ckpt
INFO:tensorflow:Restoring parameters from /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp/keras_model.ckpt
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp/model.ckpt.
INFO:tensorflow:Saving checkpoints for 1 into /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp/model.ckpt.
INFO:tensorflow:step = 1, loss = 14.745756
INFO:tensorflow:step = 1, loss = 14.745756
INFO:tensorflow:global_step/sec: 72.6573
INFO:tensorflow:global_step/sec: 72.6573
INFO:tensorflow:step = 101, loss = 10.383348 (1.379 sec)
INFO:tensorflow:step = 101, loss = 10.383348 (1.379 sec)
INFO:tensorflow:global_step/sec: 94.0097
INFO:tensorflow:global_step/sec: 94.0097
INFO:tensorflow:step = 201, loss = 8.26259 (1.063 sec)
INFO:tensorflow:step = 201, loss = 8.26259 (1.063 sec)
INFO:tensorflow:global_step/sec: 92.2613
INFO:tensorflow:global_step/sec: 92.2613
INFO:tensorflow:step = 301, loss = 8.038471 (1.085 sec)
INFO:tensorflow:step = 301, loss = 8.038471 (1.085 sec)
INFO:tensorflow:global_step/sec: 85.0706
INFO:tensorflow:global_step/sec: 85.0706
INFO:tensorflow:step = 401, loss = 6.6674304 (1.174 sec)
INFO:tensorflow:step = 401, loss = 6.6674304 (1.174 sec)
INFO:tensorflow:Saving checkpoints for 470 into /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp/model.ckpt.
INFO:tensorflow:Saving checkpoints for 470 into /var/folders/46/293cpt8x5gq61f_5rjb2msrc0000gn/T/tmp7dn4eogp/model.ckpt.
INFO:tensorflow:Loss for final step: 6.4350104.
INFO:tensorflow:Loss for final step: 6.4350104.
<tensorflow.python.estimator.estimator.Estimator at 0x133446c50>
# Print a layer-by-layer summary of the model (output shapes, parameter counts).
model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 786)               1717410
_________________________________________________________________
activation_95 (Activation)   (None, 786)               0
_________________________________________________________________
dense_2 (Dense)              (None, 256)               201472
_________________________________________________________________
activation_96 (Activation)   (None, 256)               0
_________________________________________________________________
dense_3 (Dense)              (None, 160)               41120
_________________________________________________________________
activation_97 (Activation)   (None, 160)               0
_________________________________________________________________
dense_4 (Dense)              (None, 10)                1610
_________________________________________________________________
activation_98 (Activation)   (None, 10)                0
=================================================================
Total params: 1,961,612
Trainable params: 1,961,612
Non-trainable params: 0
_________________________________________________________________
# NOTE(review): the recorded output (None, 2184) does not match the
# input_dim=28*28 (784) used to build the model; it likely comes from a
# different session or model definition. Confirm before relying on it.
model.input_shape
(None, 2184)