Replace reduce_sum with reduce_mean: adds an extra .1% accuracy :)

main
Aurélien Geron 2018-01-18 17:41:32 +01:00
parent 94914db82e
commit 87040e084e
1 changed files with 138 additions and 248 deletions

View File

@ -77,10 +77,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 3,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from __future__ import division, print_function, unicode_literals" "from __future__ import division, print_function, unicode_literals"
@ -95,10 +93,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 4,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"%matplotlib inline\n", "%matplotlib inline\n",
@ -115,10 +111,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
@ -141,10 +135,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 6,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.reset_default_graph()" "tf.reset_default_graph()"
@ -159,10 +151,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 7,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"np.random.seed(42)\n", "np.random.seed(42)\n",
@ -185,7 +175,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -203,7 +193,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 9,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -228,7 +218,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -286,10 +276,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 11,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name=\"X\")" "X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name=\"X\")"
@ -311,10 +299,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 12,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps1_n_maps = 32\n", "caps1_n_maps = 32\n",
@ -331,10 +317,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 13,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"conv1_params = {\n", "conv1_params = {\n",
@ -356,10 +340,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 14,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"conv1 = tf.layers.conv2d(X, name=\"conv1\", **conv1_params)\n", "conv1 = tf.layers.conv2d(X, name=\"conv1\", **conv1_params)\n",
@ -382,10 +364,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 15,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims],\n", "caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims],\n",
@ -407,10 +387,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 16,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def squash(s, axis=-1, epsilon=1e-7, name=None):\n", "def squash(s, axis=-1, epsilon=1e-7, name=None):\n",
@ -432,10 +410,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 17,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps1_output = squash(caps1_raw, name=\"caps1_output\")" "caps1_output = squash(caps1_raw, name=\"caps1_output\")"
@ -478,10 +454,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 18,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps2_n_caps = 10\n", "caps2_n_caps = 10\n",
@ -568,10 +542,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 19,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"init_sigma = 0.01\n", "init_sigma = 0.01\n",
@ -591,10 +563,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 20,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"batch_size = tf.shape(X)[0]\n", "batch_size = tf.shape(X)[0]\n",
@ -610,10 +580,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 21,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps1_output_expanded = tf.expand_dims(caps1_output, -1,\n", "caps1_output_expanded = tf.expand_dims(caps1_output, -1,\n",
@ -633,7 +601,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 22,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -649,7 +617,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 23,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -665,10 +633,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 24,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled,\n", "caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled,\n",
@ -684,7 +650,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 25,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -714,10 +680,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 26,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],\n", "raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],\n",
@ -747,10 +711,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": 27,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"routing_weights = tf.nn.softmax(raw_weights, dim=2, name=\"routing_weights\")" "routing_weights = tf.nn.softmax(raw_weights, dim=2, name=\"routing_weights\")"
@ -765,10 +727,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 28,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"weighted_predictions = tf.multiply(routing_weights, caps2_predicted,\n", "weighted_predictions = tf.multiply(routing_weights, caps2_predicted,\n",
@ -797,10 +757,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 29,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps2_output_round_1 = squash(weighted_sum, axis=-2,\n", "caps2_output_round_1 = squash(weighted_sum, axis=-2,\n",
@ -809,7 +767,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 30,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -853,7 +811,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 31,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -869,7 +827,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 32,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -885,10 +843,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 33,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps2_output_round_1_tiled = tf.tile(\n", "caps2_output_round_1_tiled = tf.tile(\n",
@ -905,10 +861,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 34,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"agreement = tf.matmul(caps2_predicted, caps2_output_round_1_tiled,\n", "agreement = tf.matmul(caps2_predicted, caps2_output_round_1_tiled,\n",
@ -924,10 +878,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 35,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"raw_weights_round_2 = tf.add(raw_weights, agreement,\n", "raw_weights_round_2 = tf.add(raw_weights, agreement,\n",
@ -943,10 +895,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 36,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2,\n", "routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2,\n",
@ -972,10 +922,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 37,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps2_output = caps2_output_round_2" "caps2_output = caps2_output_round_2"
@ -1003,7 +951,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": 38,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1043,7 +991,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 39,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1073,10 +1021,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 40,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):\n", "def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):\n",
@ -1088,10 +1034,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 41,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_proba = safe_norm(caps2_output, axis=-2, name=\"y_proba\")" "y_proba = safe_norm(caps2_output, axis=-2, name=\"y_proba\")"
@ -1106,10 +1050,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": 42,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_proba_argmax = tf.argmax(y_proba, axis=2, name=\"y_proba\")" "y_proba_argmax = tf.argmax(y_proba, axis=2, name=\"y_proba\")"
@ -1124,7 +1066,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 43,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1140,10 +1082,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 44,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_pred = tf.squeeze(y_proba_argmax, axis=[1,2], name=\"y_pred\")" "y_pred = tf.squeeze(y_proba_argmax, axis=[1,2], name=\"y_pred\")"
@ -1151,7 +1091,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": 45,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1181,10 +1121,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 45, "execution_count": 46,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"y = tf.placeholder(shape=[None], dtype=tf.int64, name=\"y\")" "y = tf.placeholder(shape=[None], dtype=tf.int64, name=\"y\")"
@ -1212,10 +1150,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 46, "execution_count": 47,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"m_plus = 0.9\n", "m_plus = 0.9\n",
@ -1232,10 +1168,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 47, "execution_count": 48,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"T = tf.one_hot(y, depth=caps2_n_caps, name=\"T\")" "T = tf.one_hot(y, depth=caps2_n_caps, name=\"T\")"
@ -1250,7 +1184,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": 49,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1267,7 +1201,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 49, "execution_count": 50,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1283,10 +1217,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 50, "execution_count": 51,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True,\n", "caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True,\n",
@ -1302,10 +1234,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 51, "execution_count": 52,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm),\n", "present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm),\n",
@ -1323,10 +1253,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52, "execution_count": 53,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus),\n", "absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus),\n",
@ -1344,10 +1272,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 53, "execution_count": 54,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error,\n", "L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error,\n",
@ -1363,10 +1289,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54, "execution_count": 55,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name=\"margin_loss\")" "margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name=\"margin_loss\")"
@ -1409,10 +1333,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 55, "execution_count": 56,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"mask_with_labels = tf.placeholder_with_default(False, shape=(),\n", "mask_with_labels = tf.placeholder_with_default(False, shape=(),\n",
@ -1428,10 +1350,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 56, "execution_count": 57,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"reconstruction_targets = tf.cond(mask_with_labels, # condition\n", "reconstruction_targets = tf.cond(mask_with_labels, # condition\n",
@ -1458,10 +1378,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 57, "execution_count": 58,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"reconstruction_mask = tf.one_hot(reconstruction_targets,\n", "reconstruction_mask = tf.one_hot(reconstruction_targets,\n",
@ -1478,7 +1396,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 58, "execution_count": 59,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1494,7 +1412,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 59, "execution_count": 60,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1510,10 +1428,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60, "execution_count": 61,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"reconstruction_mask_reshaped = tf.reshape(\n", "reconstruction_mask_reshaped = tf.reshape(\n",
@ -1530,10 +1446,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 61, "execution_count": 62,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"caps2_output_masked = tf.multiply(\n", "caps2_output_masked = tf.multiply(\n",
@ -1543,7 +1457,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 62, "execution_count": 63,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1559,10 +1473,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63, "execution_count": 64,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"decoder_input = tf.reshape(caps2_output_masked,\n", "decoder_input = tf.reshape(caps2_output_masked,\n",
@ -1579,7 +1491,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 64, "execution_count": 65,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1602,10 +1514,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 65, "execution_count": 66,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"n_hidden1 = 512\n", "n_hidden1 = 512\n",
@ -1615,10 +1525,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 66, "execution_count": 67,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"with tf.name_scope(\"decoder\"):\n", "with tf.name_scope(\"decoder\"):\n",
@ -1649,16 +1557,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 67, "execution_count": 68,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"X_flat = tf.reshape(X, [-1, n_output], name=\"X_flat\")\n", "X_flat = tf.reshape(X, [-1, n_output], name=\"X_flat\")\n",
"squared_difference = tf.square(X_flat - decoder_output,\n", "squared_difference = tf.square(X_flat - decoder_output,\n",
" name=\"squared_difference\")\n", " name=\"squared_difference\")\n",
"reconstruction_loss = tf.reduce_sum(squared_difference,\n", "reconstruction_loss = tf.reduce_mean(squared_difference,\n",
" name=\"reconstruction_loss\")" " name=\"reconstruction_loss\")"
] ]
}, },
@ -1678,10 +1584,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 68, "execution_count": 69,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"alpha = 0.0005\n", "alpha = 0.0005\n",
@ -1712,10 +1616,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 69, "execution_count": 70,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"correct = tf.equal(y, y_pred, name=\"correct\")\n", "correct = tf.equal(y, y_pred, name=\"correct\")\n",
@ -1738,10 +1640,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 70, "execution_count": 71,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"optimizer = tf.train.AdamOptimizer()\n", "optimizer = tf.train.AdamOptimizer()\n",
@ -1764,10 +1664,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 71, "execution_count": 72,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"init = tf.global_variables_initializer()\n", "init = tf.global_variables_initializer()\n",
@ -1804,7 +1702,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 72, "execution_count": 73,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1870,7 +1768,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Training is finished, we reached over 99.3% accuracy on the validation set after just 5 epochs, things are looking good. Now let's evaluate the model on the test set." "Training is finished, we reached over 99.4% accuracy on the validation set after just 5 epochs, things are looking good. Now let's evaluate the model on the test set."
] ]
}, },
{ {
@ -1882,7 +1780,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 73, "execution_count": 74,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1915,7 +1813,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"We reach 99.43% accuracy on the test set. Pretty nice. :)" "We reach 99.53% accuracy on the test set. Pretty nice. :)"
] ]
}, },
{ {
@ -1934,7 +1832,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 74, "execution_count": 75,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1966,7 +1864,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 75, "execution_count": 76,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2022,7 +1920,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 76, "execution_count": 77,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2038,10 +1936,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 77, "execution_count": 78,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def tweak_pose_parameters(output_vectors, min=-0.5, max=0.5, n_steps=11):\n", "def tweak_pose_parameters(output_vectors, min=-0.5, max=0.5, n_steps=11):\n",
@ -2062,10 +1958,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 78, "execution_count": 79,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"n_steps = 11\n", "n_steps = 11\n",
@ -2084,7 +1978,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 79, "execution_count": 80,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2108,10 +2002,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 80, "execution_count": 81,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"tweak_reconstructions = decoder_output_value.reshape(\n", "tweak_reconstructions = decoder_output_value.reshape(\n",
@ -2127,7 +2019,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 81, "execution_count": 82,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2161,9 +2053,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [] "source": []
} }