{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 1 – The Machine Learning landscape**\n",
"\n",
"_This is the code used to generate some of the figures in chapter 1._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/ageron/handson-ml2/blob/master/01_the_machine_learning_landscape.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code example 1-1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Python 2 is deprecated and will not pass the version check in the next cell, which requires Python ≥3.5, so please use Python 3."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
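"# Scikit-Learn ≥0.20 is required\n",
"import re\n",
"import sklearn\n",
"# Compare (major, minor) as integers: a plain string comparison would wrongly accept e.g. \"0.9\"\n",
"major, minor = map(int, re.match(r\"(\\d+)\\.(\\d+)\", sklearn.__version__).groups())\n",
"assert (major, minor) >= (0, 20)"
]
},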
{
"cell_type": "markdown",
"metadata": {},
"source": [
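"This function just merges the OECD's life satisfaction data and the IMF's GDP per capita data. It keeps only the GDP per capita and life satisfaction columns and sets 7 countries aside (the `remove_indices` rows, counted after sorting by GDP per capita); those countries are used later to illustrate non-representative training data. It's a bit too long and boring, and it's not specific to Machine Learning, which is why I left it out of the book."
]
},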
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
"    oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
"    oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"    gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
"    gdp_per_capita.set_index(\"Country\", inplace=True)\n",
"    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
"                                  left_index=True, right_index=True)\n",
"    full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
"    remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
"    # sorted() keeps the rows in GDP order; set iteration order is not guaranteed\n",
"    keep_indices = sorted(set(range(36)) - set(remove_indices))\n",
"    return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
]
},
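{
"cell_type": "markdown",
"metadata": {},
"source": [
"This minimal sketch runs the `pivot()` and `merge()` steps of `prepare_country_stats()` on a tiny made-up dataset. The `toy_*` names, the countries A, B, C and D, the `MN` inequality code and all values are hypothetical; none of them come from the real OECD or IMF files. Keeping the `TOT` rows and pivoting turns the long OECD table into one row per country with one column per indicator. Merging on both indices then behaves as an inner join, so only countries present in both tables survive (here A and B)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical long-format table shaped like the OECD BLI file\n",
"toy_bli = pd.DataFrame({\n",
"    'Country': ['A', 'A', 'B', 'B', 'C'],\n",
"    'INEQUALITY': ['TOT', 'MN', 'TOT', 'MN', 'TOT'],\n",
"    'Indicator': ['Life satisfaction'] * 5,\n",
"    'Value': [6.0, 5.8, 7.0, 7.1, 5.0],\n",
"})\n",
"# Hypothetical GDP table shaped like the IMF file (country D has no OECD data)\n",
"toy_gdp = pd.DataFrame({'Country': ['A', 'B', 'D'], '2015': [20000.0, 40000.0, 10000.0]})\n",
"\n",
"# Keep only the totals, then pivot to one row per country and one column per indicator\n",
"toy_wide = toy_bli[toy_bli['INEQUALITY'] == 'TOT'].pivot(index='Country', columns='Indicator', values='Value')\n",
"toy_gdp = toy_gdp.rename(columns={'2015': 'GDP per capita'}).set_index('Country')\n",
"\n",
"# pd.merge defaults to an inner join: C (no GDP) and D (no OECD data) are dropped\n",
"toy_stats = pd.merge(left=toy_wide, right=toy_gdp, left_index=True, right_index=True)\n",
"toy_stats.sort_values(by='GDP per capita')"
]
},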
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code in the book expects the data files to be located in the current directory. I just tweaked it here to load them from `datasets/lifesat/`."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"datapath = os.path.join(\"datasets\", \"lifesat\", \"\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# To plot pretty figures directly within Jupyter\n",
"%matplotlib inline\n",
"import matplotlib as mpl\n",
"mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Download the data\n",
"# Import the submodule explicitly: `import urllib` alone does not guarantee urllib.request is loaded\n",
"import urllib.request\n",
"DOWNLOAD_ROOT = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"os.makedirs(datapath, exist_ok=True)\n",
"for filename in (\"oecd_bli_2015.csv\", \"gdp_per_capita.csv\"):\n",
"    print(\"Downloading\", filename)\n",
"    url = DOWNLOAD_ROOT + \"datasets/lifesat/\" + filename\n",
"    urllib.request.urlretrieve(url, datapath + filename)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Code example\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn.linear_model\n",
"\n",
"# Load the data\n",
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"\n",
"# Prepare the data\n",
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Visualize the data\n",
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n",
"plt.show()\n",
"\n",
"# Select a linear model\n",
"model = sklearn.linear_model.LinearRegression()\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.96242338]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Note: you can ignore the rest of this notebook; it just generates many of the figures in chapter 1."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a function to save the figures."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"fundamentals\"\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
"    print(\"Saving figure\", fig_id)\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make this notebook's output stable across runs:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load and prepare Life satisfaction data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want, you can get fresh data from the OECD's website.\n",
"Download the CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI\n",
"and save it to `datasets/lifesat/` as `oecd_bli_2015.csv`, the file name the code below reads."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
"oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"oecd_bli.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli[\"Life satisfaction\"].head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load and prepare GDP per capita data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like above, you can update the GDP per capita data if you want: download it from http://goo.gl/j1MSKe (=> imf.org) and save it to `datasets/lifesat/`. The code below expects a tab-separated file named `gdp_per_capita.csv` with a `Country` column and a `2015` column."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',', delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
"gdp_per_capita.set_index(\"Country\", inplace=True)\n",
"gdp_per_capita.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)\n",
"full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
"full_country_stats"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats[[\"GDP per capita\", 'Life satisfaction']].loc[\"United States\"]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
"# sorted() keeps the rows in GDP order; set iteration order is not guaranteed\n",
"keep_indices = sorted(set(range(36)) - set(remove_indices))\n",
"\n",
"sample_data = full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]\n",
"missing_data = full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[remove_indices]"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.axis([0, 60000, 0, 10])\n",
"position_text = {\n",
"    \"Hungary\": (5000, 1),\n",
"    \"Korea\": (18000, 1.7),\n",
"    \"France\": (29000, 2.4),\n",
"    \"Australia\": (40000, 3.0),\n",
"    \"United States\": (52000, 3.8),\n",
"}\n",
"for country, pos_text in position_text.items():\n",
"    pos_data_x, pos_data_y = sample_data.loc[country]\n",
"    country = \"U.S.\" if country == \"United States\" else country\n",
"    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
"                 arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n",
"    plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('money_happy_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"sample_data.to_csv(os.path.join(\"datasets\", \"lifesat\", \"lifesat.csv\"))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"sample_data.loc[list(position_text.keys())]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.axis([0, 60000, 0, 10])\n",
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, 2*X/100000, \"r\")\n",
"plt.text(40000, 2.7, r\"$\\theta_0 = 0$\", fontsize=14, color=\"r\")\n",
"plt.text(40000, 1.8, r\"$\\theta_1 = 2 \\times 10^{-5}$\", fontsize=14, color=\"r\")\n",
"plt.plot(X, 8 - 5*X/100000, \"g\")\n",
"plt.text(5000, 9.1, r\"$\\theta_0 = 8$\", fontsize=14, color=\"g\")\n",
"plt.text(5000, 8.2, r\"$\\theta_1 = -5 \\times 10^{-5}$\", fontsize=14, color=\"g\")\n",
"plt.plot(X, 4 + 5*X/100000, \"b\")\n",
"plt.text(5000, 3.5, r\"$\\theta_0 = 4$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 2.6, r\"$\\theta_1 = 5 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"save_fig('tweaking_model_params_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import linear_model\n",
"lin1 = linear_model.LinearRegression()\n",
"Xsample = np.c_[sample_data[\"GDP per capita\"]]\n",
"ysample = np.c_[sample_data[\"Life satisfaction\"]]\n",
"lin1.fit(Xsample, ysample)\n",
"t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
"t0, t1"
]
},
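{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LinearRegression` picks the $\\theta_0$ and $\\theta_1$ that minimize the sum of squared errors. As a cross-check, the same parameters can be obtained by solving the least-squares problem directly with `np.linalg.lstsq()`, using a design matrix with a column of ones for the intercept. This is not necessarily how Scikit-Learn computes them internally. The sketch below uses synthetic data: the `*_demo` names and the data, drawn around the line $4 + 5 \\times 10^{-5} x$, are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# Hypothetical data: noisy points around the line y = 4 + 5e-5 * x\n",
"rng_demo = np.random.default_rng(42)\n",
"x_demo = rng_demo.uniform(10000, 60000, size=30)\n",
"y_demo = 4 + 5e-5 * x_demo + rng_demo.normal(scale=0.3, size=30)\n",
"\n",
"# Add a column of ones so that theta_demo[0] is the intercept and theta_demo[1] the slope\n",
"X_b_demo = np.c_[np.ones(len(x_demo)), x_demo]\n",
"theta_demo, *_ = np.linalg.lstsq(X_b_demo, y_demo, rcond=None)\n",
"\n",
"# Fit the same data with Scikit-Learn and compare both results with a small tolerance\n",
"lin_demo = LinearRegression().fit(x_demo.reshape(-1, 1), y_demo)\n",
"print('least squares:   ', theta_demo)\n",
"print('LinearRegression:', lin_demo.intercept_, lin_demo.coef_[0])\n",
"print('same parameters: ', np.allclose(theta_demo, [lin_demo.intercept_, lin_demo.coef_[0]]))"
]
},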
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.axis([0, 60000, 0, 10])\n",
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b\")\n",
"plt.text(5000, 3.1, r\"$\\theta_0 = 4.85$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 2.2, r\"$\\theta_1 = 4.91 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"save_fig('best_fit_model_plot')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"cyprus_gdp_per_capita = gdp_per_capita.loc[\"Cyprus\"][\"GDP per capita\"]\n",
"print(cyprus_gdp_per_capita)\n",
"cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0][0]\n",
"cyprus_predicted_life_satisfaction"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b\")\n",
"plt.axis([0, 60000, 0, 10])\n",
"plt.text(5000, 7.5, r\"$\\theta_0 = 4.85$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 6.6, r\"$\\theta_1 = 4.91 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita], [0, cyprus_predicted_life_satisfaction], \"r--\")\n",
"plt.text(25000, 5.0, r\"Prediction = 5.96\", fontsize=14, color=\"b\")\n",
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
"save_fig('cyprus_prediction_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"sample_data[7:10]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"(5.1+5.7+6.5)/3"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"backup = oecd_bli, gdp_per_capita\n",
"\n",
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
"    oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
"    oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"    gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
"    gdp_per_capita.set_index(\"Country\", inplace=True)\n",
"    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
"                                  left_index=True, right_index=True)\n",
"    full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
"    remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
"    # sorted() keeps the rows in GDP order; set iteration order is not guaranteed\n",
"    keep_indices = sorted(set(range(36)) - set(remove_indices))\n",
"    return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Code example\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn.linear_model\n",
"\n",
"# Load the data\n",
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"\n",
"# Prepare the data\n",
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Visualize the data\n",
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n",
"plt.show()\n",
"\n",
"# Select a linear model\n",
"model = sklearn.linear_model.LinearRegression()\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.96242338]]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli, gdp_per_capita = backup"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"missing_data"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"position_text2 = {\n",
" \"Brazil\": (1000, 9.0),\n",
" \"Mexico\": (11000, 9.0),\n",
" \"Chile\": (25000, 9.0),\n",
" \"Czech Republic\": (35000, 9.0),\n",
" \"Norway\": (60000, 3),\n",
" \"Switzerland\": (72000, 3.0),\n",
" \"Luxembourg\": (90000, 3.0),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
"plt.axis([0, 110000, 0, 10])\n",
"\n",
"for country, pos_text in position_text2.items():\n",
"    pos_data_x, pos_data_y = missing_data.loc[country]\n",
"    plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
"                 arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n",
"    plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"\n",
"X=np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b:\")\n",
"\n",
"lin_reg_full = linear_model.LinearRegression()\n",
"Xfull = np.c_[full_country_stats[\"GDP per capita\"]]\n",
"yfull = np.c_[full_country_stats[\"Life satisfaction\"]]\n",
"lin_reg_full.fit(Xfull, yfull)\n",
"\n",
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"\n",
"save_fig('representative_training_data_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
"plt.axis([0, 110000, 0, 10])\n",
"\n",
"from sklearn import preprocessing\n",
"from sklearn import pipeline\n",
"\n",
"poly = preprocessing.PolynomialFeatures(degree=60, include_bias=False)\n",
"scaler = preprocessing.StandardScaler()\n",
"lin_reg2 = linear_model.LinearRegression()\n",
"\n",
"pipeline_reg = pipeline.Pipeline([('poly', poly), ('scal', scaler), ('lin', lin_reg2)])\n",
"pipeline_reg.fit(Xfull, yfull)\n",
"curve = pipeline_reg.predict(X[:, np.newaxis])\n",
"plt.plot(X, curve)\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('overfitting_model_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.loc[[c for c in full_country_stats.index if \"W\" in c.upper()]][\"Life satisfaction\"]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"gdp_per_capita.loc[[c for c in gdp_per_capita.index if \"W\" in c.upper()]].head()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(8,3))\n",
"\n",
"plt.xlabel(\"GDP per capita\")\n",
"plt.ylabel('Life satisfaction')\n",
"\n",
"plt.plot(list(sample_data[\"GDP per capita\"]), list(sample_data[\"Life satisfaction\"]), \"bo\")\n",
"plt.plot(list(missing_data[\"GDP per capita\"]), list(missing_data[\"Life satisfaction\"]), \"rs\")\n",
"\n",
"X = np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"r--\", label=\"Linear model on all data\")\n",
"plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n",
"\n",
"ridge = linear_model.Ridge(alpha=10**9.5)\n",
"Xsample = np.c_[sample_data[\"GDP per capita\"]]\n",
"ysample = np.c_[sample_data[\"Life satisfaction\"]]\n",
"ridge.fit(Xsample, ysample)\n",
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b\", label=\"Regularized linear model on partial data\")\n",
"\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 110000, 0, 10])\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('ridge_model_plot')\n",
"plt.show()"
]
},
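{
"cell_type": "markdown",
"metadata": {},
"source": [
"The regularized model above uses `Ridge` with `alpha=10**9.5`, which may look absurdly large. Consider a single feature with an unpenalized intercept. Minimizing the squared error plus $\\alpha\\,\\theta_1^2$ then gives the slope $\\theta_1 = S_{xy} / (S_{xx} + \\alpha)$, where $S_{xx}$ and $S_{xy}$ are sums of products of the centered $x$ and $y$ values. $S_{xx}$ grows with the square of the unscaled GDP values, so `alpha` only flattens the line once it reaches a comparable size. As `alpha` keeps growing, $\\theta_1$ goes to 0 and the intercept approaches the mean of $y$. The minimal sketch below compares this formula with `Ridge` on synthetic data; the `*_demo` names and the data are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import Ridge\n",
"\n",
"# Hypothetical data: noisy points around the line y = 4 + 5e-5 * x\n",
"rng_demo = np.random.default_rng(0)\n",
"x_demo = rng_demo.uniform(10000, 60000, size=30)\n",
"y_demo = 4 + 5e-5 * x_demo + rng_demo.normal(scale=0.3, size=30)\n",
"\n",
"# Centered sums used by the one-feature ridge slope formula S_xy / (S_xx + alpha)\n",
"x_c = x_demo - x_demo.mean()\n",
"y_c = y_demo - y_demo.mean()\n",
"s_xx, s_xy = x_c @ x_c, x_c @ y_c\n",
"print(f'S_xx = {s_xx:.2e}')\n",
"\n",
"# Compare Ridge's slope with the closed-form slope for increasing alpha\n",
"for alpha in [1.0, 1e8, 10**9.5, 1e12]:\n",
"    ridge_demo = Ridge(alpha=alpha).fit(x_demo.reshape(-1, 1), y_demo)\n",
"    slope_formula = s_xy / (s_xx + alpha)\n",
"    print(f'alpha={alpha:.1e}  Ridge slope={ridge_demo.coef_[0]:.3e}  formula={slope_formula:.3e}  intercept={ridge_demo.intercept_:.3f}')\n",
"print(f'mean of y_demo: {y_demo.mean():.3f}')"
]
},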
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"backup = oecd_bli, gdp_per_capita\n",
"\n",
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
"    return sample_data"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"# Replace this linear model:\n",
"import sklearn.linear_model\n",
"model = sklearn.linear_model.LinearRegression()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"# with this k-nearest neighbors regression model:\n",
"import sklearn.neighbors\n",
"model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = np.array([[22587.0]]) # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.76666667]]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": true
},
"toc_position": {
"height": "616px",
"left": "0px",
"right": "20px",
"top": "106px",
"width": "213px"
}
},
"nbformat": 4,
"nbformat_minor": 1
}