{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 10 Introduction to Artificial Neural Networks with Keras**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 10._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/10_neural_nets_with_keras.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/10_neural_nets_with_keras.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import sklearn\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And TensorFlow ≥ 2.8:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"assert version.parse(tf.__version__) >= version.parse(\"2.8.0\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/ann` folder (if it doesn't already exist), and define the `save_fig()` function which is used through this notebook to save the figures in high-res for the book:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"ann\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# From Biological to Artificial Neurons\n",
"## The Perceptron"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.linear_model import Perceptron\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 0) # Iris setosa\n",
"\n",
"per_clf = Perceptron(random_state=42)\n",
"per_clf.fit(X, y)\n",
"\n",
"X_new = [[2, 0.5], [3, 1]]\n",
"y_pred = per_clf.predict(X_new) # predicts True and False for these 2 flowers"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True, False])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Perceptron` is equivalent to a `SGDClassifier` with `loss=\"perceptron\"`, no regularization, and a constant learning rate equal to 1:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# extra code shows how to build and train a Perceptron\n",
"\n",
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"perceptron\", penalty=None,\n",
" learning_rate=\"constant\", eta0=1, random_state=42)\n",
"sgd_clf.fit(X, y)\n",
"assert (sgd_clf.coef_ == per_clf.coef_).all()\n",
"assert (sgd_clf.intercept_ == per_clf.intercept_).all()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the Perceptron finds a decision boundary that properly separates the classes, it stops learning. This means that the decision boundary is often quite close to one class:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAccAAADYCAYAAACeCyhIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA+RElEQVR4nO3deZzNZfvA8c81YxjbkJjJMnaSNZHlIfRkSxSP9NMilSUVLUpJpcxgpKwT2bVrkxayPCWEMDMMKk9RZM2Ssi8zXL8/zpnVLMecbWZc79frvMy5v9t1TnLN/f3e132LqmKMMcaYFAH+DsAYY4zJbSw5GmOMMelYcjTGGGPSseRojDHGpGPJ0RhjjEnHkqMxxhiTjs+So4iEi8h3IrJNRH4SkScy2EdEZLKI7BCRLSJyQ6ptHUXkF+e2ob6K2xhjzJXHlz3HROBpVb0OaAY8JiK10+1zK1DD+eoPvAkgIoHAFOf22sDdGRxrjDHGeITPkqOqHlDVjc6fTwDbgPLpdrsDeEcd1gElRaQs0ATYoaq/q+p54EPnvsYYY4zH+eWZo4hUBhoC69NtKg/sSfV+r7Mts3ZjjDHG4wr4+oIiUgyYDzypqsfTb87gEM2iPaPz98dxS5ZChYo2Cgur5Ua0xhhjcrvduzPfVrFi+u27UD2SUU5Jw6fJUUSCcCTG91X1swx22QuEp3pfAdgPFMyk/RKqOgOYAVCpUmMdNizWA5EbY4zJrQYMyHzbsGHptzd26Zy+HK0qwGxgm6qOz2S3L4H7naNWmwHHVPUAEAPUEJEqIlIQ6Onc1xhjjPE4X/YcWwC9gK0iEu9sGwZUBFDVacDXQCdgB3AaeNC5LVFEBgJLgUBgjqr+5MPYjTHG5FIhIXA8/UM6Z3tW27Mi+XnJKrutaowxJrUBAyROVbO9t+rzATnGGGOMLz37bOqeY6NGrhxj08cZY4zJ1y73lipYcjTGGGMuYcnRGGOMSceSozHGGJOOJUdjjDEmHUuOxhhj8rWkesfLYaUcxhhj8rWxY1N+HjAgLs6VY/J1cjxyZCcnT/5FsWJX+zsUY4wxbshq/lQRyGg+GxF4882cXS9f31Y9ffooERF1iI//3N+hGGOM8ZLMJnpzZwK4fJ0cAY4fP8i0ad2YPfseTp484u9wjDHG5AH5OjkGBQUm/xwTM48RI+qwaVNGK2UZY4wxKXy5ZNUcETkkIj9msn2IiMQ7Xz+KyAURKeXctktEtjq3uTyTeO3aFenV6+bk9ydOHGL69O7MmtWTEycOu/2ZjDHG5E++7Dm+BXTMbKOqvqaq16vq9cDzwEpVPZpql5ud211bqRIIDAxg9uwn+PzzFylXrlRye2zsR0RE1CEu7tPL/hDGGGPyP58lR1VdBRzNdkeHu4F5nrp2p06NiY+fTO/etyS3nThxmJkzezBjxl0cP37IU5cyxhjjYyKX1+7SOX25nqOIVAYWqmrdLPYpAuwFqif1HEVkJ/A3oMB0VZ2RxfH9gf4AFSuWabRjx8w025csieORR6ayb99fyW3FipWmZ88pNG58V04/mjHGmDzA1fUcc2Ny/D/gPlXtkqqtnKruF5FQ4L/AIGdPNEuNGlXXdevGXdJ+7Ngpnn12LnPnfpOmvWHD7tx99xRCQsJc/kzGGGPSr5mYIiQkbRG+P6SNrTGqsdn2KXPjaNWepLulqqr7nX8eAhYATdy5QIkSRZk+fSALF75MhQopEwRs2jSfESPqEBPzIb78pcEYY/K6zNZMzMlaip6W59dzFJESQGvgi1RtRUWkeNLPQHsgwxGvl6t9+4Zs2jSZhx5ql9x26tRfzJ59N9Ond+fYsT89cRljjDF5jC9LOeYBPwDXisheEekjIgNEJPWkQN2AZap6KlVbGLBaRDYDG4BFqrrEU3GVKFGUadMeY9GilwkPL53cHh+/gIiIOmzY8IH1Io0x5grjs7lVVfVuF/Z5C0fJR+q234EG3okqRbt2jl7k0K
FvMWvWMgBOnTrKnDn3Ehf3Cffc8yYlSlzj7TCMMcbkArnqtqq/hYQUYerUR1m8eASVKpVJbt+8+XNGjKjN+vXvWy/SGGOuAJYcM3DLLQ3YuHEy/funzFlw+vTfzJ17H2++2ZVjxw74MTpjjMl9MlszMSdrKXpaTmLwaSmHr2VWynE5li/fzMMPv8Eff6RMN1ekSEnuumsyTZveh7hTZWqMMcancmWdo695IjkCnDhxhhdeeIdp0xanaa9XrzP33judkiXLuX0NY4zJy9ytc/RVnaSrydFuq7qgePHCTJ78MMuWRVKlSsoEAVu3LiQiog4//PC2PYs0xlzR3K1zzG11kpYcL0ObNvWIi5vIo492Sm47ffof3n77AaZM6czff+/zY3TGGGM8xZLjZSpWrDATJ/bnm29GUrVqSi/yxx+/JiKiDmvXzrVepDHG5HGWHHOoVau6xMVNYuDAzsltZ84c4513HuKNNzpx9OgeP0ZnjDHGHZYc3VC0aDDjx/fl229HUa1aygQBP/20hIiIuqxZM9t6kcYYkwdZcvSAm26qQ1zcJB5/vEtyacfZs8d5992+REd35OjR3X6O0BhjvMvdOsfcVidppRwetmbNz/Tr9wY7duxPbgsOLk737uNo2bKv1UUaY4wf5bo6RxGZA3QGDmW0nqOItMGxGsdOZ9Nnqhrh3NYRmAQEArNUdYwr1/RHcgQ4ffocr7zyPpMmfZXmtup117XjvvtmcvXVlXwekzHGeLOWcMCAzLdNm5b9td2J7ZFHIKNUJgJvvpn713N8C+iYzT7fq+r1zldSYgwEpgC3ArWBu0WktlcjdVORIoUYO/YhVqyIokaNlAkCtm37LxERdVm1aro9izTG+Jw/awmzu7Y7sWX2z2lSe65ez1FVVwFHc3BoE2CHqv6uqueBD4E7PBqclzRvXovY2AkMHtw1+XbquXMn+eCDAUya1I4jR3b5N0BjjDEZym0DcpqLyGYRWSwidZxt5YHUdRF7nW0ZEpH+IhIrIrFHjvh/CerChQsxZswDrFwZRc2aKWH/73/fEhlZj5Ur3+TixYt+jNAYY0x6uSk5bgQqqWoDIBr43Nme0b3hTO9JquoMVW2sqo1Ll84F08E7NWtWi5iY8Qwe3JWAAMfXfu7cSebNe5RJk9py5MjObM5gjDHGV3JNclTV46p60vnz10CQiJTG0VMMT7VrBWB/BqfI9ZJ6katWjeHaayskt//yy3dERtZjxYqp1os0xphcINckRxG5RpwP5kSkCY7Y/gJigBoiUkVECgI9gS/9F6n7mjSpSUzMeJ555j+pepGn+PDDx5g48RYOH/7dzxEaY/Ijf9YSZndtd2LLrEIuqT1Xr+coIvOANkBp4CDwMhAEoKrTRGQg8AiQCJwBBqvqWuexnYCJOEo55qjqKFeu6a9SjssRE/MrfftGs21bymPVggWL0K3bq7Ru/Why8jTGGOM+j9Y5ikgw8ARwCxBKuh6nqtbPYZxelReSI8DZs+cZOfIjXn99QZrbqjVqtOL+++dQpkw1P0ZnjDH5h6eT4xygG/AJjud9aQ5S1RE5jNOr8kpyTBIbu52+faP5+eeU6eaCggrTrdsY2rQZaL1IY/I5Xy34mxPZFdpnJ7vPdnmF/Jce7ypXk2MBF8/XFeihqt+4HoK5XI0b12D9+nGMGvUxr702nwsXLpKQcIaPP36CjRs/pVev2YSF1fB3mMYYL8ltC/6mll2hfXay+2w5LeT31nfjalfkNGlrDY2XFCoURETEvaxZM5Y6dSomt+/Y8T0jRzbg228ncvHiBT9GaIwx+Z+ryXEsMFhE7L6ej9xwQ3XWrx/HsGF3ERjo+NoTEs7wySdPMW5caw4e/NXPERpjTP6VabITkS+TXkBb4P+Anc7Za75Mt914QcGCQbzyyj2sXfsa9epVTm7/7bc1jBzZgG++GW+9SGOM8YKseoJ/pXstAJYDf2awzXhRw4bV+OGH13jxxf+jQIFAABISzvLpp0
/z+us38eefv/g5QmOMyV8yHZCjqg/6MhCTtYIFgxg+/G5uv70pfftOZsuWXQD8/vsPjBp1PV26RNK27VMEBAT6N1B
"text/plain": [
"<Figure size 504x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code plots the decision boundary of a Perceptron on the iris dataset\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"a = -per_clf.coef_[0, 0] / per_clf.coef_[0, 1]\n",
"b = -per_clf.intercept_ / per_clf.coef_[0, 1]\n",
"axes = [0, 5, 0, 2]\n",
"x0, x1 = np.meshgrid(\n",
" np.linspace(axes[0], axes[1], 500).reshape(-1, 1),\n",
" np.linspace(axes[2], axes[3], 200).reshape(-1, 1),\n",
")\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
"y_predict = per_clf.predict(X_new)\n",
"zz = y_predict.reshape(x0.shape)\n",
"custom_cmap = ListedColormap(['#9898ff', '#fafab0'])\n",
"\n",
"plt.figure(figsize=(7, 3))\n",
"plt.plot(X[y == 0, 0], X[y == 0, 1], \"bs\", label=\"Not Iris setosa\")\n",
"plt.plot(X[y == 1, 0], X[y == 1, 1], \"yo\", label=\"Iris setosa\")\n",
"plt.plot([axes[0], axes[1]], [a * axes[0] + b, a * axes[1] + b], \"k-\",\n",
" linewidth=3)\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis(axes)\n",
"plt.show()"
]
},
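{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why training stops once the classes are separated, here is a minimal NumPy sketch of the classic Perceptron learning rule. It is an illustration only, not Scikit-Learn's implementation: the toy data and the names `X_toy`, `y_toy`, `w`, `b` and `eta` are assumptions made for this example. Weights only change when an instance is misclassified, so as soon as an epoch produces no errors, nothing is left to update:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: illustrative sketch of the Perceptron learning rule (not from the book)\n",
"\n",
"import numpy as np\n",
"\n",
"# hypothetical toy data: two linearly separable clusters\n",
"X_toy = np.array([[1.0, 0.2], [1.5, 0.3], [4.5, 1.5], [5.0, 1.8]])\n",
"y_toy = np.array([1, 1, 0, 0])\n",
"\n",
"w = np.zeros(X_toy.shape[1])  # connection weights\n",
"b = 0.0                       # bias term\n",
"eta = 1.0                     # constant learning rate\n",
"\n",
"for epoch in range(1000):\n",
"    n_errors = 0\n",
"    for x_i, y_i in zip(X_toy, y_toy):\n",
"        y_hat = int(x_i @ w + b >= 0)  # step function output (0 or 1)\n",
"        error = y_i - y_hat            # 0 when the prediction is correct\n",
"        if error != 0:\n",
"            w += eta * error * x_i     # only misclassified instances move the boundary\n",
"            b += eta * error\n",
"            n_errors += 1\n",
"    if n_errors == 0:\n",
"        break  # every instance is correctly classified, so learning stops\n",
"\n",
"w, b"
]
},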
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Activation functions**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwgAAADPCAYAAABRAPaGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABpV0lEQVR4nO3dd3gU1frA8e/ZTe89JCT0GqQXEURCFyxY0CuKilLEckWxIPK7lmvvqNgQFRWuvWNBRCKIIE16ryEkEBIS0uue3x+zWVIhCZvdDbyf55lndnfOnHl3N9mZd+bMOUprjRBCCCGEEEIAmJwdgBBCCCGEEMJ1SIIghBBCCCGEsJEEQQghhBBCCGEjCYIQQgghhBDCRhIEIYQQQgghhI0kCEIIIYQQQggbSRCEQymlWiiltFKqlwO2laCUmu2A7TRRSv2qlMpVSjm932Cl1AGl1P3OjkMIIRobpdR4pVSOg7allVJjHLEtIepKEgRxSkqp7kqpUqXUinqsW90B+iEgCthgj/is26npB/0qYIa9tnMK9wPRQDeM9+YQSqnHlFJbqlnUG3jTUXEIIYSjKKXmWQ+stVKqWCmVqpRaqpS6UynlbodNfAa0skM9NtaYF1azKAr4wZ7bEsJeJEEQpzMJ42DzPKVUxzOtTGtdqrU+orUuOfPQTrut41rr7IbeDtAGWKe13q21PuKA7Z2S1vqY1jrP2XEIIUQD+Q3j4LoFMBzjIPtxYLlSyre+lSql3LXW+VrrVLtEeRrWfWGhI7YlRF1JgiBqpJTyBq4H3gW+BCZUU6avUup3a/OaE0qpJUqpaKXUPGAgcGe5sz0tyjcxUkqZlFJJSql/V6qznbVMd+vzaUqpTdZtHFZKzVVKBVmXxQMfAL7ltvOYdVmFKxhKqWCl1IdKqQylVL5S6jelVKdyy8crpXKUUkOUUlus21uqlGp5is/oADAauMm67XnW16tcOq7c9MdaZrJS6gvrtvYppcZVWidaKbVAKZWulMpTSm1QSg1SSo0HHgU6lXvf42vYTjOl1DdKqWzr9LVSKqbc8ses7/c6pdRea5lvlVJh5cp0tn63WdblG5VSg2r6XIQQogEVWg+uD2utN2itXwbigR7AgwBKKQ+l1HPWfUyuUmqNUmpEWQVKqXjr7+YopdRqpVQRMKL8Fely+6LO5Tdu/d1OU0q5K6XMSqn3lFL7rfuV3UqpB5VSJmvZx4CbgUvK/VbHW5fZ9hNKqZVKqZcqbSfAWueVtXxP7kqp15RSyUqpQqXUIaXUs/b84MW5QxIEcSpjgINa603AxxgHwbZLuEqprsBSYA/QH+gLfA64AVOBlRgH71HW6VD5yrXWFuAT4IZK270B2Ka1/sf63ALcA3TCSFj6AK9bl/1lXZZXbjsv1vB+5gHnYxzQ97Gu84syEqEynhjNkm4FLgCCgLdrqA+M5jy/Wd93lPV918UjwHdAV4xL2+8rpZoDKONM2B8YZ8muBDoD/7Wu9xnwErCTk+/7s8qVK6UU8C0QCQwGBmE0h/rWuqxMC+Bf1u0MB7oDT5Vb/j8gBeNz6w48BhTU8b0KIUSD0FpvAX4Brra+9AHGSarrMX47PwR+sO63ynsO+D+gA/B3pTp3AWupfh/1mda6GOM46jBwLdARmAk8DNxiLfsixv6h7KpHFMZ+q7L5wHVliYXV1UA+8GMt39PdGL/h1wFtMX7Td1azLSFOT2stk0zVThgHp/dbHyvgAHB1ueULgFWnWD8BmF3ptRaABnpZn3exPm9TrsxuYMYp6r0YKARM1ufjgZxTbR/jx1IDF5VbHgicACaWq0cD7cuVuQEoKttWDfEsBOZVek0DYyq9dqDs8yxX5plyz90wkpZx1ueTgGwgrIbtPgZsqeZ123aAYUAp0KLc8lYYSdfQcvUUAIHlyswE9pR7ngXc7Oy/SZlkkuncnjBO9CysYdmz1t/Q1tbfuGaVln8LvGl9HG/9Db66UpkK+xOMkz
4HAWV9Hmut+4JTxPgs8NvpYi6/nwBCrfuaIeWW/wa8Y31cm/f0GrCkLFaZZDqTSa4giGoppdpgXBX4H4DWWmMkBBPLFeuO8WNUb9q4OrEZ44wISqnzMX4I/1culsFKqcXWy6rZwNeAB9CkDpvqiPHjurLctk9Ytx1Xrlyh1rr8GZdkwB3jSkJD2FQunhLgGBBhfak7sElrnXYG9XcEkrXWB8ptZx/G+yr/vg9aP48yyeXiAHgZmKuM5mQzlVIdziAmIYRoCArjoLuH9fE2a7PRHGuzoUsw9i/lrT1NnZ9gXHUdYH1+PbBPa23blyilpiil1iqljlm3cy/QrC6Ba63TgUVYr1YopaIwrvjOtxapzXuah9FZxi6l1BtKqUsqXZEQotbkD0fUZCJgBhKVUiVKqRLgIWC4UirWWkbVuHbdLODkJdwbgOVa64MA1uY2PwLbgWuAnhjNf8BIEmrrVLGW75q08s3TZcvq+r+iq9lmdT1sFFezXtm27PH5lu0wq1P+9VPFgdb6MYyE4lugH7BJKXUrQgjhOuKAfRi/XRqjCWi3clNHTu4/yuSeqkJt3LD8GxX3UQvKliul/gXMwjg4H2HdzpvUbf9UZj5wtVLKCxiL0Sz3T+uy074nrfV6jKv0D1vLfwgsliRB1If80YgqlFJuGDdVzaDiD1FXjDPeZW0r12O0a69JEUaScToLgDZKqb4YbSbnl1vWC+OH9l6t9UpttAmNrsd2tmH8vV9Q9oJSKgCjHee2WsRYV8co1+WpUiqSuneBuh7oUv5m4Upq+76bKqValIulFcZnWKf3rY1eml7TWl8CvEfFq0lCCOE0SqnzMJqffgn8g3FypInWek+l6XA9qp8PXKOU6omxzyi/j7oQ+FtrPVtrvV5rvYeqVylquy/8zjq/FGsiYr16T23fk9Y6W2v9hdb6doyrC4MxetoTok4kQRDVuQQIA97VWm8pPwGfArdaz0i8AHRXSs1RSnVVSrVXSk1USpVdWj0A9FFGz0VhNZ3F0FonAcswbgYOBL4ot3g3xt/pPUqplkqpsRg3JZd3APBSSg2zbsenmm3sxvjxfUcpNcDaK8V8jLb1/6tc3g5+x+jBqZcyemOaR91v6v0fkIpxQ/EA6/u/vFzvQQeA5kqpHtb37VlNHb8BG4EFSqmeyhigbgFG8vF7bYJQSnlbL1fHW7/L8zF2ig2RWAkhxOl4KmOAymjrvmcaxj1n64AXrSeSFgDzlFJjlFKtrL/F9yulrqrH9r7BuAL8HrDauj8pswvooZQaqZRqq5T6D8aNxOUdwOgqvL31t7ra8Rq01gUYTWj/D6NJ0fxyy077npTR499YpVRHazPh6zH2cUn1eM/iHCcJgqjOBGCptU1kZV8AzTFucN0ADMXo/WEVRg8Q13GyucqLGGdOtmGcUT9Vm8yPMa5Q/Ki1zix70XqPwlRgmrWeiRgDk1GuzF8YycUn1u08WMM2bgFWA99b5z7AxVrr/FPEVV/3YVzqTsA4ozUX42C/1rTWuRg7msMY/Xxvxejru+yM0lfATxj3gRzDuCRduQ4NXGFdnoDR69QR4IpyZ6ZOpxQIxrhcvRNjZ7kS4zsRQghHG4rRq1oixu/f5Ri/jRdZfzfB+L3/AHge2IHRmcRFGDcc14k2xpX5BmMfNb/S4ncwein6H7AGo4nPS5XKvIvRTHYtxm9x/1NsrmxfuF5rvb3SstO9p2zgAYz923qMK/8jtYyLI+pB1f4YQQghhBBCCHG2kysIQgghhBBCCBtJEIQQQtSLUup9pVSqUmpLDctvUMYo6JuUUn9VM0iVEEIIFyQJghBCiPqah9FzTE32AwO11l2AJ4A5jghKCCHEmXFzdgBCCCEaJ631svJd6Faz/K9yT1cBMQ0elBBCiDMmVxCEEEI4wgTgZ2cHIYQQ4vSccgUhLCxMt2jRwi515ebm4uvra5e67M1VY3PVuMB1Y3PVuMB1Y6tPXCUnSsjfmw8aPCI98IypbmgH58
TmCPaOa926dWla63C7VVhP1rE7JmCMn1FTmcnAZABvb++esbGxNRWtE4vFgsnkmufCJLa6c9W4wHVjc9W4wHVjc9W
"text/plain": [
"<Figure size 792x223.2 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 108\n",
"\n",
"from scipy.special import expit as sigmoid\n",
"\n",
"def relu(z):\n",
" return np.maximum(0, z)\n",
"\n",
"def derivative(f, z, eps=0.000001):\n",
" return (f(z + eps) - f(z - eps))/(2 * eps)\n",
"\n",
"max_z = 4.5\n",
"z = np.linspace(-max_z, max_z, 200)\n",
"\n",
"plt.figure(figsize=(11, 3.1))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot([-max_z, 0], [0, 0], \"r-\", linewidth=2, label=\"Heaviside\")\n",
"plt.plot(z, relu(z), \"m-.\", linewidth=2, label=\"ReLU\")\n",
"plt.plot([0, 0], [0, 1], \"r-\", linewidth=0.5)\n",
"plt.plot([0, max_z], [1, 1], \"r-\", linewidth=2)\n",
"plt.plot(z, sigmoid(z), \"g--\", linewidth=2, label=\"Sigmoid\")\n",
"plt.plot(z, np.tanh(z), \"b-\", linewidth=1, label=\"Tanh\")\n",
"plt.grid(True)\n",
"plt.title(\"Activation functions\")\n",
"plt.axis([-max_z, max_z, -1.65, 2.4])\n",
"plt.gca().set_yticks([-1, 0, 1, 2])\n",
"plt.legend(loc=\"lower right\", fontsize=13)\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(z, derivative(np.sign, z), \"r-\", linewidth=2, label=\"Heaviside\")\n",
"plt.plot(0, 0, \"ro\", markersize=5)\n",
"plt.plot(0, 0, \"rx\", markersize=10)\n",
"plt.plot(z, derivative(sigmoid, z), \"g--\", linewidth=2, label=\"Sigmoid\")\n",
"plt.plot(z, derivative(np.tanh, z), \"b-\", linewidth=1, label=\"Tanh\")\n",
"plt.plot([-max_z, 0], [0, 0], \"m-.\", linewidth=2)\n",
"plt.plot([0, max_z], [1, 1], \"m-.\", linewidth=2)\n",
"plt.plot([0, 0], [0, 1], \"m-.\", linewidth=1.2)\n",
"plt.plot(0, 1, \"mo\", markersize=5)\n",
"plt.plot(0, 1, \"mx\", markersize=10)\n",
"plt.grid(True)\n",
"plt.title(\"Derivatives\")\n",
"plt.axis([-max_z, max_z, -0.2, 1.2])\n",
"\n",
"save_fig(\"activation_functions_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression MLPs"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neural_network import MLPRegressor\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"housing = fetch_california_housing()\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
" housing.data, housing.target, random_state=42)\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X_train_full, y_train_full, random_state=42)\n",
"\n",
"mlp_reg = MLPRegressor(hidden_layer_sizes=[50, 50, 50], random_state=42)\n",
"pipeline = make_pipeline(StandardScaler(), mlp_reg)\n",
"pipeline.fit(X_train, y_train)\n",
"y_pred = pipeline.predict(X_valid)\n",
"rmse = mean_squared_error(y_valid, y_pred, squared=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5053326657968465"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rmse"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification MLPs"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code this was left as an exercise for the reader\n",
"\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"iris = load_iris()\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
" iris.data, iris.target, test_size=0.1, random_state=42)\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X_train_full, y_train_full, test_size=0.1, random_state=42)\n",
"\n",
"mlp_clf = MLPClassifier(hidden_layer_sizes=[5], max_iter=10_000,\n",
" random_state=42)\n",
"pipeline = make_pipeline(StandardScaler(), mlp_clf)\n",
"pipeline.fit(X_train, y_train)\n",
"accuracy = pipeline.score(X_valid, y_valid)\n",
"accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implementing MLPs with Keras\n",
"## Building an Image Classifier Using the Sequential API\n",
"### Using Keras to load the dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by loading the fashion MNIST dataset. Keras has a number of functions to load popular datasets in `tf.keras.datasets`. The dataset is already split for you between a training set (60,000 images) and a test set (10,000 images), but it can be useful to split the training set further to have a validation set. We'll use 55,000 images for training, and 5,000 for validation."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]\n",
"X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training set contains 60,000 grayscale images, each 28x28 pixels:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(55000, 28, 28)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each pixel intensity is represented as a byte (0 to 255):"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('uint8')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's scale the pixel intensities down to the 0-1 range and convert them to floats, by dividing by 255:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_valid, X_test = X_train / 255., X_valid / 255., X_test / 255."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can plot an image using Matplotlib's `imshow()` function, with a `'binary'`\n",
" color map:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAKRElEQVR4nO3dy2/N3R/F8d3HpbSSaqXu1bgNOqiIqNAhISoxMDc1MiZh4C8wNxFMS4iRSCUGNKQuIQYI4hZxJ6rUtTyDX36/Ub9rPTkn/VnN834Nu7JPz6UrJ+kne++G379/FwB5/vrTTwDA+CgnEIpyAqEoJxCKcgKhppqcf+UCE69hvB/yzQmEopxAKMoJhKKcQCjKCYSinEAoygmEopxAKMoJhKKcQCjKCYSinEAoygmEopxAKMoJhKKcQCjKCYSinEAoygmEopxAKMoJhKKcQCh3NCb+z9zFUg0N456i+I+NjIzIfHBwsDLr6+ur63e71zY2NlaZTZ36Z/9U67nwq9bPjG9OIBTlBEJRTiAU5QRCUU4gFOUEQlFOIBRzzjC/fv2S+ZQpU2T+4MEDmR8+fFjmM2fOrMyam5vl2hkzZsh83bp1Mq9nlunmkO59devreW5qfltK9WfKNycQinICoSgnEIpyAqEoJxCKcgKhKCcQijlnmFpnYv91/vx5mZ87d07mHR0dldm3b9/k2tHRUZkPDAzIfNeuXZXZvHnz5Fq3Z9K9b86nT58qs7/+0t9xTU1NNf1OvjmBUJQTCEU5gVCUEwhFOYFQlBMIRTmBUMw5w0yfPr2u9VevXpX548ePZa72Pbo9kVu2bJH5jRs3ZL53797KbO3atXJtd3e3zLu6umR+5coVmav3tbe3V67dsGGDzFtaWsb9Od+cQCjKCYSinEAoygmEopxAKMoJhGowRwLWfu8ZKqn33G19clu+1DiilFI+fPgg82nTplVmbmuU09PTI/MVK1ZUZm7E5I62fPnypczd0ZfqWM8TJ07Itbt375b5xo0bx/3Q+eYEQlFOIBTlBEJRTiAU5QRCUU4gFOUEQjHnrIGbqdXDzTnXr18vc7clzFGvzR0v2djYWNfvVlcIuvdlzZo1Ml+5cqXM3Ws7e/ZsZfbw4UO59vnz5zIvpTDnBCYTygmEopxAKMoJhKKcQCjKCYSinEAojsasgZu5TaTW1laZv3jxQuYzZ86Uubrm78ePH3KtuiavFD3HLKWUL1++VGbuPR8cHJT5pUuXZO5m169evarMtm7dKtfWim9OIBTlBEJRTiAU5QRCUU4gFOUEQlFOIBRzzklmdHRU5mNjYzJ31/ipOej8+fPl2jlz5sjc7TVV5+K6OaR73WqG6n53KXq/57Nnz+TaWvHNCYSinEAoygmEopxAKMoJhKKcQCjKCYRizlkDN3Nzs0Q1M3N7It0ZqO7sWHfP5ffv32t+7ObmZpkPDw/LXM1J3XxXPe9SSpk1a5bMP378KPPu7u7K7PPnz3LttWvXZL527dpxf843JxCKcgKhKCcQinICoSgnEIpyAqEYpdTAHdPoti+pUUp/f79c646+bG9vl7nbOqWemxsZPH36VObTpk2TuTqWc+pU/afqju10r/vt27cy3717d2V28+ZNufbnz58yr8I3JxCKcgKhKCcQinICoSgnEIpyAqEoJxCqwWx/0nuj/qXc3MrN5JShoSGZb9u2Tebuir96ZrD1XvHX1tYmc/W+ujmmm8G6qxMd9dr27Nkj1+7cudM9/LiDc745gVCUEwhFOYFQlBMIRTmBUJQTCEU5gVATup9TzVDrvarOHU+p9g66696ceuaYTl9fn8zdEY9uzumOkFTcXlE3//369avM3bGdivtM3Gfu/h5v3bpVmbW0tMi1teKbEwhFOYFQlBMIRTmBUJQTCEU5gVCUEwhV18Cunr2BEzkrnGgXLlyQ+cmTJ2U+ODhYmTU1Ncm16pq8UvTZr6X4M3fV5+Kem/t7cM9NzUHd83
bXDzpu/qse/9SpU3Lt9u3ba3pOfHMCoSgnEIpyAqEoJxCKcgKhKCcQinICoWLPrX3//r3Mnz9/LvN79+7VvNbNrdRjl1JKY2OjzNVeVben0d0zuXDhQpm7eZ46H9bdYele9+joqMx7e3srs5GREbn24sWLMnf7Od2eTPW+zZ8/X669c+eOzAvn1gKTC+UEQlFOIBTlBEJRTiAU5QRC1TVKuXz5snzwAwcOVGZv3ryRaz98+CBz969xNa6YPXu2XKu2upXiRwJupKDec3e0ZVdXl8z7+/tl3tPTI/OPHz9WZu4zefz4scydpUuXVmbu+kF3ZKjbUuY+U3XF4PDwsFzrxl+FUQowuVBOIBTlBEJRTiAU5QRCUU4gFOUEQsk559jYmJxzbtiwQT642ppV75Vt9RyF6K6qc7PGeqm52Lt37+TaY8eOyXxgYEDmhw4dkvmCBQsqsxkzZsi1ak5ZSinLly+X+f379ysz976oKx9L8Z+5mu+WorfSubn4kydPZF6YcwKTC+UEQlFOIBTlBEJRTiAU5QRCUU4glJxzHjlyRM459+3bJx982bJllZnaH1eKPwrRXSenuJmX25+3ePFimS9atEjmai+r2odaSikvX76U+enTp2WurtkrpZRHjx5VZu4zu379el25ukKwnuNGS/FHgjqqJ+6xh4aGZN7R0cGcE5hMKCcQinICoSgnEIpyAqEoJxCKcgKh5KbKuXPnysVu3qdmlW5utWTJkpofuxS9/87t3Wtra5N5Z2enzN1zU/si3Z5Jt3dwx44dMu/u7pa5OnvW7al0n6k7L1jtyXSv212d6GaRbv+wmnOas5/tlZEdHR3jPye5CsAfQzmBUJQTCEU5gVCUEwhFOYFQcpTiRiXu389V/yIuxW8/clcEun/Lt7e315SV4reUue1qbr3atuWuulPbqkopZc6cOTK/ffu2zNVVem681draKnO3XU19Lu4oVXc0plvvrulTW/VaWlrk2ps3b8p806ZN4/6cb04gFOUEQlFOIBTlBEJRTiAU5QRCUU4glBz+rF69Wi5225OOHj1amS1cuFCuddfFua1Val7otg+5mZfajlaKn3Oq5+7WNjSMe4ri/zQ1NclcXfFXip5du21b7rm72XQ9WwzdY7vcbTlTc1R1nGgppcybN0/mVfjmBEJRTiAU5QRCUU4gFOUEQlFOIBTlBELJKwBLKfrMP+PMmTOV2cGDB+Xa169fy9ztyVRzLbcP1V0n5/Zzuj2Xah7ojll0c043a3QzXpW7x3bP3VHr3TGtjptNu78JtZ9z1apVcu3x48dlXkrhCkBgMqGcQCjKCYSinEAoygmEopxAKMoJhJJzzl+/fsnBlZsN1eP8+fMy379/v8xfvXpVmQ0PD8u1bl7n5phupqbOUHW/28373By0nrOI1Zm2pfj3pR5uv6Xbx+pm15s3b5Z5V1dXZdbb2yvX/gPMOYHJhHICoSgnEIpyAqEoJxCKcgKhKCcQakL3c6a6e/euzN3doO4eymfPnsm8s7OzMnPzPHeeLyYl5pzAZEI5gVCUEwhFOYFQlBMIRTmBUP/KUQoQhlEKMJlQTiAU5QRCUU4gFOUEQlFOIBTlBEJRTiAU5QRCUU4gFOUEQlFOIBTlBEJRTiAU5QRCVd9F9x/6PjkAE4ZvTiAU5QRCUU4gFOUEQlFOIBTlBEL9DRgW8qPu1lMTAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code\n",
"\n",
"plt.imshow(X_train[0], cmap=\"binary\")\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The labels are the class IDs (represented as uint8), from 0 to 9:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([9, 0, 0, ..., 9, 0, 2], dtype=uint8)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are the corresponding class names:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class_names = [\"T-shirt/top\", \"Trouser\", \"Pullover\", \"Dress\", \"Coat\",\n",
" \"Sandal\", \"Shirt\", \"Sneaker\", \"Bag\", \"Ankle boot\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So the first image in the training set is an ankle boot:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Ankle boot'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_names[y_train[0]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at a sample of the images in the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0AAAAFJCAYAAACy802jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAADmp0lEQVR4nOydd5gdVfnHP2c3m00nhZDQk1BD702Q0HtHRaSJCD8VpSgoCIiKioIFLKCgIiCiSEd6B0NHeighBEhCGiFs2m62zO+Pme+Z2XPv3b7ZW97P8+xz996ZO3fmnXdOedtxURRhGIZhGIZhGIZRCVT19QkYhmEYhmEYhmGsKGwCZBiGYRiGYRhGxWATIMMwDMMwDMMwKgabABmGYRiGYRiGUTHYBMgwDMMwDMMwjIrBJkCGYRiGYRiGYVQMK2wC5Jyb5JyLnHMrt7FP5Jw7spu/0+1jlBrOuXHJdW/TnX2MGJOnUcyYfvY+zrkTnHOLC703+gbn3KPOud/19Xn0NaafxYlz7kjnXKfWljGdLkxvy7PDEyDn3JbOuWbn3H87czLlSHcHF8l32/q7podPGeBDYFXgpXbO7ULn3GttbH/TOXdIRya0KwqTZ/HhnLsmI/9G59xc59wjzrlvOOdq+vr8ViSmnyuePPo3zTl3qXNucF+fWyninBvtnPuDc266c67BOTfHOfeQc26vvj63UsT0s3cxfe1ZylWe/Tqx71eBPwDHOecmRlE0pZfOqRJYNfP/gcBVwWfLevoHoyhqBma3tU97A1Pn3AbAWsADwHY9d3bdxuRZnDwIHAtUA6OB3YEfAsc65/aIomhJ+AXnXP8oipav2NPsdUw/+wbpXw2wC3A1MBj4Wl+eVHdwztVEUdTYBz99MzAI+AowFVgF2BUY1Qfn0mM45/oBzVHfrAhv+tl7lKW+9iHlKc8oitr9AwYCC4HNgD8DlwbbxwERcARxZ7kUeAPYK7PPpGSflZP3tcCtwIvAKslnEXBk5jurAzcCnyR//wHWa+dcI+DUZN+lwPvAMcE+mxI3PsuABcA1wEqZ7VXA+cRW1AbgVeCQ4Deyf492RI4FzvfI+Da0u9+awO3J+S4F3gSO6oT8tc82wf3YH3gWWJ7ILby2EzLHODs5h3F59rsmc19/A8wB6oGngZ3z6MGBxNbpeuAFYOuuytDk2Xvy7MZ9uAa4K8/nmySy+WHyfjpwIfAX4jbmpuTznYDHEtnPBK4AhmWO89lEFouBT4FngE2SbSsB1wFzE3lMA07vS3mYfq5Y/cynf8QTz48SfXst2HYCsLij75PPTiEeDCxPXr+a2fYP4OZg/yriPuWM5L1LZPwucV/0Kpm+KiP3LwIPJ/uc2gc6Ozw5jz3b2Gc6cB7wR6AOmAGcFeyzEvAn4udyEfHzvU1m+6hEbjOSa30d+HJwjEeB32Xe70HcbpySvG9zzKB7n9zPd4FmYEgfyNT0s2/19RjguUQP5wI3Aatntk9KjrEHcd+yFHge2Co4znHEY8ylwF3AN8i078A6xG3obGAJ8Xj3wLZ0utj+ylmeHRXAscDLmQuZC9TkeRDeBA4C1gP+BnxM0rhkBLAyMAx4hLgBzA5qIpIJEPFs823ihmIzYENiC8n7wKA2zjVKfvcUYH3g+0AL6UBgEPGA6jbiidCuye/cnDnGGcSN+NHJMX5E3FBukWzfNvmdfYCxwMhuKFdHB0R3Eg92NgfGA/sC+3ZC/tonHBC9CuwNTCAedF2aHGds8jcwcw6TgROJLfqHJ9/fKNlvpWSfy4gb8QOAicSN+mJg1eB330zktwnxwzK7rftq8uwbeXbjPlxDnglQsu0Okg6eeOBUR9zRrpvIetPkGr+dvN8eeAr4d/KdfsSDm0uJG8QNiZ/Vicn23xIPtrdL7tMk4HN9JQvTzxWvn/n0D7gcmE8PDDCBw4BG4knk+sA3k/cHJdsPIJ7sDc98ZzegCRibvP8J8FZyb8cnOrwEOCC4h9
MTvRkPrNEHOtuPeGBzOTCgwD7TE/07lfg5/mZy7jsm2x3wJPGEZLtknx8TP/vSldWBs4AtEv08mXjwvkfmdx4lGdwQGwDqgM8n79sdMyT3fglwP7BVoo/9+kCmpp99q68nEhuDJiT6+AjweGb7pOTank3ksiFwHzAFcMk+2xOPLb+fyPgU4mcgyhxnc+D/iPu0dZN9lwMb5tPpYvwrZ3l2VACPAd9J/neJwh+R2a4H4ZTMZ6snn+0cCGAisQXwjlCYtJ4AnQi8I+Ekn1UnAvl8G+caAVcFnz0IXJ/8/1Vii/HQPDdn3eT9TOCC4BiPZo6h692m0Hl0Qrk6OiB6BfhBgW0dkX+rc85c8xHBsS4kaHyTz8cQN6Cjg++vnNlncKKMxwX37F3gouB7X8rsM4TYineSybO45NmN+3ANhSdAFwNLk/+nA3cG268F/hx8tkVynasAI5P/dy1w/DuAv/bVtZt+9r1+hvpH3CnPB/6Z75rp/ADzv8Bf8vzmk8n//YgNhV/JbL8auC8ju2XALsExfgPcHdzDbxeB3h5B7I2sJzZGXApsn9k+HfhH8J13gPOS/3cnnhgPDPZ5CTi7jd+9Ebg68/5R4HfEk6NPgb0z29odMyT3vhEY08fyNP3sQ33Ns/+GybWskbyflLzfJ7PPZ4J9bgAeCI5zNe2078Re8vNCne5rmVWiPNstguCcWzc50RvQmcDfgZPy7P5K5v9ZyesqwT73E7u4D4+iqL6Nn96a2KKwyDm3OKlw8ikwgtjq2xZP5Xm/UfL/ROCVKIoWZbZPJp55buScGwasRtyAZHkyc4xeRdeb/F2ZfHwZcJ5z7inn3EXOua3zfLUj8g95voOndRDwdBRF89rYZx3ieGYvuyjORcjKXzyV2WcxsWW6V+Rr8iw6HHHDJ0KZbQ0ck71vpDJYJ4oiha3e55z7j3PuTOfcmpnvXwF83jn3cpJYvGsvXUePYPrZa+ybyFSd9uPElvCeYCJt9BFRFDURD2a/BOCcqyUeRFyf7LsRMAC4N9Dzr5Hbv3X0nvUaURTdTNwvHgTcQxyi+rRz7tzMbq8EX5tFqo9bE3to5gXXuwnJ9Trnqp1z33fOveKc+zjZfjhxHlqWQ4DfE3tE78983tExw4woiuZ0QQw9jelnL9GevjrntnLO3e6ce985t4j0GkJda6uNnUj+sabHOTfYOfcL59wbzrlPEhluk+d3ippylWdHiiCcRGxF+cA5588jOZk1oyj6MLOvT36LoihK9g8nWXcBnyd2Yf2vjd+tIrYOHZVn24IOnHchwsFXlqjA/2191htskfm/DiCKoj875+4jdjPuCUx2zv0siqILM/t2RP4hOYnoBTiUOPayLaQgfSm7fGyR+d/k2fdsRJyXI0KZVRFbfn6d57szAaIo+rJz7jfE4RkHAz9xzh0aRdF9URTd45xbG9iPOOb4P865m6Io+nIPX0dPsUXmf9PPnuNxYk9BIzArSpKznXMtpOcuulKZsL3rvp74vq1OHN7RnzjvFdL7dBDwQXCMMIm8o/esV0kMlg8kfz9yzl0NXOicuzTZJTzviPQ6q4jzxHbJc+i65PU7xGGvpxFPkBcDPyV3Uv9KcuyvOOeeToyy+o2XaH/MUBTyxPSzV2lDX39PHH6lIhRziVMzniCWQZbstWb1DHLvUT4uJe6jvkPsnVxKHOEQ/k7RU47ybLOzTCqkHA+cQ9xJ629z4kaoKwOK84ErgQedc1u0sd+LxDF+86Momhr8tTcB2iHPe1WtewPY3Dk3NLN9J2JZTImiqI54ZrpzcIydk+9CHPYB8cSwxwmudW7m8xlRFP0piqLPAxcQN549zXKC63Jxac49iPOmsvsR7KuESy8751w1sCOp7MQOmX0GE1sCe6WyoMmzeHDObULcgP27jd1eBDbO89xPjaLIV0yLoujlKIp+HkXRJGK39/GZbfOjKLouiqITiCvXHJ9YOYsO089eY2ki0/ej1pWp5gFjXMaiR+tJaE
eYQtt9BFEUPUMcDvhFYkv7bYn3i2S/BmDtPDr+fifPpa94g9iIOqAD+75IHFbZkud6pfM7E4fDXhdF0UvEsls/z7H
"text/plain": [
"<Figure size 864x345.6 with 40 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 1010\n",
"\n",
"n_rows = 4\n",
"n_cols = 10\n",
"plt.figure(figsize=(n_cols * 1.2, n_rows * 1.2))\n",
"for row in range(n_rows):\n",
" for col in range(n_cols):\n",
" index = n_cols * row + col\n",
" plt.subplot(n_rows, n_cols, index + 1)\n",
" plt.imshow(X_train[index], cmap=\"binary\", interpolation=\"nearest\")\n",
" plt.axis('off')\n",
" plt.title(class_names[y_train[index]])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.5)\n",
"\n",
"save_fig(\"fashion_mnist_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating the model using the Sequential API"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.InputLayer(input_shape=[28, 28]))\n",
"model.add(tf.keras.layers.Flatten())\n",
"model.add(tf.keras.layers.Dense(300, activation=\"relu\"))\n",
"model.add(tf.keras.layers.Dense(100, activation=\"relu\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# extra code: clear the session to reset the name counters\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, activation=\"relu\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" flatten (Flatten) (None, 784) 0 \n",
" \n",
" dense (Dense) (None, 300) 235500 \n",
" \n",
" dense_1 (Dense) (None, 100) 30100 \n",
" \n",
" dense_2 (Dense) (None, 10) 1010 \n",
" \n",
"=================================================================\n",
"Total params: 266,610\n",
"Trainable params: 266,610\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code: another way to display the model's architecture\n",
"tf.keras.utils.plot_model(model, \"my_fashion_mnist_model.png\", show_shapes=True)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<keras.layers.core.flatten.Flatten at 0x7fb220f4d430>,\n",
" <keras.layers.core.dense.Dense at 0x7fb282285af0>,\n",
" <keras.layers.core.dense.Dense at 0x7fb282365b50>,\n",
" <keras.layers.core.dense.Dense at 0x7fb282365fa0>]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.layers"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'dense'"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hidden1 = model.layers[1]\n",
"hidden1.name"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.get_layer('dense') is hidden1"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.02448617, -0.00877795, -0.02189048, ..., -0.02766046,\n",
" 0.03859074, -0.06889391],\n",
" [ 0.00476504, -0.03105379, -0.0586676 , ..., 0.00602964,\n",
" -0.02763776, -0.04165364],\n",
" [-0.06189284, -0.06901957, 0.07102345, ..., -0.04238207,\n",
" 0.07121518, -0.07331658],\n",
" ...,\n",
" [-0.03048757, 0.02155137, -0.05400612, ..., -0.00113463,\n",
" 0.00228987, 0.05581069],\n",
" [ 0.07061854, -0.06960931, 0.07038955, ..., -0.00384101,\n",
" 0.00034875, 0.02878492],\n",
" [-0.06022581, 0.01577859, -0.02585464, ..., -0.00527829,\n",
" 0.00272203, -0.06793761]], dtype=float32)"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights, biases = hidden1.get_weights()\n",
"weights"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(784, 300)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights.shape"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"biases"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(300,)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"biases.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compiling the model"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=\"sgd\",\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is equivalent to:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"# extra code: this cell is equivalent to the previous cell\n",
"model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
" optimizer=tf.keras.optimizers.SGD(),\n",
" metrics=[tf.keras.metrics.sparse_categorical_accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code: shows how to convert class ids to one-hot vectors\n",
"tf.keras.utils.to_categorical([0, 5, 1, 0], num_classes=10)"
]
},
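{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: it's important to set `num_classes` explicitly when the sample may not contain the highest class id: if `num_classes` is omitted, `to_categorical()` infers it as the maximum class id in the sample plus one, which yields too few columns."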
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 5, 1, 0])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code: shows how to convert one-hot vectors to class ids\n",
"np.argmax(\n",
" [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
" axis=1\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training and evaluating the model"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.7220 - sparse_categorical_accuracy: 0.7649 - val_loss: 0.4959 - val_sparse_categorical_accuracy: 0.8332\n",
"Epoch 2/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4825 - sparse_categorical_accuracy: 0.8332 - val_loss: 0.4567 - val_sparse_categorical_accuracy: 0.8384\n",
"Epoch 3/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4369 - sparse_categorical_accuracy: 0.8480 - val_loss: 0.4228 - val_sparse_categorical_accuracy: 0.8542\n",
"Epoch 4/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4122 - sparse_categorical_accuracy: 0.8558 - val_loss: 0.3966 - val_sparse_categorical_accuracy: 0.8624\n",
"Epoch 5/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3910 - sparse_categorical_accuracy: 0.8631 - val_loss: 0.3890 - val_sparse_categorical_accuracy: 0.8632\n",
"Epoch 6/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3751 - sparse_categorical_accuracy: 0.8686 - val_loss: 0.3912 - val_sparse_categorical_accuracy: 0.8600\n",
"Epoch 7/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3628 - sparse_categorical_accuracy: 0.8710 - val_loss: 0.3723 - val_sparse_categorical_accuracy: 0.8698\n",
"Epoch 8/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3514 - sparse_categorical_accuracy: 0.8755 - val_loss: 0.3767 - val_sparse_categorical_accuracy: 0.8612\n",
"Epoch 9/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3406 - sparse_categorical_accuracy: 0.8795 - val_loss: 0.3513 - val_sparse_categorical_accuracy: 0.8726\n",
"Epoch 10/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3306 - sparse_categorical_accuracy: 0.8812 - val_loss: 0.3539 - val_sparse_categorical_accuracy: 0.8738\n",
"Epoch 11/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3223 - sparse_categorical_accuracy: 0.8860 - val_loss: 0.3606 - val_sparse_categorical_accuracy: 0.8712\n",
"Epoch 12/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3146 - sparse_categorical_accuracy: 0.8869 - val_loss: 0.3472 - val_sparse_categorical_accuracy: 0.8742\n",
"Epoch 13/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3071 - sparse_categorical_accuracy: 0.8900 - val_loss: 0.3284 - val_sparse_categorical_accuracy: 0.8800\n",
"Epoch 14/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3001 - sparse_categorical_accuracy: 0.8922 - val_loss: 0.3413 - val_sparse_categorical_accuracy: 0.8780\n",
"Epoch 15/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2938 - sparse_categorical_accuracy: 0.8945 - val_loss: 0.3376 - val_sparse_categorical_accuracy: 0.8822\n",
"Epoch 16/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2867 - sparse_categorical_accuracy: 0.8971 - val_loss: 0.3272 - val_sparse_categorical_accuracy: 0.8796\n",
"Epoch 17/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2822 - sparse_categorical_accuracy: 0.8978 - val_loss: 0.3317 - val_sparse_categorical_accuracy: 0.8796\n",
"Epoch 18/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2757 - sparse_categorical_accuracy: 0.9001 - val_loss: 0.3240 - val_sparse_categorical_accuracy: 0.8824\n",
"Epoch 19/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2711 - sparse_categorical_accuracy: 0.9030 - val_loss: 0.3484 - val_sparse_categorical_accuracy: 0.8720\n",
"Epoch 20/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2662 - sparse_categorical_accuracy: 0.9045 - val_loss: 0.3209 - val_sparse_categorical_accuracy: 0.8800\n",
"Epoch 21/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2613 - sparse_categorical_accuracy: 0.9046 - val_loss: 0.3178 - val_sparse_categorical_accuracy: 0.8862\n",
"Epoch 22/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2563 - sparse_categorical_accuracy: 0.9069 - val_loss: 0.3122 - val_sparse_categorical_accuracy: 0.8848\n",
"Epoch 23/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2520 - sparse_categorical_accuracy: 0.9098 - val_loss: 0.3480 - val_sparse_categorical_accuracy: 0.8716\n",
"Epoch 24/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2469 - sparse_categorical_accuracy: 0.9113 - val_loss: 0.3202 - val_sparse_categorical_accuracy: 0.8878\n",
"Epoch 25/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2428 - sparse_categorical_accuracy: 0.9123 - val_loss: 0.3152 - val_sparse_categorical_accuracy: 0.8856\n",
"Epoch 26/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2393 - sparse_categorical_accuracy: 0.9143 - val_loss: 0.3102 - val_sparse_categorical_accuracy: 0.8852\n",
"Epoch 27/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2341 - sparse_categorical_accuracy: 0.9147 - val_loss: 0.3200 - val_sparse_categorical_accuracy: 0.8850\n",
"Epoch 28/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2313 - sparse_categorical_accuracy: 0.9169 - val_loss: 0.3100 - val_sparse_categorical_accuracy: 0.8900\n",
"Epoch 29/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2268 - sparse_categorical_accuracy: 0.9185 - val_loss: 0.3215 - val_sparse_categorical_accuracy: 0.8864\n",
"Epoch 30/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2235 - sparse_categorical_accuracy: 0.9200 - val_loss: 0.3056 - val_sparse_categorical_accuracy: 0.8894\n"
]
}
],
"source": [
"history = model.fit(X_train, y_train, epochs=30,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'verbose': 1, 'epochs': 30, 'steps': 1719}"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"history.params"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n"
]
}
],
"source": [
"print(history.epoch)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"pd.DataFrame(history.history).plot(\n",
" figsize=(8, 5), xlim=[0, 29], ylim=[0, 1], grid=True, xlabel=\"Epoch\",\n",
" style=[\"r--\", \"r--.\", \"b-\", \"b-*\"])\n",
"plt.legend(loc=\"lower left\") # extra code\n",
"save_fig(\"keras_learning_curves_plot\") # extra code\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: shows how to shift the training curve by -1/2 epoch\n",
"plt.figure(figsize=(8, 5))\n",
"for key, style in zip(history.history, [\"r--\", \"r--.\", \"b-\", \"b-*\"]):\n",
" epochs = np.array(history.epoch) + (0 if key.startswith(\"val_\") else -0.5)\n",
" plt.plot(epochs, history.history[key], style, label=key)\n",
"plt.xlabel(\"Epoch\")\n",
"plt.axis([-0.5, 29, 0., 1])\n",
"plt.legend(loc=\"lower left\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 867us/step - loss: 0.3243 - sparse_categorical_accuracy: 0.8864\n"
]
},
{
"data": {
"text/plain": [
"[0.32431697845458984, 0.8863999843597412]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using the model to make predictions"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.01, 0. , 0.02, 0. , 0.97],\n",
" [0. , 0. , 0.99, 0. , 0.01, 0. , 0. , 0. , 0. , 0. ],\n",
" [0. , 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]],\n",
" dtype=float32)"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_new = X_test[:3]\n",
"y_proba = model.predict(X_new)\n",
"y_proba.round(2)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([9, 2, 1])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred = y_proba.argmax(axis=-1)\n",
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['Ankle boot', 'Pullover', 'Trouser'], dtype='<U11')"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array(class_names)[y_pred]"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([9, 2, 1], dtype=uint8)"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_new = y_test[:3]\n",
"y_new"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "(base64-encoded PNG data omitted: Figure 10-12, three Fashion MNIST images titled with their class names)",
"text/plain": [
"<Figure size 518.4x172.8 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 10-12\n",
"plt.figure(figsize=(7.2, 2.4))\n",
"for index, image in enumerate(X_new):\n",
" plt.subplot(1, 3, index + 1)\n",
" plt.imshow(image, cmap=\"binary\", interpolation=\"nearest\")\n",
" plt.axis('off')\n",
" plt.title(class_names[y_test[index]])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.5)\n",
"save_fig('fashion_mnist_images_plot', tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a Regression MLP Using the Sequential API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load and split the California housing dataset (the original one, not the modified version used in chapter 2). Scaling is handled inside the model by a `Normalization` layer:"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# extra code: load and split the California housing dataset, like earlier\n",
"housing = fetch_california_housing()\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
" housing.data, housing.target, random_state=42)\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X_train_full, y_train_full, random_state=42)"
]
},
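{
"cell_type": "markdown",
"metadata": {},
"source": [
"The progress bars in this notebook report 363 steps per training epoch and 162 evaluation steps. Here is a minimal sketch of where those numbers come from, assuming the dataset's 20,640 rows, Scikit-Learn's default `test_size` of 0.25 (no explicit size is passed to `train_test_split()` above) and Keras's default batch size of 32:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"n_total = 20640        # rows in the California housing dataset\n",
"test_fraction = 0.25   # assumed: train_test_split() default when no size is given\n",
"batch_size = 32        # assumed: Keras default for fit() and evaluate()\n",
"\n",
"n_test = math.ceil(test_fraction * n_total)  # the test share is rounded up\n",
"n_train_full = n_total - n_test\n",
"n_valid = math.ceil(test_fraction * n_train_full)\n",
"n_train = n_train_full - n_valid\n",
"\n",
"train_steps = math.ceil(n_train / batch_size)  # the last batch may be partial\n",
"test_steps = math.ceil(n_test / batch_size)\n",
"print(n_train, n_valid, n_test, train_steps, test_steps)"
]
},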
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.9051 - root_mean_squared_error: 0.9514 - val_loss: 0.4030 - val_root_mean_squared_error: 0.6348\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3843 - root_mean_squared_error: 0.6199 - val_loss: 0.8436 - val_root_mean_squared_error: 0.9185\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3609 - root_mean_squared_error: 0.6007 - val_loss: 0.3744 - val_root_mean_squared_error: 0.6119\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3416 - root_mean_squared_error: 0.5844 - val_loss: 0.4343 - val_root_mean_squared_error: 0.6590\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3301 - root_mean_squared_error: 0.5746 - val_loss: 0.3085 - val_root_mean_squared_error: 0.5554\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3168 - root_mean_squared_error: 0.5629 - val_loss: 0.4544 - val_root_mean_squared_error: 0.6741\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3162 - root_mean_squared_error: 0.5623 - val_loss: 0.2941 - val_root_mean_squared_error: 0.5423\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3045 - root_mean_squared_error: 0.5518 - val_loss: 0.3333 - val_root_mean_squared_error: 0.5773\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2974 - root_mean_squared_error: 0.5453 - val_loss: 0.3446 - val_root_mean_squared_error: 0.5870\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2921 - root_mean_squared_error: 0.5404 - val_loss: 0.2874 - val_root_mean_squared_error: 0.5361\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2863 - root_mean_squared_error: 0.5351 - val_loss: 0.4141 - val_root_mean_squared_error: 0.6435\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2942 - root_mean_squared_error: 0.5424 - val_loss: 1.0956 - val_root_mean_squared_error: 1.0467\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2864 - root_mean_squared_error: 0.5352 - val_loss: 0.3063 - val_root_mean_squared_error: 0.5534\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2804 - root_mean_squared_error: 0.5295 - val_loss: 0.2709 - val_root_mean_squared_error: 0.5205\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2784 - root_mean_squared_error: 0.5276 - val_loss: 0.3680 - val_root_mean_squared_error: 0.6066\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2757 - root_mean_squared_error: 0.5250 - val_loss: 0.2730 - val_root_mean_squared_error: 0.5225\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2739 - root_mean_squared_error: 0.5234 - val_loss: 0.3668 - val_root_mean_squared_error: 0.6056\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2694 - root_mean_squared_error: 0.5191 - val_loss: 0.4188 - val_root_mean_squared_error: 0.6472\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2677 - root_mean_squared_error: 0.5174 - val_loss: 0.9663 - val_root_mean_squared_error: 0.9830\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2755 - root_mean_squared_error: 0.5249 - val_loss: 0.2978 - val_root_mean_squared_error: 0.5457\n",
"162/162 [==============================] - 0s 508us/step - loss: 0.2806 - root_mean_squared_error: 0.5297\n"
]
}
],
"source": [
"tf.random.set_seed(42)\n",
"norm_layer = tf.keras.layers.Normalization(input_shape=X_train.shape[1:])\n",
"model = tf.keras.Sequential([\n",
" norm_layer,\n",
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"norm_layer.adapt(X_train)\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
" validation_data=(X_valid, y_valid))\n",
"mse_test, rmse_test = model.evaluate(X_test, y_test)\n",
"X_new = X_test[:3]\n",
"y_pred = model.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5297096967697144"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rmse_test"
]
},
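{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the loss is the MSE, the reported RMSE should be close to its square root. A quick check using the rounded values from the evaluation log above (any small gap comes from rounding in the log):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"mse_logged = 0.2806   # test loss printed by model.evaluate()\n",
"rmse_logged = 0.5297  # test RMSE printed by model.evaluate()\n",
"\n",
"rmse_from_mse = math.sqrt(mse_logged)\n",
"print(round(rmse_from_mse, 4))\n",
"assert abs(rmse_from_mse - rmse_logged) < 1e-3"
]
},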
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.4969182],\n",
" [1.195265 ],\n",
" [4.9428763]], dtype=float32)"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building Complex Models Using the Functional API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not all neural network models are simply sequential. Some may have complex topologies. Some may have multiple inputs and/or multiple outputs. For example, a Wide & Deep neural network (see [paper](https://ai.google/research/pubs/pub45413)) connects all or part of the inputs directly to the output layer."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# extra code: reset the name counters and make the code reproducible\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"normalization_layer = tf.keras.layers.Normalization()\n",
"hidden_layer1 = tf.keras.layers.Dense(30, activation=\"relu\")\n",
"hidden_layer2 = tf.keras.layers.Dense(30, activation=\"relu\")\n",
"concat_layer = tf.keras.layers.Concatenate()\n",
"output_layer = tf.keras.layers.Dense(1)\n",
"\n",
"input_ = tf.keras.layers.Input(shape=X_train.shape[1:])\n",
"normalized = normalization_layer(input_)\n",
"hidden1 = hidden_layer1(normalized)\n",
"hidden2 = hidden_layer2(hidden1)\n",
"concat = concat_layer([input_, hidden2])\n",
"output = output_layer(concat)\n",
"\n",
"model = tf.keras.Model(inputs=[input_], outputs=[output])"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 8)] 0 [] \n",
" \n",
" normalization (Normalization) (None, 8) 17 ['input_1[0][0]'] \n",
" \n",
" dense (Dense) (None, 30) 270 ['normalization[0][0]'] \n",
" \n",
" dense_1 (Dense) (None, 30) 930 ['dense[0][0]'] \n",
" \n",
" concatenate (Concatenate) (None, 38) 0 ['input_1[0][0]', \n",
" 'dense_1[0][0]'] \n",
" \n",
" dense_2 (Dense) (None, 1) 39 ['concatenate[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 1,256\n",
"Trainable params: 1,239\n",
"Non-trainable params: 17\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
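{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in the summary can be reproduced by hand. A `Dense` layer has one weight per input/unit pair plus one bias per unit. The 17 non-trainable parameters belong to the `Normalization` layer; the breakdown used below (a mean and a variance per feature, plus one sample counter) is an assumption about how Keras stores its statistics, chosen because it matches the reported total:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def dense_params(n_inputs, n_units):\n",
"    # weights (n_inputs * n_units) plus one bias per unit\n",
"    return n_inputs * n_units + n_units\n",
"\n",
"n_features = 8\n",
"hidden1 = dense_params(n_features, 30)     # 270\n",
"hidden2 = dense_params(30, 30)             # 930\n",
"output = dense_params(n_features + 30, 1)  # 39, fed by the 38-wide concatenation\n",
"normalization = 2 * n_features + 1         # 17 (assumed: mean, variance, count)\n",
"\n",
"trainable = hidden1 + hidden2 + output\n",
"total = trainable + normalization\n",
"print(trainable, normalization, total)     # 1239 17 1256"
]
},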
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 1ms/step - loss: 122.3226 - root_mean_squared_error: 11.0600 - val_loss: 305.9134 - val_root_mean_squared_error: 17.4904\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 5.5425 - root_mean_squared_error: 2.3543 - val_loss: 183.4622 - val_root_mean_squared_error: 13.5448\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 979us/step - loss: 3.0631 - root_mean_squared_error: 1.7502 - val_loss: 87.2228 - val_root_mean_squared_error: 9.3393\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 1.5796 - root_mean_squared_error: 1.2568 - val_loss: 35.3699 - val_root_mean_squared_error: 5.9473\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.9536 - root_mean_squared_error: 0.9765 - val_loss: 12.3882 - val_root_mean_squared_error: 3.5197\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.6322 - root_mean_squared_error: 0.7951 - val_loss: 4.1676 - val_root_mean_squared_error: 2.0415\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.5069 - root_mean_squared_error: 0.7120 - val_loss: 1.2937 - val_root_mean_squared_error: 1.1374\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 980us/step - loss: 0.4525 - root_mean_squared_error: 0.6727 - val_loss: 0.4837 - val_root_mean_squared_error: 0.6955\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4293 - root_mean_squared_error: 0.6552 - val_loss: 0.4343 - val_root_mean_squared_error: 0.6590\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 962us/step - loss: 0.4120 - root_mean_squared_error: 0.6419 - val_loss: 0.3996 - val_root_mean_squared_error: 0.6321\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 988us/step - loss: 0.4203 - root_mean_squared_error: 0.6483 - val_loss: 0.4149 - val_root_mean_squared_error: 0.6441\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 952us/step - loss: 0.3916 - root_mean_squared_error: 0.6257 - val_loss: 0.4569 - val_root_mean_squared_error: 0.6759\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 957us/step - loss: 0.4147 - root_mean_squared_error: 0.6440 - val_loss: 0.3736 - val_root_mean_squared_error: 0.6113\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 949us/step - loss: 0.3824 - root_mean_squared_error: 0.6184 - val_loss: 0.4550 - val_root_mean_squared_error: 0.6745\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 982us/step - loss: 0.4003 - root_mean_squared_error: 0.6327 - val_loss: 0.8553 - val_root_mean_squared_error: 0.9248\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 960us/step - loss: 0.4245 - root_mean_squared_error: 0.6516 - val_loss: 1.9204 - val_root_mean_squared_error: 1.3858\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 987us/step - loss: 0.4580 - root_mean_squared_error: 0.6767 - val_loss: 2.0632 - val_root_mean_squared_error: 1.4364\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 961us/step - loss: 0.4692 - root_mean_squared_error: 0.6850 - val_loss: 3.5730 - val_root_mean_squared_error: 1.8902\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4367 - root_mean_squared_error: 0.6608 - val_loss: 3.9989 - val_root_mean_squared_error: 1.9997\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4683 - root_mean_squared_error: 0.6843 - val_loss: 2.2966 - val_root_mean_squared_error: 1.5155\n",
"162/162 [==============================] - 0s 612us/step - loss: 0.5723 - root_mean_squared_error: 0.7565\n"
]
}
],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"normalization_layer.adapt(X_train)\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
" validation_data=(X_valid, y_valid))\n",
"mse_test, rmse_test = model.evaluate(X_test, y_test)  # returns [loss, RMSE]\n",
"y_pred = model.predict(X_new)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What if you want to send different subsets of input features through the wide and deep paths? We will send 5 features through the wide path (features 0 to 4), and 6 through the deep path (features 2 to 7). Note that 3 features will go through both (features 2, 3 and 4)."
]
},
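{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the index arithmetic, here is a small sketch (plain Python, no Keras) of which feature columns each path receives when the wide input takes the first five columns and the deep input takes columns 2 onward. The variable names are illustrative only:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"feature_indices = list(range(8))  # the dataset has 8 input features\n",
"\n",
"wide_indices = feature_indices[:5]  # features 0 to 4\n",
"deep_indices = feature_indices[2:]  # features 2 to 7\n",
"shared = sorted(set(wide_indices) & set(deep_indices))\n",
"\n",
"print(len(wide_indices), len(deep_indices), shared)  # 5 6 [2, 3, 4]"
]
},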
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"input_wide = tf.keras.layers.Input(shape=[5]) # features 0 to 4\n",
"input_deep = tf.keras.layers.Input(shape=[6]) # features 2 to 7\n",
"norm_layer_wide = tf.keras.layers.Normalization()\n",
"norm_layer_deep = tf.keras.layers.Normalization()\n",
"norm_wide = norm_layer_wide(input_wide)\n",
"norm_deep = norm_layer_deep(input_deep)\n",
"hidden1 = tf.keras.layers.Dense(30, activation=\"relu\")(norm_deep)\n",
"hidden2 = tf.keras.layers.Dense(30, activation=\"relu\")(hidden1)\n",
"concat = tf.keras.layers.concatenate([norm_wide, hidden2])\n",
"output = tf.keras.layers.Dense(1)(concat)\n",
"model = tf.keras.Model(inputs=[input_wide, input_deep], outputs=[output])"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 2ms/step - loss: 1.2768 - root_mean_squared_error: 1.1300 - val_loss: 0.9497 - val_root_mean_squared_error: 0.9745\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4767 - root_mean_squared_error: 0.6904 - val_loss: 1.4311 - val_root_mean_squared_error: 1.1963\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4433 - root_mean_squared_error: 0.6658 - val_loss: 0.4258 - val_root_mean_squared_error: 0.6525\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4057 - root_mean_squared_error: 0.6370 - val_loss: 0.4016 - val_root_mean_squared_error: 0.6338\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3940 - root_mean_squared_error: 0.6277 - val_loss: 1.4914 - val_root_mean_squared_error: 1.2212\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3873 - root_mean_squared_error: 0.6224 - val_loss: 2.6759 - val_root_mean_squared_error: 1.6358\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3914 - root_mean_squared_error: 0.6257 - val_loss: 3.0592 - val_root_mean_squared_error: 1.7490\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3735 - root_mean_squared_error: 0.6112 - val_loss: 3.3043 - val_root_mean_squared_error: 1.8178\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3712 - root_mean_squared_error: 0.6093 - val_loss: 2.1298 - val_root_mean_squared_error: 1.4594\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3693 - root_mean_squared_error: 0.6077 - val_loss: 1.7402 - val_root_mean_squared_error: 1.3192\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3578 - root_mean_squared_error: 0.5982 - val_loss: 0.6127 - val_root_mean_squared_error: 0.7827\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3605 - root_mean_squared_error: 0.6005 - val_loss: 1.3970 - val_root_mean_squared_error: 1.1819\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3527 - root_mean_squared_error: 0.5939 - val_loss: 0.9449 - val_root_mean_squared_error: 0.9721\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3436 - root_mean_squared_error: 0.5861 - val_loss: 0.7757 - val_root_mean_squared_error: 0.8807\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3421 - root_mean_squared_error: 0.5849 - val_loss: 0.8920 - val_root_mean_squared_error: 0.9445\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3405 - root_mean_squared_error: 0.5835 - val_loss: 0.9334 - val_root_mean_squared_error: 0.9661\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3394 - root_mean_squared_error: 0.5826 - val_loss: 1.3433 - val_root_mean_squared_error: 1.1590\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3384 - root_mean_squared_error: 0.5817 - val_loss: 2.6406 - val_root_mean_squared_error: 1.6250\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3459 - root_mean_squared_error: 0.5881 - val_loss: 2.2482 - val_root_mean_squared_error: 1.4994\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3503 - root_mean_squared_error: 0.5919 - val_loss: 1.4407 - val_root_mean_squared_error: 1.2003\n",
"162/162 [==============================] - 0s 672us/step - loss: 0.3388 - root_mean_squared_error: 0.5821\n"
]
}
],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"\n",
"X_train_wide, X_train_deep = X_train[:, :5], X_train[:, 2:]\n",
"X_valid_wide, X_valid_deep = X_valid[:, :5], X_valid[:, 2:]\n",
"X_test_wide, X_test_deep = X_test[:, :5], X_test[:, 2:]\n",
"X_new_wide, X_new_deep = X_test_wide[:3], X_test_deep[:3]\n",
"\n",
"norm_layer_wide.adapt(X_train_wide)\n",
"norm_layer_deep.adapt(X_train_deep)\n",
"history = model.fit((X_train_wide, X_train_deep), y_train, epochs=20,\n",
" validation_data=((X_valid_wide, X_valid_deep), y_valid))\n",
"mse_test, rmse_test = model.evaluate((X_test_wide, X_test_deep), y_test)  # [loss, RMSE]\n",
"y_pred = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding an auxiliary output for regularization:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"input_wide = tf.keras.layers.Input(shape=[5]) # features 0 to 4\n",
"input_deep = tf.keras.layers.Input(shape=[6]) # features 2 to 7\n",
"norm_layer_wide = tf.keras.layers.Normalization()\n",
"norm_layer_deep = tf.keras.layers.Normalization()\n",
"norm_wide = norm_layer_wide(input_wide)\n",
"norm_deep = norm_layer_deep(input_deep)\n",
"hidden1 = tf.keras.layers.Dense(30, activation=\"relu\")(norm_deep)\n",
"hidden2 = tf.keras.layers.Dense(30, activation=\"relu\")(hidden1)\n",
"concat = tf.keras.layers.concatenate([norm_wide, hidden2])\n",
"output = tf.keras.layers.Dense(1)(concat)\n",
"aux_output = tf.keras.layers.Dense(1)(hidden2)\n",
"model = tf.keras.Model(inputs=[input_wide, input_deep],\n",
" outputs=[output, aux_output])"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=(\"mse\", \"mse\"), loss_weights=(0.9, 0.1), optimizer=optimizer,\n",
" metrics=[\"RootMeanSquaredError\"])"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 2ms/step - loss: 1.3490 - dense_2_loss: 1.2742 - dense_3_loss: 2.0215 - dense_2_root_mean_squared_error: 1.1288 - dense_3_root_mean_squared_error: 1.4218 - val_loss: 1.5415 - val_dense_2_loss: 0.9593 - val_dense_3_loss: 6.7806 - val_dense_2_root_mean_squared_error: 0.9795 - val_dense_3_root_mean_squared_error: 2.6040\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.5101 - dense_2_loss: 0.4785 - dense_3_loss: 0.7952 - dense_2_root_mean_squared_error: 0.6917 - dense_3_root_mean_squared_error: 0.8917 - val_loss: 1.3624 - val_dense_2_loss: 1.0094 - val_dense_3_loss: 4.5401 - val_dense_2_root_mean_squared_error: 1.0047 - val_dense_3_root_mean_squared_error: 2.1307\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4618 - dense_2_loss: 0.4404 - dense_3_loss: 0.6546 - dense_2_root_mean_squared_error: 0.6636 - dense_3_root_mean_squared_error: 0.8091 - val_loss: 0.5361 - val_dense_2_loss: 0.3975 - val_dense_3_loss: 1.7837 - val_dense_2_root_mean_squared_error: 0.6305 - val_dense_3_root_mean_squared_error: 1.3356\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4252 - dense_2_loss: 0.4059 - dense_3_loss: 0.5985 - dense_2_root_mean_squared_error: 0.6371 - dense_3_root_mean_squared_error: 0.7736 - val_loss: 0.5182 - val_dense_2_loss: 0.4590 - val_dense_3_loss: 1.0517 - val_dense_2_root_mean_squared_error: 0.6775 - val_dense_3_root_mean_squared_error: 1.0255\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4106 - dense_2_loss: 0.3931 - dense_3_loss: 0.5690 - dense_2_root_mean_squared_error: 0.6269 - dense_3_root_mean_squared_error: 0.7543 - val_loss: 0.4049 - val_dense_2_loss: 0.3588 - val_dense_3_loss: 0.8196 - val_dense_2_root_mean_squared_error: 0.5990 - val_dense_3_root_mean_squared_error: 0.9053\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3944 - dense_2_loss: 0.3780 - dense_3_loss: 0.5424 - dense_2_root_mean_squared_error: 0.6148 - dense_3_root_mean_squared_error: 0.7365 - val_loss: 0.4168 - val_dense_2_loss: 0.3934 - val_dense_3_loss: 0.6275 - val_dense_2_root_mean_squared_error: 0.6272 - val_dense_3_root_mean_squared_error: 0.7921\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - dense_2_loss: 0.3694 - dense_3_loss: 0.5126 - dense_2_root_mean_squared_error: 0.6078 - dense_3_root_mean_squared_error: 0.7160 - val_loss: 0.3661 - val_dense_2_loss: 0.3430 - val_dense_3_loss: 0.5747 - val_dense_2_root_mean_squared_error: 0.5856 - val_dense_3_root_mean_squared_error: 0.7581\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3731 - dense_2_loss: 0.3608 - dense_3_loss: 0.4840 - dense_2_root_mean_squared_error: 0.6007 - dense_3_root_mean_squared_error: 0.6957 - val_loss: 0.8555 - val_dense_2_loss: 0.8704 - val_dense_3_loss: 0.7218 - val_dense_2_root_mean_squared_error: 0.9330 - val_dense_3_root_mean_squared_error: 0.8496\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3672 - dense_2_loss: 0.3567 - dense_3_loss: 0.4624 - dense_2_root_mean_squared_error: 0.5972 - dense_3_root_mean_squared_error: 0.6800 - val_loss: 2.6877 - val_dense_2_loss: 2.9011 - val_dense_3_loss: 0.7675 - val_dense_2_root_mean_squared_error: 1.7033 - val_dense_3_root_mean_squared_error: 0.8761\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - dense_2_loss: 0.3765 - dense_3_loss: 0.4481 - dense_2_root_mean_squared_error: 0.6136 - dense_3_root_mean_squared_error: 0.6694 - val_loss: 3.6017 - val_dense_2_loss: 3.8004 - val_dense_3_loss: 1.8132 - val_dense_2_root_mean_squared_error: 1.9495 - val_dense_3_root_mean_squared_error: 1.3466\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3728 - dense_2_loss: 0.3656 - dense_3_loss: 0.4377 - dense_2_root_mean_squared_error: 0.6046 - dense_3_root_mean_squared_error: 0.6616 - val_loss: 0.6115 - val_dense_2_loss: 0.6325 - val_dense_3_loss: 0.4226 - val_dense_2_root_mean_squared_error: 0.7953 - val_dense_3_root_mean_squared_error: 0.6501\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3750 - dense_2_loss: 0.3688 - dense_3_loss: 0.4303 - dense_2_root_mean_squared_error: 0.6073 - dense_3_root_mean_squared_error: 0.6560 - val_loss: 0.9371 - val_dense_2_loss: 0.9545 - val_dense_3_loss: 0.7799 - val_dense_2_root_mean_squared_error: 0.9770 - val_dense_3_root_mean_squared_error: 0.8831\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3570 - dense_2_loss: 0.3499 - dense_3_loss: 0.4203 - dense_2_root_mean_squared_error: 0.5915 - dense_3_root_mean_squared_error: 0.6483 - val_loss: 0.4224 - val_dense_2_loss: 0.4245 - val_dense_3_loss: 0.4039 - val_dense_2_root_mean_squared_error: 0.6515 - val_dense_3_root_mean_squared_error: 0.6355\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3493 - dense_2_loss: 0.3421 - dense_3_loss: 0.4148 - dense_2_root_mean_squared_error: 0.5849 - dense_3_root_mean_squared_error: 0.6440 - val_loss: 0.3410 - val_dense_2_loss: 0.3221 - val_dense_3_loss: 0.5105 - val_dense_2_root_mean_squared_error: 0.5676 - val_dense_3_root_mean_squared_error: 0.7145\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3496 - dense_2_loss: 0.3432 - dense_3_loss: 0.4076 - dense_2_root_mean_squared_error: 0.5858 - dense_3_root_mean_squared_error: 0.6384 - val_loss: 0.6461 - val_dense_2_loss: 0.6671 - val_dense_3_loss: 0.4570 - val_dense_2_root_mean_squared_error: 0.8168 - val_dense_3_root_mean_squared_error: 0.6760\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3435 - dense_2_loss: 0.3370 - dense_3_loss: 0.4022 - dense_2_root_mean_squared_error: 0.5805 - dense_3_root_mean_squared_error: 0.6342 - val_loss: 0.6875 - val_dense_2_loss: 0.6841 - val_dense_3_loss: 0.7182 - val_dense_2_root_mean_squared_error: 0.8271 - val_dense_3_root_mean_squared_error: 0.8475\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3458 - dense_2_loss: 0.3393 - dense_3_loss: 0.4037 - dense_2_root_mean_squared_error: 0.5825 - dense_3_root_mean_squared_error: 0.6354 - val_loss: 1.1564 - val_dense_2_loss: 1.2129 - val_dense_3_loss: 0.6483 - val_dense_2_root_mean_squared_error: 1.1013 - val_dense_3_root_mean_squared_error: 0.8052\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3446 - dense_2_loss: 0.3385 - dense_3_loss: 0.3994 - dense_2_root_mean_squared_error: 0.5818 - dense_3_root_mean_squared_error: 0.6320 - val_loss: 3.9325 - val_dense_2_loss: 4.0947 - val_dense_3_loss: 2.4722 - val_dense_2_root_mean_squared_error: 2.0235 - val_dense_3_root_mean_squared_error: 1.5723\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3563 - dense_2_loss: 0.3511 - dense_3_loss: 0.4029 - dense_2_root_mean_squared_error: 0.5925 - dense_3_root_mean_squared_error: 0.6347 - val_loss: 1.4560 - val_dense_2_loss: 1.5433 - val_dense_3_loss: 0.6697 - val_dense_2_root_mean_squared_error: 1.2423 - val_dense_3_root_mean_squared_error: 0.8183\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3546 - dense_2_loss: 0.3498 - dense_3_loss: 0.3981 - dense_2_root_mean_squared_error: 0.5914 - dense_3_root_mean_squared_error: 0.6310 - val_loss: 1.1709 - val_dense_2_loss: 1.1945 - val_dense_3_loss: 0.9589 - val_dense_2_root_mean_squared_error: 1.0929 - val_dense_3_root_mean_squared_error: 0.9792\n"
]
}
],
"source": [
"norm_layer_wide.adapt(X_train_wide)\n",
"norm_layer_deep.adapt(X_train_deep)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=20,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"162/162 [==============================] - 0s 778us/step - loss: 0.3446 - dense_2_loss: 0.3381 - dense_3_loss: 0.4031 - dense_2_root_mean_squared_error: 0.5815 - dense_3_root_mean_squared_error: 0.6349\n"
]
}
],
"source": [
"eval_results = model.evaluate((X_test_wide, X_test_deep), (y_test, y_test))\n",
"weighted_sum_of_losses, main_loss, aux_loss, main_rmse, aux_rmse = eval_results"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7fb250e69310> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
]
}
],
"source": [
"y_pred_main, y_pred_aux = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"y_pred_tuple = model.predict((X_new_wide, X_new_deep))\n",
"y_pred = dict(zip(model.output_names, y_pred_tuple))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the Subclassing API to Build Dynamic Models"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"class WideAndDeepModel(tf.keras.Model):\n",
"    def __init__(self, units=30, activation=\"relu\", **kwargs):\n",
"        super().__init__(**kwargs)  # needed to support naming the model\n",
"        # one Normalization layer per input; both are adapted before training\n",
"        self.norm_layer_wide = tf.keras.layers.Normalization()\n",
"        self.norm_layer_deep = tf.keras.layers.Normalization()\n",
"        self.hidden1 = tf.keras.layers.Dense(units, activation=activation)\n",
"        self.hidden2 = tf.keras.layers.Dense(units, activation=activation)\n",
"        self.main_output = tf.keras.layers.Dense(1)\n",
"        self.aux_output = tf.keras.layers.Dense(1)\n",
"\n",
"    def call(self, inputs):\n",
"        input_wide, input_deep = inputs\n",
"        norm_wide = self.norm_layer_wide(input_wide)\n",
"        norm_deep = self.norm_layer_deep(input_deep)\n",
"        hidden1 = self.hidden1(norm_deep)\n",
"        hidden2 = self.hidden2(hidden1)\n",
"        # main output sees the wide features plus the deep path's output\n",
"        concat = tf.keras.layers.concatenate([norm_wide, hidden2])\n",
"        output = self.main_output(concat)\n",
"        # auxiliary output is computed from the deep path only\n",
"        aux_output = self.aux_output(hidden2)\n",
"        return output, aux_output\n",
"\n",
"tf.random.set_seed(42)  # extra code: just for reproducibility\n",
"model = WideAndDeepModel(30, activation=\"relu\", name=\"my_cool_model\")"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"363/363 [==============================] - 1s 2ms/step - loss: 1.3490 - output_1_loss: 1.2742 - output_2_loss: 2.0215 - output_1_root_mean_squared_error: 1.1288 - output_2_root_mean_squared_error: 1.4218 - val_loss: 1.5415 - val_output_1_loss: 0.9593 - val_output_2_loss: 6.7806 - val_output_1_root_mean_squared_error: 0.9795 - val_output_2_root_mean_squared_error: 2.6040\n",
"Epoch 2/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.5101 - output_1_loss: 0.4785 - output_2_loss: 0.7952 - output_1_root_mean_squared_error: 0.6917 - output_2_root_mean_squared_error: 0.8917 - val_loss: 1.3624 - val_output_1_loss: 1.0094 - val_output_2_loss: 4.5401 - val_output_1_root_mean_squared_error: 1.0047 - val_output_2_root_mean_squared_error: 2.1307\n",
"Epoch 3/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4618 - output_1_loss: 0.4404 - output_2_loss: 0.6546 - output_1_root_mean_squared_error: 0.6636 - output_2_root_mean_squared_error: 0.8091 - val_loss: 0.5361 - val_output_1_loss: 0.3975 - val_output_2_loss: 1.7837 - val_output_1_root_mean_squared_error: 0.6305 - val_output_2_root_mean_squared_error: 1.3356\n",
"Epoch 4/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4252 - output_1_loss: 0.4059 - output_2_loss: 0.5985 - output_1_root_mean_squared_error: 0.6371 - output_2_root_mean_squared_error: 0.7736 - val_loss: 0.5182 - val_output_1_loss: 0.4590 - val_output_2_loss: 1.0517 - val_output_1_root_mean_squared_error: 0.6775 - val_output_2_root_mean_squared_error: 1.0255\n",
"Epoch 5/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4106 - output_1_loss: 0.3931 - output_2_loss: 0.5690 - output_1_root_mean_squared_error: 0.6269 - output_2_root_mean_squared_error: 0.7543 - val_loss: 0.4049 - val_output_1_loss: 0.3588 - val_output_2_loss: 0.8196 - val_output_1_root_mean_squared_error: 0.5990 - val_output_2_root_mean_squared_error: 0.9053\n",
"Epoch 6/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3944 - output_1_loss: 0.3780 - output_2_loss: 0.5424 - output_1_root_mean_squared_error: 0.6148 - output_2_root_mean_squared_error: 0.7365 - val_loss: 0.4168 - val_output_1_loss: 0.3934 - val_output_2_loss: 0.6275 - val_output_1_root_mean_squared_error: 0.6272 - val_output_2_root_mean_squared_error: 0.7921\n",
"Epoch 7/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - output_1_loss: 0.3694 - output_2_loss: 0.5126 - output_1_root_mean_squared_error: 0.6078 - output_2_root_mean_squared_error: 0.7160 - val_loss: 0.3661 - val_output_1_loss: 0.3430 - val_output_2_loss: 0.5747 - val_output_1_root_mean_squared_error: 0.5856 - val_output_2_root_mean_squared_error: 0.7581\n",
"Epoch 8/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3731 - output_1_loss: 0.3608 - output_2_loss: 0.4840 - output_1_root_mean_squared_error: 0.6007 - output_2_root_mean_squared_error: 0.6957 - val_loss: 0.8555 - val_output_1_loss: 0.8704 - val_output_2_loss: 0.7218 - val_output_1_root_mean_squared_error: 0.9330 - val_output_2_root_mean_squared_error: 0.8496\n",
"Epoch 9/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3672 - output_1_loss: 0.3567 - output_2_loss: 0.4624 - output_1_root_mean_squared_error: 0.5972 - output_2_root_mean_squared_error: 0.6800 - val_loss: 2.6877 - val_output_1_loss: 2.9011 - val_output_2_loss: 0.7675 - val_output_1_root_mean_squared_error: 1.7033 - val_output_2_root_mean_squared_error: 0.8761\n",
"Epoch 10/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - output_1_loss: 0.3765 - output_2_loss: 0.4481 - output_1_root_mean_squared_error: 0.6136 - output_2_root_mean_squared_error: 0.6694 - val_loss: 3.6017 - val_output_1_loss: 3.8004 - val_output_2_loss: 1.8132 - val_output_1_root_mean_squared_error: 1.9495 - val_output_2_root_mean_squared_error: 1.3466\n",
"162/162 [==============================] - 0s 781us/step - loss: 0.3652 - output_1_loss: 0.3570 - output_2_loss: 0.4387 - output_1_root_mean_squared_error: 0.5975 - output_2_root_mean_squared_error: 0.6624\n",
"WARNING:tensorflow:6 out of the last 7 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7fb250b9d820> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
]
}
],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", loss_weights=[0.9, 0.1], optimizer=optimizer,\n",
" metrics=[\"RootMeanSquaredError\"])\n",
"model.norm_layer_wide.adapt(X_train_wide)\n",
"model.norm_layer_deep.adapt(X_train_deep)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=10,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)))\n",
"eval_results = model.evaluate((X_test_wide, X_test_deep), (y_test, y_test))\n",
"weighted_sum_of_losses, main_loss, aux_loss, main_rmse, aux_rmse = eval_results\n",
"y_pred_main, y_pred_aux = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving and Restoring a Model"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"# extra code: delete the directory, in case it already exists\n",
"\n",
"import shutil\n",
"\n",
"shutil.rmtree(\"my_keras_model\", ignore_errors=True)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_keras_model/assets\n"
]
}
],
"source": [
"model.save(\"my_keras_model\", save_format=\"tf\")"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"my_keras_model/assets\n",
"my_keras_model/keras_metadata.pb\n",
"my_keras_model/saved_model.pb\n",
"my_keras_model/variables\n",
"my_keras_model/variables/variables.data-00000-of-00001\n",
"my_keras_model/variables/variables.index\n"
]
}
],
"source": [
"# extra code: show the contents of the my_keras_model/ directory\n",
"for path in sorted(Path(\"my_keras_model\").glob(\"**/*\")):\n",
" print(path)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.models.load_model(\"my_keras_model\")\n",
"y_pred_main, y_pred_aux = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"model.save_weights(\"my_weights\")"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb210482490>"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_weights(\"my_weights\")"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"my_weights.data-00000-of-00001\n",
"my_weights.index\n"
]
}
],
"source": [
"# extra code: show the list of my_weights.* files\n",
"for path in sorted(Path().glob(\"my_weights.*\")):\n",
" print(path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Callbacks"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"shutil.rmtree(\"my_checkpoints\", ignore_errors=True) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"363/363 [==============================] - 1s 2ms/step - loss: 0.3775 - output_1_loss: 0.3706 - output_2_loss: 0.4402 - output_1_root_mean_squared_error: 0.6088 - output_2_root_mean_squared_error: 0.6635 - val_loss: 0.3369 - val_output_1_loss: 0.3234 - val_output_2_loss: 0.4587 - val_output_1_root_mean_squared_error: 0.5687 - val_output_2_root_mean_squared_error: 0.6773\n",
"Epoch 2/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3556 - output_1_loss: 0.3480 - output_2_loss: 0.4242 - output_1_root_mean_squared_error: 0.5899 - output_2_root_mean_squared_error: 0.6513 - val_loss: 0.4940 - val_output_1_loss: 0.4650 - val_output_2_loss: 0.7551 - val_output_1_root_mean_squared_error: 0.6819 - val_output_2_root_mean_squared_error: 0.8689\n",
"Epoch 3/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3612 - output_1_loss: 0.3547 - output_2_loss: 0.4198 - output_1_root_mean_squared_error: 0.5956 - output_2_root_mean_squared_error: 0.6480 - val_loss: 0.3443 - val_output_1_loss: 0.3355 - val_output_2_loss: 0.4241 - val_output_1_root_mean_squared_error: 0.5792 - val_output_2_root_mean_squared_error: 0.6512\n",
"Epoch 4/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3493 - output_1_loss: 0.3425 - output_2_loss: 0.4110 - output_1_root_mean_squared_error: 0.5852 - output_2_root_mean_squared_error: 0.6411 - val_loss: 0.4676 - val_output_1_loss: 0.4635 - val_output_2_loss: 0.5046 - val_output_1_root_mean_squared_error: 0.6808 - val_output_2_root_mean_squared_error: 0.7104\n",
"Epoch 5/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3525 - output_1_loss: 0.3465 - output_2_loss: 0.4069 - output_1_root_mean_squared_error: 0.5886 - output_2_root_mean_squared_error: 0.6379 - val_loss: 1.3020 - val_output_1_loss: 1.3842 - val_output_2_loss: 0.5623 - val_output_1_root_mean_squared_error: 1.1765 - val_output_2_root_mean_squared_error: 0.7499\n",
"Epoch 6/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3512 - output_1_loss: 0.3453 - output_2_loss: 0.4039 - output_1_root_mean_squared_error: 0.5876 - output_2_root_mean_squared_error: 0.6356 - val_loss: 1.6719 - val_output_1_loss: 1.7502 - val_output_2_loss: 0.9670 - val_output_1_root_mean_squared_error: 1.3230 - val_output_2_root_mean_squared_error: 0.9833\n",
"Epoch 7/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3533 - output_1_loss: 0.3477 - output_2_loss: 0.4038 - output_1_root_mean_squared_error: 0.5897 - output_2_root_mean_squared_error: 0.6355 - val_loss: 0.6855 - val_output_1_loss: 0.7149 - val_output_2_loss: 0.4210 - val_output_1_root_mean_squared_error: 0.8455 - val_output_2_root_mean_squared_error: 0.6488\n",
"Epoch 8/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3409 - output_1_loss: 0.3348 - output_2_loss: 0.3965 - output_1_root_mean_squared_error: 0.5786 - output_2_root_mean_squared_error: 0.6297 - val_loss: 2.0126 - val_output_1_loss: 1.9280 - val_output_2_loss: 2.7742 - val_output_1_root_mean_squared_error: 1.3885 - val_output_2_root_mean_squared_error: 1.6656\n",
"Epoch 9/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3441 - output_1_loss: 0.3375 - output_2_loss: 0.4028 - output_1_root_mean_squared_error: 0.5810 - output_2_root_mean_squared_error: 0.6347 - val_loss: 1.6894 - val_output_1_loss: 1.8009 - val_output_2_loss: 0.6859 - val_output_1_root_mean_squared_error: 1.3420 - val_output_2_root_mean_squared_error: 0.8282\n",
"Epoch 10/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3517 - output_1_loss: 0.3468 - output_2_loss: 0.3962 - output_1_root_mean_squared_error: 0.5889 - output_2_root_mean_squared_error: 0.6294 - val_loss: 1.2969 - val_output_1_loss: 1.3365 - val_output_2_loss: 0.9407 - val_output_1_root_mean_squared_error: 1.1561 - val_output_2_root_mean_squared_error: 0.9699\n"
]
}
],
"source": [
"checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_checkpoints\",\n",
" save_weights_only=True)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=10,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)),\n",
" callbacks=[checkpoint_cb])"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3405 - output_1_loss: 0.3349 - output_2_loss: 0.3910 - output_1_root_mean_squared_error: 0.5787 - output_2_root_mean_squared_error: 0.6253 - val_loss: 0.6245 - val_output_1_loss: 0.6502 - val_output_2_loss: 0.3937 - val_output_1_root_mean_squared_error: 0.8063 - val_output_2_root_mean_squared_error: 0.6275\n",
"Epoch 2/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3400 - output_1_loss: 0.3344 - output_2_loss: 0.3900 - output_1_root_mean_squared_error: 0.5783 - output_2_root_mean_squared_error: 0.6245 - val_loss: 0.9552 - val_output_1_loss: 0.9508 - val_output_2_loss: 0.9947 - val_output_1_root_mean_squared_error: 0.9751 - val_output_2_root_mean_squared_error: 0.9974\n",
"Epoch 3/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3442 - output_1_loss: 0.3389 - output_2_loss: 0.3921 - output_1_root_mean_squared_error: 0.5821 - output_2_root_mean_squared_error: 0.6262 - val_loss: 0.3574 - val_output_1_loss: 0.3552 - val_output_2_loss: 0.3766 - val_output_1_root_mean_squared_error: 0.5960 - val_output_2_root_mean_squared_error: 0.6137\n",
"Epoch 4/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3347 - output_1_loss: 0.3289 - output_2_loss: 0.3865 - output_1_root_mean_squared_error: 0.5735 - output_2_root_mean_squared_error: 0.6217 - val_loss: 0.4521 - val_output_1_loss: 0.4401 - val_output_2_loss: 0.5609 - val_output_1_root_mean_squared_error: 0.6634 - val_output_2_root_mean_squared_error: 0.7489\n",
"Epoch 5/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3363 - output_1_loss: 0.3311 - output_2_loss: 0.3832 - output_1_root_mean_squared_error: 0.5754 - output_2_root_mean_squared_error: 0.6190 - val_loss: 0.4903 - val_output_1_loss: 0.5018 - val_output_2_loss: 0.3869 - val_output_1_root_mean_squared_error: 0.7084 - val_output_2_root_mean_squared_error: 0.6220\n",
"Epoch 6/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3300 - output_1_loss: 0.3245 - output_2_loss: 0.3801 - output_1_root_mean_squared_error: 0.5696 - output_2_root_mean_squared_error: 0.6165 - val_loss: 0.8351 - val_output_1_loss: 0.8434 - val_output_2_loss: 0.7602 - val_output_1_root_mean_squared_error: 0.9184 - val_output_2_root_mean_squared_error: 0.8719\n",
"Epoch 7/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3324 - output_1_loss: 0.3270 - output_2_loss: 0.3814 - output_1_root_mean_squared_error: 0.5718 - output_2_root_mean_squared_error: 0.6176 - val_loss: 0.6880 - val_output_1_loss: 0.7171 - val_output_2_loss: 0.4259 - val_output_1_root_mean_squared_error: 0.8468 - val_output_2_root_mean_squared_error: 0.6526\n",
"Epoch 8/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3286 - output_1_loss: 0.3231 - output_2_loss: 0.3774 - output_1_root_mean_squared_error: 0.5684 - output_2_root_mean_squared_error: 0.6143 - val_loss: 4.4284 - val_output_1_loss: 4.2604 - val_output_2_loss: 5.9404 - val_output_1_root_mean_squared_error: 2.0641 - val_output_2_root_mean_squared_error: 2.4373\n",
"Epoch 9/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3378 - output_1_loss: 0.3322 - output_2_loss: 0.3886 - output_1_root_mean_squared_error: 0.5764 - output_2_root_mean_squared_error: 0.6234 - val_loss: 1.7043 - val_output_1_loss: 1.7984 - val_output_2_loss: 0.8578 - val_output_1_root_mean_squared_error: 1.3410 - val_output_2_root_mean_squared_error: 0.9262\n",
"Epoch 10/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3401 - output_1_loss: 0.3354 - output_2_loss: 0.3824 - output_1_root_mean_squared_error: 0.5792 - output_2_root_mean_squared_error: 0.6184 - val_loss: 0.6170 - val_output_1_loss: 0.6282 - val_output_2_loss: 0.5169 - val_output_1_root_mean_squared_error: 0.7926 - val_output_2_root_mean_squared_error: 0.7190\n",
"Epoch 11/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3230 - output_1_loss: 0.3177 - output_2_loss: 0.3706 - output_1_root_mean_squared_error: 0.5637 - output_2_root_mean_squared_error: 0.6088 - val_loss: 0.3558 - val_output_1_loss: 0.3490 - val_output_2_loss: 0.4170 - val_output_1_root_mean_squared_error: 0.5907 - val_output_2_root_mean_squared_error: 0.6457\n",
"Epoch 12/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3253 - output_1_loss: 0.3201 - output_2_loss: 0.3727 - output_1_root_mean_squared_error: 0.5658 - output_2_root_mean_squared_error: 0.6105 - val_loss: 0.4612 - val_output_1_loss: 0.4597 - val_output_2_loss: 0.4745 - val_output_1_root_mean_squared_error: 0.6780 - val_output_2_root_mean_squared_error: 0.6888\n",
"Epoch 13/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3221 - output_1_loss: 0.3167 - output_2_loss: 0.3699 - output_1_root_mean_squared_error: 0.5628 - output_2_root_mean_squared_error: 0.6082 - val_loss: 0.3120 - val_output_1_loss: 0.3056 - val_output_2_loss: 0.3694 - val_output_1_root_mean_squared_error: 0.5528 - val_output_2_root_mean_squared_error: 0.6078\n",
"Epoch 14/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3204 - output_1_loss: 0.3149 - output_2_loss: 0.3695 - output_1_root_mean_squared_error: 0.5612 - output_2_root_mean_squared_error: 0.6078 - val_loss: 0.4120 - val_output_1_loss: 0.4013 - val_output_2_loss: 0.5076 - val_output_1_root_mean_squared_error: 0.6335 - val_output_2_root_mean_squared_error: 0.7124\n",
"Epoch 15/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3196 - output_1_loss: 0.3144 - output_2_loss: 0.3662 - output_1_root_mean_squared_error: 0.5607 - output_2_root_mean_squared_error: 0.6052 - val_loss: 0.3304 - val_output_1_loss: 0.3269 - val_output_2_loss: 0.3619 - val_output_1_root_mean_squared_error: 0.5718 - val_output_2_root_mean_squared_error: 0.6016\n",
"Epoch 16/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3166 - output_1_loss: 0.3113 - output_2_loss: 0.3639 - output_1_root_mean_squared_error: 0.5579 - output_2_root_mean_squared_error: 0.6032 - val_loss: 0.4455 - val_output_1_loss: 0.4414 - val_output_2_loss: 0.4819 - val_output_1_root_mean_squared_error: 0.6644 - val_output_2_root_mean_squared_error: 0.6942\n",
"Epoch 17/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3186 - output_1_loss: 0.3134 - output_2_loss: 0.3650 - output_1_root_mean_squared_error: 0.5599 - output_2_root_mean_squared_error: 0.6041 - val_loss: 0.3255 - val_output_1_loss: 0.3212 - val_output_2_loss: 0.3643 - val_output_1_root_mean_squared_error: 0.5667 - val_output_2_root_mean_squared_error: 0.6035\n",
"Epoch 18/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3143 - output_1_loss: 0.3091 - output_2_loss: 0.3611 - output_1_root_mean_squared_error: 0.5560 - output_2_root_mean_squared_error: 0.6009 - val_loss: 1.6360 - val_output_1_loss: 1.6925 - val_output_2_loss: 1.1276 - val_output_1_root_mean_squared_error: 1.3010 - val_output_2_root_mean_squared_error: 1.0619\n",
"Epoch 19/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3169 - output_1_loss: 0.3122 - output_2_loss: 0.3601 - output_1_root_mean_squared_error: 0.5587 - output_2_root_mean_squared_error: 0.6001 - val_loss: 1.2441 - val_output_1_loss: 1.3093 - val_output_2_loss: 0.6572 - val_output_1_root_mean_squared_error: 1.1442 - val_output_2_root_mean_squared_error: 0.8107\n",
"Epoch 20/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3245 - output_1_loss: 0.3201 - output_2_loss: 0.3641 - output_1_root_mean_squared_error: 0.5658 - output_2_root_mean_squared_error: 0.6034 - val_loss: 1.5466 - val_output_1_loss: 1.5582 - val_output_2_loss: 1.4424 - val_output_1_root_mean_squared_error: 1.2483 - val_output_2_root_mean_squared_error: 1.2010\n",
"Epoch 21/100\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3202 - output_1_loss: 0.3153 - output_2_loss: 0.3640 - output_1_root_mean_squared_error: 0.5615 - output_2_root_mean_squared_error: 0.6033 - val_loss: 0.6704 - val_output_1_loss: 0.6907 - val_output_2_loss: 0.4873 - val_output_1_root_mean_squared_error: 0.8311 - val_output_2_root_mean_squared_error: 0.6980\n",
"Epoch 22/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3150 - output_1_loss: 0.3103 - output_2_loss: 0.3573 - output_1_root_mean_squared_error: 0.5570 - output_2_root_mean_squared_error: 0.5978 - val_loss: 0.4909 - val_output_1_loss: 0.4955 - val_output_2_loss: 0.4493 - val_output_1_root_mean_squared_error: 0.7039 - val_output_2_root_mean_squared_error: 0.6703\n",
"Epoch 23/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3104 - output_1_loss: 0.3054 - output_2_loss: 0.3552 - output_1_root_mean_squared_error: 0.5526 - output_2_root_mean_squared_error: 0.5960 - val_loss: 0.3845 - val_output_1_loss: 0.3803 - val_output_2_loss: 0.4228 - val_output_1_root_mean_squared_error: 0.6167 - val_output_2_root_mean_squared_error: 0.6502\n"
]
}
],
"source": [
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=10,\n",
" restore_best_weights=True)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=100,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)),\n",
" callbacks=[checkpoint_cb, early_stopping_cb])"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"class PrintValTrainRatioCallback(tf.keras.callbacks.Callback):\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        # logs holds this epoch's metrics; \"val_loss\" is only present when validation data is used\n",
" ratio = logs[\"val_loss\"] / logs[\"loss\"]\n",
" print(f\"Epoch={epoch}, val/train={ratio:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch=0, val/train=2.29\n",
"Epoch=1, val/train=1.03\n",
"Epoch=2, val/train=2.07\n",
"Epoch=3, val/train=1.76\n",
"Epoch=4, val/train=3.56\n",
"Epoch=5, val/train=1.86\n",
"Epoch=6, val/train=2.45\n",
"Epoch=7, val/train=7.86\n",
"Epoch=8, val/train=11.20\n",
"Epoch=9, val/train=1.14\n"
]
}
],
"source": [
"val_train_ratio_cb = PrintValTrainRatioCallback()\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=10,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)),\n",
" callbacks=[val_train_ratio_cb], verbose=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using TensorBoard for Visualization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is preinstalled on Colab, but the `tensorboard-plugin-profile` plugin is not, so let's install it:"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"if \"google.colab\" in sys.modules: # extra code\n",
" %pip install -q -U tensorboard-plugin-profile"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"shutil.rmtree(\"my_logs\", ignore_errors=True)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from time import strftime\n",
"\n",
"def get_run_logdir(root_logdir=\"my_logs\"):\n",
" return Path(root_logdir) / strftime(\"run_%Y_%m_%d_%H_%M_%S\")\n",
"\n",
"run_logdir = get_run_logdir()"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"# extra code: builds the first regression model we used earlier\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"norm_layer = tf.keras.layers.Normalization(input_shape=X_train.shape[1:])\n",
"model = tf.keras.Sequential([\n",
" norm_layer,\n",
" tf.keras.layers.Dense(30, activation=\"relu\"),\n",
" tf.keras.layers.Dense(30, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"norm_layer.adapt(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-08-01 17:25:59.099970: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.\n",
"2022-08-01 17:25:59.099982: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.\n",
"2022-08-01 17:25:59.100137: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"261/363 [====================>.........] - ETA: 0s - loss: 2.3165 - root_mean_squared_error: 1.5220"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-08-01 17:25:59.430946: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.\n",
"2022-08-01 17:25:59.430962: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.\n",
"2022-08-01 17:25:59.510100: I tensorflow/core/profiler/lib/profiler_session.cc:67] Profiler session collecting data.\n",
"2022-08-01 17:25:59.524969: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.\n",
"2022-08-01 17:25:59.539451: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00\n",
"\n",
"2022-08-01 17:25:59.549606: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.trace.json.gz\n",
"2022-08-01 17:25:59.558338: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00\n",
"\n",
"2022-08-01 17:25:59.558474: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.memory_profile.json.gz\n",
"2022-08-01 17:25:59.559618: I tensorflow/core/profiler/rpc/client/capture_profile.cc:251] Creating directory: my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00\n",
"Dumped tool data for xplane.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.xplane.pb\n",
"Dumped tool data for overview_page.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.overview_page.pb\n",
"Dumped tool data for input_pipeline.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.input_pipeline.pb\n",
"Dumped tool data for tensorflow_stats.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.tensorflow_stats.pb\n",
"Dumped tool data for kernel_stats.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.kernel_stats.pb\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"363/363 [==============================] - 1s 1ms/step - loss: 1.8866 - root_mean_squared_error: 1.3736 - val_loss: 0.7126 - val_root_mean_squared_error: 0.8442\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 907us/step - loss: 0.6577 - root_mean_squared_error: 0.8110 - val_loss: 0.6880 - val_root_mean_squared_error: 0.8295\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 836us/step - loss: 0.5934 - root_mean_squared_error: 0.7703 - val_loss: 0.5803 - val_root_mean_squared_error: 0.7618\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 832us/step - loss: 0.5557 - root_mean_squared_error: 0.7455 - val_loss: 0.5166 - val_root_mean_squared_error: 0.7188\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 985us/step - loss: 0.5272 - root_mean_squared_error: 0.7261 - val_loss: 0.4895 - val_root_mean_squared_error: 0.6997\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 887us/step - loss: 0.5033 - root_mean_squared_error: 0.7094 - val_loss: 0.4951 - val_root_mean_squared_error: 0.7036\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 894us/step - loss: 0.4854 - root_mean_squared_error: 0.6967 - val_loss: 0.4862 - val_root_mean_squared_error: 0.6973\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 868us/step - loss: 0.4709 - root_mean_squared_error: 0.6862 - val_loss: 0.4554 - val_root_mean_squared_error: 0.6748\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 780us/step - loss: 0.4578 - root_mean_squared_error: 0.6766 - val_loss: 0.4413 - val_root_mean_squared_error: 0.6643\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 819us/step - loss: 0.4474 - root_mean_squared_error: 0.6689 - val_loss: 0.4379 - val_root_mean_squared_error: 0.6617\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 795us/step - loss: 0.4393 - root_mean_squared_error: 0.6628 - val_loss: 0.4396 - val_root_mean_squared_error: 0.6630\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 852us/step - loss: 0.4318 - root_mean_squared_error: 0.6571 - val_loss: 0.4505 - val_root_mean_squared_error: 0.6712\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 910us/step - loss: 0.4260 - root_mean_squared_error: 0.6527 - val_loss: 0.3997 - val_root_mean_squared_error: 0.6322\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 796us/step - loss: 0.4202 - root_mean_squared_error: 0.6482 - val_loss: 0.3956 - val_root_mean_squared_error: 0.6290\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 816us/step - loss: 0.4155 - root_mean_squared_error: 0.6446 - val_loss: 0.3916 - val_root_mean_squared_error: 0.6257\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 759us/step - loss: 0.4112 - root_mean_squared_error: 0.6412 - val_loss: 0.3937 - val_root_mean_squared_error: 0.6275\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 826us/step - loss: 0.4077 - root_mean_squared_error: 0.6385 - val_loss: 0.3809 - val_root_mean_squared_error: 0.6172\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 832us/step - loss: 0.4039 - root_mean_squared_error: 0.6356 - val_loss: 0.3793 - val_root_mean_squared_error: 0.6159\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 747us/step - loss: 0.4004 - root_mean_squared_error: 0.6328 - val_loss: 0.3850 - val_root_mean_squared_error: 0.6205\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 755us/step - loss: 0.3980 - root_mean_squared_error: 0.6308 - val_loss: 0.3809 - val_root_mean_squared_error: 0.6172\n"
]
}
],
"source": [
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir,\n",
" profile_batch=(100, 200))\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[tensorboard_cb])"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"my_logs\n",
" run_2022_08_01_17_25_59\n",
" events.out.tfevents.1638910166.my_computer.profile-empty\n",
" plugins\n",
" profile\n",
" 2022_08_01_17_26_00\n",
" my_computer.input_pipeline.pb\n",
" my_computer.kernel_stats.pb\n",
" my_computer.memory_profile.json.gz\n",
" my_computer.overview_page.pb\n",
" my_computer.tensorflow_stats.pb\n",
" my_computer.trace.json.gz\n",
" my_computer.xplane.pb\n",
" train\n",
" events.out.tfevents.1638910166.my_computer.22294.0.v2\n",
" validation\n",
" events.out.tfevents.1638910166.my_computer.22294.1.v2\n"
]
}
],
"source": [
"print(\"my_logs\")\n",
"for path in sorted(Path(\"my_logs\").glob(\"**/*\")):\n",
" print(\" \" * (len(path.parts) - 1) + path.parts[-1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the `tensorboard` Jupyter extension and start the TensorBoard server: "
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-18d562db9bb9706a\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-18d562db9bb9706a\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6006;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir=./my_logs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: if you prefer to access TensorBoard in a separate tab, click the \"localhost:6006\" link below:"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<a href=\"http://localhost:6006/\">http://localhost:6006/</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extra code\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" from google.colab import output\n",
"\n",
" output.serve_kernel_port_as_window(6006)\n",
"else:\n",
" from IPython.display import display, HTML\n",
"\n",
" display(HTML('<a href=\"http://localhost:6006/\">http://localhost:6006/</a>'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can use also visualize histograms, images, text, and even listen to audio using TensorBoard:"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"test_logdir = get_run_logdir()\n",
"writer = tf.summary.create_file_writer(str(test_logdir))\n",
"with writer.as_default():\n",
" for step in range(1, 1000 + 1):\n",
" tf.summary.scalar(\"my_scalar\", np.sin(step / 10), step=step)\n",
" \n",
" data = (np.random.randn(100) + 2) * step / 100 # gets larger\n",
" tf.summary.histogram(\"my_hist\", data, buckets=50, step=step)\n",
" \n",
" images = np.random.rand(2, 32, 32, 3) * step / 1000 # gets brighter\n",
" tf.summary.image(\"my_images\", images, step=step)\n",
" \n",
" texts = [\"The step is \" + str(step), \"Its square is \" + str(step ** 2)]\n",
" tf.summary.text(\"my_text\", texts, step=step)\n",
" \n",
" sine_wave = tf.math.sin(tf.range(12000) / 48000 * 2 * np.pi * step)\n",
" audio = tf.reshape(tf.cast(sine_wave, tf.float32), [1, -1, 1])\n",
" tf.summary.audio(\"my_audio\", audio, sample_rate=48000, step=step)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can share your TensorBoard logs with the world by uploading them to https://tensorboard.dev/. For this, you can run the `tensorboard dev upload` command, with the `--logdir` and `--one_shot` options, and optionally the `--name` and `--description` options. The first time, it will ask you to accept Google's Terms of Service, and to authenticate.\n",
"\n",
"Notes:\n",
"* Authenticating requires user input. Colab supports user input from shell commands, but the main other Jupyter environments do not, so for them we use a hackish workaround (alternatively, you could run the command in a terminal window, after you make sure to activate this project's conda environment and move to this notebook's directory).\n",
"* If you get an authentication related error (such as *invalid_grant: Bad Request*), it's likely that your login session has expired. In this case, try running the command `tensorboard dev auth revoke` to logout, and try again."
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"New experiment created. View your TensorBoard at: https://tensorboard.dev/experiment/il3YO6KgQHeQuMpX8cxgYw/\n",
"\n",
"\u001b[1m[2022-08-01T17:26:19]\u001b[0m Started scanning logdir.\n",
"\u001b[1m[2022-08-01T17:26:25]\u001b[0m Total uploaded: 1120 scalars, 2000 tensors (1.2 MB), 1 binary objects (29.2 kB)\n",
"\u001b[1m[2022-08-01T17:26:25]\u001b[0m Done scanning logdir.\n",
"\n",
"\n",
"Done. View your TensorBoard at https://tensorboard.dev/experiment/il3YO6KgQHeQuMpX8cxgYw/\n"
]
}
],
"source": [
"# extra code\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" !tensorboard dev upload --logdir ./my_logs --one_shot \\\n",
" --name \"Quick test\" --description \"This is a test\" \n",
"else:\n",
" from tensorboard.main import run_main\n",
"\n",
" argv = \"tensorboard dev upload --logdir ./my_logs --one_shot\".split()\n",
" argv += [\"--name\", \"Quick test\", \"--description\", \"This is a test\"]\n",
" try:\n",
" original_sys_argv_and_sys_exit = sys.argv, sys.exit\n",
" sys.argv, sys.exit = argv, lambda status: None\n",
" run_main()\n",
" finally:\n",
" sys.argv, sys.exit = original_sys_argv_and_sys_exit"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can get list your published experiments:"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"https://tensorboard.dev/experiment/il3YO6KgQHeQuMpX8cxgYw/\n",
"\tName Quick test\n",
"\tDescription This is a test\n",
"\tId il3YO6KgQHeQuMpX8cxgYw\n",
"\tCreated 2022-08-01 17:26:19 (20 seconds ago)\n",
"\tUpdated 2022-08-01 17:26:25 (14 seconds ago)\n",
"\tRuns 3\n",
"\tTags 8\n",
"\tScalars 1120\n",
"\tTensor bytes 1421436\n",
"\tBinary object bytes 30096\n",
"Total: 1 experiment(s)\n"
]
}
],
"source": [
"!tensorboard dev list"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To delete an experiment, use the following command:\n",
"\n",
"```python\n",
"!tensorboard dev delete --experiment_id <experiment_id>\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When you stop this Jupyter kernel (a.k.a. Runtime), it will automatically stop the TensorBoard server as well. Another way to stop the TensorBoard server is to kill it, if you are running on Linux or MacOSX. First, you need to find its process ID:"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Known TensorBoard instances:\n",
" - port 6006: logdir ./my_logs (started 0:00:31 ago; pid 22701)\n"
]
}
],
"source": [
"# extra code lists all running TensorBoard server instances\n",
"\n",
"from tensorboard import notebook\n",
"\n",
"notebook.list()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next you can use the following command on Linux or MacOSX, replacing `<pid>` with the pid listed above:\n",
"\n",
" !kill <pid>\n",
"\n",
"On Windows:\n",
"\n",
" !taskkill /F /PID <pid>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-Tuning Neural Network Hyperparameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section we'll use the Fashion MNIST dataset again:"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]\n",
"X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"if \"google.colab\" in sys.modules:\n",
" %pip install -q -U keras_tuner"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"import keras_tuner as kt\n",
"\n",
"def build_model(hp):\n",
" n_hidden = hp.Int(\"n_hidden\", min_value=0, max_value=8, default=2)\n",
" n_neurons = hp.Int(\"n_neurons\", min_value=16, max_value=256)\n",
" learning_rate = hp.Float(\"learning_rate\", min_value=1e-4, max_value=1e-2,\n",
" sampling=\"log\")\n",
" optimizer = hp.Choice(\"optimizer\", values=[\"sgd\", \"adam\"])\n",
" if optimizer == \"sgd\":\n",
" optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)\n",
" else:\n",
" optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
"\n",
" model = tf.keras.Sequential()\n",
" model.add(tf.keras.layers.Flatten())\n",
" for _ in range(n_hidden):\n",
" model.add(tf.keras.layers.Dense(n_neurons, activation=\"relu\"))\n",
" model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
" model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 5 Complete [00h 00m 24s]\n",
"val_accuracy: 0.8736000061035156\n",
"\n",
"Best val_accuracy So Far: 0.8736000061035156\n",
"Total elapsed time: 00h 01m 43s\n",
"INFO:tensorflow:Oracle triggered exit\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1208 09:51:50.359315 4451454400 1158129808.py:4] Oracle triggered exit\n"
]
}
],
"source": [
"random_search_tuner = kt.RandomSearch(\n",
" build_model, objective=\"val_accuracy\", max_trials=5, overwrite=True,\n",
" directory=\"my_fashion_mnist\", project_name=\"my_rnd_search\", seed=42)\n",
"random_search_tuner.search(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"top3_models = random_search_tuner.get_best_models(num_models=3)\n",
"best_model = top3_models[0]"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'n_hidden': 5,\n",
" 'n_neurons': 70,\n",
" 'learning_rate': 0.00041268008323824807,\n",
" 'optimizer': 'adam'}"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top3_params = random_search_tuner.get_best_hyperparameters(num_trials=3)\n",
"top3_params[0].values # best hyperparameter values"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial summary\n",
"Hyperparameters:\n",
"n_hidden: 5\n",
"n_neurons: 70\n",
"learning_rate: 0.00041268008323824807\n",
"optimizer: adam\n",
"Score: 0.8736000061035156\n"
]
}
],
"source": [
"best_trial = random_search_tuner.oracle.get_best_trials(num_trials=1)[0]\n",
"best_trial.summary()"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8736000061035156"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_trial.metrics.get_last_value(\"val_accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.3274 - accuracy: 0.8799\n",
"Epoch 2/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.3155 - accuracy: 0.8827\n",
"Epoch 3/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.3049 - accuracy: 0.8867\n",
"Epoch 4/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2962 - accuracy: 0.8914\n",
"Epoch 5/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2886 - accuracy: 0.8931\n",
"Epoch 6/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2831 - accuracy: 0.8935\n",
"Epoch 7/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2795 - accuracy: 0.8962\n",
"Epoch 8/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2701 - accuracy: 0.8999: 0s - loss: 0\n",
"Epoch 9/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2661 - accuracy: 0.9009\n",
"Epoch 10/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2628 - accuracy: 0.9012\n",
"313/313 [==============================] - 0s 744us/step - loss: 0.3625 - accuracy: 0.8753\n"
]
}
],
"source": [
"best_model.fit(X_train_full, y_train_full, epochs=10)\n",
"test_loss, test_accuracy = best_model.evaluate(X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
"class MyClassificationHyperModel(kt.HyperModel):\n",
" def build(self, hp):\n",
" return build_model(hp)\n",
"\n",
" def fit(self, hp, model, X, y, **kwargs):\n",
" if hp.Boolean(\"normalize\"):\n",
" norm_layer = tf.keras.layers.Normalization()\n",
" X = norm_layer(X)\n",
" return model.fit(X, y, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
"hyperband_tuner = kt.Hyperband(\n",
" MyClassificationHyperModel(), objective=\"val_accuracy\", seed=42,\n",
" max_epochs=10, factor=3, hyperband_iterations=2,\n",
" overwrite=True, directory=\"my_fashion_mnist\", project_name=\"hyperband\")"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 60 Complete [00h 00m 18s]\n",
"val_accuracy: 0.819599986076355\n",
"\n",
"Best val_accuracy So Far: 0.8704000115394592\n",
"Total elapsed time: 00h 08m 44s\n",
"INFO:tensorflow:Oracle triggered exit\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1208 10:00:59.856360 4451454400 3169670597.py:4] Oracle triggered exit\n"
]
}
],
"source": [
"root_logdir = Path(hyperband_tuner.project_dir) / \"tensorboard\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(root_logdir)\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=2)\n",
"hyperband_tuner.search(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[early_stopping_cb, tensorboard_cb])"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 10 Complete [00h 00m 13s]\n",
"val_accuracy: 0.7228000164031982\n",
"\n",
"Best val_accuracy So Far: 0.8636000156402588\n",
"Total elapsed time: 00h 02m 10s\n",
"INFO:tensorflow:Oracle triggered exit\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1208 10:03:10.004801 4451454400 1918178380.py:5] Oracle triggered exit\n"
]
}
],
"source": [
"bayesian_opt_tuner = kt.BayesianOptimization(\n",
" MyClassificationHyperModel(), objective=\"val_accuracy\", seed=42,\n",
" max_trials=10, alpha=1e-4, beta=2.6,\n",
" overwrite=True, directory=\"my_fashion_mnist\", project_name=\"bayesian_opt\")\n",
"bayesian_opt_tuner.search(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[early_stopping_cb])"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-e1aedbefe0e1f220\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-e1aedbefe0e1f220\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6007;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%tensorboard --logdir {root_logdir}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Visit the [TensorFlow Playground](https://playground.tensorflow.org/) and play around with it, as described in this exercise.\n",
"2. Here is a neural network based on the original artificial neurons that computes _A_ ⊕ _B_ (where ⊕ represents the exclusive OR), using the fact that _A_ ⊕ _B_ = (_A_ ∧ ¬ _B_) (¬ _A_ ∧ _B_). There are other solutions—for example, using the fact that _A_ ⊕ _B_ = (_A_ _B_) ∧ ¬(_A_ ∧ _B_), or the fact that _A_ ⊕ _B_ = (_A_ _B_) ∧ (¬ _A_ ¬ _B_), and so on.<br /><img width=\"70%\" src=\"images/ann/exercise2.png\" />\n",
"3. A classical Perceptron will converge only if the dataset is linearly separable, and it won't be able to estimate class probabilities. In contrast, a Logistic Regression classifier will generally converge to a reasonably good solution even if the dataset is not linearly separable, and it will output class probabilities. If you change the Perceptron's activation function to the sigmoid activation function (or the softmax activation function if there are multiple neurons), and if you train it using Gradient Descent (or some other optimization algorithm minimizing the cost function, typically cross entropy), then it becomes equivalent to a Logistic Regression classifier.\n",
"4. The sigmoid activation function was a key ingredient in training the first MLPs because its derivative is always nonzero, so Gradient Descent can always roll down the slope. When the activation function is a step function, Gradient Descent cannot move, as there is no slope at all.\n",
"5. Popular activation functions include the step function, the sigmoid function, the hyperbolic tangent (tanh) function, and the Rectified Linear Unit (ReLU) function (see Figure 10-8). See Chapter 11 for other examples, such as ELU and variants of the ReLU function.\n",
"6. Considering the MLP described in the question, composed of one input layer with 10 passthrough neurons, followed by one hidden layer with 50 artificial neurons, and finally one output layer with 3 artificial neurons, where all artificial neurons use the ReLU activation function:\n",
" * The shape of the input matrix **X** is _m_ × 10, where _m_ represents the training batch size.\n",
" * The shape of the hidden layer's weight matrix **W**<sub>_h_</sub> is 10 × 50, and the length of its bias vector **b**<sub>_h_</sub> is 50.\n",
" * The shape of the output layer's weight matrix **W**<sub>_o_</sub> is 50 × 3, and the length of its bias vector **b**<sub>_o_</sub> is 3.\n",
" * The shape of the network's output matrix **Y** is _m_ × 3.\n",
" * **Y** = ReLU(ReLU(**X** **W**<sub>_h_</sub> + **b**<sub>_h_</sub>) **W**<sub>_o_</sub> + **b**<sub>_o_</sub>). Recall that the ReLU function just sets every negative number in the matrix to zero. Also note that when you are adding a bias vector to a matrix, it is added to every single row in the matrix, which is called _broadcasting_.\n",
"7. To classify email into spam or ham, you just need one neuron in the output layer of a neural network—for example, indicating the probability that the email is spam. You would typically use the sigmoid activation function in the output layer when estimating a probability. If instead you want to tackle MNIST, you need 10 neurons in the output layer, and you must replace the sigmoid function with the softmax activation function, which can handle multiple classes, outputting one probability per class. If you want your neural network to predict housing prices like in Chapter 2, then you need one output neuron, using no activation function at all in the output layer. Note: when the values to predict can vary by many orders of magnitude, you may want to predict the logarithm of the target value rather than the target value directly. Simply computing the exponential of the neural network's output will give you the estimated value (since exp(log _v_) = _v_).\n",
"8. Backpropagation is a technique used to train artificial neural networks. It first computes the gradients of the cost function with regard to every model parameter (all the weights and biases), then it performs a Gradient Descent step using these gradients. This backpropagation step is typically performed thousands or millions of times, using many training batches, until the model parameters converge to values that (hopefully) minimize the cost function. To compute the gradients, backpropagation uses reverse-mode autodiff (although it wasn't called that when backpropagation was invented, and it has been reinvented several times). Reverse-mode autodiff performs a forward pass through a computation graph, computing every node's value for the current training batch, and then it performs a reverse pass, computing all the gradients at once (see Appendix B for more details). So what's the difference? Well, backpropagation refers to the whole process of training an artificial neural network using multiple backpropagation steps, each of which computes gradients and uses them to perform a Gradient Descent step. In contrast, reverse-mode autodiff is just a technique to compute gradients efficiently, and it happens to be used by backpropagation.\n",
"9. Here is a list of all the hyperparameters you can tweak in a basic MLP: the number of hidden layers, the number of neurons in each hidden layer, and the activation function used in each hidden layer and in the output layer. In general, the ReLU activation function (or one of its variants; see Chapter 11) is a good default for the hidden layers. For the output layer, in general you will want the sigmoid activation function for binary classification, the softmax activation function for multiclass classification, or no activation function for regression. If the MLP overfits the training data, you can try reducing the number of hidden layers and reducing the number of neurons per hidden layer."
]
},
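{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on answer 6, the shapes can be traced through the network one layer at a time. This is only a worked sketch: the symbol **H** for the hidden layer's output is introduced here for illustration and does not appear in the exercise, and the parameter count simply adds up the weight and bias sizes given in answer 6.\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\mathbf{H} &= \\operatorname{ReLU}(\\mathbf{X}\\,\\mathbf{W}_h + \\mathbf{b}_h), & (m \\times 10)(10 \\times 50) &\\to m \\times 50 \\\\\n",
"\\mathbf{Y} &= \\operatorname{ReLU}(\\mathbf{H}\\,\\mathbf{W}_o + \\mathbf{b}_o), & (m \\times 50)(50 \\times 3) &\\to m \\times 3 \\\\\n",
"\\text{parameters} &= (10 \\times 50 + 50) + (50 \\times 3 + 3) = 550 + 153 = 703\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"The batch size _m_ only ever appears as the number of rows, so the same weights and biases work for any batch size. Each bias vector is broadcast across the _m_ rows of the matrix product it is added to."
]
},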
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Train a deep MLP on the MNIST dataset (you can load it using `tf.keras.datasets.mnist.load_data()`. See if you can get over 98% accuracy by manually tuning the hyperparameters. Try searching for the optimal learning rate by using the approach presented in this chapter (i.e., by growing the learning rate exponentially, plotting the loss, and finding the point where the loss shoots up). Next, try tuning the hyperparameters using Keras Tuner with all the bells and whistles—save checkpoints, use early stopping, and plot learning curves using TensorBoard.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**TODO**: update this solution to use Keras Tuner."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
"(X_train_full, y_train_full), (X_test, y_test) = tf.keras.datasets.mnist.load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like for the Fashion MNIST dataset, the MNIST training set contains 60,000 grayscale images, each 28x28 pixels:"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60000, 28, 28)"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train_full.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each pixel intensity is also represented as a byte (0 to 255):"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('uint8')"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train_full.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's split the full training set into a validation set and a (smaller) training set. We also scale the pixel intensities down to the 0-1 range and convert them to floats, by dividing by 255, just like we did for Fashion MNIST:"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
"X_valid, X_train = X_train_full[:5000] / 255., X_train_full[5000:] / 255.\n",
"y_valid, y_train = y_train_full[:5000], y_train_full[5000:]\n",
"X_test = X_test / 255."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot an image using Matplotlib's `imshow()` function, with a `'binary'`\n",
" color map:"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGHElEQVR4nO3cz4tNfQDH8blPU4Zc42dKydrCpJQaopSxIdlYsLSykDBbO1slJWExSjKRP2GytSEWyvjRGKUkGzYUcp/dU2rO9z7umTv3c++8XkufzpkjvTvl25lGq9UaAvL80+sHABYmTgglTgglTgglTgg13Gb3X7nQfY2F/tCbE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0IN9/oBlqPbt29Xbo1Go3jthg0bivvLly+L+/j4eHHft29fcWfpeHNCKHFCKHFCKHFCKHFCKHFCKHFCqJ6dc967d6+4P3v2rLhPTU0t5uMsqS9fvnR87fBw+Z/sx48fxX1kZKS4r1q1qnIbGxsrXvvgwYPivmnTpuLOn7w5IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IVSj1WqV9uLYzoULFyq3q1evFq/9/ft3nR9NDxw4cKC4T09PF/fNmzcv5uP0kwU/4vXmhFDihFDihFDihFDihFDihFDihFBdPefcunVr5fbhw4fite2+HVy5cmVHz7QY9u7dW9yPHTu2NA/SgZmZmeJ+586dym1+fr7Wz253Dnr//v3KbcC/BXXOCf1EnBBKnBBKnBBKnBBKnBBKnBCqq+ecr1+/rtxevHhRvHZiYqK4N5vNjp6Jsrm5ucrt8OHDxWtnZ2dr/ezLly9XbpOTk7XuHc45J/QTcUIocUIocUIocUIocUKorh6lMFgePnxY3I8fP17r/hs3bqzcPn/+XOve4RylQD8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4Qa7vUDkOX69euV25MnT7r6s79//165PX36tHjtrl27Fvtxes6bE0KJE0KJE0KJE0KJE0KJE0KJE0L5vbU98PHjx8rt7t27xWuvXLmy2I/zh9Kz9dKaNWuK+9evX5foSbrC762FfiJOCCVOCCVOCCVOCCVOCCVOCOV7zg7MzMwU93bfHt68ebNye/fuXUfPNOhOnTrV60dYct6cEEqcEEqcEEqcEEqcEEqcEGpZHqW8efOmuJ8+fbq4P3r0aDEf569s27atuK9bt67W/S9dulS5jYyMFK89c+ZMcX/16lVHzzQ0NDS0ZcuWjq/tV96cEEqcEEqcEEqcEEqcEEqcEEqcEGpgzzlLv0Ly2rVrxWvn5uaK++rVq4v76OhocT9//nzl1u48b8+ePcW93TloN7X7e7fTbDYrtyNHjtS6dz/y5oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQA3vO+fjx48qt3Tnm0aNHi/vk5GRx379/f3HvV8+fPy/u79+/r3X/FStWVG7bt2+vde9+5M0JocQJocQJocQJocQJocQJocQJoQb2nPPGjRuV29jYWPHaixcvLvbjDIS3b98W90+fPtW6/8GDB2tdP2i8OSGUOCGUOCGUOCGUOCGUOCHUwB6lrF+/vnJzVNKZ0md4/8fatWuL+9mzZ2vdf9B4c0IocUIocUIocUIocUIocUIocUKogT3npDM7duyo3GZnZ2vd+9ChQ8V9fHy81v0HjTcnhBInhBInhBInhBInhBInhBInhHLOyR/m5+crt1+/fhWvHR0dLe7nzp3r4ImWL29OCCVOCCVOCCVOCCVOCCVOCCVOCOWcc5mZnp4u7t
++favcms1m8dpbt24Vd99r/h1vTgglTgglTgglTgglTgglTgglTgjVaLVapb04kufnz5/Ffffu3cW99LtpT5w4Ubx2amqquFOpsdAfenNCKHFCKHFCKHFCKHFCKHFCKJ+MDZhGY8H/lf/PyZMni/vOnTsrt4mJiU4eiQ55c0IocUIocUIocUIocUIocUIocUIon4xB7/lkDPqJOCGUOCGUOCGUOCGUOCGUOCFUu+85yx8HAl3jzQmhxAmhxAmhxAmhxAmhxAmh/gWlotX4VjU5XgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(X_train[0], cmap=\"binary\")\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The labels are the class IDs (represented as uint8), from 0 to 9. Conveniently, the class IDs correspond to the digits represented in the images, so we don't need a `class_names` array:"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([7, 3, 4, ..., 5, 6, 8], dtype=uint8)"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation set contains 5,000 images, and the test set contains 10,000 images:"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5000, 28, 28)"
]
},
"execution_count": 114,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_valid.shape"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10000, 28, 28)"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at a sample of the images in the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqIAAAEkCAYAAAD0AFOrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABl3klEQVR4nO3debxN1fvA8c+iIkNJplSiEoki9S19FQ1Ko6SkwdAgRfMslG+ERmmSSEShUUWpNBkqoZCvMvySkG+mZAhF+/fHvs/a59xz7nzOWfuc87xfLy/3nnvuueuuu/c+az/rWc8ynuehlFJKKaVUqpVy3QCllFJKKZWddCCqlFJKKaWc0IGoUkoppZRyQgeiSimllFLKCR2IKqWUUkopJ3QgqpRSSimlnNCBqFJKKaWUciLlA1FjzM/GGC/Ov8mpbksYGGN6GGMWGGM25/z7yhhzrut2hYEx5r6cY+MZ121xyRhzgDFmtDFmnTFmhzFmkTGmhet2uWCMOcUY864xZnXOsdHFdZtcMsaUNsb0M8Yszzk2lhtj+htj9nDdNlf0GIlljOkecYzMNcac7LpNrhhjKhpjnjTGrDDGbDfGfGmMOd51u1wIy/XDRUT0eOCAiH/HAh7wmoO2hMEq4B78fjgO+BSYaIw52mmrHDPGnAh0BRa4botLxphKwEzAAOcCRwI3AWsdNsulCsBC4BZgu+O2hME9QA/gZqA+fr/0AHq6bJRjeoxEMMZcCgwBBgBNgC+BD4wxtZw2zJ0RwFlAZ6AR8BEw1RhzoNNWuRGK64dxvbOSMaYXcBdQ0/O8P502JiSMMRuBnp7nDXPdFheMMfsC3+IPRO8HFnqed6PbVrlhjBkAtPA879+u2xI2xpitwI2e541y3RZXjDGTgA2e53WOeGw0sL/neee5a1k46DECxphZwALP87pGPLYUeMPzvKy6YTHG7A1sAdp5nvdOxONzgQ88z+vtrHEOhOX64TRH1BhjgGuAsToItWHyDvh39F+6bo9DL+BfJD913ZAQuBCYZYyZYIxZa4yZZ4y5MefcUWoGcKoxpj6AMaYBcBrwvtNWqVAwxuwFNMWP+kX6CDgp9S1ybg+gNLAj1+Pbgeapb45zobh+uM4jagXUwQ+VZy1jTCPgK6AssBVo63ne925b5YYxpitwONDRdVtC4lCgOzAYGAQ0Bp7O+VpW584qAB4GKgKLjDG78a/pD3me95zbZqmQqII/8Pot1+O/AWekvjlueZ63xRjzFdDbGLMQ+B9wGdAMWOa0cW6E4vrheiDaFZjted48x+1wbTH+AKMS0A4YbYxp6XneQpeNSjVjTD38PKaTPc/7y3V7QqIUMCdiCu07Y0xd/DweHYiqS4FOwOXAf/GvI0OMMcs9z3vRZcNUqOTOwTNxHssWHYGR+OszduOngY3DX6eRbUJx/XA2EDXGVAPa4L+hZrWcQZfcjc3JWcF3G37aQjZphn8HvzBi5rk0cIox5nqgvOd5O101zpE1wKJcj/2An1Su1KPAY57njc/5/HtjzCH4iw10IKrW4w+2auR6vBqxUdKs4Hne/wEtjDHlgX08z1tjjJkALHfcNBdCcf1wmSPaBdgJjC/gedmoFFDGdSMcmIi/irFxxL85+MdIYyAbo6QzgXq5HjsCWOGgLSp8yuEPNCLtRmtEK2yQYy5+GlykVmT3OgQ8z9uWMwjdD38V/TsFfU8GCsX1w0lENGehxbXAeM/ztrhoQ1gYYwYBk4GV+LkalwMt8Uv1ZBXP8zYBmyIfM8ZsAzZmW5pChMHAlznVJSbgl1+5GbjPaascMcZUwM8hBv9iWcsY0xj/GPnFWcPceQ+41xizHH9qrQlwO/Cy01Y5pMdIjCeAMcaYb/BvbK8HagLPO22VI8aYs/CPix/xj5NH8dPjXnLZLkdCcf1wUr7JGHMqfr3MEzzP+yblDQgRY8wo4FT8qZM/8OtmPup53ocu2xUWxpjPyeLyTQA5GxwMwI
+M/oKfG/q057r2mgPGmJbAZ3G+NNrzvC4pbUwIGGMqAv2AtvjTrWvwZxAe9Dwv98rgrKDHSCxjTHfgbvza3QuB2zzPm+a2VW4YY9oDA4GDgI3Am0Avz/P+cNowB8Jy/XBeR1QppZRSSmUnzSNSSimllFJO6EBUKaWUUko5oQNRpZRSSinlhA5ElVJKKaWUEzoQVUoppZRSThRURzTdl9Sbgp9SJNof0bQ/oml/xNI+iab9EU37I5r2RzTtj2gZ2R8aEVVKKaWUUk7oQFQppZRSSjmhA1GllFJKKeWEk73mlVLF988//3DHHXcA8MwzzwDw1VdfAXDcccc5a5dSSilVVBoRVUoppZRSTmhEVKk0sXbtWgD69OnDCy+8EPW15cuXA9kXEe3atSsAY8eOZebMmQAce+yxLpukQujBBx9k/PjxAEyaNAmAQw891GWTUmrRokUAPPnkkwAMHz6cbt26AfD888+7apYKgbVr1zJ//nwA3nnnHQCmTZvGwoULAbjqqqsAOOywwwC44447KFOmTNRrbNy4kcqVKxe7DRoRVUoppZRSTmhENARWrFgB+HepAA899BDG+OW2PM8vG3bkkUcC0L9/fy666CIHrVSurFmzBoBHHnkEICoaevLJJwNwwgknpL5hIXDIIYcAsGPHDpYuXQpoRBRgxowZDBs2DPCjxbnJcSPXkk6dOpUoohFWGzZsAPxr66pVqwD49ttvgeyJiI4ePZo+ffoA2D4wxvD+++/Hff7YsWNp06YNABUrVkxNI1XKjRgxAoABAwbYMYjwPM+OQUaNGhX1tb333pvbbrst6rHLLruMDz/8sNht0YGoI+vWrQNg4MCBvPLKKwCsX78e8C8SchCIxYsXA35Y/JRTTgGgSpUqqWpu0vz1118AnH766YD/BioqVaoEwIIFCzj44INT3rYw2LVrFw899BAAzz77rH28R48eADzxxBMA7LXXXqlvXAjIQBT8N1yASy+91FVznNm1axcAffv2Bfxj5Y8//gCIuZYATJ8+HQjOt3nz5sW84WQCOSZkAJYN/v77bwA7MLjuuuvsY/kZOnQoADfffDN16tQBoF+/fkBmnVP/93//Z1MUJJ3nhx9+sCkKnTt3dtW0lJBB54ABA6I+B3+QCVChQgV73ZBxyT///APAnXfeyb777gvA1VdfDcCvv/5aojbp1LxSSimllHIiZRHRl156CfDvzvfff3/AvwsBaNasmZ0qynT9+/cHsFMlxhg7/S53ILVq1aJq1apR3yd3JT///LONiEoCejqSSOg111wDREdCL7zwQgDuvfdeAGrWrJnva/32228AVK9ePdHNdK5nz55RkVCAbt262bJNKpCtUWGAXr16AfDoo48C0VNruZ1yyil88cUXUY999NFHbNmyBcis6djPP//cdRNSTmZJevbsmedz6tevzy233BL1mLzH7N69m2XLlgFw/fXX26+na1RUosETJkwA/IinXCvkvJkzZ07WRETlGiGR0L322otLLrkEwE65N2nSxD7/tddeA2DQoEEAzJ8/nx07dkS9ZkHv0QXRiKhSSimllHKiyBHRV199FYDvvvsOgJEjRxbq+zZt2hT80D38HytRsbJly1KuXDkAjj76aCAYheeODKY7KY8g0YrIqEWDBg0A/y4+d/6n5HS1aNHC5oums8cffxyIXUjRo0cPHnvsMcA/Lgpyxx132Gj7/fffD8Ctt96awJa68cADDwDYvgC48cYbgSDioeDtt9+2H1922WUOW5J6khfaq1evmGOifPny3H777QC0bdsW8GdaAPbZZx+b2yX56VWqVLHX5UwgMyySA5gNJPInpXjikVz7F154gebNmxf4mpJn3K1bN+bMmQMEEbWwk/GFzD7KYs+jjjqKwYMHA9CqVSvAzyFeuXIlELzXSr5kppXEGzduXNTnzZs35+WXX87z+e3btwegWrVqQLCeI5IsbiuuQl955KI2ZMgQIEhcLQ45QMSOHTtsqFemUm
QaYNy4cRkx5SppCD/++CMQvClUrVrVDjrlzaR3797cd999Uc+T1AWZxodg9fR1112X7OYn1MKFC20SvJDpwCeffLJ
"text/plain": [
"<Figure size 864x345.6 with 40 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_rows = 4\n",
"n_cols = 10\n",
"plt.figure(figsize=(n_cols * 1.2, n_rows * 1.2))\n",
"for row in range(n_rows):\n",
" for col in range(n_cols):\n",
" index = n_cols * row + col\n",
" plt.subplot(n_rows, n_cols, index + 1)\n",
" plt.imshow(X_train[index], cmap=\"binary\", interpolation=\"nearest\")\n",
" plt.axis('off')\n",
" plt.title(y_train[index])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.5)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build a simple dense network and find the optimal learning rate. We will need a callback to grow the learning rate at each iteration. It will also record the learning rate and the loss at each iteration:"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"K = tf.keras.backend\n",
"\n",
"class ExponentialLearningRate(tf.keras.callbacks.Callback):\n",
" def __init__(self, factor):\n",
" self.factor = factor\n",
" self.rates = []\n",
" self.losses = []\n",
" def on_batch_end(self, batch, logs):\n",
" self.rates.append(K.get_value(self.model.optimizer.learning_rate))\n",
" self.losses.append(logs[\"loss\"])\n",
" K.set_value(self.model.optimizer.learning_rate, self.model.optimizer.learning_rate * self.factor)"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, activation=\"relu\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will start with a small learning rate of 1e-3, and grow it by 0.5% at each iteration:"
]
},
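{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sanity check (a sketch, not part of the original solution), we can estimate how far the learning rate grows over one epoch. The step count assumes a 55,000-image training set (the 60,000 MNIST training images minus the 5,000 held out for validation) and the default batch size of 32 used by `fit()`; `initial_lr`, `factor`, `n_train`, `batch_size` and `n_steps` are illustrative names. The final rate ends up well above 1, so it would not be surprising for the loss to diverge before the epoch ends:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"initial_lr = 1e-3   # starting learning rate\n",
"factor = 1.005      # growth of 0.5% per iteration\n",
"n_train = 55_000    # assumed training set size\n",
"batch_size = 32     # default batch size of fit()\n",
"n_steps = math.ceil(n_train / batch_size)  # iterations in one epoch\n",
"\n",
"# the rate is multiplied by `factor` once per iteration\n",
"final_lr = initial_lr * factor ** n_steps\n",
"print(f\"{n_steps} steps, final learning rate ~ {final_lr:.2f}\")"
]
},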
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"expon_lr = ExponentialLearningRate(factor=1.005)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's train the model for just 1 epoch:"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 3s 2ms/step - loss: nan - accuracy: 0.5843 - val_loss: nan - val_accuracy: 0.0958\n"
]
}
],
"source": [
"history = model.fit(X_train, y_train, epochs=1,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[expon_lr])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now plot the loss as a function of the learning rate:"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Loss')"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAEOCAYAAACNY7BQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAl1ElEQVR4nO3dd3hUZd7/8fd3UkgCKUAgJKH3XkMvxg4oggXL2teVde3rurqu+zzr+lu3Wh67YtdVURQFFcHVNdKkVxFxg7RQRJAWirT798cMJGcMkGByJpl8Xtc1FzPnnDn5zk3gM/e5z7mPOecQERE5LBDpAkREpHJRMIiIiIeCQUREPBQMIiLioWAQEREPBYOIiHjERrqAnyomKdV1bNOCuBhlXEXatWsXNWvWjHQZUU/t7I/wds7fVEhMwGiWXn3aft68eZudc/VKWlflgyE2tT5vfJhH2wYpkS4lquXl5ZGbmxvpMqKe2tkf4e084vHppCTG8fLPe0WuKJ+Z2eqjrfPta7aZNTKzT81smZktNbNbStgm18y2m9nC0ON/S7PvPfsOln/BIlJt6DJfLz97DAeA3zjn5ptZMjDPzP7tnPsybLupzrmzy7LjPfsVDCLy01ikC6hEfOsxOOc2OOfmh57vBJYB2eWx770KBhH5KTQ1kEdExhjMrCnQDZhVwuq+ZrYIWA/c7pxbWsL7RwGjAOIbtOT2MfN4MDepAiuWwsJC8vLyIl1G1FM7+yO8nXfs3IPba2r7EN+DwcxqAW8DtzrndoStng80cc4VmtlQ4F2gVfg+nHOjgdEANTJbue/3Oq6atIv5/3M6dWrGV+wHqKY0KOoPtbM/wtv5oS+mUbtmPLm51Wfw+Vh8PcfTzOIIhsKrzrlx4eudczucc4Wh5xOBODNLP9Y+M1ISjjzv/v/+Xb4Fi0i1oANJXn6elWTAc8Ay59yDR9mmQWg7zKxXqL4tx9pvvVo16N2szpHX0/M3s3XXPqZ8/R2HDumvW0RKR4PPRfw8lNQfuBxYYmYLQ8t+DzQGcM49BVwA/MrMDgB7gIvdcW4YYQZv/LIv67bt4ZT787j02aJhi4Gt0nnsku6kJsWV/6cRkaihsWcv34LBOTeN44Syc+4x4LET2X92WiIf33YSuffncfCQo3PDVKb+dzNd7v2Iv5/fiSGdMkmuEUuoQyIi4qH/G4pU+Sufi2tUJ4n8+4ZwyEFMwJixYjM3vDqfO99ewp1vL6FFvZrcclprzu6USSCgXwIRkZJEVTBAMPVjQv/n92uRzpy7T2P2qu/5+MtNTFi0jptfX8Bz01ZyRZ8m9GpWh0Z1dJqrSHXnNPzsEXXBEC42JkC/Fun0a5HO3We1Y/zCdTzw0df8ZuwiAPo2r8uNp7SkX4u66kqKVGP6118k6oOhuJiAcV73hozoms1XG3fy6fJNvPL5ai59dhYt69diQMt0ft6/GY3rqhchUp1o8NmrWgXDYYGA0T4rhfZZKVwzoBlvzl3LxCUbeG3WGv41czWDWtfj/O4NOa19fWrExkS6XBHxgQ4YFKmWwVBcQlwMV/RtyhV9m/Ltjr08P20lExat54bX5pOWFMfwLlmMzGlEx+zUSJcqIuKLah8MxWWkJHDX0HbcMbgt0/M38+bctbw+Zy0vfb6adpkpDOuSycCW9WiflUKMzmoSiRo6lOSlYChBTMAY1Loeg1rXY/vu/YxftI635xXwj0nL+QfLqZ0Ux1mdMzm3WzbdG9fWoLVIVNC/48MUDMeRmhR35FDTpp17+XzFFj5etom35hXwr5lryEipwfndGzK4YwM6ZqXq+giRKkgdBi8FQxnUT05geNdshnfNZufe/Uz6YiMTl2zg6Snf8ETeCtKS4ujfMp0hHRtwUut6JCdoKg6RqkId/yIKhhOUnBDHyJxGjMxpxPehSfum5W8mb/kmPli8gfiYALlt6jEypxG5beoRF+PrRLYiIidMwVAO6tSMZ0S3bEZ0y+bgIc
e81VuZ9MVGJixax0dffkt6rRqc2y2LEd2yaZ+ZojEJkUrmOHN1VjsKhnIWEzB6NatDr2Z1uGtoW/KWf8fYuWt5Yfoqnpm6klb1azEypyEjumVTPznh+DsUEV/o61oRBUMFiosJcHr7DE5vn8GWwh+YtHQjb80r4C8Tv+Lvk5bTv2U653fP5oz2DUiM14V0IlI5KBh8UrdWDS7t3YRLezchf9NOxs1fx/iF67llzEJq1YhlSMcGnNe9Ib2b1dGZTSIRoCO8RRQMEdCyfjJ3DG7L7We0Yfaq7xk3v4CJSzYydl4B2WmJjOiWxfndG9K8Xq1Ilyoi1ZCCIYICAaNP87r0aV6XP53TkY++3Mi4+et4Mm8Fj3+6gkGt63F1/6ac1KqeehEiFUhjz14KhkoiMT7myDUSm3bsZcyctbwyczVXvzCH5uk1ubJfU87v0ZBaNfRXJlIRTMPPR+jk+kqofkoCN5/aiul3nsLDF3clOTGOP05YSt+/fMK9733J6i27Il2iSFTRjXq89PWzEouPDRzpRSxYs5UXpq/i5c9X8cKMlZzatj6/PKkFPZvWiXSZIlFBg89FFAxVRLfGtenWuDZ3n9WOV2eu5tVZaxj51Of0b1mXW09rrYAQkXKjQ0lVTEZKAred0YZpd57CH85qx/KNOxn51Odc+uxM5qz6PtLliVRJGnz2UjBUUYnxMfxiYHOm3uENiMuencXCtdsiXZ5IlaNDSUUUDFVceEB8uWEHIx6fzi9fmcvX3+6MdHkiVYI6DF4KhihxOCCm3HEyvz6tNdPztzD4/6Zw17jFfLfzh0iXJ1Lp6XTVIgqGKFOrRiy3nNaKqXeczJX9mjJ2bgEn35/Hk3kr2Lv/YKTLE5EqQMEQpWrXjOePwzow+deD6NO8Dn+f9BWnP/QZE5ds0BTDImH0b8JLwRDlWtSrxbNX9uRf1/QmKS6W61+dz0VPz2RJwfZIlyZSuehI0hEKhmpiQKt0Prh5APed25EV3xVyzuPTuH3sIr7ftS/SpYlEnPoLXgqGaiQ2JsClvZvw6W9zGTWwOeMXruOMh6bwybJvI12aSMSpw1BEwVANpSTEcdfQdoy/YQDpteK55qW5/HbsIrbtVu9BRBQM1Vr7rBTG39if63NbMG7BOk578DMmLFqvgTipfvQr76FgqOZqxMZwx+C2vHfjALLTErn59QVc+/I8NmzfE+nSRHxluvT5CN+CwcwamdmnZrbMzJaa2S0lbGNm9oiZ5ZvZYjPr7ld91V37rBTGXd+fu4e2Y1r+d5zx4BTGL1wX6bJEfKEOg5efPYYDwG+cc+2APsANZtY+bJshQKvQYxTwpI/1VXsxAePaQc2ZfOsg2jRI5pYxC/n1GwvZuXd/pEsTqXDqLxTxLRiccxucc/NDz3cCy4DssM2GAy+7oJlAmpll+lWjBDWpW5Mxo/pw62mtGL9wHUMfmUr+Vl01LVJdROR+DGbWFOgGzApblQ2sLfa6ILRsQ9j7RxHsUZCRkUFeXl5FlVqtdY2Fu3ol8PTivfxl1iG+3voRg5vFEdCx2ApTWFio32cfhLfz7t272bRpr9o+xPdgMLNawNvArc65HeGrS3jLjw7/OedGA6MBcnJyXG5ubnmXKSG5wEVD9vPzJz/hza/3syVQmwcv7EpqUlykS4tKeXl56Pe54oW3c+KcT8nISCM3t1vkiqpEfD0rycziCIbCq865cSVsUgA0Kva6IbDej9rk6FIS4rihaw3uGdaeKf/9jmGPTePL9eGZLlJ1afDZy8+zkgx4DljmnHvwKJtNAK4InZ3UB9junNtwlG3FR2bGVf2bMWZUX344cJDznpzOe4uU2RI9dIC0iJ89hv7A5cApZrYw9BhqZteZ2XWhbSYC3wD5wDPA9T7WJ6XQo0lt3r9pIJ2yU7np9QX87cOvOHhI37dEoolvYwzOuWkcJ5Rd8JLbG/ypSE5UveQavPqLPvzpvaU89dkKvtq4g4cv7kZqosYdpGrSxf5euv
JZTkh8bID7zu3Efed2ZHr+ZkY8Pp38TbqVqFRduvK5iIJBfpJLezfhtWv7sHPvfkY8PoOPv9RMrVL1OA0/eygY5Cf
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(expon_lr.rates, expon_lr.losses)\n",
"plt.gca().set_xscale('log')\n",
"plt.hlines(min(expon_lr.losses), min(expon_lr.rates), max(expon_lr.rates))\n",
"plt.axis([min(expon_lr.rates), max(expon_lr.rates), 0, expon_lr.losses[0]])\n",
"plt.grid()\n",
"plt.xlabel(\"Learning rate\")\n",
"plt.ylabel(\"Loss\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss starts shooting back up violently when the learning rate goes over 6e-1, so let's try using half of that, at 3e-1:"
]
},
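{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reading the threshold off the plot is a manual step. Here is a minimal sketch of one way to automate a similar choice: take the learning rate with the lowest finite recorded loss and divide it by some factor. This is a different heuristic from the one used above (which looks at where the loss starts shooting back up), and `suggest_learning_rate` and `divisor` are hypothetical names introduced only for illustration. Non-finite losses are skipped, since the loss can become `nan` once training diverges. The toy values below are made up, not the recorded ones:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def suggest_learning_rate(rates, losses, divisor=2.0):\n",
"    rates = np.asarray(rates, dtype=float)\n",
"    losses = np.asarray(losses, dtype=float)\n",
"    finite = np.isfinite(losses)      # drop nan/inf losses recorded after divergence\n",
"    best = np.argmin(losses[finite])  # index of the lowest finite loss\n",
"    return rates[finite][best] / divisor\n",
"\n",
"# toy example with made-up values\n",
"toy_rates = [1e-3, 1e-2, 1e-1, 1.0, 10.0]\n",
"toy_losses = [2.3, 1.5, 0.4, 3.0, float(\"nan\")]\n",
"suggested = suggest_learning_rate(toy_rates, toy_losses)\n",
"suggested"
]
},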
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, activation=\"relu\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=3e-1)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('my_mnist_logs/run_001')"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"run_index = 1 # increment this at every run\n",
"run_logdir = Path() / \"my_mnist_logs\" / \"run_{:03d}\".format(run_index)\n",
"run_logdir"
]
},
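{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of incrementing `run_index` by hand, a timestamp can give each run its own log directory. A minimal sketch, assuming the same `my_mnist_logs` root directory (the `get_run_logdir` helper and `timestamped_logdir` variable are illustrative names and are not used in the cells below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from time import strftime\n",
"\n",
"def get_run_logdir(root_logdir=\"my_mnist_logs\"):\n",
"    # subdirectory named after the current date and time\n",
"    return Path(root_logdir) / strftime(\"run_%Y_%m_%d_%H_%M_%S\")\n",
"\n",
"timestamped_logdir = get_run_logdir()\n",
"timestamped_logdir"
]
},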
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2363 - accuracy: 0.9264 - val_loss: 0.0972 - val_accuracy: 0.9720\n",
"Epoch 2/100\n",
"1719/1719 [==============================] - 2s 997us/step - loss: 0.0948 - accuracy: 0.9702 - val_loss: 0.1035 - val_accuracy: 0.9706\n",
"Epoch 3/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0667 - accuracy: 0.9792 - val_loss: 0.0783 - val_accuracy: 0.9770\n",
"Epoch 4/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0463 - accuracy: 0.9848 - val_loss: 0.0827 - val_accuracy: 0.9766\n",
"Epoch 5/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0359 - accuracy: 0.9881 - val_loss: 0.0698 - val_accuracy: 0.9826\n",
"Epoch 6/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0297 - accuracy: 0.9908 - val_loss: 0.1048 - val_accuracy: 0.9758\n",
"Epoch 7/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0245 - accuracy: 0.9917 - val_loss: 0.0932 - val_accuracy: 0.9794\n",
"Epoch 8/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0239 - accuracy: 0.9922 - val_loss: 0.0816 - val_accuracy: 0.9798\n",
"Epoch 9/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0154 - accuracy: 0.9952 - val_loss: 0.0775 - val_accuracy: 0.9838\n",
"Epoch 10/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0126 - accuracy: 0.9960 - val_loss: 0.0805 - val_accuracy: 0.9812\n",
"Epoch 11/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0111 - accuracy: 0.9964 - val_loss: 0.0962 - val_accuracy: 0.9804\n",
"Epoch 12/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0118 - accuracy: 0.9963 - val_loss: 0.1044 - val_accuracy: 0.9774\n",
"Epoch 13/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0114 - accuracy: 0.9961 - val_loss: 0.1055 - val_accuracy: 0.9802\n",
"Epoch 14/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0150 - accuracy: 0.9948 - val_loss: 0.0993 - val_accuracy: 0.9826\n",
"Epoch 15/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0054 - accuracy: 0.9981 - val_loss: 0.0955 - val_accuracy: 0.9822\n",
"Epoch 16/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0046 - accuracy: 0.9984 - val_loss: 0.0982 - val_accuracy: 0.9822\n",
"Epoch 17/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0055 - accuracy: 0.9983 - val_loss: 0.0908 - val_accuracy: 0.9844\n",
"Epoch 18/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0070 - accuracy: 0.9978 - val_loss: 0.0883 - val_accuracy: 0.9840\n",
"Epoch 19/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0025 - accuracy: 0.9992 - val_loss: 0.0978 - val_accuracy: 0.9838\n",
"Epoch 20/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0058 - accuracy: 0.9983 - val_loss: 0.1011 - val_accuracy: 0.9830\n",
"Epoch 21/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0039 - accuracy: 0.9989 - val_loss: 0.0991 - val_accuracy: 0.9840\n",
"Epoch 22/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 9.2480e-04 - accuracy: 0.9998 - val_loss: 0.0963 - val_accuracy: 0.9840\n",
"Epoch 23/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 1.2642e-04 - accuracy: 1.0000 - val_loss: 0.0970 - val_accuracy: 0.9846\n",
"Epoch 24/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 6.9068e-05 - accuracy: 1.0000 - val_loss: 0.0970 - val_accuracy: 0.9854\n",
"Epoch 25/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 5.1481e-05 - accuracy: 1.0000 - val_loss: 0.0977 - val_accuracy: 0.9850\n"
]
}
],
"source": [
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20)\n",
"checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_mnist_model\", save_best_only=True)\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"\n",
"history = model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb])"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 908us/step - loss: 0.0708 - accuracy: 0.9799\n"
]
},
{
"data": {
"text/plain": [
"[0.07079131156206131, 0.9799000024795532]"
]
},
"execution_count": 128,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.models.load_model(\"my_mnist_model\") # rollback to best model\n",
"model.evaluate(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We got almost 98% accuracy on the test set. Finally, let's look at the learning curves using TensorBoard:"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-9f95f24bb0151492\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-9f95f24bb0151492\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6008;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%tensorboard --logdir=./my_mnist_logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"nav_menu": {
"height": "264px",
"width": "369px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}