handson-ml/13_convolutional_neural_net...

856 lines
24 KiB
Plaintext
Raw Normal View History

2016-09-27 23:31:21 +02:00
{
"cells": [
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"**Chapter 13 Convolutional Neural Networks**"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 13._"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# To support both python 2 and python 3\n",
"from __future__ import division, print_function, unicode_literals\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import numpy.random as rnd\n",
"import os\n",
"\n",
"# to make this notebook's output stable across runs\n",
"rnd.seed(42)\n",
"\n",
"# To plot pretty figures\n",
"%matplotlib inline\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"plt.rcParams['axes.labelsize'] = 14\n",
"plt.rcParams['xtick.labelsize'] = 12\n",
"plt.rcParams['ytick.labelsize'] = 12\n",
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"cnn\"\n",
"\n",
"def save_fig(fig_id, tight_layout=True):\n",
"    \"\"\"Save the current Matplotlib figure as a 300 dpi PNG file.\n",
"\n",
"    The file is written to PROJECT_ROOT_DIR/images/CHAPTER_ID/ and is named\n",
"    fig_id + \".png\". The directory is created first if it does not exist,\n",
"    so saving also works on a fresh checkout.\n",
"\n",
"    fig_id: base name of the image file, without extension.\n",
"    tight_layout: when true, call plt.tight_layout() before saving.\n",
"    \"\"\"\n",
"    image_dir = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"    # os.makedirs(..., exist_ok=True) is Python 3 only; this notebook also targets Python 2.\n",
"    if not os.path.isdir(image_dir):\n",
"        os.makedirs(image_dir)\n",
"    path = os.path.join(image_dir, fig_id + \".png\")\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format='png', dpi=300)"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"A couple of utility functions to plot grayscale and RGB images:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"def plot_image(image):\n",
"    \"\"\"Show `image` with the gray colormap, nearest interpolation and no axes.\"\"\"\n",
"    imshow_options = dict(cmap=\"gray\", interpolation=\"nearest\")\n",
"    plt.imshow(image, **imshow_options)\n",
"    plt.axis(\"off\")\n",
"\n",
"def plot_color_image(image):\n",
"    \"\"\"Show `image`, cast to uint8, with nearest interpolation and no axes.\"\"\"\n",
"    pixels = image.astype(np.uint8)\n",
"    plt.imshow(pixels, interpolation=\"nearest\")\n",
"    plt.axis(\"off\")"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"And of course we will need TensorFlow:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"import tensorflow as tf"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Convolutional layer"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2016-09-28 11:30:56 +02:00
"from sklearn.datasets import load_sample_image\n",
"china = load_sample_image(\"china.jpg\")\n",
"flower = load_sample_image(\"flower.jpg\")\n",
2016-09-27 23:31:21 +02:00
"image = china[150:220, 130:250]\n",
"height, width, channels = image.shape\n",
"image_grayscale = image.mean(axis=2).astype(np.float32)\n",
"images = image_grayscale.reshape(1, height, width, 1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"fmap = np.zeros(shape=(7, 7, 1, 2), dtype=np.float32)\n",
"fmap[:, 3, 0, 0] = 1\n",
"fmap[3, :, 0, 1] = 1\n",
"fmap[:, :, 0, 0]\n",
"plot_image(fmap[:, :, 0, 0])\n",
"plt.show()\n",
"plot_image(fmap[:, :, 0, 1])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"X = tf.placeholder(tf.float32, shape=(None, height, width, 1))\n",
"feature_maps = tf.constant(fmap)\n",
"convolution = tf.nn.conv2d(X, feature_maps, strides=[1,1,1,1], padding=\"SAME\", use_cudnn_on_gpu=False)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"with tf.Session() as sess:\n",
" output = convolution.eval(feed_dict={X: images})"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plot_image(images[0, :, :, 0])\n",
"save_fig(\"china_original\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plot_image(output[0, :, :, 0])\n",
"save_fig(\"china_vertical\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"plot_image(output[0, :, :, 1])\n",
"save_fig(\"china_horizontal\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## Simple example"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2016-09-28 11:30:56 +02:00
"dataset = np.array([china, flower], dtype=np.float32)\n",
2016-09-27 23:31:21 +02:00
"batch_size, height, width, channels = dataset.shape\n",
"\n",
"filters = np.zeros(shape=(7, 7, channels, 2), dtype=np.float32)\n",
"filters[:, 3, :, 0] = 1 # vertical line\n",
"filters[3, :, :, 1] = 1 # horizontal line\n",
"\n",
"X = tf.placeholder(tf.float32, shape=(None, height, width, channels))\n",
"convolution = tf.nn.conv2d(X, filters, strides=[1,2,2,1], padding=\"SAME\")\n",
"\n",
"with tf.Session() as sess:\n",
" output = sess.run(convolution, feed_dict={X: dataset})\n",
"\n",
"for image_index in (0, 1):\n",
" for feature_map_index in (0, 1):\n",
" plot_image(output[image_index, :, :, feature_map_index])\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"## VALID vs SAME padding"
]
},
{
"cell_type": "code",
2016-09-28 12:37:31 +02:00
"execution_count": 12,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"tf.reset_default_graph()\n",
"\n",
"filter_primes = np.array([2., 3., 5., 7., 11., 13.], dtype=np.float32)\n",
"x = tf.constant(np.arange(1, 13+1, dtype=np.float32).reshape([1, 1, 13, 1]))\n",
"filters = tf.constant(filter_primes.reshape(1, 6, 1, 1))\n",
"\n",
"valid_conv = tf.nn.conv2d(x, filters, strides=[1, 1, 5, 1], padding='VALID')\n",
"same_conv = tf.nn.conv2d(x, filters, strides=[1, 1, 5, 1], padding='SAME')\n",
"\n",
"with tf.Session() as sess:\n",
" print(\"VALID:\\n\", valid_conv.eval())\n",
" print(\"SAME:\\n\", same_conv.eval())"
]
},
{
"cell_type": "code",
2016-09-28 12:37:31 +02:00
"execution_count": 13,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"print(\"VALID:\")\n",
"print(np.array([1,2,3,4,5,6]).T.dot(filter_primes))\n",
"print(np.array([6,7,8,9,10,11]).T.dot(filter_primes))\n",
"print(\"SAME:\")\n",
"print(np.array([0,1,2,3,4,5]).T.dot(filter_primes))\n",
"print(np.array([5,6,7,8,9,10]).T.dot(filter_primes))\n",
"print(np.array([10,11,12,13,0,0]).T.dot(filter_primes))\n"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Pooling layer"
]
},
{
"cell_type": "code",
2016-09-28 12:37:31 +02:00
"execution_count": 14,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"batch_size, height, width, channels = dataset.shape\n",
"\n",
"filters = np.zeros(shape=(7, 7, channels, 2), dtype=np.float32)\n",
"filters[:, 3, :, 0] = 1 # vertical line\n",
"filters[3, :, :, 1] = 1 # horizontal line\n",
"\n",
"X = tf.placeholder(tf.float32, shape=(None, height, width, channels))\n",
"max_pool = tf.nn.max_pool(X, ksize=[1, 2, 2, 1], strides=[1,2,2,1], padding=\"VALID\")\n",
"\n",
"with tf.Session() as sess:\n",
" output = sess.run(max_pool, feed_dict={X: dataset})\n",
"\n",
"plot_color_image(dataset[0])\n",
"save_fig(\"china_original\")\n",
"plt.show()\n",
" \n",
"plot_color_image(output[0])\n",
"save_fig(\"china_max_pool\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# MNIST"
]
},
{
"cell_type": "markdown",
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"Note: instead of using the `fully_connected()`, `conv2d()` and `dropout()` functions from the `tensorflow.contrib.layers` module (as in the book), we now use the `dense()`, `conv2d()` and `dropout()` functions (respectively) from the `tf.layers` module, which did not exist when this chapter was written. This is preferable because anything in contrib may change or be deleted without notice, while `tf.layers` is part of the official API. As you will see, the code is mostly the same.\n",
"\n",
"For all these functions:\n",
"* the `scope` parameter was renamed to `name`, and the `_fn` suffix was removed in all the parameters that had it (for example the `activation_fn` parameter was renamed to `activation`).\n",
"\n",
"The other main differences in `tf.layers.dense()` are:\n",
"* the `weights` parameter was renamed to `kernel` (and the weights variable is now named `\"kernel\"` rather than `\"weights\"`),\n",
"* the default activation is `None` instead of `tf.nn.relu`\n",
"\n",
"The other main differences in `tf.layers.conv2d()` are:\n",
"* the `num_outputs` parameter was renamed to `filters`,\n",
"* the `stride` parameter was renamed to `strides`,\n",
"* the default `activation` is now `None` instead of `tf.nn.relu`.\n",
"\n",
"The other main differences in `tf.layers.dropout()` are:\n",
"* it takes the dropout rate (`rate`) rather than the keep probability (`keep_prob`). Of course, `rate == 1 - keep_prob`,\n",
"* the `is_training` parameter was renamed to `training`."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"height = 28\n",
"width = 28\n",
"channels = 1\n",
"n_inputs = height * width\n",
"\n",
"conv1_fmaps = 32\n",
"conv1_ksize = 3\n",
"conv1_stride = 1\n",
"conv1_pad = \"SAME\"\n",
"\n",
"conv2_fmaps = 64\n",
"conv2_ksize = 3\n",
"conv2_stride = 1\n",
"conv2_pad = \"SAME\"\n",
"conv2_dropout_rate = 0.25\n",
"\n",
"pool3_fmaps = conv2_fmaps\n",
"\n",
"n_fc1 = 128\n",
"fc1_dropout_rate = 0.5\n",
"\n",
"n_outputs = 10\n",
"\n",
"graph = tf.Graph()\n",
"with graph.as_default():\n",
" with tf.name_scope(\"inputs\"):\n",
" X = tf.placeholder(tf.float32, shape=[None, n_inputs], name=\"X\")\n",
" X_reshaped = tf.reshape(X, shape=[-1, height, width, channels])\n",
" y = tf.placeholder(tf.int32, shape=[None], name=\"y\")\n",
" is_training = tf.placeholder_with_default(False, shape=[], name='is_training')\n",
"\n",
" conv1 = tf.layers.conv2d(X_reshaped, filters=conv1_fmaps, kernel_size=conv1_ksize, strides=conv1_stride, padding=conv1_pad, activation=tf.nn.relu, name=\"conv1\")\n",
" conv2 = tf.layers.conv2d(conv1, filters=conv2_fmaps, kernel_size=conv2_ksize, strides=conv2_stride, padding=conv2_pad, activation=tf.nn.relu, name=\"conv2\")\n",
"\n",
" with tf.name_scope(\"pool3\"):\n",
" pool3 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding=\"VALID\")\n",
" pool3_flat = tf.reshape(pool3, shape=[-1, pool3_fmaps * 14 * 14])\n",
" pool3_flat_drop = tf.layers.dropout(pool3_flat, conv2_dropout_rate, training=is_training)\n",
"\n",
" with tf.name_scope(\"fc1\"):\n",
" fc1 = tf.layers.dense(pool3_flat_drop, n_fc1, activation=tf.nn.relu, name=\"fc1\")\n",
" fc1_drop = tf.layers.dropout(fc1, fc1_dropout_rate, training=is_training)\n",
"\n",
" with tf.name_scope(\"output\"):\n",
" # Feed the dropout output rather than the raw fc1 activations, so that fc1_dropout_rate is actually applied.\n",
" logits = tf.layers.dense(fc1_drop, n_outputs, name=\"output\")\n",
" Y_proba = tf.nn.softmax(logits, name=\"Y_proba\")\n",
"\n",
" with tf.name_scope(\"train\"):\n",
" xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=y)\n",
" loss = tf.reduce_mean(xentropy)\n",
" optimizer = tf.train.AdamOptimizer()\n",
" training_op = optimizer.minimize(loss)\n",
"\n",
" with tf.name_scope(\"eval\"):\n",
" correct = tf.nn.in_top_k(logits, y, 1)\n",
" accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))\n",
"\n",
" with tf.name_scope(\"init_and_save\"):\n",
" init = tf.global_variables_initializer()\n",
" saver = tf.train.Saver()"
]
},
2016-09-27 23:31:21 +02:00
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"from tensorflow.examples.tutorials.mnist import input_data\n",
"mnist = input_data.read_data_sets(\"/tmp/data/\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"def get_model_params():\n",
"    \"\"\"Return a dict mapping each global variable's op name to its current value.\n",
"\n",
"    The variables come from the GLOBAL_VARIABLES collection and are evaluated\n",
"    in the default session.\n",
"    \"\"\"\n",
"    variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)\n",
"    values = tf.get_default_session().run(variables)\n",
"    return {variable.op.name: value for variable, value in zip(variables, values)}\n",
"\n",
"def restore_model_params(model_params):\n",
"    \"\"\"Assign the saved values in `model_params` back to the graph's variables.\n",
"\n",
"    model_params: dict mapping a variable's op name to the value to assign,\n",
"        in the form returned by get_model_params().\n",
"\n",
"    For each name, the op called name + \"/Assign\" is looked up in the default\n",
"    graph and the saved value is fed to that op's second input. All the assign\n",
"    ops are then run together in the default session.\n",
"    \"\"\"\n",
"    default_graph = tf.get_default_graph()\n",
"    assign_ops = {}\n",
"    feed_dict = {}\n",
"    for name in list(model_params.keys()):\n",
"        assign_op = default_graph.get_operation_by_name(name + \"/Assign\")\n",
"        assign_ops[name] = assign_op\n",
"        # inputs[1] is the tensor the assign op reads its new value from.\n",
"        feed_dict[assign_op.inputs[1]] = model_params[name]\n",
"    tf.get_default_session().run(assign_ops, feed_dict=feed_dict)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
"execution_count": 18,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"n_epochs = 1000\n",
"batch_size = 50\n",
"\n",
"best_acc_val = 0\n",
"check_interval = 100\n",
"checks_since_last_progress = 0\n",
"max_checks_without_progress = 100\n",
"best_model_params = None \n",
"\n",
"with tf.Session(graph=graph) as sess:\n",
" init.run()\n",
" for epoch in range(n_epochs):\n",
" for iteration in range(mnist.train.num_examples // batch_size):\n",
" X_batch, y_batch = mnist.train.next_batch(batch_size)\n",
" sess.run(training_op, feed_dict={X: X_batch, y: y_batch, is_training: True})\n",
" if iteration % check_interval == 0:\n",
" acc_val = accuracy.eval(feed_dict={X: mnist.test.images[:2000], y: mnist.test.labels[:2000]})\n",
" if acc_val > best_acc_val:\n",
" best_acc_val = acc_val\n",
" checks_since_last_progress = 0\n",
" best_model_params = get_model_params()\n",
" else:\n",
" checks_since_last_progress += 1\n",
" acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})\n",
" acc_test = accuracy.eval(feed_dict={X: mnist.test.images[2000:], y: mnist.test.labels[2000:]})\n",
" print(epoch, \"Train accuracy:\", acc_train, \"Test accuracy:\", acc_test, \"Best validation accuracy:\", best_acc_val)\n",
" if checks_since_last_progress > max_checks_without_progress:\n",
" print(\"Early stopping!\")\n",
" break\n",
"\n",
" if best_model_params:\n",
" restore_model_params(best_model_params)\n",
" acc_test = accuracy.eval(feed_dict={X: mnist.test.images[2000:], y: mnist.test.labels[2000:]})\n",
" print(\"Final accuracy on test set:\", acc_test)\n",
" save_path = saver.save(sess, \"./my_mnist_model\")"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"# Inception v3"
]
},
{
"cell_type": "code",
"execution_count": 21,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"import sys\n",
"import tarfile\n",
"from six.moves import urllib\n",
2016-09-27 23:31:21 +02:00
"\n",
"TF_MODELS_URL = \"http://download.tensorflow.org/models\"\n",
"INCEPTION_V3_URL = TF_MODELS_URL + \"/inception_v3_2016_08_28.tar.gz\"\n",
"INCEPTION_PATH = os.path.join(\"datasets\", \"inception\")\n",
"INCEPTION_V3_CHECKPOINT_PATH = os.path.join(INCEPTION_PATH, \"inception_v3.ckpt\")\n",
"\n",
"def download_progress(count, block_size, total_size):\n",
"    \"\"\"Report hook for urlretrieve(): rewrite the download progress on one line.\n",
"\n",
"    count: number of blocks transferred so far.\n",
"    block_size: size of one block, in bytes.\n",
"    total_size: total size of the file in bytes; non-positive when the server\n",
"        does not report it.\n",
"    \"\"\"\n",
"    if total_size > 0:\n",
"        # The last block is usually partial, so cap the figure at 100%.\n",
"        percent = min(100, count * block_size * 100 // total_size)\n",
"        message = \"\\rDownloading: {}%\".format(percent)\n",
"    else:\n",
"        # Unknown size: show the bytes received instead of a meaningless percentage.\n",
"        message = \"\\rDownloading: {} bytes\".format(count * block_size)\n",
"    sys.stdout.write(message)\n",
"    sys.stdout.flush()\n",
"\n",
"def fetch_pretrained_inception_v3(url=INCEPTION_V3_URL, path=INCEPTION_PATH):\n",
"    \"\"\"Download and unpack the pretrained Inception v3 archive unless already present.\n",
"\n",
"    url: location of the .tar.gz archive to download.\n",
"    path: directory that receives the extracted files; created if missing.\n",
"\n",
"    Nothing is downloaded when inception_v3.ckpt already exists in `path`.\n",
"    The downloaded archive is deleted once it has been extracted.\n",
"    \"\"\"\n",
"    # Look for the checkpoint in the requested directory, not only in the default one.\n",
"    checkpoint_path = os.path.join(path, \"inception_v3.ckpt\")\n",
"    if os.path.exists(checkpoint_path):\n",
"        return\n",
"    # os.makedirs(..., exist_ok=True) is Python 3 only; this notebook also targets Python 2.\n",
"    if not os.path.isdir(path):\n",
"        os.makedirs(path)\n",
"    tgz_path = os.path.join(path, \"inception_v3.tgz\")\n",
"    urllib.request.urlretrieve(url, tgz_path, reporthook=download_progress)\n",
"    # The context manager closes the archive even if extraction raises.\n",
"    with tarfile.open(tgz_path) as inception_tgz:\n",
"        inception_tgz.extractall(path=path)\n",
"    os.remove(tgz_path)"
]
},
{
"cell_type": "code",
"execution_count": 22,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"fetch_pretrained_inception_v3()"
]
},
{
"cell_type": "code",
"execution_count": 23,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"import re\n",
"\n",
"CLASS_NAME_REGEX = re.compile(r\"^n\\d+\\s+(.*)\\s*$\", re.M | re.U)\n",
"\n",
"def load_class_names():\n",
"    \"\"\"Return the class descriptions found in imagenet_class_names.txt.\n",
"\n",
"    The file is read from INCEPTION_PATH and decoded as UTF-8. Each line that\n",
"    matches CLASS_NAME_REGEX contributes its captured group, in file order.\n",
"    \"\"\"\n",
"    class_names_path = os.path.join(INCEPTION_PATH, \"imagenet_class_names.txt\")\n",
"    # Read bytes and decode explicitly so the result is text on both Python 2 and 3.\n",
"    with open(class_names_path, \"rb\") as f:\n",
"        content = f.read().decode(\"utf-8\")\n",
"    return CLASS_NAME_REGEX.findall(content)"
]
},
{
"cell_type": "code",
"execution_count": 24,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"# The network is built with num_classes=1001: TF-Slim's ImageNet checkpoints use index 0\n",
"# for a background class, so prepend it to keep the names aligned with the predictions.\n",
"class_names = [\"background\"] + load_class_names()"
]
},
{
"cell_type": "code",
"execution_count": 25,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"width = 299\n",
"height = 299\n",
"channels = 3"
]
},
{
"cell_type": "code",
"execution_count": 26,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"import matplotlib.image as mpimg\n",
"test_image = mpimg.imread(os.path.join(\"images\",\"cnn\",\"test_image.png\"))[:, :, :channels]\n",
"plt.imshow(test_image)\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 27,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
2017-02-17 11:51:26 +01:00
"from tensorflow.contrib.slim.nets import inception\n",
2016-09-27 23:31:21 +02:00
"import tensorflow.contrib.slim as slim\n",
"\n",
"tf.reset_default_graph()\n",
"\n",
"X = tf.placeholder(tf.float32, shape=[None, height, width, channels], name=\"X\")\n",
2017-02-17 11:51:26 +01:00
"with slim.arg_scope(inception.inception_v3_arg_scope()):\n",
" logits, end_points = inception.inception_v3(X, num_classes=1001, is_training=False)\n",
2016-09-27 23:31:21 +02:00
"predictions = end_points[\"Predictions\"]\n",
"saver = tf.train.Saver()"
]
},
{
"cell_type": "code",
"execution_count": 28,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"X_test = test_image.reshape(-1, height, width, channels)\n",
"\n",
"with tf.Session() as sess:\n",
" saver.restore(sess, INCEPTION_V3_CHECKPOINT_PATH)\n",
" predictions_val = predictions.eval(feed_dict={X: X_test})"
]
},
{
"cell_type": "code",
"execution_count": 29,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"class_names[np.argmax(predictions_val[0])]"
]
},
{
"cell_type": "code",
"execution_count": 30,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"np.argmax(predictions_val, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 31,
2016-09-27 23:31:21 +02:00
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": false,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": [
"top_5 = np.argpartition(predictions_val[0], -5)[-5:]\n",
"top_5 = top_5[np.argsort(predictions_val[0][top_5])]\n",
"for i in top_5:\n",
" print(\"{0}: {1:.2f}%\".format(class_names[i], 100*predictions_val[0][i]))"
]
},
{
"cell_type": "markdown",
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
2017-02-17 11:51:26 +01:00
"metadata": {
"deletable": true,
"editable": true
},
2016-09-27 23:31:21 +02:00
"source": [
"**Coming soon**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
2017-02-17 11:51:26 +01:00
"collapsed": true,
"deletable": true,
"editable": true
2016-09-27 23:31:21 +02:00
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
2016-09-27 23:31:21 +02:00
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}