Merge branch 'master' of https://github.com/ageron/handson-ml into upstream
commit
0f46b85009
|
@ -2,10 +2,7 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"**Chapter 1 – The Machine Learning landscape**\n",
|
"**Chapter 1 – The Machine Learning landscape**\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -14,20 +11,14 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"# Setup"
|
"# Setup"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:"
|
"First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:"
|
||||||
]
|
]
|
||||||
|
@ -36,9 +27,6 @@
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 1,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true,
|
|
||||||
"slideshow": {
|
"slideshow": {
|
||||||
"slide_type": "-"
|
"slide_type": "-"
|
||||||
}
|
}
|
||||||
|
@ -50,11 +38,10 @@
|
||||||
"\n",
|
"\n",
|
||||||
"# Common imports\n",
|
"# Common imports\n",
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import numpy.random as rnd\n",
|
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# to make this notebook's output stable across runs\n",
|
"# to make this notebook's output stable across runs\n",
|
||||||
"rnd.seed(42)\n",
|
"np.random.seed(42)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# To plot pretty figures\n",
|
"# To plot pretty figures\n",
|
||||||
"%matplotlib inline\n",
|
"%matplotlib inline\n",
|
||||||
|
@ -73,35 +60,173 @@
|
||||||
" print(\"Saving figure\", fig_id)\n",
|
" print(\"Saving figure\", fig_id)\n",
|
||||||
" if tight_layout:\n",
|
" if tight_layout:\n",
|
||||||
" plt.tight_layout()\n",
|
" plt.tight_layout()\n",
|
||||||
" plt.savefig(path, format='png', dpi=300)"
|
" plt.savefig(path, format='png', dpi=300)\n",
|
||||||
|
"\n",
|
||||||
|
"# Ignore useless warnings (see SciPy issue #5998)\n",
|
||||||
|
"import warnings\n",
|
||||||
|
"warnings.filterwarnings(action=\"ignore\", module=\"scipy\", message=\"^internal gelsd\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"# Load and prepare Life satisfaction data"
|
"# Code example 1-1"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"This function just merges the OECD's life satisfaction data and the IMF's GDP per capita data. It's a bit too long and boring and it's not specific to Machine Learning, which is why I left it out of the book."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 2,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
|
||||||
|
" oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
|
||||||
|
" oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
|
||||||
|
" gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
|
||||||
|
" gdp_per_capita.set_index(\"Country\", inplace=True)\n",
|
||||||
|
" full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
|
||||||
|
" left_index=True, right_index=True)\n",
|
||||||
|
" full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
|
||||||
|
" remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
|
||||||
|
" keep_indices = list(set(range(36)) - set(remove_indices))\n",
|
||||||
|
" return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"The code in the book expects the data files to be located in the current directory. I just tweaked it here to fetch the files in datasets/lifesat."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import os\n",
|
||||||
|
"datapath = os.path.join(\"datasets\", \"lifesat\", \"\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Code example\n",
|
||||||
|
"import matplotlib\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import numpy as np\n",
|
||||||
"import pandas as pd\n",
|
"import pandas as pd\n",
|
||||||
|
"import sklearn.linear_model\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Download CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI\n",
|
"# Load the data\n",
|
||||||
"datapath = \"datasets/lifesat/\"\n",
|
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
|
||||||
|
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n",
|
||||||
|
" encoding='latin1', na_values=\"n/a\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"oecd_bli = pd.read_csv(datapath+\"oecd_bli_2015.csv\", thousands=',')\n",
|
"# Prepare the data\n",
|
||||||
|
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
|
||||||
|
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
|
||||||
|
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
|
||||||
|
"\n",
|
||||||
|
"# Visualize the data\n",
|
||||||
|
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n",
|
||||||
|
"plt.show()\n",
|
||||||
|
"\n",
|
||||||
|
"# Select a linear model\n",
|
||||||
|
"model = sklearn.linear_model.LinearRegression()\n",
|
||||||
|
"\n",
|
||||||
|
"# Train the model\n",
|
||||||
|
"model.fit(X, y)\n",
|
||||||
|
"\n",
|
||||||
|
"# Make a prediction for Cyprus\n",
|
||||||
|
"X_new = [[22587]] # Cyprus' GDP per capita\n",
|
||||||
|
"print(model.predict(X_new)) # outputs [[ 5.96242338]]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Note: you can ignore the rest of this notebook, it just generates many of the figures in chapter 1."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Load and prepare Life satisfaction data"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"If you want, you can get fresh data from the OECD's website.\n",
|
||||||
|
"Download the CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI\n",
|
||||||
|
"and save it to `datasets/lifesat/`."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
|
||||||
"oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
|
"oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
|
||||||
"oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
|
"oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
|
||||||
"oecd_bli.head(2)"
|
"oecd_bli.head(2)"
|
||||||
|
@ -109,12 +234,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 6,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"oecd_bli[\"Life satisfaction\"].head()"
|
"oecd_bli[\"Life satisfaction\"].head()"
|
||||||
|
@ -122,25 +243,24 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"source": [
|
"source": [
|
||||||
"# Load and prepare GDP per capita data"
|
"# Load and prepare GDP per capita data"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Just like above, you can update the GDP per capita data if you want. Just download data from http://goo.gl/j1MSKe (=> imf.org) and save it to `datasets/lifesat/`."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 7,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Download data from http://goo.gl/j1MSKe (=> imf.org)\n",
|
|
||||||
"gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',', delimiter='\\t',\n",
|
"gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',', delimiter='\\t',\n",
|
||||||
" encoding='latin1', na_values=\"n/a\")\n",
|
" encoding='latin1', na_values=\"n/a\")\n",
|
||||||
"gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
|
"gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
|
||||||
|
@ -150,12 +270,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 8,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)\n",
|
"full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)\n",
|
||||||
|
@ -165,12 +281,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 9,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"full_country_stats[[\"GDP per capita\", 'Life satisfaction']].loc[\"United States\"]"
|
"full_country_stats[[\"GDP per capita\", 'Life satisfaction']].loc[\"United States\"]"
|
||||||
|
@ -178,12 +290,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 10,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
|
"remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
|
||||||
|
@ -195,12 +303,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 11,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
|
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
|
||||||
|
@ -224,25 +328,17 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 12,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sample_data.to_csv(\"life_satisfaction_vs_gdp_per_capita.csv\")"
|
"sample_data.to_csv(os.path.join(\"datasets\", \"lifesat\", \"lifesat.csv\"))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 13,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sample_data.loc[list(position_text.keys())]"
|
"sample_data.loc[list(position_text.keys())]"
|
||||||
|
@ -250,12 +346,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 14,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
|
@ -278,12 +370,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 15,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from sklearn import linear_model\n",
|
"from sklearn import linear_model\n",
|
||||||
|
@ -297,12 +385,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 16,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
|
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
|
||||||
|
@ -317,12 +401,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 17,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"cyprus_gdp_per_capita = gdp_per_capita.loc[\"Cyprus\"][\"GDP per capita\"]\n",
|
"cyprus_gdp_per_capita = gdp_per_capita.loc[\"Cyprus\"][\"GDP per capita\"]\n",
|
||||||
|
@ -333,12 +413,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 18,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n",
|
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n",
|
||||||
|
@ -356,12 +432,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 19,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sample_data[7:10]"
|
"sample_data[7:10]"
|
||||||
|
@ -369,12 +441,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 20,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"(5.1+5.7+6.5)/3"
|
"(5.1+5.7+6.5)/3"
|
||||||
|
@ -382,28 +450,29 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 21,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"backup = oecd_bli, gdp_per_capita\n",
|
"backup = oecd_bli, gdp_per_capita\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
|
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
|
||||||
" return sample_data"
|
" oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
|
||||||
|
" oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
|
||||||
|
" gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
|
||||||
|
" gdp_per_capita.set_index(\"Country\", inplace=True)\n",
|
||||||
|
" full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
|
||||||
|
" left_index=True, right_index=True)\n",
|
||||||
|
" full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
|
||||||
|
" remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
|
||||||
|
" keep_indices = list(set(range(36)) - set(remove_indices))\n",
|
||||||
|
" return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 22,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Code example\n",
|
"# Code example\n",
|
||||||
|
@ -440,12 +509,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 23,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"oecd_bli, gdp_per_capita = backup"
|
"oecd_bli, gdp_per_capita = backup"
|
||||||
|
@ -453,12 +518,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 24,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"missing_data"
|
"missing_data"
|
||||||
|
@ -466,12 +527,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 22,
|
"execution_count": 25,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"position_text2 = {\n",
|
"position_text2 = {\n",
|
||||||
|
@ -487,12 +544,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 23,
|
"execution_count": 26,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
|
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
|
||||||
|
@ -522,12 +575,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 24,
|
"execution_count": 27,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"full_country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
|
"full_country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
|
||||||
|
@ -550,12 +599,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 25,
|
"execution_count": 28,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"full_country_stats.loc[[c for c in full_country_stats.index if \"W\" in c.upper()]][\"Life satisfaction\"]"
|
"full_country_stats.loc[[c for c in full_country_stats.index if \"W\" in c.upper()]][\"Life satisfaction\"]"
|
||||||
|
@ -563,12 +608,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 26,
|
"execution_count": 29,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"gdp_per_capita.loc[[c for c in gdp_per_capita.index if \"W\" in c.upper()]].head()"
|
"gdp_per_capita.loc[[c for c in gdp_per_capita.index if \"W\" in c.upper()]].head()"
|
||||||
|
@ -576,12 +617,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 27,
|
"execution_count": 30,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"plt.figure(figsize=(8,3))\n",
|
"plt.figure(figsize=(8,3))\n",
|
||||||
|
@ -611,12 +648,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 28,
|
"execution_count": 31,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"backup = oecd_bli, gdp_per_capita\n",
|
"backup = oecd_bli, gdp_per_capita\n",
|
||||||
|
@ -627,12 +660,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 29,
|
"execution_count": 32,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# Replace this linear model:\n",
|
"# Replace this linear model:\n",
|
||||||
|
@ -641,12 +670,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 30,
|
"execution_count": 33,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# with this k-neighbors regression model:\n",
|
"# with this k-neighbors regression model:\n",
|
||||||
|
@ -655,12 +680,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 31,
|
"execution_count": 34,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": false,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
|
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
|
||||||
|
@ -677,11 +698,7 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true,
|
|
||||||
"deletable": true,
|
|
||||||
"editable": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
|
@ -702,7 +719,7 @@
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.5.3"
|
"version": "3.6.3"
|
||||||
},
|
},
|
||||||
"nav_menu": {},
|
"nav_menu": {},
|
||||||
"toc": {
|
"toc": {
|
||||||
|
@ -723,5 +740,5 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 0
|
"nbformat_minor": 1
|
||||||
}
|
}
|
||||||
|
|
|
@ -2281,7 +2281,7 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is a usual information to have."
|
"It seems that the ham emails are more often plain text, while spam has quite a lot of HTML. Moreover, quite a few ham emails are signed using PGP, while no spam is. In short, it seems that the email structure is useful information to have."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -2714,8 +2714,8 @@
|
||||||
"\n",
|
"\n",
|
||||||
"y_pred = log_clf.predict(X_test_transformed)\n",
|
"y_pred = log_clf.predict(X_test_transformed)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print(\"Precision: {:.2f}%\".format(precision_score(y_test, y_pred)))\n",
|
"print(\"Precision: {:.2f}%\".format(100 * precision_score(y_test, y_pred)))\n",
|
||||||
"print(\"Recall: {:.2f}%\".format(recall_score(y_test, y_pred)))"
|
"print(\"Recall: {:.2f}%\".format(100 * recall_score(y_test, y_pred)))"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|
|
@ -77,10 +77,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 3,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from __future__ import division, print_function, unicode_literals"
|
"from __future__ import division, print_function, unicode_literals"
|
||||||
|
@ -95,10 +93,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 4,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"%matplotlib inline\n",
|
"%matplotlib inline\n",
|
||||||
|
@ -115,10 +111,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 5,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
|
@ -141,10 +135,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 6,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"tf.reset_default_graph()"
|
"tf.reset_default_graph()"
|
||||||
|
@ -159,10 +151,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 7,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"np.random.seed(42)\n",
|
"np.random.seed(42)\n",
|
||||||
|
@ -185,7 +175,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 8,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -203,7 +193,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 9,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -228,7 +218,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -286,10 +276,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 11,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name=\"X\")"
|
"X = tf.placeholder(shape=[None, 28, 28, 1], dtype=tf.float32, name=\"X\")"
|
||||||
|
@ -311,10 +299,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 12,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps1_n_maps = 32\n",
|
"caps1_n_maps = 32\n",
|
||||||
|
@ -331,10 +317,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 13,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"conv1_params = {\n",
|
"conv1_params = {\n",
|
||||||
|
@ -356,10 +340,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 14,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"conv1 = tf.layers.conv2d(X, name=\"conv1\", **conv1_params)\n",
|
"conv1 = tf.layers.conv2d(X, name=\"conv1\", **conv1_params)\n",
|
||||||
|
@ -382,10 +364,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 15,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims],\n",
|
"caps1_raw = tf.reshape(conv2, [-1, caps1_n_caps, caps1_n_dims],\n",
|
||||||
|
@ -407,10 +387,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 16,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def squash(s, axis=-1, epsilon=1e-7, name=None):\n",
|
"def squash(s, axis=-1, epsilon=1e-7, name=None):\n",
|
||||||
|
@ -432,10 +410,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 17,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps1_output = squash(caps1_raw, name=\"caps1_output\")"
|
"caps1_output = squash(caps1_raw, name=\"caps1_output\")"
|
||||||
|
@ -478,10 +454,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 18,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps2_n_caps = 10\n",
|
"caps2_n_caps = 10\n",
|
||||||
|
@ -568,10 +542,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 19,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"init_sigma = 0.01\n",
|
"init_sigma = 0.01\n",
|
||||||
|
@ -591,10 +563,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 20,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"batch_size = tf.shape(X)[0]\n",
|
"batch_size = tf.shape(X)[0]\n",
|
||||||
|
@ -610,10 +580,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 21,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps1_output_expanded = tf.expand_dims(caps1_output, -1,\n",
|
"caps1_output_expanded = tf.expand_dims(caps1_output, -1,\n",
|
||||||
|
@ -633,7 +601,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 21,
|
"execution_count": 22,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -649,7 +617,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 22,
|
"execution_count": 23,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -665,10 +633,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 23,
|
"execution_count": 24,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled,\n",
|
"caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled,\n",
|
||||||
|
@ -684,7 +650,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 24,
|
"execution_count": 25,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -714,10 +680,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 25,
|
"execution_count": 26,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],\n",
|
"raw_weights = tf.zeros([batch_size, caps1_n_caps, caps2_n_caps, 1, 1],\n",
|
||||||
|
@ -747,10 +711,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 26,
|
"execution_count": 27,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"routing_weights = tf.nn.softmax(raw_weights, dim=2, name=\"routing_weights\")"
|
"routing_weights = tf.nn.softmax(raw_weights, dim=2, name=\"routing_weights\")"
|
||||||
|
@ -765,10 +727,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 27,
|
"execution_count": 28,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"weighted_predictions = tf.multiply(routing_weights, caps2_predicted,\n",
|
"weighted_predictions = tf.multiply(routing_weights, caps2_predicted,\n",
|
||||||
|
@ -797,10 +757,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 28,
|
"execution_count": 29,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps2_output_round_1 = squash(weighted_sum, axis=-2,\n",
|
"caps2_output_round_1 = squash(weighted_sum, axis=-2,\n",
|
||||||
|
@ -809,7 +767,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 29,
|
"execution_count": 30,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -853,7 +811,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 30,
|
"execution_count": 31,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -869,7 +827,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 31,
|
"execution_count": 32,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -885,10 +843,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 32,
|
"execution_count": 33,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps2_output_round_1_tiled = tf.tile(\n",
|
"caps2_output_round_1_tiled = tf.tile(\n",
|
||||||
|
@ -905,10 +861,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 33,
|
"execution_count": 34,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"agreement = tf.matmul(caps2_predicted, caps2_output_round_1_tiled,\n",
|
"agreement = tf.matmul(caps2_predicted, caps2_output_round_1_tiled,\n",
|
||||||
|
@ -924,10 +878,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 34,
|
"execution_count": 35,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"raw_weights_round_2 = tf.add(raw_weights, agreement,\n",
|
"raw_weights_round_2 = tf.add(raw_weights, agreement,\n",
|
||||||
|
@ -943,10 +895,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 35,
|
"execution_count": 36,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2,\n",
|
"routing_weights_round_2 = tf.nn.softmax(raw_weights_round_2,\n",
|
||||||
|
@ -972,10 +922,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 36,
|
"execution_count": 37,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps2_output = caps2_output_round_2"
|
"caps2_output = caps2_output_round_2"
|
||||||
|
@ -1003,7 +951,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 37,
|
"execution_count": 38,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1043,7 +991,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 38,
|
"execution_count": 39,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1073,10 +1021,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 39,
|
"execution_count": 40,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):\n",
|
"def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False, name=None):\n",
|
||||||
|
@ -1088,10 +1034,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 40,
|
"execution_count": 41,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"y_proba = safe_norm(caps2_output, axis=-2, name=\"y_proba\")"
|
"y_proba = safe_norm(caps2_output, axis=-2, name=\"y_proba\")"
|
||||||
|
@ -1106,10 +1050,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 41,
|
"execution_count": 42,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"y_proba_argmax = tf.argmax(y_proba, axis=2, name=\"y_proba\")"
|
"y_proba_argmax = tf.argmax(y_proba, axis=2, name=\"y_proba\")"
|
||||||
|
@ -1124,7 +1066,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 42,
|
"execution_count": 43,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1140,10 +1082,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 43,
|
"execution_count": 44,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"y_pred = tf.squeeze(y_proba_argmax, axis=[1,2], name=\"y_pred\")"
|
"y_pred = tf.squeeze(y_proba_argmax, axis=[1,2], name=\"y_pred\")"
|
||||||
|
@ -1151,7 +1091,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 44,
|
"execution_count": 45,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1181,10 +1121,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 45,
|
"execution_count": 46,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"y = tf.placeholder(shape=[None], dtype=tf.int64, name=\"y\")"
|
"y = tf.placeholder(shape=[None], dtype=tf.int64, name=\"y\")"
|
||||||
|
@ -1212,10 +1150,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 46,
|
"execution_count": 47,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"m_plus = 0.9\n",
|
"m_plus = 0.9\n",
|
||||||
|
@ -1232,10 +1168,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 47,
|
"execution_count": 48,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"T = tf.one_hot(y, depth=caps2_n_caps, name=\"T\")"
|
"T = tf.one_hot(y, depth=caps2_n_caps, name=\"T\")"
|
||||||
|
@ -1250,7 +1184,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 48,
|
"execution_count": 49,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1267,7 +1201,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 49,
|
"execution_count": 50,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1283,10 +1217,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 50,
|
"execution_count": 51,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True,\n",
|
"caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True,\n",
|
||||||
|
@ -1302,10 +1234,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 51,
|
"execution_count": 52,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm),\n",
|
"present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm),\n",
|
||||||
|
@ -1323,10 +1253,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 52,
|
"execution_count": 53,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus),\n",
|
"absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus),\n",
|
||||||
|
@ -1344,10 +1272,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 53,
|
"execution_count": 54,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error,\n",
|
"L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error,\n",
|
||||||
|
@ -1363,10 +1289,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 54,
|
"execution_count": 55,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name=\"margin_loss\")"
|
"margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1), name=\"margin_loss\")"
|
||||||
|
@ -1409,10 +1333,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 55,
|
"execution_count": 56,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"mask_with_labels = tf.placeholder_with_default(False, shape=(),\n",
|
"mask_with_labels = tf.placeholder_with_default(False, shape=(),\n",
|
||||||
|
@ -1428,10 +1350,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 56,
|
"execution_count": 57,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"reconstruction_targets = tf.cond(mask_with_labels, # condition\n",
|
"reconstruction_targets = tf.cond(mask_with_labels, # condition\n",
|
||||||
|
@ -1458,10 +1378,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 57,
|
"execution_count": 58,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"reconstruction_mask = tf.one_hot(reconstruction_targets,\n",
|
"reconstruction_mask = tf.one_hot(reconstruction_targets,\n",
|
||||||
|
@ -1478,7 +1396,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 58,
|
"execution_count": 59,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1494,7 +1412,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 59,
|
"execution_count": 60,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1510,10 +1428,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 60,
|
"execution_count": 61,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"reconstruction_mask_reshaped = tf.reshape(\n",
|
"reconstruction_mask_reshaped = tf.reshape(\n",
|
||||||
|
@ -1530,10 +1446,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 61,
|
"execution_count": 62,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"caps2_output_masked = tf.multiply(\n",
|
"caps2_output_masked = tf.multiply(\n",
|
||||||
|
@ -1543,7 +1457,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 62,
|
"execution_count": 63,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1559,10 +1473,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 63,
|
"execution_count": 64,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"decoder_input = tf.reshape(caps2_output_masked,\n",
|
"decoder_input = tf.reshape(caps2_output_masked,\n",
|
||||||
|
@ -1579,7 +1491,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 64,
|
"execution_count": 65,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1602,10 +1514,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 65,
|
"execution_count": 66,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"n_hidden1 = 512\n",
|
"n_hidden1 = 512\n",
|
||||||
|
@ -1615,10 +1525,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 66,
|
"execution_count": 67,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"with tf.name_scope(\"decoder\"):\n",
|
"with tf.name_scope(\"decoder\"):\n",
|
||||||
|
@ -1649,16 +1557,14 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 67,
|
"execution_count": 68,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"X_flat = tf.reshape(X, [-1, n_output], name=\"X_flat\")\n",
|
"X_flat = tf.reshape(X, [-1, n_output], name=\"X_flat\")\n",
|
||||||
"squared_difference = tf.square(X_flat - decoder_output,\n",
|
"squared_difference = tf.square(X_flat - decoder_output,\n",
|
||||||
" name=\"squared_difference\")\n",
|
" name=\"squared_difference\")\n",
|
||||||
"reconstruction_loss = tf.reduce_sum(squared_difference,\n",
|
"reconstruction_loss = tf.reduce_mean(squared_difference,\n",
|
||||||
" name=\"reconstruction_loss\")"
|
" name=\"reconstruction_loss\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -1678,10 +1584,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 68,
|
"execution_count": 69,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"alpha = 0.0005\n",
|
"alpha = 0.0005\n",
|
||||||
|
@ -1712,10 +1616,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 69,
|
"execution_count": 70,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"correct = tf.equal(y, y_pred, name=\"correct\")\n",
|
"correct = tf.equal(y, y_pred, name=\"correct\")\n",
|
||||||
|
@ -1738,10 +1640,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 70,
|
"execution_count": 71,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"optimizer = tf.train.AdamOptimizer()\n",
|
"optimizer = tf.train.AdamOptimizer()\n",
|
||||||
|
@ -1764,10 +1664,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 71,
|
"execution_count": 72,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"init = tf.global_variables_initializer()\n",
|
"init = tf.global_variables_initializer()\n",
|
||||||
|
@ -1804,7 +1702,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 72,
|
"execution_count": 73,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1870,7 +1768,7 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Training is finished, we reached over 99.3% accuracy on the validation set after just 5 epochs, things are looking good. Now let's evaluate the model on the test set."
|
"Training is finished, we reached over 99.4% accuracy on the validation set after just 5 epochs, things are looking good. Now let's evaluate the model on the test set."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1882,7 +1780,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 73,
|
"execution_count": 74,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1915,7 +1813,7 @@
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"We reach 99.43% accuracy on the test set. Pretty nice. :)"
|
"We reach 99.53% accuracy on the test set. Pretty nice. :)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1934,7 +1832,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 74,
|
"execution_count": 75,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1966,7 +1864,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 75,
|
"execution_count": 76,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2022,7 +1920,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 76,
|
"execution_count": 77,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2038,10 +1936,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 77,
|
"execution_count": 78,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def tweak_pose_parameters(output_vectors, min=-0.5, max=0.5, n_steps=11):\n",
|
"def tweak_pose_parameters(output_vectors, min=-0.5, max=0.5, n_steps=11):\n",
|
||||||
|
@ -2062,10 +1958,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 78,
|
"execution_count": 79,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"n_steps = 11\n",
|
"n_steps = 11\n",
|
||||||
|
@ -2084,7 +1978,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 79,
|
"execution_count": 80,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2108,10 +2002,8 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 80,
|
"execution_count": 81,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"tweak_reconstructions = decoder_output_value.reshape(\n",
|
"tweak_reconstructions = decoder_output_value.reshape(\n",
|
||||||
|
@ -2127,7 +2019,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 81,
|
"execution_count": 82,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2161,9 +2053,7 @@
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
"metadata": {
|
"metadata": {},
|
||||||
"collapsed": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue