diff --git a/01_the_machine_learning_landscape.ipynb b/01_the_machine_learning_landscape.ipynb index 6a080af..8b99fce 100644 --- a/01_the_machine_learning_landscape.ipynb +++ b/01_the_machine_learning_landscape.ipynb @@ -2,10 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "deletable": true, - "editable": true - }, + "metadata": {}, "source": [ "**Chapter 1 – The Machine Learning landscape**\n", "\n", @@ -14,20 +11,14 @@ }, { "cell_type": "markdown", - "metadata": { - "deletable": true, - "editable": true - }, + "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", - "metadata": { - "deletable": true, - "editable": true - }, + "metadata": {}, "source": [ "First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:" ] @@ -36,9 +27,6 @@ "cell_type": "code", "execution_count": 1, "metadata": { - "collapsed": false, - "deletable": true, - "editable": true, "slideshow": { "slide_type": "-" } @@ -50,11 +38,10 @@ "\n", "# Common imports\n", "import numpy as np\n", - "import numpy.random as rnd\n", "import os\n", "\n", "# to make this notebook's output stable across runs\n", - "rnd.seed(42)\n", + "np.random.seed(42)\n", "\n", "# To plot pretty figures\n", "%matplotlib inline\n", @@ -73,35 +60,173 @@ " print(\"Saving figure\", fig_id)\n", " if tight_layout:\n", " plt.tight_layout()\n", - " plt.savefig(path, format='png', dpi=300)" + " plt.savefig(path, format='png', dpi=300)\n", + "\n", + "# Ignore useless warnings (see SciPy issue #5998)\n", + "import warnings\n", + "warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"^internal gelsd\")" ] }, { "cell_type": "markdown", - "metadata": { - "deletable": true, - "editable": true - }, + "metadata": {}, "source": [ - "# Load and prepare Life satisfaction data" + "# Code example 1-1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This function just merges the OECD's life satisfaction data and the IMF's GDP per capita data. It's a bit too long and boring and it's not specific to Machine Learning, which is why I left it out of the book." ] }, { "cell_type": "code", "execution_count": 2, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "metadata": {}, "outputs": [], "source": [ + "def prepare_country_stats(oecd_bli, gdp_per_capita):\n", + " oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n", + " oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", + " gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n", + " gdp_per_capita.set_index(\"Country\", inplace=True)\n", + " full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n", + " left_index=True, right_index=True)\n", + " full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n", + " remove_indices = [0, 1, 6, 8, 33, 34, 35]\n", + " keep_indices = list(set(range(36)) - set(remove_indices))\n", + " return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The code in the book expects the data files to be located in the current directory. I just tweaked it here to fetch the files in datasets/lifesat." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "datapath = os.path.join(\"datasets\", \"lifesat\", \"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Code example\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", "import pandas as pd\n", + "import sklearn.linear_model\n", "\n", - "# Download CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI\n", - "datapath = \"datasets/lifesat/\"\n", + "# Load the data\n", + "oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n", + "gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n", + " encoding='latin1', na_values=\"n/a\")\n", "\n", - "oecd_bli = pd.read_csv(datapath+\"oecd_bli_2015.csv\", thousands=',')\n", + "# Prepare the data\n", + "country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n", + "X = np.c_[country_stats[\"GDP per capita\"]]\n", + "y = np.c_[country_stats[\"Life satisfaction\"]]\n", + "\n", + "# Visualize the data\n", + "country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n", + "plt.show()\n", + "\n", + "# Select a linear model\n", + "model = sklearn.linear_model.LinearRegression()\n", + "\n", + "# Train the model\n", + "model.fit(X, y)\n", + "\n", + "# Make a prediction for Cyprus\n", + "X_new = [[22587]] # Cyprus' GDP per capita\n", + "print(model.predict(X_new)) # outputs [[ 5.96242338]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Note: you can ignore the rest of this notebook, it just generates many of the figures in chapter 1." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load and prepare Life satisfaction data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want, you can get fresh data from the OECD's website.\n", + "Download the CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI\n", + "and save it to `datasets/lifesat/`." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n", "oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n", "oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", "oecd_bli.head(2)" @@ -109,12 +234,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 6, + "metadata": {}, "outputs": [], "source": [ "oecd_bli[\"Life satisfaction\"].head()" @@ -122,25 +243,24 @@ }, { "cell_type": "markdown", - "metadata": { - "deletable": true, - "editable": true - }, + "metadata": {}, "source": [ "# Load and prepare GDP per capita data" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Just like above, you can update the GDP per capita data if you want. Just download data from http://goo.gl/j1MSKe (=> imf.org) and save it to `datasets/lifesat/`." + ] + }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 7, + "metadata": {}, "outputs": [], "source": [ - "# Download data from http://goo.gl/j1MSKe (=> imf.org)\n", "gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',', delimiter='\\t',\n", " encoding='latin1', na_values=\"n/a\")\n", "gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n", @@ -150,12 +270,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 8, + "metadata": {}, "outputs": [], "source": [ "full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)\n", @@ -165,12 +281,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 9, + "metadata": {}, "outputs": [], "source": [ "full_country_stats[[\"GDP per capita\", 'Life satisfaction']].loc[\"United States\"]" @@ -178,12 +290,8 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 10, + "metadata": {}, "outputs": [], "source": [ "remove_indices = [0, 1, 6, 8, 33, 34, 35]\n", @@ -195,12 +303,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 11, + "metadata": {}, "outputs": [], "source": [ "sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n", @@ -224,25 +328,17 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 12, + "metadata": {}, "outputs": [], "source": [ - "sample_data.to_csv(\"life_satisfaction_vs_gdp_per_capita.csv\")" + "sample_data.to_csv(os.path.join(\"datasets\", \"lifesat\", \"lifesat.csv\"))" ] }, { "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 13, + "metadata": {}, "outputs": [], "source": [ "sample_data.loc[list(position_text.keys())]" @@ -250,12 +346,8 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 14, + "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", @@ -278,12 +370,8 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 15, + "metadata": {}, "outputs": [], "source": [ "from sklearn import linear_model\n", @@ -297,12 +385,8 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 16, + "metadata": {}, "outputs": [], "source": [ "sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n", @@ -317,12 +401,8 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 17, + "metadata": {}, "outputs": [], "source": [ "cyprus_gdp_per_capita = gdp_per_capita.loc[\"Cyprus\"][\"GDP per capita\"]\n", @@ -333,12 +413,8 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 18, + "metadata": {}, "outputs": [], "source": [ "sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n", @@ -356,12 +432,8 @@ }, { "cell_type": "code", - "execution_count": 16, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 19, + "metadata": {}, "outputs": [], "source": [ "sample_data[7:10]" @@ -369,12 +441,8 @@ }, { "cell_type": "code", - "execution_count": 17, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 20, + "metadata": {}, "outputs": [], "source": [ "(5.1+5.7+6.5)/3" @@ -382,28 +450,29 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, + "execution_count": 21, + "metadata": {}, "outputs": [], "source": [ "backup = oecd_bli, gdp_per_capita\n", "\n", "def prepare_country_stats(oecd_bli, gdp_per_capita):\n", - " return sample_data" + " oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n", + " oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", + " gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n", + " gdp_per_capita.set_index(\"Country\", inplace=True)\n", + " full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n", + " left_index=True, right_index=True)\n", + " full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n", + " remove_indices = [0, 1, 6, 8, 33, 34, 35]\n", + " keep_indices = list(set(range(36)) - set(remove_indices))\n", + " return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]" ] }, { "cell_type": "code", - "execution_count": 19, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 22, + "metadata": {}, "outputs": [], "source": [ "# Code example\n", @@ -440,12 +509,8 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, + "execution_count": 23, + "metadata": {}, "outputs": [], "source": [ "oecd_bli, gdp_per_capita = backup" @@ -453,12 +518,8 @@ }, { "cell_type": "code", - "execution_count": 21, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 24, + "metadata": {}, "outputs": [], "source": [ "missing_data" @@ -466,12 +527,8 @@ }, { "cell_type": "code", - "execution_count": 22, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, + "execution_count": 25, + "metadata": {}, "outputs": [], "source": [ "position_text2 = {\n", @@ -487,12 +544,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 26, + "metadata": {}, "outputs": [], "source": [ "sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n", @@ -522,12 +575,8 @@ }, { "cell_type": "code", - "execution_count": 24, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 27, + "metadata": {}, "outputs": [], "source": [ "full_country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n", @@ -550,12 +599,8 @@ }, { "cell_type": "code", - "execution_count": 25, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 28, + "metadata": {}, "outputs": [], "source": [ "full_country_stats.loc[[c for c in full_country_stats.index if \"W\" in c.upper()]][\"Life satisfaction\"]" @@ -563,12 +608,8 @@ }, { "cell_type": "code", - "execution_count": 26, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 29, + "metadata": {}, "outputs": [], "source": [ "gdp_per_capita.loc[[c for c in gdp_per_capita.index if \"W\" in c.upper()]].head()" @@ -576,12 +617,8 @@ }, { "cell_type": "code", - "execution_count": 27, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 30, + "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8,3))\n", @@ -611,12 +648,8 @@ }, { "cell_type": "code", - "execution_count": 28, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, + "execution_count": 31, + "metadata": {}, "outputs": [], "source": [ "backup = oecd_bli, gdp_per_capita\n", @@ -627,12 +660,8 @@ }, { "cell_type": "code", - "execution_count": 29, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, + "execution_count": 32, + "metadata": {}, "outputs": [], "source": [ "# Replace this linear model:\n", @@ -641,12 +670,8 @@ }, { "cell_type": "code", - "execution_count": 30, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, + "execution_count": 33, + "metadata": {}, "outputs": [], "source": [ "# with this k-neighbors regression model:\n", @@ -655,12 +680,8 @@ }, { "cell_type": "code", - "execution_count": 31, - "metadata": { - "collapsed": false, - "deletable": true, - "editable": true - }, + "execution_count": 34, + "metadata": {}, "outputs": [], "source": [ "X = np.c_[country_stats[\"GDP per capita\"]]\n", @@ -677,11 +698,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true, - "deletable": true, - "editable": true - }, + "metadata": {}, "outputs": [], "source": [] } @@ -702,7 +719,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.3" + "version": "3.6.3" }, "nav_menu": {}, "toc": { @@ -723,5 +740,5 @@ } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } diff --git a/03_classification.ipynb b/03_classification.ipynb index 1e7960a..0f8b455 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -2281,7 +2281,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is a usual information to have." + "It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is useful information to have." ] }, { @@ -2714,8 +2714,8 @@ "\n", "y_pred = log_clf.predict(X_test_transformed)\n", "\n", - "print(\"Precision: {:.2f}%\".format(precision_score(y_test, y_pred)))\n", - "print(\"Recall: {:.2f}%\".format(recall_score(y_test, y_pred)))" + "print(\"Precision: {:.2f}%\".format(100 * precision_score(y_test, y_pred)))\n", + "print(\"Recall: {:.2f}%\".format(100 * recall_score(y_test, y_pred)))" ] } ], diff --git a/extra_capsnets.ipynb b/extra_capsnets.ipynb index cdfeb2d..20e5dd1 100644 --- a/extra_capsnets.ipynb +++ b/extra_capsnets.ipynb @@ -77,10 +77,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "collapsed": true - }, + "execution_count": 3, + "metadata": {}, "outputs": [], "source": [ "from __future__ import division, print_function, unicode_literals" @@ -95,10 +93,8 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": true - }, + "execution_count": 4, + "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", @@ -115,10 +111,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": true - }, + "execution_count": 5, + "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", @@ -141,10 +135,8 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "collapsed": true - }, + "execution_count": 6, + "metadata": {}, "outputs": [], "source": [ "tf.reset_default_graph()" @@ -159,10 +151,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": true - }, + "execution_count": 7, + "metadata": {}, "outputs": [], "source": [ "np.random.seed(42)\n", @@ -185,7 +175,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -203,7 +193,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -228,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -286,10 +276,8 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": true - }, + "execution_count": 11, + "metadata": {}, "outputs": [], "source": [ "X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name=\"X\")" @@ -311,10 +299,8 @@ }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "collapsed": true - }, + "execution_count": 12, + "metadata": {}, "outputs": [], "source": [ "caps1_n_maps = 32\n", @@ -331,10 +317,8 @@ }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": true - }, + "execution_count": 13, + "metadata": {}, "outputs": [], "source": [ "conv1_params = {\n", @@ -356,10 +340,8 @@ }, { "cell_type": "code", - "execution_count": 13, - "metadata": { - "collapsed": true - }, + "execution_count": 14, + "metadata": {}, "outputs": [], "source": [ "conv1 = tf.layers.conv2d(X, name=\"conv1\", **conv1_params)\n", @@ -382,10 +364,8 @@ }, { "cell_type": "code", - "execution_count": 14, - "metadata": { - "collapsed": true - }, + "execution_count": 15, + "metadata": {}, "outputs": [], "source": [ "caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims],\n", @@ -407,10 +387,8 @@ }, { "cell_type": "code", - "execution_count": 15, - "metadata": { - "collapsed": true - }, + "execution_count": 16, + "metadata": {}, "outputs": [], "source": [ "def squash(s, axis=-1, epsilon=1e-7, name=None):\n", @@ -432,10 +410,8 @@ }, { "cell_type": "code", - "execution_count": 16, - "metadata": { - "collapsed": true - }, + "execution_count": 17, + "metadata": {}, "outputs": [], "source": [ "caps1_output = squash(caps1_raw, name=\"caps1_output\")" @@ -478,10 +454,8 @@ }, { "cell_type": "code", - "execution_count": 17, - "metadata": { - "collapsed": true - }, + "execution_count": 18, + "metadata": {}, "outputs": [], "source": [ "caps2_n_caps = 10\n", @@ -568,10 +542,8 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "collapsed": true - }, + "execution_count": 19, + "metadata": {}, "outputs": [], "source": [ "init_sigma = 0.01\n", @@ -591,10 +563,8 @@ }, { "cell_type": "code", - "execution_count": 19, - "metadata": { - "collapsed": true - }, + "execution_count": 20, + "metadata": {}, "outputs": [], "source": [ "batch_size = tf.shape(X)[0]\n", @@ -610,10 +580,8 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": { - "collapsed": true - }, + "execution_count": 21, + "metadata": {}, "outputs": [], "source": [ "caps1_output_expanded = tf.expand_dims(caps1_output, -1,\n", @@ -633,7 +601,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -649,7 +617,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -665,10 +633,8 @@ }, { "cell_type": "code", - "execution_count": 23, - "metadata": { - "collapsed": true - }, + "execution_count": 24, + "metadata": {}, "outputs": [], "source": [ "caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled,\n", @@ -684,7 +650,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -714,10 +680,8 @@ }, { "cell_type": "code", - "execution_count": 25, - "metadata": { - "collapsed": true - }, + "execution_count": 26, + "metadata": {}, "outputs": [], "source": [ "raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],\n", @@ -747,10 +711,8 @@ }, { "cell_type": "code", - "execution_count": 26, - "metadata": { - "collapsed": true - }, + "execution_count": 27, + "metadata": {}, "outputs": [], "source": [ "routing_weights = tf.nn.softmax(raw_weights, dim=2, name=\"routing_weights\")" @@ -765,10 +727,8 @@ }, { "cell_type": "code", - "execution_count": 27, - "metadata": { - "collapsed": true - }, + "execution_count": 28, + "metadata": {}, "outputs": [], "source": [ "weighted_predictions = tf.multiply(routing_weights, caps2_predicted,\n", @@ -797,10 +757,8 @@ }, { "cell_type": "code", - "execution_count": 28, - "metadata": { - "collapsed": true - }, + "execution_count": 29, + "metadata": {}, "outputs": [], "source": [ "caps2_output_round_1 = squash(weighted_sum, axis=-2,\n", @@ -809,7 +767,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ @@ -853,7 +811,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 31, "metadata": {}, "outputs": [], "source": [ @@ -869,7 +827,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -885,10 +843,8 @@ }, { "cell_type": "code", - "execution_count": 32, - "metadata": { - "collapsed": true - }, + "execution_count": 33, + "metadata": {}, "outputs": [], "source": [ "caps2_output_round_1_tiled = tf.tile(\n", @@ -905,10 +861,8 @@ }, { "cell_type": "code", - "execution_count": 33, - "metadata": { - "collapsed": true - }, + "execution_count": 34, + "metadata": {}, "outputs": [], "source": [ "agreement = tf.matmul(caps2_predicted, caps2_output_round_1_tiled,\n", @@ -924,10 +878,8 @@ }, { "cell_type": "code", - "execution_count": 34, - "metadata": { - "collapsed": true - }, + "execution_count": 35, + "metadata": {}, "outputs": [], "source": [ "raw_weights_round_2 = tf.add(raw_weights, agreement,\n", @@ -943,10 +895,8 @@ }, { "cell_type": "code", - "execution_count": 35, - "metadata": { - "collapsed": true - }, + "execution_count": 36, + "metadata": {}, "outputs": [], "source": [ "routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2,\n", @@ -972,10 +922,8 @@ }, { "cell_type": "code", - "execution_count": 36, - "metadata": { - "collapsed": true - }, + "execution_count": 37, + "metadata": {}, "outputs": [], "source": [ "caps2_output = caps2_output_round_2" @@ -1003,7 +951,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -1043,7 +991,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -1073,10 +1021,8 @@ }, { "cell_type": "code", - "execution_count": 39, - "metadata": { - "collapsed": true - }, + "execution_count": 40, + "metadata": {}, "outputs": [], "source": [ "def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):\n", @@ -1088,10 +1034,8 @@ }, { "cell_type": "code", - "execution_count": 40, - "metadata": { - "collapsed": true - }, + "execution_count": 41, + "metadata": {}, "outputs": [], "source": [ "y_proba = safe_norm(caps2_output, axis=-2, name=\"y_proba\")" @@ -1106,10 +1050,8 @@ }, { "cell_type": "code", - "execution_count": 41, - "metadata": { - "collapsed": true - }, + "execution_count": 42, + "metadata": {}, "outputs": [], "source": [ "y_proba_argmax = tf.argmax(y_proba, axis=2, name=\"y_proba\")" @@ -1124,7 +1066,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -1140,10 +1082,8 @@ }, { "cell_type": "code", - "execution_count": 43, - "metadata": { - "collapsed": true - }, + "execution_count": 44, + "metadata": {}, "outputs": [], "source": [ "y_pred = tf.squeeze(y_proba_argmax, axis=[1,2], name=\"y_pred\")" @@ -1151,7 +1091,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -1181,10 +1121,8 @@ }, { "cell_type": "code", - "execution_count": 45, - "metadata": { - "collapsed": true - }, + "execution_count": 46, + "metadata": {}, "outputs": [], "source": [ "y = tf.placeholder(shape=[None], dtype=tf.int64, name=\"y\")" @@ -1212,10 +1150,8 @@ }, { "cell_type": "code", - "execution_count": 46, - "metadata": { - "collapsed": true - }, + "execution_count": 47, + "metadata": {}, "outputs": [], "source": [ "m_plus = 0.9\n", @@ -1232,10 +1168,8 @@ }, { "cell_type": "code", - "execution_count": 47, - "metadata": { - "collapsed": true - }, + "execution_count": 48, + "metadata": {}, "outputs": [], "source": [ "T = tf.one_hot(y, depth=caps2_n_caps, name=\"T\")" @@ -1250,7 +1184,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -1267,7 +1201,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -1283,10 +1217,8 @@ }, { "cell_type": "code", - "execution_count": 50, - "metadata": { - "collapsed": true - }, + "execution_count": 51, + "metadata": {}, "outputs": [], "source": [ "caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True,\n", @@ -1302,10 +1234,8 @@ }, { "cell_type": "code", - "execution_count": 51, - "metadata": { - "collapsed": true - }, + "execution_count": 52, + "metadata": {}, "outputs": [], "source": [ "present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm),\n", @@ -1323,10 +1253,8 @@ }, { "cell_type": "code", - "execution_count": 52, - "metadata": { - "collapsed": true - }, + "execution_count": 53, + "metadata": {}, "outputs": [], "source": [ "absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus),\n", @@ -1344,10 +1272,8 @@ }, { "cell_type": "code", - "execution_count": 53, - "metadata": { - "collapsed": true - }, + "execution_count": 54, + "metadata": {}, "outputs": [], "source": [ "L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error,\n", @@ -1363,10 +1289,8 @@ }, { "cell_type": "code", - "execution_count": 54, - "metadata": { - "collapsed": true - }, + "execution_count": 55, + "metadata": {}, "outputs": [], "source": [ "margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name=\"margin_loss\")" @@ -1409,10 +1333,8 @@ }, { "cell_type": "code", - "execution_count": 55, - "metadata": { - "collapsed": true - }, + "execution_count": 56, + "metadata": {}, "outputs": [], "source": [ "mask_with_labels = tf.placeholder_with_default(False, shape=(),\n", @@ -1428,10 +1350,8 @@ }, { "cell_type": "code", - "execution_count": 56, - "metadata": { - "collapsed": true - }, + "execution_count": 57, + "metadata": {}, "outputs": [], "source": [ "reconstruction_targets = tf.cond(mask_with_labels, # condition\n", @@ -1458,10 +1378,8 @@ }, { "cell_type": "code", - "execution_count": 57, - "metadata": { - "collapsed": true - }, + "execution_count": 58, + "metadata": {}, "outputs": [], "source": [ "reconstruction_mask = tf.one_hot(reconstruction_targets,\n", @@ -1478,7 +1396,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -1494,7 +1412,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -1510,10 +1428,8 @@ }, { "cell_type": "code", - "execution_count": 60, - "metadata": { - "collapsed": true - }, + "execution_count": 61, + "metadata": {}, "outputs": [], "source": [ "reconstruction_mask_reshaped = tf.reshape(\n", @@ -1530,10 +1446,8 @@ }, { "cell_type": "code", - "execution_count": 61, - "metadata": { - "collapsed": true - }, + "execution_count": 62, + "metadata": {}, "outputs": [], "source": [ "caps2_output_masked = tf.multiply(\n", @@ -1543,7 +1457,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ @@ -1559,10 +1473,8 @@ }, { "cell_type": "code", - "execution_count": 63, - "metadata": { - "collapsed": true - }, + "execution_count": 64, + "metadata": {}, "outputs": [], "source": [ "decoder_input = tf.reshape(caps2_output_masked,\n", @@ -1579,7 +1491,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 65, "metadata": {}, "outputs": [], "source": [ @@ -1602,10 +1514,8 @@ }, { "cell_type": "code", - "execution_count": 65, - "metadata": { - "collapsed": true - }, + "execution_count": 66, + "metadata": {}, "outputs": [], "source": [ "n_hidden1 = 512\n", @@ -1615,10 +1525,8 @@ }, { "cell_type": "code", - "execution_count": 66, - "metadata": { - "collapsed": true - }, + "execution_count": 67, + "metadata": {}, "outputs": [], "source": [ "with tf.name_scope(\"decoder\"):\n", @@ -1649,16 +1557,14 @@ }, { "cell_type": "code", - "execution_count": 67, - "metadata": { - "collapsed": true - }, + "execution_count": 68, + "metadata": {}, "outputs": [], "source": [ "X_flat = tf.reshape(X, [-1, n_output], name=\"X_flat\")\n", "squared_difference = tf.square(X_flat - decoder_output,\n", " name=\"squared_difference\")\n", - "reconstruction_loss = tf.reduce_sum(squared_difference,\n", + "reconstruction_loss = tf.reduce_mean(squared_difference,\n", " name=\"reconstruction_loss\")" ] }, @@ -1678,10 +1584,8 @@ }, { "cell_type": "code", - "execution_count": 68, - "metadata": { - "collapsed": true - }, + "execution_count": 69, + "metadata": {}, "outputs": [], "source": [ "alpha = 0.0005\n", @@ -1712,10 +1616,8 @@ }, { "cell_type": "code", - "execution_count": 69, - "metadata": { - "collapsed": true - }, + "execution_count": 70, + "metadata": {}, "outputs": [], "source": [ "correct = tf.equal(y, y_pred, name=\"correct\")\n", @@ -1738,10 +1640,8 @@ }, { "cell_type": "code", - "execution_count": 70, - "metadata": { - "collapsed": true - }, + "execution_count": 71, + "metadata": {}, "outputs": [], "source": [ "optimizer = tf.train.AdamOptimizer()\n", @@ -1764,10 +1664,8 @@ }, { "cell_type": "code", - "execution_count": 71, - "metadata": { - "collapsed": true - }, + "execution_count": 72, + "metadata": {}, "outputs": [], "source": [ "init = tf.global_variables_initializer()\n", @@ -1804,7 +1702,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ @@ -1870,7 +1768,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Training is finished, we reached over 99.3% accuracy on the validation set after just 5 epochs, things are looking good. Now let's evaluate the model on the test set." + "Training is finished, we reached over 99.4% accuracy on the validation set after just 5 epochs, things are looking good. Now let's evaluate the model on the test set." ] }, { @@ -1882,7 +1780,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -1915,7 +1813,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We reach 99.43% accuracy on the test set. Pretty nice. :)" + "We reach 99.53% accuracy on the test set. Pretty nice. :)" ] }, { @@ -1934,7 +1832,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -1966,7 +1864,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 76, "metadata": {}, "outputs": [], "source": [ @@ -2022,7 +1920,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 77, "metadata": {}, "outputs": [], "source": [ @@ -2038,10 +1936,8 @@ }, { "cell_type": "code", - "execution_count": 77, - "metadata": { - "collapsed": true - }, + "execution_count": 78, + "metadata": {}, "outputs": [], "source": [ "def tweak_pose_parameters(output_vectors, min=-0.5, max=0.5, n_steps=11):\n", @@ -2062,10 +1958,8 @@ }, { "cell_type": "code", - "execution_count": 78, - "metadata": { - "collapsed": true - }, + "execution_count": 79, + "metadata": {}, "outputs": [], "source": [ "n_steps = 11\n", @@ -2084,7 +1978,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ @@ -2108,10 +2002,8 @@ }, { "cell_type": "code", - "execution_count": 80, - "metadata": { - "collapsed": true - }, + "execution_count": 81, + "metadata": {}, "outputs": [], "source": [ "tweak_reconstructions = decoder_output_value.reshape(\n", @@ -2127,7 +2019,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -2161,9 +2053,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [] }