{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 1: The Machine Learning landscape**\n",
"\n",
"_This is the code used to generate some of the figures in chapter 1._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/ageron/handson-ml2/blob/master/01_the_machine_learning_landscape.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Code example 1-1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook requires Python 3.5 or later (the next cell checks this); Python 2.x is deprecated and will fail that check."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"slideshow": {
"slide_type": "-"
}
},
"outputs": [],
"source": [
"# Python ≥3.5 is required\n",
"import sys\n",
"assert sys.version_info >= (3, 5)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function just merges the OECD's life satisfaction data with the IMF's GDP per capita data. It's a bit too long and boring, and it's not specific to Machine Learning, which is why I left it out of the book."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
"    # Keep only the whole-population rows, then pivot to one row per country\n",
"    oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
"    oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"    # Rename and re-index a copy so the caller's DataFrame is left unchanged\n",
"    gdp_per_capita = gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"})\n",
"    gdp_per_capita = gdp_per_capita.set_index(\"Country\")\n",
"    # Inner join on the country index, sorted from poorest to richest\n",
"    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
"                                  left_index=True, right_index=True)\n",
"    full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
"    # Hold out 7 countries (by position in the sorted table) from the sample\n",
"    remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
"    keep_indices = list(set(range(36)) - set(remove_indices))\n",
"    return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
]
},
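{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `pd.merge` call above joins the two tables on their country index. Since `how` defaults to `'inner'`, any country missing from either table is silently dropped. A minimal sketch with two tiny hypothetical tables (countries `A` to `D` and all values are made up, not real OECD or IMF figures):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical toy tables indexed by country (all values are made up)\n",
"life_demo = pd.DataFrame({'Life satisfaction': [6.0, 7.0, 7.5]},\n",
"                         index=pd.Index(['A', 'B', 'C'], name='Country'))\n",
"gdp_demo_df = pd.DataFrame({'GDP per capita': [20000.0, 45000.0, 30000.0]},\n",
"                           index=pd.Index(['A', 'B', 'D'], name='Country'))\n",
"\n",
"# Inner join on both indexes: only countries present in both tables (A and B) survive\n",
"merged_demo = pd.merge(left=life_demo, right=gdp_demo_df, left_index=True, right_index=True)\n",
"merged_demo = merged_demo.sort_values(by='GDP per capita')\n",
"print(merged_demo)"
]
},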
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code in the book expects the data files to be located in the current directory. I tweaked it here to fetch the files from `datasets/lifesat/` instead."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"datapath = os.path.join(\"datasets\", \"lifesat\", \"\")"
]
},
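{
"cell_type": "markdown",
"metadata": {},
"source": [
"Passing an empty string as the last component to `os.path.join` adds a trailing path separator, which is why plain string concatenation such as `datapath + filename` works in the cells below. A quick check, using a separate illustrative name `demo_datapath` so the notebook's `datapath` is left untouched:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# The empty final component yields a trailing separator, e.g. 'datasets/lifesat/' on Linux or macOS\n",
"demo_datapath = os.path.join('datasets', 'lifesat', '')\n",
"demo_file = demo_datapath + 'oecd_bli_2015.csv'\n",
"print(demo_datapath, demo_file)"
]
},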
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# To plot pretty figures directly within Jupyter\n",
"%matplotlib inline\n",
"import matplotlib as mpl\n",
"mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Download the data\n",
"import urllib.request\n",
"DOWNLOAD_ROOT = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"os.makedirs(datapath, exist_ok=True)\n",
"for filename in (\"oecd_bli_2015.csv\", \"gdp_per_capita.csv\"):\n",
" print(\"Downloading\", filename)\n",
" url = DOWNLOAD_ROOT + \"datasets/lifesat/\" + filename\n",
" urllib.request.urlretrieve(url, datapath + filename)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Code example\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn.linear_model\n",
"\n",
"# Load the data\n",
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"\n",
"# Prepare the data\n",
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Visualize the data\n",
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n",
"plt.show()\n",
"\n",
"# Select a linear model\n",
"model = sklearn.linear_model.LinearRegression()\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.96242338]]"
]
},
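{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model above is linear: $\\text{life satisfaction} = \\theta_0 + \\theta_1 \\times \\text{GDP per capita}$. With a single feature, the least-squares parameters have a closed form, $\\theta_1 = \\sum_i (x_i - \\bar{x})(y_i - \\bar{y}) / \\sum_i (x_i - \\bar{x})^2$ and $\\theta_0 = \\bar{y} - \\theta_1 \\bar{x}$. A sketch on made-up points (not the OECD/IMF data) checking that formula against `LinearRegression`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# Hypothetical toy data: GDP per capita (USD) and life satisfaction, all values made up\n",
"gdp_toy = np.array([10000.0, 20000.0, 30000.0, 40000.0, 50000.0])\n",
"sat_toy = np.array([5.0, 5.6, 6.1, 6.9, 7.2])\n",
"\n",
"# Closed-form least squares for a single feature\n",
"x_mean, y_mean = gdp_toy.mean(), sat_toy.mean()\n",
"theta1 = np.sum((gdp_toy - x_mean) * (sat_toy - y_mean)) / np.sum((gdp_toy - x_mean) ** 2)\n",
"theta0 = y_mean - theta1 * x_mean\n",
"\n",
"# Scikit-Learn expects a 2-D feature matrix, hence reshape(-1, 1)\n",
"lin_toy = LinearRegression().fit(gdp_toy.reshape(-1, 1), sat_toy)\n",
"print(theta0, theta1)\n",
"print(lin_toy.intercept_, lin_toy.coef_[0])"
]
},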
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Note: you can ignore the rest of this notebook; it just generates many of the figures in chapter 1."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create a function to save the figures."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"fundamentals\"\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make this notebook's output stable across runs:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)"
]
},
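{
"cell_type": "markdown",
"metadata": {},
"source": [
"`np.random.seed(42)` seeds NumPy's global (legacy) random generator. Newer NumPy code often uses a local generator from `np.random.default_rng` instead; either way, reusing a seed replays the same sequence. A small sketch (variable names are illustrative; the last line re-seeds the global generator so the notebook's state matches the cell above):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Two local generators created with the same seed produce identical draws\n",
"draws_a = np.random.default_rng(42).normal(size=3)\n",
"draws_b = np.random.default_rng(42).normal(size=3)\n",
"print(np.array_equal(draws_a, draws_b))\n",
"\n",
"# The legacy global API behaves the same way: re-seeding replays the sequence\n",
"np.random.seed(42)\n",
"first = np.random.rand(3)\n",
"np.random.seed(42)\n",
"second = np.random.rand(3)\n",
"print(np.array_equal(first, second))\n",
"\n",
"# Restore the freshly seeded global state set by the previous cell\n",
"np.random.seed(42)"
]
},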
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load and prepare Life satisfaction data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want, you can get fresh data from the OECD's website:\n",
"download the CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI\n",
"and save it to `datasets/lifesat/` as `oecd_bli_2015.csv`, the file name the code below reads."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
"oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"oecd_bli.head(2)"
]
},
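{
"cell_type": "markdown",
"metadata": {},
"source": [
"The OECD file is in long format: one row per (country, indicator, inequality breakdown). Filtering on `INEQUALITY == 'TOT'` keeps the whole-population rows, so each (country, indicator) pair appears only once, which `pivot` requires before it can turn each indicator into a column. A sketch on a tiny hypothetical frame with the same column names (the non-total `'MN'` row and all values are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical long-format rows mimicking the BLI layout (values are made up)\n",
"bli_long = pd.DataFrame({\n",
"    'Country': ['A', 'A', 'A', 'B', 'B'],\n",
"    'Indicator': ['Life satisfaction', 'Life satisfaction', 'Air pollution',\n",
"                  'Life satisfaction', 'Air pollution'],\n",
"    'INEQUALITY': ['TOT', 'MN', 'TOT', 'TOT', 'TOT'],\n",
"    'Value': [6.5, 6.4, 12.0, 7.1, 8.0],\n",
"})\n",
"\n",
"# Without this filter, ('A', 'Life satisfaction') would appear twice and pivot would raise a ValueError\n",
"bli_tot = bli_long[bli_long['INEQUALITY'] == 'TOT']\n",
"bli_wide = bli_tot.pivot(index='Country', columns='Indicator', values='Value')\n",
"print(bli_wide)"
]
},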
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli[\"Life satisfaction\"].head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Load and prepare GDP per capita data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As above, you can update the GDP per capita data if you want: download it from http://goo.gl/j1MSKe (a link to imf.org) and save it to `datasets/lifesat/`."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',', delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
"gdp_per_capita.set_index(\"Country\", inplace=True)\n",
"gdp_per_capita.head(2)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)\n",
"full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
"full_country_stats"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats[[\"GDP per capita\", 'Life satisfaction']].loc[\"United States\"]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
"keep_indices = list(set(range(36)) - set(remove_indices))\n",
"\n",
"sample_data = full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]\n",
"missing_data = full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[remove_indices]"
]
},
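{
"cell_type": "markdown",
"metadata": {},
"source": [
"`remove_indices` holds positions (in the table sorted by GDP per capita) of the countries left out of the sample, and `keep_indices` is their complement, so the two `iloc` selections split the rows into disjoint parts. The same idea on a hypothetical six-row table (country labels and values are made up):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# Hypothetical table already sorted by GDP per capita (values are made up)\n",
"demo_stats = pd.DataFrame(\n",
"    {'GDP per capita': [1000.0, 5000.0, 20000.0, 30000.0, 60000.0, 90000.0],\n",
"     'Life satisfaction': [5.0, 5.5, 6.0, 6.5, 7.5, 7.0]},\n",
"    index=['P', 'Q', 'R', 'S', 'T', 'U'])\n",
"\n",
"demo_remove = [0, 5]  # hold out the poorest and the richest rows\n",
"demo_keep = sorted(set(range(len(demo_stats))) - set(demo_remove))\n",
"\n",
"demo_sample = demo_stats.iloc[demo_keep]\n",
"demo_missing = demo_stats.iloc[demo_remove]\n",
"print(demo_sample.index.tolist(), demo_missing.index.tolist())"
]
},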
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.axis([0, 60000, 0, 10])\n",
"position_text = {\n",
" \"Hungary\": (5000, 1),\n",
" \"Korea\": (18000, 1.7),\n",
" \"France\": (29000, 2.4),\n",
" \"Australia\": (40000, 3.0),\n",
" \"United States\": (52000, 3.8),\n",
"}\n",
"for country, pos_text in position_text.items():\n",
" pos_data_x, pos_data_y = sample_data.loc[country]\n",
" country = \"U.S.\" if country == \"United States\" else country\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('money_happy_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"sample_data.to_csv(os.path.join(\"datasets\", \"lifesat\", \"lifesat.csv\"))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"sample_data.loc[list(position_text.keys())]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.axis([0, 60000, 0, 10])\n",
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, 2*X/100000, \"r\")\n",
"plt.text(40000, 2.7, r\"$\\theta_0 = 0$\", fontsize=14, color=\"r\")\n",
"plt.text(40000, 1.8, r\"$\\theta_1 = 2 \\times 10^{-5}$\", fontsize=14, color=\"r\")\n",
"plt.plot(X, 8 - 5*X/100000, \"g\")\n",
"plt.text(5000, 9.1, r\"$\\theta_0 = 8$\", fontsize=14, color=\"g\")\n",
"plt.text(5000, 8.2, r\"$\\theta_1 = -5 \\times 10^{-5}$\", fontsize=14, color=\"g\")\n",
"plt.plot(X, 4 + 5*X/100000, \"b\")\n",
"plt.text(5000, 3.5, r\"$\\theta_0 = 4$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 2.6, r\"$\\theta_1 = 5 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"save_fig('tweaking_model_params_plot')\n",
"plt.show()"
]
},
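{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three lines above use hand-picked values of $\\theta_0$ and $\\theta_1$. Deciding which one fits best requires a cost function; `LinearRegression` minimizes the sum of squared errors, which is equivalent to minimizing the mean squared error (MSE). A sketch scoring the same three parameter pairs on made-up points (not the notebook's data; `toy_mse` is an illustrative helper):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical toy points: GDP per capita (USD) and life satisfaction, values made up\n",
"gdp_pts = np.array([10000.0, 20000.0, 30000.0, 40000.0, 50000.0])\n",
"sat_pts = np.array([5.0, 5.6, 6.1, 6.9, 7.2])\n",
"\n",
"def toy_mse(theta0, theta1, x, y):\n",
"    # Mean squared error of the line y_hat = theta0 + theta1 * x\n",
"    return np.mean((theta0 + theta1 * x - y) ** 2)\n",
"\n",
"# The (theta0, theta1) pairs of the red, green and blue lines plotted above\n",
"candidates = [(0.0, 2e-5), (8.0, -5e-5), (4.0, 5e-5)]\n",
"costs = [toy_mse(c0, c1, gdp_pts, sat_pts) for c0, c1 in candidates]\n",
"for (c0, c1), cost in zip(candidates, costs):\n",
"    print(c0, c1, cost)"
]
},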
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"from sklearn import linear_model\n",
"lin1 = linear_model.LinearRegression()\n",
"Xsample = np.c_[sample_data[\"GDP per capita\"]]\n",
"ysample = np.c_[sample_data[\"Life satisfaction\"]]\n",
"lin1.fit(Xsample, ysample)\n",
"t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
"t0, t1"
]
},
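{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `ysample` is a 2-D column built with `np.c_`, `LinearRegression` stores `intercept_` with shape `(1,)` and `coef_` with shape `(1, 1)`, hence the `[0]` and `[0][0]` indexing above. With a 1-D target, `coef_` has shape `(1,)` and `intercept_` is a scalar. A sketch on made-up data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"\n",
"# Hypothetical toy data (values are made up)\n",
"x_col = np.array([10000.0, 20000.0, 30000.0, 40000.0]).reshape(-1, 1)  # shape (4, 1)\n",
"y_col = np.array([5.0, 5.6, 6.1, 6.9]).reshape(-1, 1)                  # 2-D target, shape (4, 1)\n",
"y_flat = y_col.ravel()                                                 # 1-D target, shape (4,)\n",
"\n",
"reg_2d = LinearRegression().fit(x_col, y_col)\n",
"reg_1d = LinearRegression().fit(x_col, y_flat)\n",
"print(reg_2d.intercept_.shape, reg_2d.coef_.shape)  # 2-D target\n",
"print(np.ndim(reg_1d.intercept_), reg_1d.coef_.shape)  # 1-D target"
]
},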
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.axis([0, 60000, 0, 10])\n",
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b\")\n",
"plt.text(5000, 3.1, r\"$\\theta_0 = 4.85$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 2.2, r\"$\\theta_1 = 4.91 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"save_fig('best_fit_model_plot')\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"cyprus_gdp_per_capita = gdp_per_capita.loc[\"Cyprus\"][\"GDP per capita\"]\n",
"print(cyprus_gdp_per_capita)\n",
"cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0][0]\n",
"cyprus_predicted_life_satisfaction"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b\")\n",
"plt.axis([0, 60000, 0, 10])\n",
"plt.text(5000, 7.5, r\"$\\theta_0 = 4.85$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 6.6, r\"$\\theta_1 = 4.91 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita], [0, cyprus_predicted_life_satisfaction], \"r--\")\n",
"plt.text(25000, 5.0, r\"Prediction = 5.96\", fontsize=14, color=\"b\")\n",
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
"save_fig('cyprus_prediction_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"sample_data[7:10]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"(5.1+5.7+6.5)/3"
]
},
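{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `n_neighbors=3`, a k-nearest-neighbors regressor predicts the average target of the three training points closest to the input. The hand computation above averages the life satisfaction of the three countries in `sample_data[7:10]`, which appear to be those closest to Cyprus in GDP per capita. A sketch on made-up points (all values hypothetical) checking a manual average against `KNeighborsRegressor`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.neighbors import KNeighborsRegressor\n",
"\n",
"# Hypothetical training data (values are made up)\n",
"gdp_train = np.array([[9000.0], [15000.0], [21000.0], [24000.0], [27000.0], [50000.0]])\n",
"sat_train = np.array([5.0, 5.4, 5.7, 5.1, 6.5, 7.3])\n",
"query_gdp = 22500.0\n",
"k_demo = 3\n",
"\n",
"# Manual k-NN: average the targets of the k points with the smallest distance to the query\n",
"nearest = np.argsort(np.abs(gdp_train[:, 0] - query_gdp))[:k_demo]\n",
"manual_pred = sat_train[nearest].mean()\n",
"\n",
"knn_demo = KNeighborsRegressor(n_neighbors=k_demo).fit(gdp_train, sat_train)\n",
"sk_pred = knn_demo.predict([[query_gdp]])[0]\n",
"print(manual_pred, sk_pred)"
]
},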
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"backup = oecd_bli, gdp_per_capita\n",
"\n",
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
"    # Keep only the whole-population rows, then pivot to one row per country\n",
"    oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
"    oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
"    # Rename and re-index a copy so the caller's DataFrame is left unchanged\n",
"    gdp_per_capita = gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"})\n",
"    gdp_per_capita = gdp_per_capita.set_index(\"Country\")\n",
"    # Inner join on the country index, sorted from poorest to richest\n",
"    full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
"                                  left_index=True, right_index=True)\n",
"    full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
"    # Hold out 7 countries (by position in the sorted table) from the sample\n",
"    remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
"    keep_indices = list(set(range(36)) - set(remove_indices))\n",
"    return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# Code example\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn.linear_model\n",
"\n",
"# Load the data\n",
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"\n",
"# Prepare the data\n",
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Visualize the data\n",
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n",
"plt.show()\n",
"\n",
"# Select a linear model\n",
"model = sklearn.linear_model.LinearRegression()\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.96242338]]"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli, gdp_per_capita = backup"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"missing_data"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"position_text2 = {\n",
" \"Brazil\": (1000, 9.0),\n",
" \"Mexico\": (11000, 9.0),\n",
" \"Chile\": (25000, 9.0),\n",
" \"Czech Republic\": (35000, 9.0),\n",
" \"Norway\": (60000, 3),\n",
" \"Switzerland\": (72000, 3.0),\n",
" \"Luxembourg\": (90000, 3.0),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
"plt.axis([0, 110000, 0, 10])\n",
"\n",
"for country, pos_text in position_text2.items():\n",
" pos_data_x, pos_data_y = missing_data.loc[country]\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"\n",
"X=np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b:\")\n",
"\n",
"lin_reg_full = linear_model.LinearRegression()\n",
"Xfull = np.c_[full_country_stats[\"GDP per capita\"]]\n",
"yfull = np.c_[full_country_stats[\"Life satisfaction\"]]\n",
"lin_reg_full.fit(Xfull, yfull)\n",
"\n",
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"\n",
"save_fig('representative_training_data_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
"plt.axis([0, 110000, 0, 10])\n",
"\n",
"from sklearn import preprocessing\n",
"from sklearn import pipeline\n",
"\n",
"poly = preprocessing.PolynomialFeatures(degree=60, include_bias=False)\n",
"scaler = preprocessing.StandardScaler()\n",
"lin_reg2 = linear_model.LinearRegression()\n",
"\n",
"pipeline_reg = pipeline.Pipeline([('poly', poly), ('scal', scaler), ('lin', lin_reg2)])\n",
"pipeline_reg.fit(Xfull, yfull)\n",
"curve = pipeline_reg.predict(X[:, np.newaxis])\n",
"plt.plot(X, curve)\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('overfitting_model_plot')\n",
"plt.show()"
]
},
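{
"cell_type": "markdown",
"metadata": {},
"source": [
"A high-degree polynomial can push the training error close to zero while producing the wild curve above, which is the hallmark of overfitting. A sketch on made-up noisy points (degree 9 is chosen here so the model can pass through all 10 training points; the notebook uses 60) comparing the training MSE of a straight line with a polynomial pipeline built the same way as `pipeline_reg`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import LinearRegression\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.pipeline import Pipeline\n",
"from sklearn.preprocessing import PolynomialFeatures, StandardScaler\n",
"\n",
"# Hypothetical noisy points around a straight line (made up, fixed seed for repeatability)\n",
"rng_demo = np.random.default_rng(0)\n",
"x_demo = np.linspace(0.0, 1.0, 10).reshape(-1, 1)\n",
"y_demo = 5.0 + 2.0 * x_demo[:, 0] + rng_demo.normal(scale=0.3, size=10)\n",
"\n",
"line_demo = LinearRegression().fit(x_demo, y_demo)\n",
"poly_demo = Pipeline([('poly', PolynomialFeatures(degree=9, include_bias=False)),\n",
"                      ('scal', StandardScaler()),\n",
"                      ('lin', LinearRegression())]).fit(x_demo, y_demo)\n",
"\n",
"# Training error only: the flexible model wins here, which says nothing about new data\n",
"mse_line = mean_squared_error(y_demo, line_demo.predict(x_demo))\n",
"mse_poly = mean_squared_error(y_demo, poly_demo.predict(x_demo))\n",
"print(mse_line, mse_poly)"
]
},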
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.loc[[c for c in full_country_stats.index if \"W\" in c.upper()]][\"Life satisfaction\"]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"gdp_per_capita.loc[[c for c in gdp_per_capita.index if \"W\" in c.upper()]].head()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(8,3))\n",
"\n",
"plt.xlabel(\"GDP per capita\")\n",
"plt.ylabel('Life satisfaction')\n",
"\n",
"plt.plot(list(sample_data[\"GDP per capita\"]), list(sample_data[\"Life satisfaction\"]), \"bo\")\n",
"plt.plot(list(missing_data[\"GDP per capita\"]), list(missing_data[\"Life satisfaction\"]), \"rs\")\n",
"\n",
"X = np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"r--\", label=\"Linear model on all data\")\n",
"plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n",
"\n",
"ridge = linear_model.Ridge(alpha=10**9.5)\n",
"Xsample = np.c_[sample_data[\"GDP per capita\"]]\n",
"ysample = np.c_[sample_data[\"Life satisfaction\"]]\n",
"ridge.fit(Xsample, ysample)\n",
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b\", label=\"Regularized linear model on partial data\")\n",
"\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 110000, 0, 10])\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('ridge_model_plot')\n",
"plt.show()"
]
},
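{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ridge regression minimizes the sum of squared errors plus `alpha` times the sum of squared weights (the intercept is not penalized), so a larger `alpha` pulls the slope toward zero; that is why the regularized line above is flatter than the partial-data line. A sketch on made-up data (the alpha values are arbitrary):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.linear_model import Ridge\n",
"\n",
"# Hypothetical toy data (values are made up)\n",
"x_r = np.array([[10000.0], [20000.0], [30000.0], [40000.0], [50000.0]])\n",
"y_r = np.array([5.0, 5.6, 6.1, 6.9, 7.2])\n",
"\n",
"alphas = [1.0, 1e9, 1e10, 1e11]\n",
"slopes = []\n",
"for alpha in alphas:\n",
"    # The closed-form cholesky solver gives an exact solution for this small problem\n",
"    ridge_demo = Ridge(alpha=alpha, solver='cholesky').fit(x_r, y_r)\n",
"    slopes.append(ridge_demo.coef_[0])\n",
"    print(alpha, ridge_demo.coef_[0], ridge_demo.intercept_)"
]
},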
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"backup = oecd_bli, gdp_per_capita\n",
"\n",
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
" return sample_data"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"# Replace this linear model:\n",
"import sklearn.linear_model\n",
"model = sklearn.linear_model.LinearRegression()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"# with this k-neighbors regression model:\n",
"import sklearn.neighbors\n",
"model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = np.array([[22587.0]]) # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.76666667]]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": true
},
"toc_position": {
"height": "616px",
"left": "0px",
"right": "20px",
"top": "106px",
"width": "213px"
}
},
"nbformat": 4,
"nbformat_minor": 4
}