Make notebook code match book examples more closely in chapter 1

main
Aurélien Geron 2017-06-01 09:57:58 +02:00
parent 88acd2b4b9
commit 1bc60fe315
1 changed files with 213 additions and 86 deletions

View File

@ -2,7 +2,10 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"**Chapter 1 The Machine Learning landscape**\n",
"\n",
@ -11,14 +14,20 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"First, let's make sure this notebook works well in both Python 2 and 3, import a few common modules, ensure Matplotlib plots figures inline, and prepare a function to save the figures:"
]
@ -28,6 +37,8 @@
"execution_count": 1,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true,
"slideshow": {
"slide_type": "-"
}
@ -67,7 +78,10 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Load and prepare Life satisfaction data"
]
@ -76,7 +90,9 @@
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -95,7 +111,9 @@
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -104,7 +122,10 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"deletable": true,
"editable": true
},
"source": [
"# Load and prepare GDP per capita data"
]
@ -113,7 +134,9 @@
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -129,7 +152,9 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -142,7 +167,9 @@
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -153,7 +180,9 @@
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -166,9 +195,11 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 8,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -195,7 +226,22 @@
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"sample_data.to_csv(\"life_satisfaction_vs_gdp_per_capita.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -204,9 +250,11 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -230,9 +278,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -247,9 +297,11 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -265,9 +317,11 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -279,9 +333,11 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -300,9 +356,11 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 16,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -311,9 +369,11 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -322,26 +382,40 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {
"collapsed": false
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"backup = oecd_bli, gdp_per_capita\n",
"\n",
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
" return sample_data\n",
"\n",
" return sample_data"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Code example\n",
"########################################################################\n",
"import sklearn\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn\n",
"\n",
"# Load the data\n",
"oecd_bli = pd.read_csv(datapath+\"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',',delimiter='\\t',\n",
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"\n",
"# Prepare the data\n",
@ -354,24 +428,36 @@
"plt.show()\n",
"\n",
"# Select a linear model\n",
"lin_reg_model = sklearn.linear_model.LinearRegression()\n",
"model = sklearn.linear_model.LinearRegression()\n",
"\n",
"# Train the model\n",
"lin_reg_model.fit(X, y)\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n",
"print(lin_reg_model.predict(X_new)) # outputs [[ 5.96242338]]\n",
"########################################################################\n",
"\n",
"print(model.predict(X_new)) # outputs [[ 5.96242338]]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"oecd_bli, gdp_per_capita = backup"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 21,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -380,9 +466,11 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 22,
"metadata": {
"collapsed": true
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -399,9 +487,11 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 23,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -432,9 +522,11 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 24,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -458,9 +550,11 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 25,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -469,9 +563,11 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 26,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -480,9 +576,11 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 27,
"metadata": {
"collapsed": false
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
@ -513,50 +611,79 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 28,
"metadata": {
"collapsed": false
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"backup = oecd_bli, gdp_per_capita\n",
"\n",
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
" return sample_data\n",
"\n",
"# Code example\n",
"########################################################################\n",
"from sklearn import neighbors\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"# Load the data\n",
"oecd_bli = pd.read_csv(datapath+\"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',',delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"\n",
"# Prepare the data\n",
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
" return sample_data"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# Replace this linear model:\n",
"model = sklearn.linear_model.LinearRegression()"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"# with this k-neighbors regression model:\n",
"model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"collapsed": false,
"deletable": true,
"editable": true
},
"outputs": [],
"source": [
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Visualize the data\n",
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n",
"plt.show()\n",
"\n",
"# Select a k-neighboors regression model\n",
"k_neigh_reg_model = neighbors.KNeighborsRegressor(n_neighbors=3)\n",
"\n",
"# Train the model\n",
"k_neigh_reg_model.fit(X, y)\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n",
"print(lin_reg_model.predict(X_new)) # outputs [[ 5.96242338]]\n",
"########################################################################\n",
"\n",
"oecd_bli, gdp_per_capita = backup"
"X_new = np.array([[22587.0]]) # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.76666667]]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"deletable": true,
"editable": true
},
"outputs": [],
"source": []
}
],
"metadata": {
@ -575,7 +702,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
"version": "3.5.3"
},
"nav_menu": {},
"toc": {