Large change: replace os.path with pathlib, move to Python 3.7

main
Aurélien Geron 2021-10-15 21:46:27 +13:00
parent 1b16a81fe5
commit fa1ae51184
19 changed files with 969 additions and 1066 deletions

View File

@ -6,7 +6,7 @@
"source": [ "source": [
"**Chapter 1 The Machine Learning landscape**\n", "**Chapter 1 The Machine Learning landscape**\n",
"\n", "\n",
"_This is the code used to generate some of the figures in chapter 1._" "_This contains the code example in this chapter 1, as well as all the code used to generate `lifesat.csv` and some of this chapter's figures._"
] ]
}, },
{ {
@ -30,13 +30,6 @@
"# Code example 1-1" "# Code example 1-1"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
@ -47,9 +40,9 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)" "assert sys.version_info >= (3, 7)"
] ]
}, },
{ {
@ -58,16 +51,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Scikit-Learn ≥0.20 is required\n", "import numpy as np\n",
"import sklearn\n", "\n",
"assert sklearn.__version__ >= \"0.20\"" "# Make this notebook's output stable across runs\n",
] "np.random.seed(42)"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function just merges the OECD's life satisfaction data and the IMF's GDP per capita data. It's a bit too long and boring and it's not specific to Machine Learning, which is why I left it out of the book."
] ]
}, },
{ {
@ -76,24 +63,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n", "# Scikit-Learn ≥1.0 is required\n",
" oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n", "import sklearn\n",
" oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", "assert sklearn.__version__ >= \"1.0\""
" gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
" gdp_per_capita.set_index(\"Country\", inplace=True)\n",
" full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
" left_index=True, right_index=True)\n",
" full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
" remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
" keep_indices = list(set(range(36)) - set(remove_indices))\n",
" return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code in the book expects the data files to be located in the current directory. I just tweaked it here to fetch the files in `datasets/lifesat`."
] ]
}, },
{ {
@ -102,8 +74,13 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "# To plot pretty figures directly within Jupyter\n",
"datapath = os.path.join(\"datasets\", \"lifesat\", \"\")" "%matplotlib inline\n",
"import matplotlib as mpl\n",
"\n",
"mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)"
] ]
}, },
{ {
@ -112,12 +89,19 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# To plot pretty figures directly within Jupyter\n", "# Download the data\n",
"%matplotlib inline\n", "from pathlib import Path\n",
"import matplotlib as mpl\n", "import urllib.request\n",
"mpl.rc('axes', labelsize=14)\n", "\n",
"mpl.rc('xtick', labelsize=12)\n", "datapath = Path() / \"datasets\" / \"lifesat\"\n",
"mpl.rc('ytick', labelsize=12)" "datapath.mkdir(parents=True, exist_ok=True)\n",
"\n",
"root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"filename = \"lifesat.csv\"\n",
"if not (datapath / filename).is_file():\n",
" print(\"Downloading\", filename)\n",
" url = root + \"datasets/lifesat/\" + filename\n",
" urllib.request.urlretrieve(url, datapath / filename)"
] ]
}, },
{ {
@ -125,52 +109,36 @@
"execution_count": 6, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"# Download the data\n",
"import urllib.request\n",
"DOWNLOAD_ROOT = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"os.makedirs(datapath, exist_ok=True)\n",
"for filename in (\"oecd_bli_2015.csv\", \"gdp_per_capita.csv\"):\n",
" print(\"Downloading\", filename)\n",
" url = DOWNLOAD_ROOT + \"datasets/lifesat/\" + filename\n",
" urllib.request.urlretrieve(url, datapath + filename)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Code example\n", "# Code example\n",
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"import numpy as np\n", "import numpy as np\n",
"import pandas as pd\n", "import pandas as pd\n",
"import sklearn.linear_model\n", "from sklearn.linear_model import LinearRegression\n",
"\n", "\n",
"# Load the data\n", "# Load the data\n",
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n", "lifesat = pd.read_csv(Path() / \"datasets\" / \"lifesat\" / \"lifesat.csv\")\n",
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n", "X = lifesat[[\"GDP per capita (USD)\"]].values\n",
" encoding='latin1', na_values=\"n/a\")\n", "y = lifesat[[\"Life satisfaction\"]].values\n",
"\n",
"# Prepare the data\n",
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n", "\n",
"# Visualize the data\n", "# Visualize the data\n",
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n", "lifesat.plot(kind='scatter',\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"plt.axis([23_500, 62_500, 4, 9])\n",
"plt.grid(True)\n",
"plt.show()\n", "plt.show()\n",
"\n", "\n",
"# Select a linear model\n", "# Select a linear model\n",
"model = sklearn.linear_model.LinearRegression()\n", "model = LinearRegression()\n",
"\n", "\n",
"# Train the model\n", "# Train the model\n",
"model.fit(X, y)\n", "model.fit(X, y)\n",
"\n", "\n",
"# Make a prediction for Cyprus\n", "# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n", "X_new = [[37_655.2]] # Cyprus' GDP per capita in 2020\n",
"print(model.predict(X_new)) # outputs [[ 5.96242338]]" "print(model.predict(X_new)) # outputs [[6.30165767]]"
] ]
}, },
{ {
@ -195,19 +163,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Select a 3-Nearest Neighbors regression model\n", "# Select a 3-Nearest Neighbors regression model\n",
"import sklearn.neighbors\n", "import sklearn.neighbors\n",
"model1 = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)\n", "\n",
"model = sklearn.neighbors.KNeighborsRegressor(n_neighbors=3)\n",
"\n", "\n",
"# Train the model\n", "# Train the model\n",
"model1.fit(X,y)\n", "model.fit(X,y)\n",
"\n", "\n",
"# Make a prediction for Cyprus\n", "# Make a prediction for Cyprus\n",
"print(model1.predict(X_new)) # outputs [[5.76666667]]\n" "print(model.predict(X_new)) # outputs [[6.33333333]]\n"
] ]
}, },
{ {
@ -235,73 +204,33 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Note: you can ignore the rest of this notebook, it just generates many of the figures in chapter 1." "# Note: you can safely ignore the rest of this notebook, it just generates many of the figures in chapter 1."
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Create a function to save the figures." "Create a function to save the figures:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 8,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"fundamentals\"\n",
"CHAPTER_ID = \"fundamentals\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Make this notebook's output stable across runs:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -313,9 +242,39 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"If you want, you can get fresh data from the OECD's website.\n", "To create `lifesat.csv`, I downloaded the Better Life Index (BLI) data from [OECD's website](http://stats.oecd.org/index.aspx?DataSetCode=BLI) (to get the Life Satisfaction for each country), and World Bank GDP per capita data from [OurWorldInData.org](https://ourworldindata.org/grapher/gdp-per-capita-worldbank). The BLI data is in `datasets/lifesat/oecd_bli.csv` (data from 2020), and the GDP per capita data is in `datasets/lifesat/gdp_per_capita.csv` (data up to 2020).\n",
"Download the CSV from http://stats.oecd.org/index.aspx?DataSetCode=BLI\n", "\n",
"and save it to `datasets/lifesat/`." "If you want to grab the latest versions, please feel free to do so. However, there may be some changes (e.g., in the column names, or different countries missing data), so be prepared to have to tweak the code."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"for filename in (\"oecd_bli.csv\", \"gdp_per_capita.csv\"):\n",
" if not (datapath / filename).is_file():\n",
" print(\"Downloading\", filename)\n",
" url = root + \"datasets/lifesat/\" + filename\n",
" urllib.request.urlretrieve(url, datapath / filename)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli = pd.read_csv(datapath / \"oecd_bli.csv\")\n",
"gdp_per_capita = pd.read_csv(datapath / \"gdp_per_capita.csv\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function just merges the OECD's life satisfaction data and the World Bank's GDP per capita data:"
] ]
}, },
{ {
@ -324,10 +283,19 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n", "def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
"oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n", " gdp_year = 2020\n",
"oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n", " gdppc = \"GDP per capita (USD)\"\n",
"oecd_bli.head(2)" " oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
" oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
" gdp_per_capita = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
" gdp_per_capita = gdp_per_capita.drop([\"Code\", \"Year\"], axis=1)\n",
" gdp_per_capita.columns = [\"Country\", gdppc]\n",
" gdp_per_capita.set_index(\"Country\", inplace=True)\n",
" full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
" left_index=True, right_index=True)\n",
" full_country_stats.sort_values(by=gdppc, inplace=True)\n",
" return full_country_stats[[gdppc, 'Life satisfaction']]"
] ]
}, },
{ {
@ -336,21 +304,15 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"oecd_bli[\"Life satisfaction\"].head()" "full_country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"full_country_stats.to_csv(datapath / \"lifesat.csv\")"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Load and prepare GDP per capita data" "To illustrate the risk of overfitting, I use only part of the data in most figures (all countries with a GDP per capita between `min_gdp` and `max_gdp`). Later in the chapter I reveal the missing countries, and show that they don't follow the same linear trend at all."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like above, you can update the GDP per capita data if you want. Just download data from http://goo.gl/j1MSKe (=> imf.org) and save it to `datasets/lifesat/`."
] ]
}, },
{ {
@ -359,11 +321,12 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"gdp_per_capita = pd.read_csv(datapath+\"gdp_per_capita.csv\", thousands=',', delimiter='\\t',\n", "gdppc = \"GDP per capita (USD)\"\n",
" encoding='latin1', na_values=\"n/a\")\n", "min_gdp = 23_500\n",
"gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n", "max_gdp = 62_500\n",
"gdp_per_capita.set_index(\"Country\", inplace=True)\n", "country_stats = full_country_stats[(full_country_stats[gdppc] >= min_gdp) &\n",
"gdp_per_capita.head(2)" " (full_country_stats[gdppc] <= max_gdp)]\n",
"country_stats.head()"
] ]
}, },
{ {
@ -372,9 +335,35 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita, left_index=True, right_index=True)\n", "country_stats.plot(kind='scatter', figsize=(5,3),\n",
"full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n", " x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"full_country_stats" "\n",
"min_life_sat = 4\n",
"max_life_sat = 9\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"position_text = {\n",
" \"Hungary\": (28_000, 4.2),\n",
" \"France\": (40_000, 5),\n",
" \"New Zealand\": (30_000, 8),\n",
" \"Australia\": (50_000, 5.5),\n",
" \"United States\": (59_000, 5.5),\n",
" \"Denmark\": (46_000, 8.5)\n",
"}\n",
"\n",
"for country, pos_text in position_text.items():\n",
" pos_data_x, pos_data_y = country_stats[[\"GDP per capita (USD)\",\n",
" \"Life satisfaction\"]].loc[country]\n",
" country = \"U.S.\" if country == \"United States\" else country\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.2,\n",
" headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"\n",
"plt.grid(True)\n",
"\n",
"save_fig('money_happy_scatterplot')\n",
"plt.show()"
] ]
}, },
{ {
@ -383,7 +372,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"full_country_stats[[\"GDP per capita\", 'Life satisfaction']].loc[\"United States\"]" "highlighted_countries = country_stats.loc[list(position_text.keys())]\n",
"highlighted_countries[[\"Life satisfaction\"]].sort_values(by=\"Life satisfaction\")"
] ]
}, },
{ {
@ -392,11 +382,38 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"remove_indices = [0, 1, 6, 8, 33, 34, 35]\n", "import numpy as np\n",
"keep_indices = list(set(range(36)) - set(remove_indices))\n",
"\n", "\n",
"sample_data = full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]\n", "country_stats.plot(kind='scatter', figsize=(5,3),\n",
"missing_data = full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[remove_indices]" " x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"\n",
"w1, w2 = 4.2, 0\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"r\")\n",
"plt.text(40_000, 4.9, fr\"$\\theta_0 = {w1}$\",\n",
" fontsize=14, color=\"r\")\n",
"plt.text(40_000, 4.4, fr\"$\\theta_1 = {w2}$\",\n",
" fontsize=14, color=\"r\")\n",
"\n",
"w1, w2 = 10, -9\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"g\")\n",
"plt.text(26_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n",
" fontsize=14, color=\"g\")\n",
"plt.text(26_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"g\")\n",
"\n",
"w1, w2 = 3, 8\n",
"plt.plot(X, w1 + w2 * 1e-5 * X, \"b\")\n",
"plt.text(48_000, 8.5, fr\"$\\theta_0 = {w1}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.text(48_000, 8.0, fr\"$\\theta_1 = {w2} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.grid(True)\n",
"\n",
"save_fig('tweaking_model_params_plot')\n",
"plt.show()"
] ]
}, },
{ {
@ -405,24 +422,16 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n", "from sklearn import linear_model\n",
"plt.axis([0, 60000, 0, 10])\n", "\n",
"position_text = {\n", "X_sample = country_stats[[\"GDP per capita (USD)\"]].values\n",
" \"Hungary\": (5000, 1),\n", "y_sample = country_stats[[\"Life satisfaction\"]].values\n",
" \"Korea\": (18000, 1.7),\n", "\n",
" \"France\": (29000, 2.4),\n", "lin1 = linear_model.LinearRegression()\n",
" \"Australia\": (40000, 3.0),\n", "lin1.fit(X_sample, y_sample)\n",
" \"United States\": (52000, 3.8),\n", "\n",
"}\n", "t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n",
"for country, pos_text in position_text.items():\n", "t0, t1"
" pos_data_x, pos_data_y = sample_data.loc[country]\n",
" country = \"U.S.\" if country == \"United States\" else country\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"ro\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('money_happy_scatterplot')\n",
"plt.show()"
] ]
}, },
{ {
@ -431,7 +440,24 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data.to_csv(os.path.join(\"datasets\", \"lifesat\", \"lifesat.csv\"))" "country_stats.plot(kind='scatter', figsize=(5,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"\n",
"X = np.linspace(min_gdp, max_gdp, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b\")\n",
"\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.text(max_gdp - 20_000, min_life_sat + 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('best_fit_model_plot')\n",
"plt.show()"
] ]
}, },
{ {
@ -440,7 +466,11 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data.loc[list(position_text.keys())]" "gdp_year = 2020\n",
"gdp_per_capita_clean = gdp_per_capita[gdp_per_capita[\"Year\"] == gdp_year]\n",
"gdp_per_capita_clean = gdp_per_capita_clean.drop([\"Code\", \"Year\"], axis=1)\n",
"gdp_per_capita_clean.columns = [\"Country\", \"GDP per capita (USD)\"]\n",
"gdp_per_capita_clean.set_index(\"Country\", inplace=True)"
] ]
}, },
{ {
@ -449,23 +479,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n", "cyprus_gdp_per_capita = gdp_per_capita_clean.loc[\"Cyprus\"][\"GDP per capita (USD)\"]\n",
"\n", "print(cyprus_gdp_per_capita)\n",
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n", "cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]\n",
"plt.xlabel(\"GDP per capita (USD)\")\n", "cyprus_predicted_life_satisfaction"
"plt.axis([0, 60000, 0, 10])\n",
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, 2*X/100000, \"r\")\n",
"plt.text(40000, 2.7, r\"$\\theta_0 = 0$\", fontsize=14, color=\"r\")\n",
"plt.text(40000, 1.8, r\"$\\theta_1 = 2 \\times 10^{-5}$\", fontsize=14, color=\"r\")\n",
"plt.plot(X, 8 - 5*X/100000, \"g\")\n",
"plt.text(5000, 9.1, r\"$\\theta_0 = 8$\", fontsize=14, color=\"g\")\n",
"plt.text(5000, 8.2, r\"$\\theta_1 = -5 \\times 10^{-5}$\", fontsize=14, color=\"g\")\n",
"plt.plot(X, 4 + 5*X/100000, \"b\")\n",
"plt.text(5000, 3.5, r\"$\\theta_0 = 4$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 2.6, r\"$\\theta_1 = 5 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"save_fig('tweaking_model_params_plot')\n",
"plt.show()"
] ]
}, },
{ {
@ -474,13 +491,31 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn import linear_model\n", "country_stats.plot(kind='scatter', figsize=(5,3),\n",
"lin1 = linear_model.LinearRegression()\n", " x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"Xsample = np.c_[sample_data[\"GDP per capita\"]]\n", "\n",
"ysample = np.c_[sample_data[\"Life satisfaction\"]]\n", "X = np.linspace(min_gdp, max_gdp, 1000)\n",
"lin1.fit(Xsample, ysample)\n", "plt.plot(X, t0 + t1 * X, \"b\")\n",
"t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]\n", "\n",
"t0, t1" "plt.text(min_gdp + 15_000, max_life_sat - 1.5,\n",
" fr\"$\\theta_0 = {t0:.2f}$\",\n",
" fontsize=14, color=\"b\")\n",
"plt.text(min_gdp + 15_000, max_life_sat - 1,\n",
" fr\"$\\theta_1 = {t1 * 1e5:.2f} \\times 10^{{-5}}$\",\n",
" fontsize=14, color=\"b\")\n",
"\n",
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],\n",
" [min_life_sat, cyprus_predicted_life_satisfaction], \"r--\")\n",
"plt.text(cyprus_gdp_per_capita + 1000, 5.0,\n",
" fr\"Prediction = {cyprus_predicted_life_satisfaction:.2f}\",\n",
" fontsize=14, color=\"r\")\n",
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n",
"\n",
"plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('cyprus_prediction_plot')\n",
"plt.show()"
] ]
}, },
{ {
@ -489,15 +524,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3))\n", "missing_data = full_country_stats[(full_country_stats[gdppc] < min_gdp) |\n",
"plt.xlabel(\"GDP per capita (USD)\")\n", " (full_country_stats[gdppc] > max_gdp)]\n",
"plt.axis([0, 60000, 0, 10])\n", "missing_data"
"X=np.linspace(0, 60000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b\")\n",
"plt.text(5000, 3.1, r\"$\\theta_0 = 4.85$\", fontsize=14, color=\"b\")\n",
"plt.text(5000, 2.2, r\"$\\theta_1 = 4.91 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n",
"save_fig('best_fit_model_plot')\n",
"plt.show()\n"
] ]
}, },
{ {
@ -506,10 +535,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"cyprus_gdp_per_capita = gdp_per_capita.loc[\"Cyprus\"][\"GDP per capita\"]\n", "position_text2 = {\n",
"print(cyprus_gdp_per_capita)\n", " \"South Africa\": (20_000, 4.2),\n",
"cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0][0]\n", " \"Colombia\": (6_000, 8.2),\n",
"cyprus_predicted_life_satisfaction" " \"Brazil\": (18_000, 7.8),\n",
" \"Mexico\": (24_000, 7.4),\n",
" \"Chile\": (30_000, 7.0),\n",
" \"Norway\": (60_000, 6.2),\n",
" \"Switzerland\": (65_000, 5.7),\n",
" \"Ireland\": (80_000, 5.5),\n",
" \"Luxembourg\": (100_000, 5.0),\n",
"}"
] ]
}, },
{ {
@ -518,17 +554,32 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(5,3), s=1)\n", "full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
"plt.xlabel(\"GDP per capita (USD)\")\n", " x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"X=np.linspace(0, 60000, 1000)\n", "\n",
"plt.plot(X, t0 + t1*X, \"b\")\n", "for country, pos_text in position_text2.items():\n",
"plt.axis([0, 60000, 0, 10])\n", " pos_data_x, pos_data_y = missing_data.loc[country]\n",
"plt.text(5000, 7.5, r\"$\\theta_0 = 4.85$\", fontsize=14, color=\"b\")\n", " plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
"plt.text(5000, 6.6, r\"$\\theta_1 = 4.91 \\times 10^{-5}$\", fontsize=14, color=\"b\")\n", " arrowprops=dict(facecolor='black', width=0.5, shrink=0.1,\n",
"plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita], [0, cyprus_predicted_life_satisfaction], \"r--\")\n", " headwidth=5))\n",
"plt.text(25000, 5.0, r\"Prediction = 5.96\", fontsize=14, color=\"b\")\n", " plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, \"ro\")\n", "\n",
"save_fig('cyprus_prediction_plot')\n", "X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0 + t1 * X, \"b:\")\n",
"\n",
"lin_reg_full = linear_model.LinearRegression()\n",
"Xfull = np.c_[full_country_stats[\"GDP per capita (USD)\"]]\n",
"yfull = np.c_[full_country_stats[\"Life satisfaction\"]]\n",
"lin_reg_full.fit(Xfull, yfull)\n",
"\n",
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('representative_training_data_scatterplot')\n",
"plt.show()" "plt.show()"
] ]
}, },
@ -538,7 +589,28 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"sample_data[7:10]" "from sklearn import preprocessing\n",
"from sklearn import pipeline\n",
"\n",
"full_country_stats.plot(kind='scatter', figsize=(8,3),\n",
" x=\"GDP per capita (USD)\", y='Life satisfaction')\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"\n",
"poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)\n",
"scaler = preprocessing.StandardScaler()\n",
"lin_reg2 = linear_model.LinearRegression()\n",
"\n",
"pipeline_reg = pipeline.Pipeline([\n",
" ('poly', poly),\n",
" ('scal', scaler),\n",
" ('lin', lin_reg2)])\n",
"pipeline_reg.fit(Xfull, yfull)\n",
"curve = pipeline_reg.predict(X[:, np.newaxis])\n",
"plt.plot(X, curve)\n",
"plt.grid(True)\n",
"\n",
"save_fig('overfitting_model_plot')\n",
"plt.show()"
] ]
}, },
{ {
@ -547,7 +619,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"(5.1+5.7+6.5)/3" "w_countries = [c for c in full_country_stats.index if \"W\" in c.upper()]\n",
"full_country_stats.loc[w_countries][\"Life satisfaction\"]"
] ]
}, },
{ {
@ -556,19 +629,8 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"backup = oecd_bli, gdp_per_capita\n", "all_w_countries = [c for c in gdp_per_capita_clean.index if \"W\" in c.upper()]\n",
"\n", "gdp_per_capita_clean.loc[all_w_countries].sort_values(by=gdppc)"
"def prepare_country_stats(oecd_bli, gdp_per_capita):\n",
" oecd_bli = oecd_bli[oecd_bli[\"INEQUALITY\"]==\"TOT\"]\n",
" oecd_bli = oecd_bli.pivot(index=\"Country\", columns=\"Indicator\", values=\"Value\")\n",
" gdp_per_capita.rename(columns={\"2015\": \"GDP per capita\"}, inplace=True)\n",
" gdp_per_capita.set_index(\"Country\", inplace=True)\n",
" full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,\n",
" left_index=True, right_index=True)\n",
" full_country_stats.sort_values(by=\"GDP per capita\", inplace=True)\n",
" remove_indices = [0, 1, 6, 8, 33, 34, 35]\n",
" keep_indices = list(set(range(36)) - set(remove_indices))\n",
" return full_country_stats[[\"GDP per capita\", 'Life satisfaction']].iloc[keep_indices]"
] ]
}, },
{ {
@ -576,176 +638,35 @@
"execution_count": 28, "execution_count": 28,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"# Code example\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sklearn.linear_model\n",
"\n",
"# Load the data\n",
"oecd_bli = pd.read_csv(datapath + \"oecd_bli_2015.csv\", thousands=',')\n",
"gdp_per_capita = pd.read_csv(datapath + \"gdp_per_capita.csv\",thousands=',',delimiter='\\t',\n",
" encoding='latin1', na_values=\"n/a\")\n",
"\n",
"# Prepare the data\n",
"country_stats = prepare_country_stats(oecd_bli, gdp_per_capita)\n",
"X = np.c_[country_stats[\"GDP per capita\"]]\n",
"y = np.c_[country_stats[\"Life satisfaction\"]]\n",
"\n",
"# Visualize the data\n",
"country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction')\n",
"plt.show()\n",
"\n",
"# Select a linear model\n",
"model = sklearn.linear_model.LinearRegression()\n",
"\n",
"# Train the model\n",
"model.fit(X, y)\n",
"\n",
"# Make a prediction for Cyprus\n",
"X_new = [[22587]] # Cyprus' GDP per capita\n",
"print(model.predict(X_new)) # outputs [[ 5.96242338]]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"oecd_bli, gdp_per_capita = backup"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"missing_data"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"position_text2 = {\n",
" \"Brazil\": (1000, 9.0),\n",
" \"Mexico\": (11000, 9.0),\n",
" \"Chile\": (25000, 9.0),\n",
" \"Czech Republic\": (35000, 9.0),\n",
" \"Norway\": (60000, 3),\n",
" \"Switzerland\": (72000, 3.0),\n",
" \"Luxembourg\": (90000, 3.0),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"sample_data.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
"plt.axis([0, 110000, 0, 10])\n",
"\n",
"for country, pos_text in position_text2.items():\n",
" pos_data_x, pos_data_y = missing_data.loc[country]\n",
" plt.annotate(country, xy=(pos_data_x, pos_data_y), xytext=pos_text,\n",
" arrowprops=dict(facecolor='black', width=0.5, shrink=0.1, headwidth=5))\n",
" plt.plot(pos_data_x, pos_data_y, \"rs\")\n",
"\n",
"X=np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0 + t1*X, \"b:\")\n",
"\n",
"lin_reg_full = linear_model.LinearRegression()\n",
"Xfull = np.c_[full_country_stats[\"GDP per capita\"]]\n",
"yfull = np.c_[full_country_stats[\"Life satisfaction\"]]\n",
"lin_reg_full.fit(Xfull, yfull)\n",
"\n",
"t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]\n",
"X = np.linspace(0, 110000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"k\")\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"\n",
"save_fig('representative_training_data_scatterplot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.plot(kind='scatter', x=\"GDP per capita\", y='Life satisfaction', figsize=(8,3))\n",
"plt.axis([0, 110000, 0, 10])\n",
"\n",
"from sklearn import preprocessing\n",
"from sklearn import pipeline\n",
"\n",
"poly = preprocessing.PolynomialFeatures(degree=30, include_bias=False)\n",
"scaler = preprocessing.StandardScaler()\n",
"lin_reg2 = linear_model.LinearRegression()\n",
"\n",
"pipeline_reg = pipeline.Pipeline([('poly', poly), ('scal', scaler), ('lin', lin_reg2)])\n",
"pipeline_reg.fit(Xfull, yfull)\n",
"curve = pipeline_reg.predict(X[:, np.newaxis])\n",
"plt.plot(X, curve)\n",
"plt.xlabel(\"GDP per capita (USD)\")\n",
"save_fig('overfitting_model_plot')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"full_country_stats.loc[[c for c in full_country_stats.index if \"W\" in c.upper()]][\"Life satisfaction\"]"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"gdp_per_capita.loc[[c for c in gdp_per_capita.index if \"W\" in c.upper()]].head()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [ "source": [
"plt.figure(figsize=(8,3))\n", "plt.figure(figsize=(8,3))\n",
"\n", "\n",
"plt.xlabel(\"GDP per capita\")\n", "plt.xlabel(\"GDP per capita (USD)\")\n",
"plt.ylabel('Life satisfaction')\n", "plt.ylabel('Life satisfaction')\n",
"\n", "\n",
"plt.plot(list(sample_data[\"GDP per capita\"]), list(sample_data[\"Life satisfaction\"]), \"bo\")\n", "plt.plot(list(country_stats[\"GDP per capita (USD)\"]),\n",
"plt.plot(list(missing_data[\"GDP per capita\"]), list(missing_data[\"Life satisfaction\"]), \"rs\")\n", " list(country_stats[\"Life satisfaction\"]), \"bo\")\n",
"plt.plot(list(missing_data[\"GDP per capita (USD)\"]),\n",
" list(missing_data[\"Life satisfaction\"]), \"rs\")\n",
"\n", "\n",
"X = np.linspace(0, 110000, 1000)\n", "X = np.linspace(0, 115_000, 1000)\n",
"plt.plot(X, t0full + t1full * X, \"r--\", label=\"Linear model on all data\")\n", "plt.plot(X, t0full + t1full * X, \"r--\", label=\"Linear model on all data\")\n",
"plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n", "plt.plot(X, t0 + t1*X, \"b:\", label=\"Linear model on partial data\")\n",
"\n", "\n",
"ridge = linear_model.Ridge(alpha=10**9.5)\n", "ridge = linear_model.Ridge(alpha=10**9.5)\n",
"Xsample = np.c_[sample_data[\"GDP per capita\"]]\n", "Xsample = country_stats[[\"GDP per capita (USD)\"]]\n",
"ysample = np.c_[sample_data[\"Life satisfaction\"]]\n", "ysample = country_stats[[\"Life satisfaction\"]]\n",
"ridge.fit(Xsample, ysample)\n", "ridge.fit(Xsample, ysample)\n",
"t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n", "t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]\n",
"plt.plot(X, t0ridge + t1ridge * X, \"b\", label=\"Regularized linear model on partial data\")\n", "plt.plot(X, t0ridge + t1ridge * X, \"b\", label=\"Regularized linear model on partial data\")\n",
"\n", "\n",
"plt.legend(loc=\"lower right\")\n", "plt.legend(loc=\"lower right\")\n",
"plt.axis([0, 110000, 0, 10])\n", "plt.axis([0, 115_000, 0, 10])\n",
"plt.xlabel(\"GDP per capita (USD)\")\n", "plt.xlabel(\"GDP per capita (USD)\")\n",
"\n",
"plt.axis([0, 115_000, min_life_sat, max_life_sat])\n",
"plt.grid(True)\n",
"\n",
"save_fig('ridge_model_plot')\n", "save_fig('ridge_model_plot')\n",
"plt.show()" "plt.show()"
] ]
@ -760,9 +681,9 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "homl3",
"language": "python", "language": "python",
"name": "python3" "name": "homl3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {

View File

@ -4,8 +4,13 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Chapter 2 End-to-end Machine Learning project**\n", "**Chapter 2 End-to-end Machine Learning project**"
"\n", ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Welcome to Machine Learning Housing Corp.! Your task is to predict median house values in Californian districts, given a number of features from these districts.*\n", "*Welcome to Machine Learning Housing Corp.! Your task is to predict median house values in Californian districts, given a number of features from these districts.*\n",
"\n", "\n",
"*This notebook contains all the sample code and solutions to the exercices in chapter 2.*" "*This notebook contains all the sample code and solutions to the exercices in chapter 2.*"
@ -36,7 +41,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -45,17 +50,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n", "# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n", "import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n", "assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -66,14 +71,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"end_to_end_project\"\n",
"CHAPTER_ID = \"end_to_end_project\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -95,48 +97,36 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "from pathlib import Path\n",
"import tarfile\n", "import tarfile\n",
"import urllib.request\n", "import urllib.request\n",
"\n",
"DOWNLOAD_ROOT = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"HOUSING_PATH = os.path.join(\"datasets\", \"housing\")\n",
"HOUSING_URL = DOWNLOAD_ROOT + \"datasets/housing/housing.tgz\"\n",
"\n",
"def fetch_housing_data(housing_url=HOUSING_URL, housing_path=HOUSING_PATH):\n",
" if not os.path.isdir(housing_path):\n",
" os.makedirs(housing_path)\n",
" tgz_path = os.path.join(housing_path, \"housing.tgz\")\n",
" urllib.request.urlretrieve(housing_url, tgz_path)\n",
" housing_tgz = tarfile.open(tgz_path)\n",
" housing_tgz.extractall(path=housing_path)\n",
" housing_tgz.close()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"fetch_housing_data()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n", "import pandas as pd\n",
"\n", "\n",
"def load_housing_data(housing_path=HOUSING_PATH):\n", "def load_housing_data():\n",
" csv_path = os.path.join(housing_path, \"housing.csv\")\n", " housing_path = Path() / \"datasets\" / \"housing\"\n",
" return pd.read_csv(csv_path)" " if not (housing_path / \"housing.csv\").is_file():\n",
" housing_path.mkdir(parents=True, exist_ok=True)\n",
" root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
" url = root + \"datasets/housing/housing.tgz\"\n",
" tgz_path = housing_path / \"housing.tgz\"\n",
" urllib.request.urlretrieve(url, tgz_path)\n",
" housing_tgz = tarfile.open(tgz_path)\n",
" housing_tgz.extractall(path=housing_path)\n",
" housing_tgz.close()\n",
" return pd.read_csv(housing_path / \"housing.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"housing = load_housing_data()"
] ]
}, },
{ {
@ -526,18 +516,19 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 20,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Download the California image\n", "# Download the California image\n",
"images_path = os.path.join(PROJECT_ROOT_DIR, \"images\", \"end_to_end_project\")\n", "images_path = Path() / \"images\" / \"end_to_end_project\"\n",
"os.makedirs(images_path, exist_ok=True)\n",
"DOWNLOAD_ROOT = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"filename = \"california.png\"\n", "filename = \"california.png\"\n",
"print(\"Downloading\", filename)\n", "if not (images_path / filename).is_file():\n",
"url = DOWNLOAD_ROOT + \"images/end_to_end_project/\" + filename\n", " images_path.mkdir(parents=True, exist_ok=True)\n",
"urllib.request.urlretrieve(url, os.path.join(images_path, filename))" " root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
" url = root + \"images/end_to_end_project/\" + filename\n",
" print(\"Downloading\", filename)\n",
" urllib.request.urlretrieve(url, images_path / filename)"
] ]
}, },
{ {
@ -547,7 +538,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"import matplotlib.image as mpimg\n", "import matplotlib.image as mpimg\n",
"california_img=mpimg.imread(os.path.join(images_path, filename))\n", "\n",
"california_img=mpimg.imread(images_path / filename)\n",
"ax = housing.plot(kind=\"scatter\", x=\"longitude\", y=\"latitude\", figsize=(10,7),\n", "ax = housing.plot(kind=\"scatter\", x=\"longitude\", y=\"latitude\", figsize=(10,7),\n",
" s=housing['population']/100, label=\"Population\",\n", " s=housing['population']/100, label=\"Population\",\n",
" c=\"median_house_value\", cmap=plt.get_cmap(\"jet\"),\n", " c=\"median_house_value\", cmap=plt.get_cmap(\"jet\"),\n",
@ -2342,7 +2334,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -4,8 +4,13 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Chapter 3 Classification**\n", "**Chapter 3 Classification**"
"\n", ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 3._" "_This notebook contains all the sample code and solutions to the exercises in chapter 3._"
] ]
}, },
@ -34,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -43,24 +48,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Scikit-Learn ≥1.0 is required\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n", "import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n", "assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -71,14 +69,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"classification\"\n",
"CHAPTER_ID = \"classification\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -1466,49 +1461,36 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 100, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "from pathlib import Path\n",
"import pandas as pd\n",
"import urllib.request\n", "import urllib.request\n",
"\n", "\n",
"TITANIC_PATH = os.path.join(\"datasets\", \"titanic\")\n", "def load_titanic_data():\n",
"DOWNLOAD_URL = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/datasets/titanic/\"\n", " titanic_path = Path() / \"datasets\" / \"titanic\"\n",
"\n", " titanic_path.mkdir(parents=True, exist_ok=True)\n",
"def fetch_titanic_data(url=DOWNLOAD_URL, path=TITANIC_PATH):\n", " filenames = (\"train.csv\", \"test.csv\")\n",
" if not os.path.isdir(path):\n", " for filename in filenames:\n",
" os.makedirs(path)\n", " filepath = titanic_path / filename\n",
" for filename in (\"train.csv\", \"test.csv\"):\n", " if filepath.is_file():\n",
" filepath = os.path.join(path, filename)\n", " continue\n",
" if not os.path.isfile(filepath):\n", " root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
" url = root + \"/datasets/titanic/\" + filename\n",
" print(\"Downloading\", filename)\n", " print(\"Downloading\", filename)\n",
" urllib.request.urlretrieve(url + filename, filepath)\n", " urllib.request.urlretrieve(url, filepath)\n",
"\n", " return [pd.read_csv(titanic_path / filename) for filename in filenames]"
"fetch_titanic_data() "
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 101, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import pandas as pd\n", "train_data, test_data = load_titanic_data()"
"\n",
"def load_titanic_data(filename, titanic_path=TITANIC_PATH):\n",
" csv_path = os.path.join(titanic_path, filename)\n",
" return pd.read_csv(csv_path)"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"train_data = load_titanic_data(\"train.csv\")\n",
"test_data = load_titanic_data(\"test.csv\")"
] ]
}, },
{ {
@ -1966,38 +1948,40 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 125, "execution_count": 38,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "from pathlib import Path\n",
"import tarfile\n", "import tarfile\n",
"import urllib.request\n", "import urllib.request\n",
"\n", "\n",
"DOWNLOAD_ROOT = \"http://spamassassin.apache.org/old/publiccorpus/\"\n", "def fetch_spam_data():\n",
"HAM_URL = DOWNLOAD_ROOT + \"20030228_easy_ham.tar.bz2\"\n", " root = \"http://spamassassin.apache.org/old/publiccorpus/\"\n",
"SPAM_URL = DOWNLOAD_ROOT + \"20030228_spam.tar.bz2\"\n", " ham_url = root + \"20030228_easy_ham.tar.bz2\"\n",
"SPAM_PATH = os.path.join(\"datasets\", \"spam\")\n", " spam_url = root + \"20030228_spam.tar.bz2\"\n",
"\n", "\n",
"def fetch_spam_data(ham_url=HAM_URL, spam_url=SPAM_URL, spam_path=SPAM_PATH):\n", " spam_path = Path() / \"datasets\" / \"spam\"\n",
" if not os.path.isdir(spam_path):\n", " spam_path.mkdir(parents=True, exist_ok=True)\n",
" os.makedirs(spam_path)\n", " for dir_name, tar_name, url in ((\"easy_ham\", \"ham\", ham_url),\n",
" for filename, url in ((\"ham.tar.bz2\", ham_url), (\"spam.tar.bz2\", spam_url)):\n", " (\"spam\", \"spam\", spam_url)):\n",
" path = os.path.join(spam_path, filename)\n", " if not (spam_path / dir_name).is_dir():\n",
" if not os.path.isfile(path):\n", " path = (spam_path / tar_name).with_suffix(\".tar.bz2\")\n",
" print(\"Downloading\", path)\n",
" urllib.request.urlretrieve(url, path)\n", " urllib.request.urlretrieve(url, path)\n",
" tar_bz2_file = tarfile.open(path)\n", " tar_bz2_file = tarfile.open(path)\n",
" tar_bz2_file.extractall(path=spam_path)\n", " tar_bz2_file.extractall(path=spam_path)\n",
" tar_bz2_file.close()" " tar_bz2_file.close()\n",
" return [spam_path / dir_name for dir_name in (\"easy_ham\", \"spam\")]"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 126, "execution_count": 39,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"fetch_spam_data()" "ham_dir, spam_dir = fetch_spam_data()"
] ]
}, },
{ {
@ -2009,19 +1993,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 127, "execution_count": 40,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"HAM_DIR = os.path.join(SPAM_PATH, \"easy_ham\")\n", "ham_filenames = [f for f in sorted(ham_dir.iterdir()) if len(f.name) > 20]\n",
"SPAM_DIR = os.path.join(SPAM_PATH, \"spam\")\n", "spam_filenames = [f for f in sorted(spam_dir.iterdir()) if len(f.name) > 20]"
"ham_filenames = [name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]\n",
"spam_filenames = [name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 128, "execution_count": 41,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2030,7 +2012,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 129, "execution_count": 42,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2046,27 +2028,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 130, "execution_count": 45,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import email\n", "import email\n",
"import email.policy\n", "import email.policy\n",
"\n", "\n",
"def load_email(is_spam, filename, spam_path=SPAM_PATH):\n", "def load_email(filepath):\n",
" directory = \"spam\" if is_spam else \"easy_ham\"\n", " with open(filepath, \"rb\") as f:\n",
" with open(os.path.join(spam_path, directory, filename), \"rb\") as f:\n",
" return email.parser.BytesParser(policy=email.policy.default).parse(f)" " return email.parser.BytesParser(policy=email.policy.default).parse(f)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 131, "execution_count": 48,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"ham_emails = [load_email(is_spam=False, filename=name) for name in ham_filenames]\n", "ham_emails = [load_email(filepath) for filepath in ham_filenames]\n",
"spam_emails = [load_email(is_spam=True, filename=name) for name in spam_filenames]" "spam_emails = [load_email(filepath) for filepath in spam_filenames]"
] ]
}, },
{ {
@ -2078,7 +2059,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 49,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2087,7 +2068,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 133, "execution_count": 50,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2103,7 +2084,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 134, "execution_count": 51,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2122,7 +2103,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 135, "execution_count": 52,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2138,7 +2119,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 136, "execution_count": 53,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2147,7 +2128,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 137, "execution_count": 54,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2170,7 +2151,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 138, "execution_count": 55,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2187,7 +2168,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 139, "execution_count": 56,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2203,7 +2184,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 140, "execution_count": 57,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2225,7 +2206,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 141, "execution_count": 58,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2249,7 +2230,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 142, "execution_count": 59,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2268,7 +2249,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 143, "execution_count": 60,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2284,7 +2265,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 144, "execution_count": 61,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2308,7 +2289,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 145, "execution_count": 62,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2326,7 +2307,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 146, "execution_count": 63,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2352,10 +2333,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 147, "execution_count": 67,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n",
"# if running this notebook on Colab or Kaggle, we just pip install urlextract\n", "# if running this notebook on Colab or Kaggle, we just pip install urlextract\n",
"if IS_COLAB or IS_KAGGLE:\n", "if IS_COLAB or IS_KAGGLE:\n",
" %pip install -q -U urlextract" " %pip install -q -U urlextract"
@ -2370,7 +2355,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 148, "execution_count": 68,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2393,7 +2378,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 149, "execution_count": 69,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2445,7 +2430,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 150, "execution_count": 70,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2470,7 +2455,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 151, "execution_count": 71,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2501,7 +2486,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 152, "execution_count": 72,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2512,7 +2497,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 153, "execution_count": 73,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2528,7 +2513,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 154, "execution_count": 74,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2544,7 +2529,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 155, "execution_count": 75,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2567,14 +2552,14 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 156, "execution_count": 76,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.linear_model import LogisticRegression\n", "from sklearn.linear_model import LogisticRegression\n",
"from sklearn.model_selection import cross_val_score\n", "from sklearn.model_selection import cross_val_score\n",
"\n", "\n",
"log_clf = LogisticRegression(solver=\"lbfgs\", max_iter=1000, random_state=42)\n", "log_clf = LogisticRegression(max_iter=1000, random_state=42)\n",
"score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, verbose=3)\n", "score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, verbose=3)\n",
"score.mean()" "score.mean()"
] ]
@ -2590,7 +2575,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 157, "execution_count": 78,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2598,7 +2583,7 @@
"\n", "\n",
"X_test_transformed = preprocess_pipeline.transform(X_test)\n", "X_test_transformed = preprocess_pipeline.transform(X_test)\n",
"\n", "\n",
"log_clf = LogisticRegression(solver=\"lbfgs\", max_iter=1000, random_state=42)\n", "log_clf = LogisticRegression(max_iter=1000, random_state=42)\n",
"log_clf.fit(X_train_transformed, y_train)\n", "log_clf.fit(X_train_transformed, y_train)\n",
"\n", "\n",
"y_pred = log_clf.predict(X_test_transformed)\n", "y_pred = log_clf.predict(X_test_transformed)\n",
@ -2617,7 +2602,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,17 +48,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n", "# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n", "import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n", "assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
@ -72,14 +72,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"training_linear_models\"\n",
"CHAPTER_ID = \"training_linear_models\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -1829,7 +1826,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -4,8 +4,13 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Chapter 5 Support Vector Machines**\n", "**Chapter 5 Support Vector Machines**"
"\n", ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 5._" "_This notebook contains all the sample code and solutions to the exercises in chapter 5._"
] ]
}, },
@ -34,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -43,17 +48,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n", "# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n", "import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n", "assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
@ -67,14 +72,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"svm\"\n",
"CHAPTER_ID = \"svm\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -1942,7 +1944,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,17 +48,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n", "# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n", "import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n", "assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
@ -72,14 +72,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"decision_trees\"\n",
"CHAPTER_ID = \"decision_trees\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -127,14 +124,14 @@
"\n", "\n",
"export_graphviz(\n", "export_graphviz(\n",
" tree_clf,\n", " tree_clf,\n",
" out_file=os.path.join(IMAGES_PATH, \"iris_tree.dot\"),\n", " out_file=IMAGES_PATH / \"iris_tree.dot\",\n",
" feature_names=iris.feature_names[2:],\n", " feature_names=iris.feature_names[2:],\n",
" class_names=iris.target_names,\n", " class_names=iris.target_names,\n",
" rounded=True,\n", " rounded=True,\n",
" filled=True\n", " filled=True\n",
" )\n", " )\n",
"\n", "\n",
"Source.from_file(os.path.join(IMAGES_PATH, \"iris_tree.dot\"))" "Source.from_file(IMAGES_PATH / \"iris_tree.dot\")"
] ]
}, },
{ {
@ -485,7 +482,7 @@
"source": [ "source": [
"export_graphviz(\n", "export_graphviz(\n",
" tree_reg1,\n", " tree_reg1,\n",
" out_file=os.path.join(IMAGES_PATH, \"regression_tree.dot\"),\n", " out_file=IMAGES_PATH / \"regression_tree.dot\",\n",
" feature_names=[\"x1\"],\n", " feature_names=[\"x1\"],\n",
" rounded=True,\n", " rounded=True,\n",
" filled=True\n", " filled=True\n",
@ -498,7 +495,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"Source.from_file(os.path.join(IMAGES_PATH, \"regression_tree.dot\"))" "Source.from_file(IMAGES_PATH / \"regression_tree.dot\")"
] ]
}, },
{ {
@ -797,7 +794,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,17 +48,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n", "# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n", "import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n", "assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
@ -72,14 +72,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"ensembles\"\n",
"CHAPTER_ID = \"ensembles\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -1502,7 +1499,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -4,8 +4,13 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Chapter 8 Dimensionality Reduction**\n", "**Chapter 8 Dimensionality Reduction**"
"\n", ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 8._" "_This notebook contains all the sample code and solutions to the exercises in chapter 8._"
] ]
}, },
@ -34,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -43,20 +48,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# Scikit-Learn ≥1.0 is required\n",
"np.random.seed(42)\n", "import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -67,14 +69,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"dim_reduction\"\n",
"CHAPTER_ID = \"dim_reduction\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -2369,7 +2368,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -4,9 +4,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Chapter 9 Unsupervised Learning**\n", "**Chapter 9 Unsupervised Learning**"
"\n", ]
"_This notebook contains all the sample code in chapter 9._" },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 9._"
] ]
}, },
{ {
@ -34,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -43,20 +48,17 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# Scikit-Learn ≥1.0 is required\n",
"np.random.seed(42)\n", "import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -67,14 +69,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"unsupervised_learning\"\n",
"CHAPTER_ID = \"unsupervised_learning\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -1504,13 +1503,12 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Download the ladybug image\n", "# Download the ladybug image\n",
"images_path = os.path.join(PROJECT_ROOT_DIR, \"images\", \"unsupervised_learning\")\n", "root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"os.makedirs(images_path, exist_ok=True)\n",
"DOWNLOAD_ROOT = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
"filename = \"ladybug.png\"\n", "filename = \"ladybug.png\"\n",
"print(\"Downloading\", filename)\n", "if not (images_path / filename).is_file():\n",
"url = DOWNLOAD_ROOT + \"images/unsupervised_learning/\" + filename\n", " print(\"Downloading\", filename)\n",
"urllib.request.urlretrieve(url, os.path.join(images_path, filename))" " url = root + \"images/unsupervised_learning/\" + filename\n",
" urllib.request.urlretrieve(url, os.path.join(images_path, filename))"
] ]
}, },
{ {
@ -3847,7 +3845,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -4,8 +4,13 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Chapter 10 Introduction to Artificial Neural Networks with Keras**\n", "**Chapter 10 Introduction to Artificial Neural Networks with Keras**"
"\n", ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 10._" "_This notebook contains all the sample code and solutions to the exercises in chapter 10._"
] ]
}, },
@ -34,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -43,30 +48,25 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"try:\n",
" # %tensorflow_version only exists in Colab.\n",
" %tensorflow_version 2.x\n",
"except Exception:\n",
" pass\n",
"\n",
"# TensorFlow ≥2.0 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -77,17 +77,14 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"ann\"\n",
"CHAPTER_ID = \"ann\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)\n" " plt.savefig(path, format=fig_extension, dpi=resolution)"
] ]
}, },
{ {
@ -1271,23 +1268,23 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 82, "execution_count": 3,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"root_logdir = os.path.join(os.curdir, \"my_logs\")" "root_logdir = Path() / \"my_logs\""
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 83, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def get_run_logdir():\n", "def get_run_logdir():\n",
" import time\n", " import time\n",
" run_id = time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")\n", " run_id = time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")\n",
" return os.path.join(root_logdir, run_id)\n", " return root_logdir / run_id\n",
"\n", "\n",
"run_logdir = get_run_logdir()\n", "run_logdir = get_run_logdir()\n",
"run_logdir" "run_logdir"
@ -1357,7 +1354,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 88, "execution_count": 5,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1936,12 +1933,12 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 125, "execution_count": 6,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"run_index = 1 # increment this at every run\n", "run_index = 1 # increment this at every run\n",
"run_logdir = os.path.join(os.curdir, \"my_mnist_logs\", \"run_{:03d}\".format(run_index))\n", "run_logdir = Path() / \"my_mnist_logs\" / \"run_{:03d}\".format(run_index)\n",
"run_logdir" "run_logdir"
] ]
}, },
@ -1996,7 +1993,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,33 +48,28 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"try:\n",
" # %tensorflow_version only exists in Colab.\n",
" %tensorflow_version 2.x\n",
"except Exception:\n",
" pass\n",
"\n",
"# TensorFlow ≥2.0 is required\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"%load_ext tensorboard\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n",
"# Load the Jupyter extension for TensorBoard\n",
"%load_ext tensorboard\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -85,14 +80,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"deep\"\n",
"CHAPTER_ID = \"deep\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -2271,7 +2263,7 @@
"early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n", "early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n",
"model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_model.h5\", save_best_only=True)\n", "model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_model.h5\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n", "run_index = 1 # increment every time you train the model\n",
"run_logdir = os.path.join(os.curdir, \"my_cifar10_logs\", \"run_{:03d}\".format(run_index))\n", "run_logdir = Path() / \"my_cifar10_logs\" / \"run_{:03d}\".format(run_index)\n",
"tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n", "tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]" "callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]"
] ]
@ -2359,7 +2351,7 @@
"early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n", "early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n",
"model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_bn_model.h5\", save_best_only=True)\n", "model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_bn_model.h5\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n", "run_index = 1 # increment every time you train the model\n",
"run_logdir = os.path.join(os.curdir, \"my_cifar10_logs\", \"run_bn_{:03d}\".format(run_index))\n", "run_logdir = Path() / \"my_cifar10_logs\" / \"run_bn_{:03d}\".format(run_index)\n",
"tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n", "tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n", "callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n", "\n",
@ -2416,7 +2408,7 @@
"early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n", "early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n",
"model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_selu_model.h5\", save_best_only=True)\n", "model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_selu_model.h5\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n", "run_index = 1 # increment every time you train the model\n",
"run_logdir = os.path.join(os.curdir, \"my_cifar10_logs\", \"run_selu_{:03d}\".format(run_index))\n", "run_logdir = Path() / \"my_cifar10_logs\" / \"run_selu_{:03d}\".format(run_index)\n",
"tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n", "tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n", "callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n", "\n",
@ -2487,7 +2479,7 @@
"early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n", "early_stopping_cb = keras.callbacks.EarlyStopping(patience=20)\n",
"model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_alpha_dropout_model.h5\", save_best_only=True)\n", "model_checkpoint_cb = keras.callbacks.ModelCheckpoint(\"my_cifar10_alpha_dropout_model.h5\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n", "run_index = 1 # increment every time you train the model\n",
"run_logdir = os.path.join(os.curdir, \"my_cifar10_logs\", \"run_alpha_dropout_{:03d}\".format(run_index))\n", "run_logdir = Path() / \"my_cifar10_logs\" / \"run_alpha_dropout_{:03d}\".format(run_index)\n",
"tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n", "tensorboard_cb = keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n", "callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n", "\n",
@ -2704,7 +2696,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook contains all the sample code in chapter 12._" "_This notebook contains all the sample code and solutions to the exercises in chapter 12._"
] ]
}, },
{ {
@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,29 +48,21 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"try:\n",
" # %tensorflow_version only exists in Colab.\n",
" %tensorflow_version 2.x\n",
"except Exception:\n",
" pass\n",
"\n",
"# TensorFlow ≥2.4 is required in this notebook\n",
"# Earlier 2.x versions will mostly work the same, but with a few bugs\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.4\"\n",
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
@ -82,20 +74,7 @@
"import matplotlib.pyplot as plt\n", "import matplotlib.pyplot as plt\n",
"mpl.rc('axes', labelsize=14)\n", "mpl.rc('axes', labelsize=14)\n",
"mpl.rc('xtick', labelsize=12)\n", "mpl.rc('xtick', labelsize=12)\n",
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)"
"\n",
"# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n",
"CHAPTER_ID = \"deep\"\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
] ]
}, },
{ {
@ -2726,7 +2705,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 203, "execution_count": 203,
"metadata": {}, "metadata": {
"scrolled": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"for epoch in range(1, n_epochs + 1):\n", "for epoch in range(1, n_epochs + 1):\n",
@ -3980,7 +3961,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -4,8 +4,13 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Chapter 13 Loading and Preprocessing Data with TensorFlow**\n", "**Chapter 13 Loading and Preprocessing Data with TensorFlow**"
"\n", ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 13._" "_This notebook contains all the sample code and solutions to the exercises in chapter 13._"
] ]
}, },
@ -34,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -43,33 +48,36 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n", "IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n", "\n",
"if IS_COLAB or IS_KAGGLE:\n", "if IS_COLAB or IS_KAGGLE:\n",
" %pip install -q -U tfx==0.21.2\n", " %pip install -q -U tfx\n",
" print(\"You can safely ignore the package incompatibility errors.\")\n", " print(\"You can safely ignore the package incompatibility errors.\")\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"# TensorFlow ≥2.0 is required\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n",
"# Load the Jupyter extension for TensorBoard\n",
"%load_ext tensorboard\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -80,14 +88,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"data\"\n",
"CHAPTER_ID = \"data\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -264,9 +269,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"def save_to_multiple_csv_files(data, name_prefix, header=None, n_parts=10):\n", "def save_to_multiple_csv_files(data, name_prefix, header=None, n_parts=10):\n",
" housing_dir = os.path.join(\"datasets\", \"housing\")\n", " housing_dir = Path() / \"datasets\" / \"housing\"\n",
" os.makedirs(housing_dir, exist_ok=True)\n", " housing_dir.mkdir(parents=True, exist_ok=True)\n",
" path_format = os.path.join(housing_dir, \"my_{}_{:02d}.csv\")\n", " path_format = housing_dir / \"my_{}_{:02d}.csv\"\n",
"\n", "\n",
" filepaths = []\n", " filepaths = []\n",
" m = len(data)\n", " m = len(data)\n",
@ -1431,21 +1436,23 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n", "from pathlib import Path\n",
"import tarfile\n", "import tarfile\n",
"import urllib.request\n", "import urllib.request\n",
"import pandas as pd\n",
"\n", "\n",
"DOWNLOAD_ROOT = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n", "def load_housing_data():\n",
"HOUSING_PATH = os.path.join(\"datasets\", \"housing\")\n", " housing_path = Path() / \"datasets\" / \"housing\"\n",
"HOUSING_URL = DOWNLOAD_ROOT + \"datasets/housing/housing.tgz\"\n", " if not (housing_path / \"housing.csv\").is_file():\n",
"\n", " housing_path.mkdir(parents=True, exist_ok=True)\n",
"def fetch_housing_data(housing_url=HOUSING_URL, housing_path=HOUSING_PATH):\n", " root = \"https://raw.githubusercontent.com/ageron/handson-ml2/master/\"\n",
" os.makedirs(housing_path, exist_ok=True)\n", " url = root + \"datasets/housing/housing.tgz\"\n",
" tgz_path = os.path.join(housing_path, \"housing.tgz\")\n", " tgz_path = housing_path / \"housing.tgz\"\n",
" urllib.request.urlretrieve(housing_url, tgz_path)\n", " urllib.request.urlretrieve(url, tgz_path)\n",
" housing_tgz = tarfile.open(tgz_path)\n", " housing_tgz = tarfile.open(tgz_path)\n",
" housing_tgz.extractall(path=housing_path)\n", " housing_tgz.extractall(path=housing_path)\n",
" housing_tgz.close()" " housing_tgz.close()\n",
" return pd.read_csv(housing_path / \"housing.csv\")"
] ]
}, },
{ {
@ -1453,31 +1460,9 @@
"execution_count": 86, "execution_count": 86,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"fetch_housing_data()"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"def load_housing_data(housing_path=HOUSING_PATH):\n",
" csv_path = os.path.join(housing_path, \"housing.csv\")\n",
" return pd.read_csv(csv_path)"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [ "source": [
"housing = load_housing_data()\n", "housing = load_housing_data()\n",
"housing.head()" "head()"
] ]
}, },
{ {
@ -2104,8 +2089,8 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"from datetime import datetime\n", "from datetime import datetime\n",
"logs = os.path.join(os.curdir, \"my_logs\",\n", "\n",
" \"run_\" + datetime.now().strftime(\"%Y%m%d_%H%M%S\"))\n", "logs = Path() / \"my_logs\" / \"run_\" + datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
"\n", "\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(\n", "tensorboard_cb = tf.keras.callbacks.TensorBoard(\n",
" log_dir=logs, histogram_freq=1, profile_batch=10)\n", " log_dir=logs, histogram_freq=1, profile_batch=10)\n",
@ -2144,33 +2129,56 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 131, "execution_count": 71,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from pathlib import Path\n", "from pathlib import Path\n",
"\n", "\n",
"DOWNLOAD_ROOT = \"http://ai.stanford.edu/~amaas/data/sentiment/\"\n", "root = \"http://ai.stanford.edu/~amaas/data/sentiment/\"\n",
"FILENAME = \"aclImdb_v1.tar.gz\"\n", "filename = \"aclImdb_v1.tar.gz\"\n",
"filepath = keras.utils.get_file(FILENAME, DOWNLOAD_ROOT + FILENAME, extract=True)\n", "filepath = keras.utils.get_file(filename, root + filename, extract=True)\n",
"path = Path(filepath).parent / \"aclImdb\"\n", "path = Path(filepath).with_name(\"aclImdb\")\n",
"path" "path"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a `tree()` function to view the structure of the `aclImdb` directory:"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 76,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"for name, subdirs, files in os.walk(path):\n", "def tree(path, level=0, indent=4, max_files=3):\n",
" indent = len(Path(name).parts) - len(path.parts)\n", " if level == 0:\n",
" print(\" \" * indent + Path(name).parts[-1] + os.sep)\n", " print(f\"{path}/\")\n",
" for index, filename in enumerate(sorted(files)):\n", " level += 1\n",
" if index == 3:\n", " sub_paths = sorted(path.iterdir())\n",
" print(\" \" * (indent + 1) + \"...\")\n", " sub_dirs = [sub_path for sub_path in sub_paths if sub_path.is_dir()]\n",
" break\n", " filepaths = [sub_path for sub_path in sub_paths if not sub_path in sub_dirs]\n",
" print(\" \" * (indent + 1) + filename)" " indent_str = \" \" * indent * level\n",
" for sub_dir in sub_dirs:\n",
" print(f\"{indent_str}{sub_dir.name}/\")\n",
" tree(sub_dir, level + 1, indent)\n",
" for filepath in filepaths[:max_files]:\n",
" print(f\"{indent_str}{filepath.name}\")\n",
" if len(filepaths) > max_files:\n",
" print(f\"{indent_str}...\")"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"tree(path)"
] ]
}, },
{ {
@ -2771,7 +2779,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook contains all the sample code in chapter 14._" "_This notebook contains all the sample code and solutions to the exercises in chapter 14._"
] ]
}, },
{ {
@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,38 +48,37 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n", "IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"# TensorFlow ≥2.0 is required\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n", "tf.random.set_seed(42)\n",
"\n", "\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"import matplotlib as mpl\n", "import matplotlib as mpl\n",
@ -89,14 +88,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"cnn\"\n",
"CHAPTER_ID = \"cnn\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -1446,7 +1442,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook contains all the sample code in chapter 15._" "_This notebook contains all the sample code and solutions to the exercises in chapter 15._"
] ]
}, },
{ {
@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,39 +48,37 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n", "IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"# TensorFlow ≥2.0 is required\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. LSTMs and CNNs can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n",
"from pathlib import Path\n", "from pathlib import Path\n",
"\n", "\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n", "tf.random.set_seed(42)\n",
"\n", "\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"import matplotlib as mpl\n", "import matplotlib as mpl\n",
@ -90,14 +88,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"rnn\"\n",
"CHAPTER_ID = \"rnn\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -2019,7 +2014,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook contains all the sample code in chapter 16._" "_This notebook contains all the sample code and solutions to the exercises in chapter 16._"
] ]
}, },
{ {
@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,9 +48,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n", "IS_COLAB = \"google.colab\" in sys.modules\n",
@ -60,17 +60,20 @@
" %pip install -q -U tensorflow-addons\n", " %pip install -q -U tensorflow-addons\n",
" %pip install -q -U transformers\n", " %pip install -q -U transformers\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n", "# Common imports\n",
"import sklearn\n", "import numpy as np\n",
"assert sklearn.__version__ >= \"0.20\"\n", "from pathlib import Path\n",
"\n", "\n",
"# TensorFlow ≥2.0 is required\n", "# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"from tensorflow import keras\n", "assert tf.__version__ >= \"2.6\"\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n", "\n",
"if not tf.config.list_physical_devices('GPU'):\n", "if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. LSTMs and CNNs can be very slow without a GPU.\")\n", " print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
" if IS_COLAB:\n", " if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n", " print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n", " if IS_KAGGLE:\n",
@ -78,11 +81,7 @@
"\n", "\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n", "\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
@ -93,14 +92,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"nlp\"\n",
"CHAPTER_ID = \"nlp\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -911,7 +907,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60, "execution_count": 7,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -920,17 +916,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 61, "execution_count": 10,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"TFHUB_CACHE_DIR = os.path.join(os.curdir, \"my_tfhub_cache\")\n", "tfhub_cache_dir = Path() / \"my_tfhub_cache\"\n",
"os.environ[\"TFHUB_CACHE_DIR\"] = TFHUB_CACHE_DIR" "os.environ[\"TFHUB_CACHE_DIR\"] = str(tfhub_cache_dir)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 62, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -946,15 +942,41 @@
" metrics=[\"accuracy\"])" " metrics=[\"accuracy\"])"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a `tree()` function to view the structure of the cache directory TF Hub just created:"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63, "execution_count": 33,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"for dirpath, dirnames, filenames in os.walk(TFHUB_CACHE_DIR):\n", "def tree(path, level=0, indent=4):\n",
" for filename in filenames:\n", " if level == 0:\n",
" print(os.path.join(dirpath, filename))" " print(f\"{path}/\")\n",
" level += 1\n",
" sub_paths = sorted(path.iterdir())\n",
" sub_dirs = [sub_path for sub_path in sub_paths if sub_path.is_dir()]\n",
" filepaths = [sub_path for sub_path in sub_paths if not sub_path in sub_dirs]\n",
" indent_str = \" \" * indent * level\n",
" for sub_dir in sub_dirs:\n",
" print(f\"{indent_str}{sub_dir.name}/\")\n",
" tree(sub_dir, level + 1, indent)\n",
" for filepath in filepaths:\n",
" print(f\"{indent_str}{filepath.name}\")"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"tree(tfhub_cache_dir)"
] ]
}, },
{ {
@ -2743,7 +2765,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook contains all the sample code in chapter 17._" "_This notebook contains all the sample code and solutions to the exercises in chapter 17._"
] ]
}, },
{ {
@ -39,7 +39,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -48,38 +48,37 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n", "IS_COLAB = \"google.colab\" in sys.modules\n",
"IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"# TensorFlow ≥2.0 is required\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. LSTMs and CNNs can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n", "tf.random.set_seed(42)\n",
"\n", "\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"import matplotlib as mpl\n", "import matplotlib as mpl\n",
@ -89,14 +88,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"autoencoders\"\n",
"CHAPTER_ID = \"autoencoders\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -1761,7 +1757,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook contains all the sample code in chapter 18_." "_This notebook contains all the sample code and solutions to the exercises in chapter 18._"
] ]
}, },
{ {
@ -32,8 +32,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Setup\n", "# Setup"
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0." ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -42,9 +48,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n", "IS_COLAB = \"google.colab\" in sys.modules\n",
@ -55,30 +61,29 @@
" %pip install -q -U tf-agents pyvirtualdisplay gym[box2d]\n", " %pip install -q -U tf-agents pyvirtualdisplay gym[box2d]\n",
" %pip install -q -U atari_py==0.2.5\n", " %pip install -q -U atari_py==0.2.5\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"0.20\"\n",
"\n",
"# TensorFlow ≥2.0 is required\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# Common imports\n", "# Common imports\n",
"import numpy as np\n", "import numpy as np\n",
"import os\n", "from pathlib import Path\n",
"\n",
"# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n",
"assert tf.__version__ >= \"2.6\"\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n", "tf.random.set_seed(42)\n",
"\n", "\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"import matplotlib as mpl\n", "import matplotlib as mpl\n",
@ -92,14 +97,11 @@
"mpl.rc('animation', html='jshtml')\n", "mpl.rc('animation', html='jshtml')\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"rl\"\n",
"CHAPTER_ID = \"rl\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -2574,7 +2576,7 @@
"source": [ "source": [
"import PIL\n", "import PIL\n",
"\n", "\n",
"image_path = os.path.join(\"images\", \"rl\", \"breakout.gif\")\n", "image_path = Path() / \"images\" / \"rl\" / \"breakout.gif\"\n",
"frame_images = [PIL.Image.fromarray(frame) for frame in frames[:150]]\n", "frame_images = [PIL.Image.fromarray(frame) for frame in frames[:150]]\n",
"frame_images[0].save(image_path, format='GIF',\n", "frame_images[0].save(image_path, format='GIF',\n",
" append_images=frame_images[1:],\n", " append_images=frame_images[1:],\n",
@ -3208,7 +3210,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View File

@ -11,7 +11,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"_This notebook contains all the sample code in chapter 19._" "_This notebook contains all the sample code and solutions to the exercises in chapter 19._"
] ]
}, },
{ {
@ -32,8 +32,14 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Setup\n", "# Setup"
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0.\n" ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures."
] ]
}, },
{ {
@ -42,9 +48,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"# Python ≥3.5 is required\n", "# Python ≥3.7 is required\n",
"import sys\n", "import sys\n",
"assert sys.version_info >= (3, 5)\n", "assert sys.version_info >= (3, 7)\n",
"\n", "\n",
"# Is this notebook running on Colab or Kaggle?\n", "# Is this notebook running on Colab or Kaggle?\n",
"IS_COLAB = \"google.colab\" in sys.modules\n", "IS_COLAB = \"google.colab\" in sys.modules\n",
@ -56,30 +62,31 @@
" !apt update && apt-get install -y tensorflow-model-server\n", " !apt update && apt-get install -y tensorflow-model-server\n",
" %pip install -q -U tensorflow-serving-api\n", " %pip install -q -U tensorflow-serving-api\n",
"\n", "\n",
"# Scikit-Learn ≥0.20 is required\n", "# Common imports\n",
"import sklearn\n", "import os\n",
"assert sklearn.__version__ >= \"0.20\"\n", "import numpy as np\n",
"from pathlib import Path\n",
"\n", "\n",
"# TensorFlow ≥2.0 is required\n", "# Scikit-Learn ≥1.0 is required\n",
"import sklearn\n",
"assert sklearn.__version__ >= \"1.0\"\n",
"\n",
"# TensorFlow ≥2.6 is required\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"from tensorflow import keras\n", "from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n", "assert tf.__version__ >= \"2.6\"\n",
"\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# Common imports\n",
"import numpy as np\n",
"import os\n",
"\n", "\n",
"# to make this notebook's output stable across runs\n", "# to make this notebook's output stable across runs\n",
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n", "tf.random.set_seed(42)\n",
"\n", "\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
" if IS_KAGGLE:\n",
" print(\"Go to Settings > Accelerator and select GPU.\")\n",
"\n",
"# To plot pretty figures\n", "# To plot pretty figures\n",
"%matplotlib inline\n", "%matplotlib inline\n",
"import matplotlib as mpl\n", "import matplotlib as mpl\n",
@ -89,14 +96,11 @@
"mpl.rc('ytick', labelsize=12)\n", "mpl.rc('ytick', labelsize=12)\n",
"\n", "\n",
"# Where to save the figures\n", "# Where to save the figures\n",
"PROJECT_ROOT_DIR = \".\"\n", "IMAGES_PATH = Path() / \"images\" / \"deploy\"\n",
"CHAPTER_ID = \"deploy\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
"\n", "\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n", "def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n", " path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" print(\"Saving figure\", fig_id)\n",
" if tight_layout:\n", " if tight_layout:\n",
" plt.tight_layout()\n", " plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)" " plt.savefig(path, format=fig_extension, dpi=resolution)"
@ -168,7 +172,7 @@
"source": [ "source": [
"model_version = \"0001\"\n", "model_version = \"0001\"\n",
"model_name = \"my_mnist_model\"\n", "model_name = \"my_mnist_model\"\n",
"model_path = os.path.join(model_name, model_version)\n", "model_path = Path() / model_name / model_version\n",
"model_path" "model_path"
] ]
}, },
@ -187,7 +191,14 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.saved_model.save(model, model_path)" "tf.saved_model.save(model, str(model_path))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a `tree()` function to view the structure of the `my_mnist_model` directory:"
] ]
}, },
{ {
@ -196,20 +207,19 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"for root, dirs, files in os.walk(model_name):\n", "def tree(path, level=0, indent=4):\n",
" indent = ' ' * root.count(os.sep)\n", " if level == 0:\n",
" print('{}{}/'.format(indent, os.path.basename(root)))\n", " print(f\"{path}/\")\n",
" for filename in files:\n", " level += 1\n",
" print('{}{}'.format(indent + ' ', filename))" " sub_paths = sorted(path.iterdir())\n",
] " sub_dirs = [sub_path for sub_path in sub_paths if sub_path.is_dir()]\n",
}, " filepaths = [sub_path for sub_path in sub_paths if not sub_path in sub_dirs]\n",
{ " indent_str = \" \" * indent * level\n",
"cell_type": "code", " for sub_dir in sub_dirs:\n",
"execution_count": 9, " print(f\"{indent_str}{sub_dir.name}/\")\n",
"metadata": {}, " tree(sub_dir, level + 1, indent)\n",
"outputs": [], " for filepath in filepaths:\n",
"source": [ " print(f\"{indent_str}{filepath.name}\")"
"!saved_model_cli show --dir {model_path}"
] ]
}, },
{ {
@ -218,7 +228,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"!saved_model_cli show --dir {model_path} --tag_set serve" "tree(model_path.parent)"
] ]
}, },
{ {
@ -226,6 +236,24 @@
"execution_count": 11, "execution_count": 11,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"!saved_model_cli show --dir {model_path}"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"!saved_model_cli show --dir {model_path} --tag_set serve"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [ "source": [
"!saved_model_cli show --dir {model_path} --tag_set serve \\\n", "!saved_model_cli show --dir {model_path} --tag_set serve \\\n",
" --signature_def serving_default" " --signature_def serving_default"
@ -233,7 +261,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -249,7 +277,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 15,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -258,7 +286,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 16,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -275,7 +303,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -286,16 +314,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 19,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"np.round([[1.1347984e-04, 1.5187356e-07, 9.7032893e-04, 2.7640699e-03, 3.7826971e-06,\n", "np.round(\n",
" 7.6876910e-05, 3.9140293e-08, 9.9559116e-01, 5.3502394e-05, 4.2665208e-04],\n", " [[1.14172166e-04, 1.51857336e-07, 9.79080913e-04, 2.77538411e-03,\n",
" [8.2443521e-04, 3.5493889e-05, 9.8826385e-01, 7.0466995e-03, 1.2957400e-07,\n", " 3.75553282e-06, 7.66718149e-05, 3.91490929e-08, 9.95566308e-01,\n",
" 2.3389691e-04, 2.5639210e-03, 9.5886099e-10, 1.0314899e-03, 8.7952529e-08],\n", " 5.34432293e-05, 4.30987304e-04],\n",
" [4.4693781e-05, 9.7028232e-01, 9.0526715e-03, 2.2641101e-03, 4.8766597e-04,\n", " [8.14584550e-04, 3.54881959e-05, 9.88290966e-01, 7.04165967e-03,\n",
" 2.8800720e-03, 2.2714981e-03, 8.3753867e-03, 4.0439744e-03, 2.9759688e-04]], 2)" " 1.27466748e-07, 2.31963830e-04, 2.55776616e-03, 9.73469416e-10,\n",
" 1.02734682e-03, 8.74494361e-08],\n",
" [4.42889832e-05, 9.70350444e-01, 9.02883708e-03, 2.26117787e-03,\n",
" 4.85437602e-04, 2.87237833e-03, 2.26676138e-03, 8.35481752e-03,\n",
" 4.03870409e-03, 2.97143910e-04]], 2)"
] ]
}, },
{ {
@ -332,11 +364,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 30,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"os.environ[\"MODEL_DIR\"] = os.path.split(os.path.abspath(model_path))[0]" "os.environ[\"MODEL_DIR\"] = str(model_path.absolute().parent)"
] ]
}, },
{ {
@ -519,7 +551,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 32,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
@ -542,36 +574,32 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 33,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"model_version = \"0002\"\n", "model_version = \"0002\"\n",
"model_name = \"my_mnist_model\"\n", "model_name = \"my_mnist_model\"\n",
"model_path = os.path.join(model_name, model_version)\n", "model_path = Path() / model_name / model_version\n",
"model_path" "model_path"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 35,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.saved_model.save(model, model_path)" "tf.saved_model.save(model, str(model_path))"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 36,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"for root, dirs, files in os.walk(model_name):\n", "tree(model_path.parent)"
" indent = ' ' * root.count(os.sep)\n",
" print('{}{}/'.format(indent, os.path.basename(root)))\n",
" for filename in files:\n",
" print('{}{}'.format(indent + ' ', filename))"
] ]
}, },
{ {
@ -955,7 +983,6 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"import os\n",
"import json\n", "import json\n",
"\n", "\n",
"os.environ[\"TF_CONFIG\"] = json.dumps({\n", "os.environ[\"TF_CONFIG\"] = json.dumps({\n",
@ -1028,7 +1055,6 @@
"source": [ "source": [
"%%writefile my_mnist_multiworker_task.py\n", "%%writefile my_mnist_multiworker_task.py\n",
"\n", "\n",
"import os\n",
"import numpy as np\n", "import numpy as np\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"from tensorflow import keras\n", "from tensorflow import keras\n",
@ -1042,9 +1068,9 @@
"\n", "\n",
"# Only worker #0 will write checkpoints and log to TensorBoard\n", "# Only worker #0 will write checkpoints and log to TensorBoard\n",
"if resolver.task_id == 0:\n", "if resolver.task_id == 0:\n",
" root_logdir = os.path.join(os.curdir, \"my_mnist_multiworker_logs\")\n", " root_logdir = Path() / \"my_mnist_multiworker_logs\"\n",
" run_id = time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")\n", " run_id = time.strftime(\"run_%Y_%m_%d-%H_%M_%S\")\n",
" run_dir = os.path.join(root_logdir, run_id)\n", " run_dir = root_logdir / run_id\n",
" callbacks = [\n", " callbacks = [\n",
" keras.callbacks.TensorBoard(run_dir),\n", " keras.callbacks.TensorBoard(run_dir),\n",
" keras.callbacks.ModelCheckpoint(\"my_mnist_multiworker_model.h5\",\n", " keras.callbacks.ModelCheckpoint(\"my_mnist_multiworker_model.h5\",\n",
@ -1240,7 +1266,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },