Use Keras load_data() to load MNIST, rather than the deprecated read_datasets()
parent
ad125bbdba
commit
e5fab6d604
|
@ -31,13 +31,12 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# To support both python 2 and python 3\n",
|
||||
"from __future__ import division, print_function, unicode_literals\n",
|
||||
"from io import open\n",
|
||||
"\n",
|
||||
"# Common imports\n",
|
||||
"import numpy as np\n",
|
||||
|
@ -79,9 +78,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_image(image):\n",
|
||||
|
@ -103,9 +100,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf"
|
||||
|
@ -121,9 +116,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.datasets import load_sample_image\n",
|
||||
|
@ -153,9 +146,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reset_graph()\n",
|
||||
|
@ -168,9 +159,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with tf.Session() as sess:\n",
|
||||
|
@ -270,9 +259,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reset_graph()\n",
|
||||
|
@ -285,9 +272,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"init = tf.global_variables_initializer()\n",
|
||||
|
@ -359,9 +344,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"batch_size, height, width, channels = dataset.shape\n",
|
||||
|
@ -435,9 +418,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"height = 28\n",
|
||||
|
@ -502,13 +483,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note: if you are using Python 3.6 on OSX, you need to run the following command on terminal to install the certifi package of certificates because Python 3.6 on OSX has no certificates to validate SSL connections (see this [StackOverflow question](https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error)):\n",
|
||||
"\n",
|
||||
" $ /Applications/Python\\ 3.6/Install\\ Certificates.command"
|
||||
"**Warning**: `tf.examples.tutorials.mnist` is deprecated. We will use `tf.keras.datasets.mnist` instead."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -517,8 +494,13 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.examples.tutorials.mnist import input_data\n",
|
||||
"mnist = input_data.read_data_sets(\"/tmp/data/\")"
|
||||
"(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
|
||||
"X_train = X_train.astype(np.float32).reshape(-1, 28*28) / 255.0\n",
|
||||
"X_test = X_test.astype(np.float32).reshape(-1, 28*28) / 255.0\n",
|
||||
"y_train = y_train.astype(np.int32)\n",
|
||||
"y_test = y_test.astype(np.int32)\n",
|
||||
"X_valid, X_train = X_train[:5000], X_train[5000:]\n",
|
||||
"y_valid, y_train = y_train[:5000], y_train[5000:]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -526,6 +508,20 @@
|
|||
"execution_count": 23,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def shuffle_batch(X, y, batch_size):\n",
|
||||
" rnd_idx = np.random.permutation(len(X))\n",
|
||||
" n_batches = len(X) // batch_size\n",
|
||||
" for batch_idx in np.array_split(rnd_idx, n_batches):\n",
|
||||
" X_batch, y_batch = X[batch_idx], y[batch_idx]\n",
|
||||
" yield X_batch, y_batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_epochs = 10\n",
|
||||
"batch_size = 100\n",
|
||||
|
@ -533,12 +529,11 @@
|
|||
"with tf.Session() as sess:\n",
|
||||
" init.run()\n",
|
||||
" for epoch in range(n_epochs):\n",
|
||||
" for iteration in range(mnist.train.num_examples // batch_size):\n",
|
||||
" X_batch, y_batch = mnist.train.next_batch(batch_size)\n",
|
||||
" for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):\n",
|
||||
" sess.run(training_op, feed_dict={X: X_batch, y: y_batch})\n",
|
||||
" acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
|
||||
" acc_test = accuracy.eval(feed_dict={X: mnist.test.images, y: mnist.test.labels})\n",
|
||||
" print(epoch, \"Train accuracy:\", acc_train, \"Test accuracy:\", acc_test)\n",
|
||||
" acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
|
||||
" acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})\n",
|
||||
" print(epoch, \"Last batch accuracy:\", acc_batch, \"Test accuracy:\", acc_test)\n",
|
||||
"\n",
|
||||
" save_path = saver.save(sess, \"./my_mnist_model\")"
|
||||
]
|
||||
|
@ -583,10 +578,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf\n",
|
||||
|
@ -657,23 +650,6 @@
|
|||
" saver = tf.train.Saver()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's load the data:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.examples.tutorials.mnist import input_data\n",
|
||||
"mnist = input_data.read_data_sets(\"/tmp/data/\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -684,9 +660,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_model_params():\n",
|
||||
|
@ -715,12 +689,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_epochs = 1000\n",
|
||||
"batch_size = 50\n",
|
||||
"iteration = 0\n",
|
||||
"\n",
|
||||
"best_loss_val = np.infty\n",
|
||||
"check_interval = 500\n",
|
||||
|
@ -731,31 +706,28 @@
|
|||
"with tf.Session() as sess:\n",
|
||||
" init.run()\n",
|
||||
" for epoch in range(n_epochs):\n",
|
||||
" for iteration in range(mnist.train.num_examples // batch_size):\n",
|
||||
" X_batch, y_batch = mnist.train.next_batch(batch_size)\n",
|
||||
" for X_batch, y_batch in shuffle_batch(X_train, y_train, batch_size):\n",
|
||||
" iteration += 1\n",
|
||||
" sess.run(training_op, feed_dict={X: X_batch, y: y_batch, training: True})\n",
|
||||
" if iteration % check_interval == 0:\n",
|
||||
" loss_val = loss.eval(feed_dict={X: mnist.validation.images,\n",
|
||||
" y: mnist.validation.labels})\n",
|
||||
" loss_val = loss.eval(feed_dict={X: X_valid, y: y_valid})\n",
|
||||
" if loss_val < best_loss_val:\n",
|
||||
" best_loss_val = loss_val\n",
|
||||
" checks_since_last_progress = 0\n",
|
||||
" best_model_params = get_model_params()\n",
|
||||
" else:\n",
|
||||
" checks_since_last_progress += 1\n",
|
||||
" acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
|
||||
" acc_val = accuracy.eval(feed_dict={X: mnist.validation.images,\n",
|
||||
" y: mnist.validation.labels})\n",
|
||||
" print(\"Epoch {}, train accuracy: {:.4f}%, valid. accuracy: {:.4f}%, valid. best loss: {:.6f}\".format(\n",
|
||||
" epoch, acc_train * 100, acc_val * 100, best_loss_val))\n",
|
||||
" acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
|
||||
" acc_val = accuracy.eval(feed_dict={X: X_valid, y: y_valid})\n",
|
||||
" print(\"Epoch {}, last batch accuracy: {:.4f}%, valid. accuracy: {:.4f}%, valid. best loss: {:.6f}\".format(\n",
|
||||
" epoch, acc_batch * 100, acc_val * 100, best_loss_val))\n",
|
||||
" if checks_since_last_progress > max_checks_without_progress:\n",
|
||||
" print(\"Early stopping!\")\n",
|
||||
" break\n",
|
||||
"\n",
|
||||
" if best_model_params:\n",
|
||||
" restore_model_params(best_model_params)\n",
|
||||
" acc_test = accuracy.eval(feed_dict={X: mnist.test.images,\n",
|
||||
" y: mnist.test.labels})\n",
|
||||
" acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})\n",
|
||||
" print(\"Final accuracy on test set:\", acc_test)\n",
|
||||
" save_path = saver.save(sess, \"./my_mnist_model\")"
|
||||
]
|
||||
|
@ -774,10 +746,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 30,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"width = 299\n",
|
||||
|
@ -787,7 +757,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 31,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -807,10 +777,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 32,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_image = 2 * test_image - 1"
|
||||
|
@ -826,10 +794,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 33,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
|
@ -860,10 +826,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 34,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fetch_pretrained_inception_v3()"
|
||||
|
@ -871,10 +835,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 35,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
|
@ -882,17 +844,16 @@
|
|||
"CLASS_NAME_REGEX = re.compile(r\"^n\\d+\\s+(.*)\\s*$\", re.M | re.U)\n",
|
||||
"\n",
|
||||
"def load_class_names():\n",
|
||||
" with open(os.path.join(\"datasets\", \"inception\", \"imagenet_class_names.txt\"), \"rb\") as f:\n",
|
||||
" content = f.read().decode(\"utf-8\")\n",
|
||||
" path = os.path.join(\"datasets\", \"inception\", \"imagenet_class_names.txt\")\n",
|
||||
" with open(path, encoding=\"utf-8\") as f:\n",
|
||||
" content = f.read()\n",
|
||||
" return CLASS_NAME_REGEX.findall(content)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 36,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class_names = [\"background\"] + load_class_names()"
|
||||
|
@ -900,7 +861,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 37,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -917,10 +878,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 38,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.contrib.slim.nets import inception\n",
|
||||
|
@ -946,7 +905,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 39,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -965,7 +924,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 40,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -978,7 +937,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 41,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -988,7 +947,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -997,7 +956,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
|
@ -1035,10 +994,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 44,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
|
@ -1062,10 +1019,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fetch_flowers()"
|
||||
|
@ -1080,7 +1035,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 46,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1099,10 +1054,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from collections import defaultdict\n",
|
||||
|
@ -1125,10 +1078,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 48,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for paths in image_paths.values():\n",
|
||||
|
@ -1144,7 +1095,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": 49,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1195,10 +1146,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 50,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from scipy.misc import imresize\n",
|
||||
|
@ -1258,7 +1207,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1278,7 +1227,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"execution_count": 52,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1300,7 +1249,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"execution_count": 53,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1333,10 +1282,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 54,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def prepare_image_with_tensorflow(image, target_width = 299, target_height = 299, max_zoom = 0.2):\n",
|
||||
|
@ -1389,7 +1336,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"execution_count": 55,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1432,10 +1379,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 56,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tensorflow.contrib.slim.nets import inception\n",
|
||||
|
@ -1460,7 +1405,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 55,
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1476,7 +1421,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 56,
|
||||
"execution_count": 58,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1492,7 +1437,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 59,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1508,7 +1453,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"execution_count": 60,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1524,7 +1469,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1540,10 +1485,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 62,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"prelogits = tf.squeeze(end_points[\"PreLogits\"], axis=[1, 2])"
|
||||
|
@ -1558,10 +1501,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 63,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_outputs = len(flower_classes)\n",
|
||||
|
@ -1588,10 +1529,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 64,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y = tf.placeholder(tf.int32, shape=[None])\n",
|
||||
|
@ -1614,7 +1553,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 63,
|
||||
"execution_count": 65,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1645,7 +1584,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"execution_count": 66,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1662,10 +1601,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 67,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"flower_paths_and_classes = []\n",
|
||||
|
@ -1683,10 +1620,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 68,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_ratio = 0.2\n",
|
||||
|
@ -1707,7 +1642,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"execution_count": 69,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1723,10 +1658,8 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 68,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"execution_count": 70,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from random import sample\n",
|
||||
|
@ -1740,33 +1673,13 @@
|
|||
" return X_batch, y_batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_batch, y_batch = prepare_batch(flower_paths_and_classes_train, batch_size=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_batch.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_batch.dtype"
|
||||
"X_batch, y_batch = prepare_batch(flower_paths_and_classes_train, batch_size=4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1775,7 +1688,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_batch.shape"
|
||||
"X_batch.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1783,6 +1696,24 @@
|
|||
"execution_count": 73,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_batch.dtype"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 74,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_batch.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_batch.dtype"
|
||||
]
|
||||
|
@ -1794,46 +1725,10 @@
|
|||
"Looking good. Now let's use this function to prepare the test set:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 74,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test, y_test = prepare_batch(flower_paths_and_classes_test, batch_size=len(flower_paths_and_classes_test))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We could prepare the training set in much the same way, but it would only generate one variant for each image. Instead, it's preferable to generate the training batches on the fly during training, so that we can really benefit from data augmentation, with many variants of each image."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"And now, we are ready to train the network (or more precisely, the output layer we just added, since all the other layers are frozen). Be aware that this may take a (very) long time."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test, y_test = prepare_batch(flower_paths_and_classes_test, batch_size=len(flower_paths_and_classes_test))"
|
||||
|
@ -1867,6 +1762,38 @@
|
|||
"execution_count": 78,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test, y_test = prepare_batch(flower_paths_and_classes_test, batch_size=len(flower_paths_and_classes_test))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 79,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_test.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We could prepare the training set in much the same way, but it would only generate one variant for each image. Instead, it's preferable to generate the training batches on the fly during training, so that we can really benefit from data augmentation, with many variants of each image."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"And now, we are ready to train the network (or more precisely, the output layer we just added, since all the other layers are frozen). Be aware that this may take a (very) long time."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 80,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_epochs = 10\n",
|
||||
"batch_size = 40\n",
|
||||
|
@ -1883,15 +1810,15 @@
|
|||
" X_batch, y_batch = prepare_batch(flower_paths_and_classes_train, batch_size)\n",
|
||||
" sess.run(training_op, feed_dict={X: X_batch, y: y_batch, training: True})\n",
|
||||
"\n",
|
||||
" acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
|
||||
" print(\" Train accuracy:\", acc_train)\n",
|
||||
" acc_batch = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
|
||||
" print(\" Last batch accuracy:\", acc_batch)\n",
|
||||
"\n",
|
||||
" save_path = saver.save(sess, \"./my_flowers_model\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 79,
|
||||
"execution_count": 81,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -1913,7 +1840,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Okay, 70.58% accuracy is not great (in fact, it's really bad), but this is only after 10 epochs, and freezing all layers except for the output layer. If you have a GPU, you can try again and let training run for much longer (e.g., using early stopping to decide when to stop). You can also improve the image preprocessing function to make more tweaks to the image (e.g., changing the brightness and hue, rotate the image slightly). You can reach above 95% accuracy on this task. If you want to dig deeper, this [great blog post](https://kwotsin.github.io/tech/2017/02/11/transfer-learning.html) goes into more details and reaches 96% accuracy."
|
||||
"Okay, 67.84% accuracy is not great (in fact, it's really bad), but this is only after 10 epochs, and freezing all layers except for the output layer. If you have a GPU, you can try again and let training run for much longer (e.g., using early stopping to decide when to stop). You can also improve the image preprocessing function to make more tweaks to the image (e.g., changing the brightness and hue, rotate the image slightly). You can reach above 95% accuracy on this task. If you want to dig deeper, this [great blog post](https://kwotsin.github.io/tech/2017/02/11/transfer-learning.html) goes into more details and reaches 96% accuracy."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1936,9 +1863,7 @@
|
|||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
|
@ -1959,7 +1884,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.2"
|
||||
"version": "3.6.6"
|
||||
},
|
||||
"nav_menu": {},
|
||||
"toc": {
|
||||
|
|
Loading…
Reference in New Issue