{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 10 Introduction to Artificial Neural Networks with Keras**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 10._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/10_neural_nets_with_keras.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/10_neural_nets_with_keras.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import sklearn\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And TensorFlow ≥ 2.8:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"assert version.parse(tf.__version__) >= version.parse(\"2.8.0\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/ann` folder (if it doesn't already exist), and define the `save_fig()` function which is used through this notebook to save the figures in high-res for the book:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"ann\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# From Biological to Artificial Neurons\n",
"## The Perceptron"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.linear_model import Perceptron\n",
"\n",
"iris = load_iris(as_frame=True)\n",
"X = iris.data[[\"petal length (cm)\", \"petal width (cm)\"]].values\n",
"y = (iris.target == 0) # Iris setosa\n",
"\n",
"per_clf = Perceptron(random_state=42)\n",
"per_clf.fit(X, y)\n",
"\n",
"X_new = [[2, 0.5], [3, 1]]\n",
"y_pred = per_clf.predict(X_new) # predicts True and False for these 2 flowers"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ True, False])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Perceptron` is equivalent to a `SGDClassifier` with `loss=\"perceptron\"`, no regularization, and a constant learning rate equal to 1:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# extra code shows how to build and train a Perceptron\n",
"\n",
"from sklearn.linear_model import SGDClassifier\n",
"\n",
"sgd_clf = SGDClassifier(loss=\"perceptron\", penalty=None,\n",
" learning_rate=\"constant\", eta0=1, random_state=42)\n",
"sgd_clf.fit(X, y)\n",
"assert (sgd_clf.coef_ == per_clf.coef_).all()\n",
"assert (sgd_clf.intercept_ == per_clf.intercept_).all()"
]
},
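{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the update rule concrete, here is a minimal from-scratch sketch of the Perceptron learning rule in NumPy. It is an illustration only, not Scikit-Learn's implementation: the toy data and the names `X_toy`, `y_toy` and `train_perceptron()` are made up for this example, and details such as shuffling the instances are left out. Each misclassified instance nudges the weights and the bias by the learning rate times the error, and training stops once a full pass makes no mistakes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Toy, linearly separable data (hypothetical values, not the iris dataset)\n",
"X_toy = np.array([[1.0, 0.2], [1.5, 0.3], [1.3, 0.2],\n",
"                  [4.5, 1.5], [5.0, 1.8], [4.0, 1.3]])\n",
"y_toy = np.array([1, 1, 1, 0, 0, 0])\n",
"\n",
"def train_perceptron(X, y, eta=1.0, max_epochs=1000):\n",
"    w = np.zeros(X.shape[1])\n",
"    b = 0.0\n",
"    for epoch in range(max_epochs):\n",
"        n_errors = 0\n",
"        for x_i, y_i in zip(X, y):\n",
"            y_hat = int(x_i @ w + b >= 0)  # step activation\n",
"            error = y_i - y_hat            # 0 if correct, +1 or -1 otherwise\n",
"            if error != 0:\n",
"                w += eta * error * x_i     # update only on a mistake\n",
"                b += eta * error\n",
"                n_errors += 1\n",
"        if n_errors == 0:  # a full pass without mistakes: stop learning\n",
"            break\n",
"    return w, b\n",
"\n",
"w_toy, b_toy = train_perceptron(X_toy, y_toy)\n",
"y_toy_pred = (X_toy @ w_toy + b_toy >= 0).astype(int)"
]
},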
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the Perceptron finds a decision boundary that properly separates the classes, it stops learning. This means that the decision boundary is often quite close to one class:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAccAAADYCAYAAACeCyhIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA+RElEQVR4nO3deZzNZfvA8c81YxjbkJjJMnaSNZHlIfRkSxSP9NMilSUVLUpJpcxgpKwT2bVrkxayPCWEMDMMKk9RZM2Ssi8zXL8/zpnVLMecbWZc79frvMy5v9t1TnLN/f3e132LqmKMMcaYFAH+DsAYY4zJbSw5GmOMMelYcjTGGGPSseRojDHGpGPJ0RhjjEnHkqMxxhiTjs+So4iEi8h3IrJNRH4SkScy2EdEZLKI7BCRLSJyQ6ptHUXkF+e2ob6K2xhjzJXHlz3HROBpVb0OaAY8JiK10+1zK1DD+eoPvAkgIoHAFOf22sDdGRxrjDHGeITPkqOqHlDVjc6fTwDbgPLpdrsDeEcd1gElRaQs0ATYoaq/q+p54EPnvsYYY4zH+eWZo4hUBhoC69NtKg/sSfV+r7Mts3ZjjDHG4wr4+oIiUgyYDzypqsfTb87gEM2iPaPz98dxS5ZChYo2Cgur5Ua0xhhjcrvduzPfVrFi+u27UD2SUU5Jw6fJUUSCcCTG91X1swx22QuEp3pfAdgPFMyk/RKqOgOYAVCpUmMdNizWA5EbY4zJrQYMyHzbsGHptzd26Zy+HK0qwGxgm6qOz2S3L4H7naNWmwHHVPUAEAPUEJEqIlIQ6Onc1xhjjPE4X/YcWwC9gK0iEu9sGwZUBFDVacDXQCdgB3AaeNC5LVFEBgJLgUBgjqr+5MPYjTHG5FIhIXA8/UM6Z3tW27Mi+XnJKrutaowxJrUBAyROVbO9t+rzATnGGGOMLz37bOqeY6NGrhxj08cZY4zJ1y73lipYcjTGGGMuYcnRGGOMSceSozHGGJOOJUdjjDEmHUuOxhhj8rWkesfLYaUcxhhj8rWxY1N+HjAgLs6VY/J1cjxyZCcnT/5FsWJX+zsUY4wxbshq/lQRyGg+GxF4882cXS9f31Y9ffooERF1iI//3N+hGGOM8ZLMJnpzZwK4fJ0cAY4fP8i0ad2YPfseTp484u9wjDHG5AH5OjkGBQUm/xwTM48RI+qwaVNGK2UZY4wxKXy5ZNUcETkkIj9msn2IiMQ7Xz+KyAURKeXctktEtjq3uTyTeO3aFenV6+bk9ydOHGL69O7MmtWTEycOu/2ZjDHG5E++7Dm+BXTMbKOqvqaq16vq9cDzwEpVPZpql5ud211bqRIIDAxg9uwn+PzzFylXrlRye2zsR0RE1CEu7tPL/hDGGGPyP58lR1VdBRzNdkeHu4F5nrp2p06NiY+fTO/etyS3nThxmJkzezBjxl0cP37IU5cyxhjjYyKX1+7SOX25nqOIVAYWqmrdLPYpAuwFqif1HEVkJ/A3oMB0VZ2RxfH9gf4AFSuWabRjx8w025csieORR6ayb99fyW3FipWmZ88pNG58V04/mjHGmDzA1fUcc2Ny/D/gPlXtkqqtnKruF5FQ4L/AIGdPNEuNGlXXdevGXdJ+7Ngpnn12LnPnfpOmvWHD7tx99xRCQsJc/kzGGGPSr5mYIiQkbRG+P6SNrTGqsdn2KXPjaNWepLulqqr7nX8eAhYATdy5QIkSRZk+fSALF75MhQopEwRs2jSfESPqEBPzIb78pcEYY/K6zNZMzMlaip6W59dzFJESQGvgi1RtRUWkeNLPQHsgwxGvl6t9+4Zs2jSZhx5ql9x26tRfzJ59N9Ond+fYsT89cRljjDF5jC9LOeYBPwDXisheEekjIgNEJPWkQN2AZap6KlVbGLBaRDYDG4BFqrrEU3GVKFGUadMeY9GilwkPL53cHh+/gIiIOmzY8IH1Io0x5grjs7lVVfVuF/Z5C0fJR+q234EG3okqRbt2jl7k0K
FvMWvWMgBOnTrKnDn3Ehf3Cffc8yYlSlzj7TCMMcbkArnqtqq/hYQUYerUR1m8eASVKpVJbt+8+XNGjKjN+vXvWy/SGGOuAJYcM3DLLQ3YuHEy/funzFlw+vTfzJ17H2++2ZVjxw74MTpjjMl9MlszMSdrKXpaTmLwaSmHr2VWynE5li/fzMMPv8Eff6RMN1ekSEnuumsyTZveh7hTZWqMMcancmWdo695IjkCnDhxhhdeeIdp0xanaa9XrzP33judkiXLuX0NY4zJy9ytc/RVnaSrydFuq7qgePHCTJ78MMuWRVKlSsoEAVu3LiQiog4//PC2PYs0xlzR3K1zzG11kpYcL0ObNvWIi5vIo492Sm47ffof3n77AaZM6czff+/zY3TGGGM8xZLjZSpWrDATJ/bnm29GUrVqSi/yxx+/JiKiDmvXzrVepDHG5HGWHHOoVau6xMVNYuDAzsltZ84c4513HuKNNzpx9OgeP0ZnjDHGHZYc3VC0aDDjx/fl229HUa1aygQBP/20hIiIuqxZM9t6kcYYkwdZcvSAm26qQ1zcJB5/vEtyacfZs8d5992+REd35OjR3X6O0BhjvMvdOsfcVidppRwetmbNz/Tr9wY7duxPbgsOLk737uNo2bKv1UUaY4wf5bo6RxGZA3QGDmW0nqOItMGxGsdOZ9Nnqhrh3NYRmAQEArNUdYwr1/RHcgQ4ffocr7zyPpMmfZXmtup117XjvvtmcvXVlXwekzHGeLOWcMCAzLdNm5b9td2J7ZFHIKNUJgJvvpn713N8C+iYzT7fq+r1zldSYgwEpgC3ArWBu0WktlcjdVORIoUYO/YhVqyIokaNlAkCtm37LxERdVm1aro9izTG+Jw/awmzu7Y7sWX2z2lSe65ez1FVVwFHc3BoE2CHqv6uqueBD4E7PBqclzRvXovY2AkMHtw1+XbquXMn+eCDAUya1I4jR3b5N0BjjDEZym0DcpqLyGYRWSwidZxt5YHUdRF7nW0ZEpH+IhIrIrFHjvh/CerChQsxZswDrFwZRc2aKWH/73/fEhlZj5Ur3+TixYt+jNAYY0x6uSk5bgQqqWoDIBr43Nme0b3hTO9JquoMVW2sqo1Ll84F08E7NWtWi5iY8Qwe3JWAAMfXfu7cSebNe5RJk9py5MjObM5gjDHGV3JNclTV46p60vnz10CQiJTG0VMMT7VrBWB/BqfI9ZJ6katWjeHaayskt//yy3dERtZjxYqp1os0xphcINckRxG5RpwP5kSkCY7Y/gJigBoiUkVECgI9gS/9F6n7mjSpSUzMeJ555j+pepGn+PDDx5g48RYOH/7dzxEaY/Ijf9YSZndtd2LLrEIuqT1Xr+coIvOANkBp4CDwMhAEoKrTRGQg8AiQCJwBBqvqWuexnYCJOEo55qjqKFeu6a9SjssRE/MrfftGs21bymPVggWL0K3bq7Ru/Why8jTGGOM+j9Y5ikgw8ARwCxBKuh6nqtbPYZxelReSI8DZs+cZOfIjXn99QZrbqjVqtOL+++dQpkw1P0ZnjDH5h6eT4xygG/AJjud9aQ5S1RE5jNOr8kpyTBIbu52+faP5+eeU6eaCggrTrdsY2rQZaL1IY/I5Xy34mxPZFdpnJ7vPdnmF/Jce7ypXk2MBF8/XFeihqt+4HoK5XI0b12D9+nGMGvUxr702nwsXLpKQcIaPP36CjRs/pVev2YSF1fB3mMYYL8ltC/6mll2hfXay+2w5LeT31nfjalfkNGlrDY2XFCoURETEvaxZM5Y6dSomt+/Y8T0jRzbg228ncvHiBT9GaIwx+Z+ryXEsMFhE7L6ej9xwQ3XWrx/HsGF3ERjo+NoTEs7wySdPMW5caw4e/NXPERpjTP6VabITkS+TXkBb4P+Anc7Za75Mt914QcGCQbzyyj2sXfsa9epVTm7/7bc1jBzZgG++GW+9SGOM8YKseoJ/pXstAJYDf2awzXhRw4bV+OGH13jxxf+jQIFAABISzvLpp0
/z+us38eefv/g5QmOMyV8yHZCjqg/6MhCTtYIFgxg+/G5uv70pfftOZsuWXQD8/vsPjBp1PV26RNK27VMEBAT6N1B
"text/plain": [
"<Figure size 504x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code plots the decision boundary of a Perceptron on the iris dataset\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"a = -per_clf.coef_[0, 0] / per_clf.coef_[0, 1]\n",
"b = -per_clf.intercept_ / per_clf.coef_[0, 1]\n",
"axes = [0, 5, 0, 2]\n",
"x0, x1 = np.meshgrid(\n",
" np.linspace(axes[0], axes[1], 500).reshape(-1, 1),\n",
" np.linspace(axes[2], axes[3], 200).reshape(-1, 1),\n",
")\n",
"X_new = np.c_[x0.ravel(), x1.ravel()]\n",
"y_predict = per_clf.predict(X_new)\n",
"zz = y_predict.reshape(x0.shape)\n",
"custom_cmap = ListedColormap(['#9898ff', '#fafab0'])\n",
"\n",
"plt.figure(figsize=(7, 3))\n",
"plt.plot(X[y == 0, 0], X[y == 0, 1], \"bs\", label=\"Not Iris setosa\")\n",
"plt.plot(X[y == 1, 0], X[y == 1, 1], \"yo\", label=\"Iris setosa\")\n",
"plt.plot([axes[0], axes[1]], [a * axes[0] + b, a * axes[1] + b], \"k-\",\n",
" linewidth=3)\n",
"plt.contourf(x0, x1, zz, cmap=custom_cmap)\n",
"plt.xlabel(\"Petal length\")\n",
"plt.ylabel(\"Petal width\")\n",
"plt.legend(loc=\"lower right\")\n",
"plt.axis(axes)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Activation functions**"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAwgAAADPCAYAAABRAPaGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABpV0lEQVR4nO3dd3gU1frA8e/ZTe89JCT0GqQXEURCFyxY0CuKilLEckWxIPK7lmvvqNgQFRWuvWNBRCKIIE16ryEkEBIS0uue3x+zWVIhCZvdDbyf55lndnfOnHl3N9mZd+bMOUprjRBCCCGEEEIAmJwdgBBCCCGEEMJ1SIIghBBCCCGEsJEEQQghhBBCCGEjCYIQQgghhBDCRhIEIYQQQgghhI0kCEIIIYQQQggbSRCEQymlWiiltFKqlwO2laCUmu2A7TRRSv2qlMpVSjm932Cl1AGl1P3OjkMIIRobpdR4pVSOg7allVJjHLEtIepKEgRxSkqp7kqpUqXUinqsW90B+iEgCthgj/is26npB/0qYIa9tnMK9wPRQDeM9+YQSqnHlFJbqlnUG3jTUXEIIYSjKKXmWQ+stVKqWCmVqpRaqpS6UynlbodNfAa0skM9NtaYF1azKAr4wZ7bEsJeJEEQpzMJ42DzPKVUxzOtTGtdqrU+orUuOfPQTrut41rr7IbeDtAGWKe13q21PuKA7Z2S1vqY1jrP2XEIIUQD+Q3j4LoFMBzjIPtxYLlSyre+lSql3LXW+VrrVLtEeRrWfWGhI7YlRF1JgiBqpJTyBq4H3gW+BCZUU6avUup3a/OaE0qpJUqpaKXUPGAgcGe5sz0tyjcxUkqZlFJJSql/V6qznbVMd+vzaUqpTdZtHFZKzVVKBVmXxQMfAL7ltvOYdVmFKxhKqWCl1IdKqQylVL5S6jelVKdyy8crpXKUUkOUUlus21uqlGp5is/oADAauMm67XnW16tcOq7c9MdaZrJS6gvrtvYppcZVWidaKbVAKZWulMpTSm1QSg1SSo0HHgU6lXvf42vYTjOl1DdKqWzr9LVSKqbc8ses7/c6pdRea5lvlVJh5cp0tn63WdblG5VSg2r6XIQQogEVWg+uD2utN2itXwbigR7AgwBKKQ+l1HPWfUyuUmqNUmpEWQVKqXjr7+YopdRqpVQRMKL8Fely+6LO5Tdu/d1OU0q5K6XMSqn3lFL7rfuV3UqpB5VSJmvZx4CbgUvK/VbHW5fZ9hNKqZVKqZcqbSfAWueVtXxP7kqp15RSyUqpQqXUIaXUs/b84MW5QxIEcSpjgINa603AxxgHwbZLuEqprsBSYA/QH+gLfA64AVOBlRgH71HW6VD5yrXWFuAT4IZK270B2Ka1/sf63ALcA3TCSFj6AK9bl/1lXZZXbjsv1vB+5gHnYxzQ97Gu84syEqEynhjNkm4FLgCCgLdrqA+M5jy/Wd93lPV918UjwHdAV4xL2+8rpZoDKONM2B8YZ8muBDoD/7Wu9xnwErCTk+/7s8qVK6UU8C0QCQwGBmE0h/rWuqxMC+Bf1u0MB7oDT5Vb/j8gBeNz6w48BhTU8b0KIUSD0FpvAX4Brra+9AHGSarrMX47PwR+sO63ynsO+D+gA/B3pTp3AWupfh/1mda6GOM46jBwLdARmAk8DNxiLfsixv6h7KpHFMZ+q7L5wHVliYXV1UA+8GMt39PdGL/h1wFtMX7Td1azLSFOT2stk0zVThgHp/dbHyvgAHB1ueULgFWnWD8BmF3ptRaABnpZn3exPm9TrsxuYMYp6r0YKARM1ufjgZxTbR/jx1IDF5VbHgicACaWq0cD7cuVuQEoKttWDfEsBOZVek0DYyq9dqDs8yxX5plyz90wkpZx1ueTgGwgrIbtPgZsqeZ123aAYUAp0KLc8lYYSdfQcvUUAIHlyswE9pR7ngXc7Oy/SZlkkuncnjBO9CysYdmz1t/Q1tbfuGaVln8LvGl9HG/9Db66UpkK+xOMkz
4HAWV9Hmut+4JTxPgs8NvpYi6/nwBCrfuaIeWW/wa8Y31cm/f0GrCkLFaZZDqTSa4giGoppdpgXBX4H4DWWmMkBBPLFeuO8WNUb9q4OrEZ44wISqnzMX4I/1culsFKqcXWy6rZwNeAB9CkDpvqiPHjurLctk9Ytx1Xrlyh1rr8GZdkwB3jSkJD2FQunhLgGBBhfak7sElrnXYG9XcEkrXWB8ptZx/G+yr/vg9aP48yyeXiAHgZmKuM5mQzlVIdziAmIYRoCArjoLuH9fE2a7PRHGuzoUsw9i/lrT1NnZ9gXHUdYH1+PbBPa23blyilpiil1iqljlm3cy/QrC6Ba63TgUVYr1YopaIwrvjOtxapzXuah9FZxi6l1BtKqUsqXZEQotbkD0fUZCJgBhKVUiVKqRLgIWC4UirWWkbVuHbdLODkJdwbgOVa64MA1uY2PwLbgWuAnhjNf8BIEmrrVLGW75q08s3TZcvq+r+iq9lmdT1sFFezXtm27PH5lu0wq1P+9VPFgdb6MYyE4lugH7BJKXUrQgjhOuKAfRi/XRqjCWi3clNHTu4/yuSeqkJt3LD8GxX3UQvKliul/gXMwjg4H2HdzpvUbf9UZj5wtVLKCxiL0Sz3T+uy074nrfV6jKv0D1vLfwgsliRB1If80YgqlFJuGDdVzaDiD1FXjDPeZW0r12O0a69JEUaScToLgDZKqb4YbSbnl1vWC+OH9l6t9UpttAmNrsd2tmH8vV9Q9oJSKgCjHee2WsRYV8co1+WpUiqSuneBuh7oUv5m4Upq+76bKqValIulFcZnWKf3rY1eml7TWl8CvEfFq0lCCOE0SqnzMJqffgn8g3FypInWek+l6XA9qp8PXKOU6omxzyi/j7oQ+FtrPVtrvV5rvYeqVylquy/8zjq/FGsiYr16T23fk9Y6W2v9hdb6doyrC4MxetoTok4kQRDVuQQIA97VWm8pPwGfArdaz0i8AHRXSs1RSnVVSrVXSk1USpVdWj0A9FFGz0VhNZ3F0FonAcswbgYOBL4ot3g3xt/pPUqplkqpsRg3JZd3APBSSg2zbsenmm3sxvjxfUcpNcDaK8V8jLb1/6tc3g5+x+jBqZcyemOaR91v6v0fkIpxQ/EA6/u/vFzvQQeA5kqpHtb37VlNHb8BG4EFSqmeyhigbgFG8vF7bYJQSnlbL1fHW7/L8zF2ig2RWAkhxOl4KmOAymjrvmcaxj1n64AXrSeSFgDzlFJjlFKtrL/F9yulrqrH9r7BuAL8HrDauj8pswvooZQaqZRqq5T6D8aNxOUdwOgqvL31t7ra8Rq01gUYTWj/D6NJ0fxyy077npTR499YpVRHazPh6zH2cUn1eM/iHCcJgqjOBGCptU1kZV8AzTFucN0ADMXo/WEVRg8Q13GyucqLGGdOtmGcUT9Vm8yPMa5Q/Ki1zix70XqPwlRgmrWeiRgDk1GuzF8YycUn1u08WMM2bgFWA99b5z7AxVrr/FPEVV/3YVzqTsA4ozUX42C/1rTWuRg7msMY/Xxvxejru+yM0lfATxj3gRzDuCRduQ4NXGFdnoDR69QR4IpyZ6ZOpxQIxrhcvRNjZ7kS4zsRQghHG4rRq1oixu/f5Ri/jRdZfzfB+L3/AHge2IHRmcRFGDcc14k2xpX5BmMfNb/S4ncwein6H7AGo4nPS5XKvIvRTHYtxm9x/1NsrmxfuF5rvb3SstO9p2zgAYz923qMK/8jtYyLI+pB1f4YQQghhBBCCHG2kysIQgghhBBCCBtJEIQQQtSLUup9pVSqUmpLDctvUMYo6JuUUn9VM0iVEEIIFyQJghBCiPqah9FzTE32AwO11l2AJ4A5jghKCCHEmXFzdgBCCCEaJ631svJd6Faz/K9yT1cBMQ0elBBCiDMmVxCEEEI4wgTgZ2cHIYQQ4vSccgUhLCxMt2jRwi515ebm4uvra5e67M1VY3PVuMB1Y3PVuMB1Y6tPXCUnSsjfmw8aPCI98IypbmgH58
TmCPaOa926dWla63C7VVhP1rE7JmCMn1FTmcnAZABvb++esbGxNRWtE4vFgsnkmufCJLa6c9W4wHVjc9W4wHVjc9W
"text/plain": [
"<Figure size 792x223.2 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 108\n",
"\n",
"from scipy.special import expit as sigmoid\n",
"\n",
"def relu(z):\n",
" return np.maximum(0, z)\n",
"\n",
"def derivative(f, z, eps=0.000001):\n",
" return (f(z + eps) - f(z - eps))/(2 * eps)\n",
"\n",
"max_z = 4.5\n",
"z = np.linspace(-max_z, max_z, 200)\n",
"\n",
"plt.figure(figsize=(11, 3.1))\n",
"\n",
"plt.subplot(121)\n",
"plt.plot([-max_z, 0], [0, 0], \"r-\", linewidth=2, label=\"Heaviside\")\n",
"plt.plot(z, relu(z), \"m-.\", linewidth=2, label=\"ReLU\")\n",
"plt.plot([0, 0], [0, 1], \"r-\", linewidth=0.5)\n",
"plt.plot([0, max_z], [1, 1], \"r-\", linewidth=2)\n",
"plt.plot(z, sigmoid(z), \"g--\", linewidth=2, label=\"Sigmoid\")\n",
"plt.plot(z, np.tanh(z), \"b-\", linewidth=1, label=\"Tanh\")\n",
"plt.grid(True)\n",
"plt.title(\"Activation functions\")\n",
"plt.axis([-max_z, max_z, -1.65, 2.4])\n",
"plt.gca().set_yticks([-1, 0, 1, 2])\n",
"plt.legend(loc=\"lower right\", fontsize=13)\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(z, derivative(np.sign, z), \"r-\", linewidth=2, label=\"Heaviside\")\n",
"plt.plot(0, 0, \"ro\", markersize=5)\n",
"plt.plot(0, 0, \"rx\", markersize=10)\n",
"plt.plot(z, derivative(sigmoid, z), \"g--\", linewidth=2, label=\"Sigmoid\")\n",
"plt.plot(z, derivative(np.tanh, z), \"b-\", linewidth=1, label=\"Tanh\")\n",
"plt.plot([-max_z, 0], [0, 0], \"m-.\", linewidth=2)\n",
"plt.plot([0, max_z], [1, 1], \"m-.\", linewidth=2)\n",
"plt.plot([0, 0], [0, 1], \"m-.\", linewidth=1.2)\n",
"plt.plot(0, 1, \"mo\", markersize=5)\n",
"plt.plot(0, 1, \"mx\", markersize=10)\n",
"plt.grid(True)\n",
"plt.title(\"Derivatives\")\n",
"plt.axis([-max_z, max_z, -0.2, 1.2])\n",
"\n",
"save_fig(\"activation_functions_plot\")\n",
"plt.show()"
]
},
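{
"cell_type": "markdown",
"metadata": {},
"source": [
"The derivatives plotted above are estimated numerically with a central difference. As a sanity check, the sketch below compares that kind of estimate with the closed-form derivatives σ′(z) = σ(z)(1 − σ(z)) and tanh′(z) = 1 − tanh²(z). The helper names `sigmoid_np()` and `central_diff()` and the tolerance are assumptions introduced for this illustration; they are separate from the functions defined in the previous cell:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid_np(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"def central_diff(f, z, eps=1e-6):\n",
"    # symmetric finite-difference estimate of f'(z)\n",
"    return (f(z + eps) - f(z - eps)) / (2 * eps)\n",
"\n",
"z_check = np.linspace(-4.5, 4.5, 200)\n",
"\n",
"# closed-form derivatives\n",
"sigmoid_grad = sigmoid_np(z_check) * (1 - sigmoid_np(z_check))\n",
"tanh_grad = 1 - np.tanh(z_check) ** 2\n",
"\n",
"# True when the numerical estimate agrees with the closed form\n",
"sigmoid_ok = np.allclose(central_diff(sigmoid_np, z_check), sigmoid_grad, atol=1e-6)\n",
"tanh_ok = np.allclose(central_diff(np.tanh, z_check), tanh_grad, atol=1e-6)"
]
},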
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Regression MLPs"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import fetch_california_housing\n",
"from sklearn.metrics import mean_squared_error\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neural_network import MLPRegressor\n",
"from sklearn.pipeline import make_pipeline\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"housing = fetch_california_housing()\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
" housing.data, housing.target, random_state=42)\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X_train_full, y_train_full, random_state=42)\n",
"\n",
"mlp_reg = MLPRegressor(hidden_layer_sizes=[50, 50, 50], random_state=42)\n",
"pipeline = make_pipeline(StandardScaler(), mlp_reg)\n",
"pipeline.fit(X_train, y_train)\n",
"y_pred = pipeline.predict(X_valid)\n",
"rmse = mean_squared_error(y_valid, y_pred, squared=False)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5053326657968465"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rmse"
]
},
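{
"cell_type": "markdown",
"metadata": {},
"source": [
"The RMSE reported above is the square root of the mean squared error. The short sketch below computes it by hand on three made-up targets and predictions (`y_true_toy` and `y_pred_toy` are hypothetical values, not taken from the housing data), which may help when reading the metric:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical targets and predictions, for illustration only\n",
"y_true_toy = np.array([3.0, 2.0, 4.0])\n",
"y_pred_toy = np.array([2.5, 2.0, 5.0])\n",
"\n",
"squared_errors = (y_true_toy - y_pred_toy) ** 2  # per-instance squared error\n",
"rmse_toy = np.sqrt(squared_errors.mean())        # root of the mean squared error"
]
},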
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification MLPs"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code this was left as an exercise for the reader\n",
"\n",
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.neural_network import MLPClassifier\n",
"\n",
"iris = load_iris()\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
" iris.data, iris.target, test_size=0.1, random_state=42)\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X_train_full, y_train_full, test_size=0.1, random_state=42)\n",
"\n",
"mlp_clf = MLPClassifier(hidden_layer_sizes=[5], max_iter=10_000,\n",
" random_state=42)\n",
"pipeline = make_pipeline(StandardScaler(), mlp_clf)\n",
"pipeline.fit(X_train, y_train)\n",
"accuracy = pipeline.score(X_valid, y_valid)\n",
"accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Implementing MLPs with Keras\n",
"## Building an Image Classifier Using the Sequential API\n",
"### Using Keras to load the dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by loading the fashion MNIST dataset. Keras has a number of functions to load popular datasets in `tf.keras.datasets`. The dataset is already split for you between a training set (60,000 images) and a test set (10,000 images), but it can be useful to split the training set further to have a validation set. We'll use 55,000 images for training, and 5,000 for validation."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]\n",
"X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training set contains 60,000 grayscale images, each 28x28 pixels:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(55000, 28, 28)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each pixel intensity is represented as a byte (0 to 255):"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('uint8')"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's scale the pixel intensities down to the 0-1 range and convert them to floats, by dividing by 255:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_valid, X_test = X_train / 255., X_valid / 255., X_test / 255."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can plot an image using Matplotlib's `imshow()` function, with a `'binary'`\n",
" color map:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAKRElEQVR4nO3dy2/N3R/F8d3HpbSSaqXu1bgNOqiIqNAhISoxMDc1MiZh4C8wNxFMS4iRSCUGNKQuIQYI4hZxJ6rUtTyDX36/Ub9rPTkn/VnN834Nu7JPz6UrJ+kne++G379/FwB5/vrTTwDA+CgnEIpyAqEoJxCKcgKhppqcf+UCE69hvB/yzQmEopxAKMoJhKKcQCjKCYSinEAoygmEopxAKMoJhKKcQCjKCYSinEAoygmEopxAKMoJhKKcQCjKCYSinEAoygmEopxAKMoJhKKcQCh3NCb+z9zFUg0N456i+I+NjIzIfHBwsDLr6+ur63e71zY2NlaZTZ36Z/9U67nwq9bPjG9OIBTlBEJRTiAU5QRCUU4gFOUEQlFOIBRzzjC/fv2S+ZQpU2T+4MEDmR8+fFjmM2fOrMyam5vl2hkzZsh83bp1Mq9nlunmkO59devreW5qfltK9WfKNycQinICoSgnEIpyAqEoJxCKcgKhKCcQijlnmFpnYv91/vx5mZ87d07mHR0dldm3b9/k2tHRUZkPDAzIfNeuXZXZvHnz5Fq3Z9K9b86nT58qs7/+0t9xTU1NNf1OvjmBUJQTCEU5gVCUEwhFOYFQlBMIRTmBUMw5w0yfPr2u9VevXpX548ePZa72Pbo9kVu2bJH5jRs3ZL53797KbO3atXJtd3e3zLu6umR+5coVmav3tbe3V67dsGGDzFtaWsb9Od+cQCjKCYSinEAoygmEopxAKMoJhGowRwLWfu8ZKqn33G19clu+1DiilFI+fPgg82nTplVmbmuU09PTI/MVK1ZUZm7E5I62fPnypczd0ZfqWM8TJ07Itbt375b5xo0bx/3Q+eYEQlFOIBTlBEJRTiAU5QRCUU4gFOUEQjHnrIGbqdXDzTnXr18vc7clzFGvzR0v2djYWNfvVlcIuvdlzZo1Ml+5cqXM3Ws7e/ZsZfbw4UO59vnz5zIvpTDnBCYTygmEopxAKMoJhKKcQCjKCYSinEAojsasgZu5TaTW1laZv3jxQuYzZ86Uubrm78ePH3KtuiavFD3HLKWUL1++VGbuPR8cHJT5pUuXZO5m169evarMtm7dKtfWim9OIBTlBEJRTiAU5QRCUU4gFOUEQlFOIBRzzklmdHRU5mNjYzJ31/ipOej8+fPl2jlz5sjc7TVV5+K6OaR73WqG6n53KXq/57Nnz+TaWvHNCYSinEAoygmEopxAKMoJhKKcQCjKCYRizlkDN3Nzs0Q1M3N7It0ZqO7sWHfP5ffv32t+7ObmZpkPDw/LXM1J3XxXPe9SSpk1a5bMP378KPPu7u7K7PPnz3LttWvXZL527dpxf843JxCKcgKhKCcQinICoSgnEIpyAqEYpdTAHdPoti+pUUp/f79c646+bG9vl7nbOqWemxsZPH36VObTpk2TuTqWc+pU/afqju10r/vt27cy3717d2V28+ZNufbnz58yr8I3JxCKcgKhKCcQinICoSgnEIpyAqEoJxCqwWx/0nuj/qXc3MrN5JShoSGZb9u2Tebuir96ZrD1XvHX1tYmc/W+ujmmm8G6qxMd9dr27Nkj1+7cudM9/LiDc745gVCUEwhFOYFQlBMIRTmBUJQTCEU5gVATup9TzVDrvarOHU+p9g66696ceuaYTl9fn8zdEY9uzumOkFTcXlE3//369avM3bGdivtM3Gfu/h5v3bpVmbW0tMi1teKbEwhFOYFQlBMIRTmBUJQTCEU5gVCUEwhV18Cunr2BEzkrnGgXLlyQ+cmTJ2U+ODhYmTU1Ncm16pq8UvTZr6X4M3fV5+Kem/t7cM9NzUHd83
bXDzpu/qse/9SpU3Lt9u3ba3pOfHMCoSgnEIpyAqEoJxCKcgKhKCcQinICoWLPrX3//r3Mnz9/LvN79+7VvNbNrdRjl1JKY2OjzNVeVben0d0zuXDhQpm7eZ46H9bdYele9+joqMx7e3srs5GREbn24sWLMnf7Od2eTPW+zZ8/X669c+eOzAvn1gKTC+UEQlFOIBTlBEJRTiAU5QRC1TVKuXz5snzwAwcOVGZv3ryRaz98+CBz969xNa6YPXu2XKu2upXiRwJupKDec3e0ZVdXl8z7+/tl3tPTI/OPHz9WZu4zefz4scydpUuXVmbu+kF3ZKjbUuY+U3XF4PDwsFzrxl+FUQowuVBOIBTlBEJRTiAU5QRCUU4gFOUEQsk559jYmJxzbtiwQT642ppV75Vt9RyF6K6qc7PGeqm52Lt37+TaY8eOyXxgYEDmhw4dkvmCBQsqsxkzZsi1ak5ZSinLly+X+f379ysz976oKx9L8Z+5mu+WorfSubn4kydPZF6YcwKTC+UEQlFOIBTlBEJRTiAU5QRCUU4glJxzHjlyRM459+3bJx982bJllZnaH1eKPwrRXSenuJmX25+3ePFimS9atEjmai+r2odaSikvX76U+enTp2WurtkrpZRHjx5VZu4zu379el25ukKwnuNGS/FHgjqqJ+6xh4aGZN7R0cGcE5hMKCcQinICoSgnEIpyAqEoJxCKcgKh5KbKuXPnysVu3qdmlW5utWTJkpofuxS9/87t3Wtra5N5Z2enzN1zU/si3Z5Jt3dwx44dMu/u7pa5OnvW7al0n6k7L1jtyXSv212d6GaRbv+wmnOas5/tlZEdHR3jPye5CsAfQzmBUJQTCEU5gVCUEwhFOYFQcpTiRiXu389V/yIuxW8/clcEun/Lt7e315SV4reUue1qbr3atuWuulPbqkopZc6cOTK/ffu2zNVVem681draKnO3XU19Lu4oVXc0plvvrulTW/VaWlrk2ps3b8p806ZN4/6cb04gFOUEQlFOIBTlBEJRTiAU5QRCUU4glBz+rF69Wi5225OOHj1amS1cuFCuddfFua1Val7otg+5mZfajlaKn3Oq5+7WNjSMe4ri/zQ1NclcXfFXip5du21b7rm72XQ9WwzdY7vcbTlTc1R1nGgppcybN0/mVfjmBEJRTiAU5QRCUU4gFOUEQlFOIBTlBELJKwBLKfrMP+PMmTOV2cGDB+Xa169fy9ztyVRzLbcP1V0n5/Zzuj2Xah7ojll0c043a3QzXpW7x3bP3VHr3TGtjptNu78JtZ9z1apVcu3x48dlXkrhCkBgMqGcQCjKCYSinEAoygmEopxAKMoJhJJzzl+/fsnBlZsN1eP8+fMy379/v8xfvXpVmQ0PD8u1bl7n5phupqbOUHW/28373By0nrOI1Zm2pfj3pR5uv6Xbx+pm15s3b5Z5V1dXZdbb2yvX/gPMOYHJhHICoSgnEIpyAqEoJxCKcgKhKCcQakL3c6a6e/euzN3doO4eymfPnsm8s7OzMnPzPHeeLyYl5pzAZEI5gVCUEwhFOYFQlBMIRTmBUP/KUQoQhlEKMJlQTiAU5QRCUU4gFOUEQlFOIBTlBEJRTiAU5QRCUU4gFOUEQlFOIBTlBEJRTiAU5QRCVd9F9x/6PjkAE4ZvTiAU5QRCUU4gFOUEQlFOIBTlBEL9DRgW8qPu1lMTAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code\n",
"\n",
"plt.imshow(X_train[0], cmap=\"binary\")\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The labels are the class IDs (represented as uint8), from 0 to 9:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([9, 0, 0, ..., 9, 0, 2], dtype=uint8)"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are the corresponding class names:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"class_names = [\"T-shirt/top\", \"Trouser\", \"Pullover\", \"Dress\", \"Coat\",\n",
" \"Sandal\", \"Shirt\", \"Sneaker\", \"Bag\", \"Ankle boot\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So the first image in the training set is an ankle boot:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Ankle boot'"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class_names[y_train[0]]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at a sample of the images in the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0AAAAFJCAYAAACy802jAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAADmp0lEQVR4nOydd5gdVfnHP2c3m00nhZDQk1BD702Q0HtHRaSJCD8VpSgoCIiKioIFLKCgIiCiSEd6B0NHeighBEhCGiFs2m62zO+Pme+Z2XPv3b7ZW97P8+xz996ZO3fmnXdOedtxURRhGIZhGIZhGIZRCVT19QkYhmEYhmEYhmGsKGwCZBiGYRiGYRhGxWATIMMwDMMwDMMwKgabABmGYRiGYRiGUTHYBMgwDMMwDMMwjIrBJkCGYRiGYRiGYVQMK2wC5Jyb5JyLnHMrt7FP5Jw7spu/0+1jlBrOuXHJdW/TnX2MGJOnUcyYfvY+zrkTnHOLC703+gbn3KPOud/19Xn0NaafxYlz7kjnXKfWljGdLkxvy7PDEyDn3JbOuWbn3H87czLlSHcHF8l32/q7podPGeBDYFXgpXbO7ULn3GttbH/TOXdIRya0KwqTZ/HhnLsmI/9G59xc59wjzrlvOOdq+vr8ViSmnyuePPo3zTl3qXNucF+fWyninBvtnPuDc266c67BOTfHOfeQc26vvj63UsT0s3cxfe1ZylWe/Tqx71eBPwDHOecmRlE0pZfOqRJYNfP/gcBVwWfLevoHoyhqBma3tU97A1Pn3AbAWsADwHY9d3bdxuRZnDwIHAtUA6OB3YEfAsc65/aIomhJ+AXnXP8oipav2NPsdUw/+wbpXw2wC3A1MBj4Wl+eVHdwztVEUdTYBz99MzAI+AowFVgF2BUY1Qfn0mM45/oBzVHfrAhv+tl7lKW+9iHlKc8oitr9AwYCC4HNgD8DlwbbxwERcARxZ7kUeAPYK7PPpGSflZP3tcCtwIvAKslnEXBk5jurAzcCnyR//wHWa+dcI+DUZN+lwPvAMcE+mxI3PsuABcA1wEqZ7VXA+cRW1AbgVeCQ4Deyf492RI4FzvfI+Da0u9+awO3J+S4F3gSO6oT8tc82wf3YH3gWWJ7ILby2EzLHODs5h3F59rsmc19/A8wB6oGngZ3z6MGBxNbpeuAFYOuuytDk2Xvy7MZ9uAa4K8/nmySy+WHyfjpwIfAX4jbmpuTznYDHEtnPBK4AhmWO89lEFouBT4FngE2SbSsB1wFzE3lMA07vS3mYfq5Y/cynf8QTz48SfXst2HYCsLij75PPTiEeDCxPXr+a2fYP4OZg/yriPuWM5L1LZPwucV/0Kpm+KiP3LwIPJ/uc2gc6Ozw5jz3b2Gc6cB7wR6AOmAGcFeyzEvAn4udyEfHzvU1m+6hEbjOSa30d+HJwjEeB32Xe70HcbpySvG9zzKB7n9zPd4FmYEgfyNT0s2/19RjguUQP5wI3Aatntk9KjrEHcd+yFHge2Co4znHEY8ylwF3AN8i078A6xG3obGAJ8Xj3wLZ0utj+ylmeHRXAscDLmQuZC9TkeRDeBA4C1gP+BnxM0rhkBLAyMAx4hLgBzA5qIpIJEPFs823ihmIzYENiC8n7wKA2zjVKfvcUYH3g+0AL6UBgEPGA6jbiidCuye/cnDnGGcSN+NHJMX5E3FBukWzfNvmdfYCxwMhuKFdHB0R3Eg92NgfGA/sC+3ZC/tonHBC9CuwNTCAedF2aHGds8jcwcw6TgROJLfqHJ9/fKNlvpWSfy4gb8QOAicSN+mJg1eB330zktwnxwzK7rftq8uwbeXbjPlxDnglQsu0Okg6eeOBUR9zRrpvIetPkGr+dvN8eeAr4d/KdfsSDm0uJG8QNiZ/Vicn23xIPtrdL7tMk4HN9JQvTzxWvn/n0D7gcmE8PDDCBw4BG4knk+sA3k/cHJdsPIJ7sDc98ZzegCRibvP8J8FZyb8cnOrwEOCC4h9
MTvRkPrNEHOtuPeGBzOTCgwD7TE/07lfg5/mZy7jsm2x3wJPGEZLtknx8TP/vSldWBs4AtEv08mXjwvkfmdx4lGdwQGwDqgM8n79sdMyT3fglwP7BVoo/9+kCmpp99q68nEhuDJiT6+AjweGb7pOTank3ksiFwHzAFcMk+2xOPLb+fyPgU4mcgyhxnc+D/iPu0dZN9lwMb5tPpYvwrZ3l2VACPAd9J/neJwh+R2a4H4ZTMZ6snn+0cCGAisQXwjlCYtJ4AnQi8I+Ekn1UnAvl8G+caAVcFnz0IXJ/8/1Vii/HQPDdn3eT9TOCC4BiPZo6h692m0Hl0Qrk6OiB6BfhBgW0dkX+rc85c8xHBsS4kaHyTz8cQN6Cjg++vnNlncKKMxwX37F3gouB7X8rsM4TYineSybO45NmN+3ANhSdAFwNLk/+nA3cG268F/hx8tkVynasAI5P/dy1w/DuAv/bVtZt+9r1+hvpH3CnPB/6Z75rp/ADzv8Bf8vzmk8n//YgNhV/JbL8auC8ju2XALsExfgPcHdzDbxeB3h5B7I2sJzZGXApsn9k+HfhH8J13gPOS/3cnnhgPDPZ5CTi7jd+9Ebg68/5R4HfEk6NPgb0z29odMyT3vhEY08fyNP3sQ33Ns/+GybWskbyflLzfJ7PPZ4J9bgAeCI5zNe2078Re8vNCne5rmVWiPNstguCcWzc50RvQmcDfgZPy7P5K5v9ZyesqwT73E7u4D4+iqL6Nn96a2KKwyDm3OKlw8ikwgtjq2xZP5Xm/UfL/ROCVKIoWZbZPJp55buScGwasRtyAZHkyc4xeRdeb/F2ZfHwZcJ5z7inn3EXOua3zfLUj8g95voOndRDwdBRF89rYZx3ieGYvuyjORcjKXzyV2WcxsWW6V+Rr8iw6HHHDJ0KZbQ0ck71vpDJYJ4oiha3e55z7j3PuTOfcmpnvXwF83jn3cpJYvGsvXUePYPrZa+ybyFSd9uPElvCeYCJt9BFRFDURD2a/BOCcqyUeRFyf7LsRMAC4N9Dzr5Hbv3X0nvUaURTdTNwvHgTcQxyi+rRz7tzMbq8EX5tFqo9bE3to5gXXuwnJ9Trnqp1z33fOveKc+zjZfjhxHlqWQ4DfE3tE78983tExw4woiuZ0QQw9jelnL9GevjrntnLO3e6ce985t4j0GkJda6uNnUj+sabHOTfYOfcL59wbzrlPEhluk+d3ippylWdHiiCcRGxF+cA5588jOZk1oyj6MLOvT36LoihK9g8nWXcBnyd2Yf2vjd+tIrYOHZVn24IOnHchwsFXlqjA/2191htskfm/DiCKoj875+4jdjPuCUx2zv0siqILM/t2RP4hOYnoBTiUOPayLaQgfSm7fGyR+d/k2fdsRJyXI0KZVRFbfn6d57szAaIo+rJz7jfE4RkHAz9xzh0aRdF9URTd45xbG9iPOOb4P865m6Io+nIPX0dPsUXmf9PPnuNxYk9BIzArSpKznXMtpOcuulKZsL3rvp74vq1OHN7RnzjvFdL7dBDwQXCMMIm8o/esV0kMlg8kfz9yzl0NXOicuzTZJTzviPQ6q4jzxHbJc+i65PU7xGGvpxFPkBcDPyV3Uv9KcuyvOOeeToyy+o2XaH/MUBTyxPSzV2lDX39PHH6lIhRziVMzniCWQZbstWb1DHLvUT4uJe6jvkPsnVxKHOEQ/k7RU47ybLOzTCqkHA+cQ9xJ629z4kaoKwOK84ErgQedc1u0sd+LxDF+86Momhr8tTcB2iHPe1WtewPY3Dk3NLN9J2JZTImiqI54ZrpzcIydk+9CHPYB8cSwxwmudW7m8xlRFP0piqLPAxcQN549zXKC63Jxac49iPOmsvsR7KuESy8751w1sCOp7MQOmX0GE1sCe6WyoMmzeHDObULcgP27jd1eBDbO89xPjaLIV0yLoujlKIp+HkXRJGK39/GZbfOjKLouiqITiCvXHJ9YOYsO089eY2ki0/ej1pWp5gFjXMaiR+tJaE
eYQtt9BFEUPUMcDvhFYkv7bYn3i2S/BmDtPDr+fifPpa94g9iIOqAD+75IHFbZkud6pfM7E4fDXhdF0UvEsls/z7H
"text/plain": [
"<Figure size 864x345.6 with 40 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell generates and saves Figure 1010\n",
"\n",
"n_rows = 4\n",
"n_cols = 10\n",
"plt.figure(figsize=(n_cols * 1.2, n_rows * 1.2))\n",
"for row in range(n_rows):\n",
" for col in range(n_cols):\n",
" index = n_cols * row + col\n",
" plt.subplot(n_rows, n_cols, index + 1)\n",
" plt.imshow(X_train[index], cmap=\"binary\", interpolation=\"nearest\")\n",
" plt.axis('off')\n",
" plt.title(class_names[y_train[index]])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.5)\n",
"\n",
"save_fig(\"fashion_mnist_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Creating the model using the Sequential API"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.InputLayer(input_shape=[28, 28]))\n",
"model.add(tf.keras.layers.Flatten())\n",
"model.add(tf.keras.layers.Dense(300, activation=\"relu\"))\n",
"model.add(tf.keras.layers.Dense(100, activation=\"relu\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# extra code clear the session to reset the name counters\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, activation=\"relu\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" flatten (Flatten) (None, 784) 0 \n",
" \n",
" dense (Dense) (None, 300) 235500 \n",
" \n",
" dense_1 (Dense) (None, 100) 30100 \n",
" \n",
" dense_2 (Dense) (None, 10) 1010 \n",
" \n",
"=================================================================\n",
"Total params: 266,610\n",
"Trainable params: 266,610\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcEAAAIECAIAAADuFAIlAAAABmJLR0QA/wD/AP+gvaeTAAAgAElEQVR4nOy9fzyUWf/4f4ZBIaO12rSkDeu3Vt2FrLU/UFKRtEqljdBm27vctVuPfny6Pah73/Xe2LfaVKKNLCGlX6vWr61Ii5JfdZeEIqFhRowfc33/ON+97uueMWNcjJnR6/nXnNc51+t6nTPm5brOeZ3XYRAEgQAAAABaKMnaAAAAAAUGfCgAAAB9wIcCAADQB3woAAAAfZjUQmFh4Y8//igrUwAAAOQfBweHsLAwsvhfz6ENDQ1paWljbtJbQWNjI4wtbdLS0hobG2VthbSAvw0FoqioqLCwkCphUGObUlNTfX19IdpJGsDYjgQGg5GSkvLll1/K2hCpAH8bCsSKFSsQQufOnSMlMB8KAABAH/ChAAAA9AEfCgAAQB/woQAAAPQBHwoAAEAf5tBNAEABqa2tjYiICA8P19fXl7Uto0NdXR0ZVfPhhx/OmTOHrOrv7y8uLuZyuW1tbQghMzMzW1tbspbNZl+9epUsLly4cPLkyWNl9f8Ph8M5e/bs06dPjY2N/fz81NXVqbVcLjc1NbWurs7e3t7V1VVFRUVKOktLS3V0dAwNDclmtbW1d+7cwZ9NTU1nz5497L4RFFJSUgQkwGgBYzsSEEIpKSnDugRHn1y5ckVKJo0iEv5tJCYmIoSSk5Obmpo6OztJOZvN3r9/f2dnJ5fL3bt3L0KIxWI9fPiQbMDn80tKSqytrS0sLHJzc/l8vlS6IZqampqpU6eamJioqqoihIyMjJqamqi1xsbGly9fxj5x+vTp+fn5UtLZ19e3ceNGqn4ul1tXV/fHH3+oqKhs3bp1yPv6+Pj4+PhQJeBDxwgY25FAw4cSBPHq1StpGEPl9OnTI1cyLB/KZrOpwsbGxiVLllCF2KeYm5tT/SxBEPipfOTW0sDd3f3+/fsEQbS0tGzYsAEhFBAQQK0NDAwki+vWrXNycpKezv7+fnd39/LycgGFM2bMoOdDYT4UGLe8++67UtWfk5Ozc+dOqd5iSMLCwpYtW8ZisUiJsbGxm5tbdXW1v78/QYnb19HR0dbWHnsLS0pKVq9ebWNjgxDS1dUNDw9XUlK6ffs22aCpqamyspIsqqmp8Xg86elUVlYOCwsLDg4ejc4hBGtKwHiFz+fn5ubevXsXFxsaGqKjo/l8fkVFRWRk5JkzZ/h8Ptm4sbHx6NGjBEHk5eXt3LkzJiamu7sbIZSVlRUVFXXy5EmEEIfDOXLkSFRUFH5szM3N9fLy4nK5sbGxWVlZCKHW1tYDBw68fPlyzPpYXFx8+fJlHx8fqpDJZP76669GRkaZmZkRERGkXElJSUnpv37vHA4nJSVl3759cXFxDQ0NpFz8WCGEXrx4cerUqfDw8N9//31II2fMmOHn50cW9fT05syZQ52Q9fb2Lioqwk/ZXC73/PnzW7ZskapOFxcXDoeTkZExpPESQX0ohfdN6QFjOxLQMN/lKysrsWf5+eefCYK4ePGirq4uQujw4cPr169fvHgxQmj//v24cWJi4uTJkydOnLhx48aAgIBFixYhhObOndvb20sQhKWlpb6+Pm7Z2dmppaXl4OBAEERZWZmjo6Ourm5ubm5ZWRlBECdOnEAI/fTTT8PtHe13+eXLl7u4uAg0s7GxIQjiwYMHmpqaDAYjKysLy2NjY2NiYshm9+7ds7a2Tk9Pb2lpOXTokKamJp6XED9WBEHk5OQEBQWVlpampqZqampu2rRpuP2dOnUqdVahubnZ1NQUIbR161Y3N7eMjIzhKqShMzg42NbWliqh/S4PPnSMgLEdCcP1oQRBlJeXkz6UIIgdO3YghG7cuIGLs2fPnjNnDtl4zZo1DAajoqICF/fs2YMQOnbsGEEQPj4+pA/FF2IfShCEl5eXgYEBWcXlcs+ePSswCykJtH2oiYkJfmGngn0oQRDp6e
kMBoNcX6L6UB6PZ2ZmtnfvXvIqPz8/VVXVyspKQuxYcTicmTNncrlcXAwMDEQIFRYWSt7Z/Px8fX19DodDFba0tBgZGSGEHBwcmpubJddGW2d0dDSTyeTxeKQE5kMB4L9QU1OjFidOnIgQMjMzw0ULC4v6+nqyVkNDg8lkWlpa4uKOHTuYTGZBQcGQd2EwGFQlq1atmjRp0siNl4Te3t7a2lo9PT1RDby9vXft2tXR0eHl5cXhcKhV165dq6mpsbe3JyULFizo7e2Ni4tDYscqOTm5u7v7u+++Cw0NDQ0NbWpqMjIyevz4sYQ2DwwM7N279+LFi5qamlR5XFycs7NzQEBAYWGhnZ0d9auRkk4Wi9Xf3y+55WKA+FDgbURZWZkQnSdJXV1dX1//1atXQ+qh+tAxpr29fWBgAPs7UYSHh9+/fz8rK8vf33/hwoWkvKqqCiFEdTpOTk4IoerqamEl1LGqrKzU09M7cuQIPZu3bdsWFhZGjV1FCMXHx6ekpNy9e5fJZDo6OoaEhISGhuIpZunpxH1vbGy0sLCg1xcSeA4FAEF4PF5zc/PMmTOHbClDHzp16lRtbW2BB0wBGAxGYmKimZlZZmZmdHQ0KX/nnXcQQtQ8mIaGhioqKkPG3isrKz98+LCvr4+GwcePH7e1tV26dKmA/PTp0+7u7kwmEyEUEBAQFBSUnZ3NZrOlqvP169cIIQMDAxodEQB8KAAIUlRU1NPTg5dTmExmT0/PoM0YDMbAwMDYmvZfWFpatrS0UCUEQbx584Yq0dLSyszMZLFY1GdMOzs7hBB1sqKioqKvr8/BwUH8HWfNmtXV1XXs2DFSwmazjx49OqSp58+fJwjC39+flOTn5+MP5eXlVO/m6enZ29srSXjDSHQ2NTUxGIwPPvhgyLsMCfhQYHyC4wFbW1txsbOzEyHU29uLi62trXg9gWzf399Pepm0tDRnZ2fsQ93c3FpbW+Pj47u6uuLj49va2mpra/FTjJ6eXnNzc21t7ZMnT7q6ukpKSubNm5eXlzdmfXRycnrw4AFV0tTU9Pz5cwGnb2pqmpSURA1smjVr1rp16woKCshZwps3b5qYmOCoSTFj5evra2BgsG3btoMHD1ZXV6empgYHB69duxa3DA4OXrRokbD7u3Hjxg8//NDX1xcTExMTExMdHR0SEoIX/RBCXl5e58+fJ8OnioqKbGxsTExMpKcTIVRXV+fm5jZhwgQJhnkoqAtMsHYsPWBsRwIa5rp8UVERjm2ysrK6dOlSXl4efjHfsGFDU1NTcnKylpYWQmjfvn19fX0EQYSEhCgrK3/zzTfbt29fuXLlkiVLyOV1DoeD117Mzc0zMjK8vb0XLFhw4sQJgiByc3OZTKa2tjaOZ8Lr4LhqWNBel29vb58yZcrjx49x8dy5c5988glCyNXVNScnR+DyyMhIamxTd3d3aGiopaVlQkLCyZMnPTw86uvrCYIYcqyqqqo+/PBD7D0sLS1LS0tJnXgd/NChQ9T7lpSUaGhoCLidCRMmtLW14QZdXV2BgYFWVlZRUVEbNmxYunRpbW2tVHXyeDwdHZ3r169TdUJsk7wDYzsShutDh0tISIiKigpBEPX19R0dHcINWlpa8Ifu7m6qnM1mU4OZBr12SEay1/PYsWOhoaES3ujly5cCEjabfevWrYaGBgk1kNTV1T179kxA2NPTk5KScuHCheFqIwiiq6urqqqqvb19DHSmpqZ6enoKCCG2CQBGAQMDA/zYJQAOO0cICbz9sVgsajDToNeOLgL7IIOCgtra2srKyiS5dsqUKQISFos1f/58GqmtDA0Np0+fLmxbYWEh3qQwXNTV1c3NzYUXtUZdZ01NTVJSUnJyskBj2lPbENsEAOjNmzf9/f1cLlcgxlCuUFFR0dLS2rBhg4ODw9y5c11cXBBCSkpKCQkJmzdvDgoKmjt3rmwtLC4u3r9/P14Nl0+dz549O3DgwKlTp8iYsIqKimvXrtXX13d2dt
KbHqVvWVVV1dWrVx89emRvb6+lpcVkMj09PWlrkwYKl0GyoKDg+fPnZFFbW9vd3V3aN83OzsZJJzE2NjZkqPlbQlJ
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code another way to display the model's architecture\n",
"tf.keras.utils.plot_model(model, \"my_fashion_mnist_model.png\", show_shapes=True)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<keras.layers.core.flatten.Flatten at 0x7fb220f4d430>,\n",
" <keras.layers.core.dense.Dense at 0x7fb282285af0>,\n",
" <keras.layers.core.dense.Dense at 0x7fb282365b50>,\n",
" <keras.layers.core.dense.Dense at 0x7fb282365fa0>]"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.layers"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'dense'"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hidden1 = model.layers[1]\n",
"hidden1.name"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.get_layer('dense') is hidden1"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.02448617, -0.00877795, -0.02189048, ..., -0.02766046,\n",
" 0.03859074, -0.06889391],\n",
" [ 0.00476504, -0.03105379, -0.0586676 , ..., 0.00602964,\n",
" -0.02763776, -0.04165364],\n",
" [-0.06189284, -0.06901957, 0.07102345, ..., -0.04238207,\n",
" 0.07121518, -0.07331658],\n",
" ...,\n",
" [-0.03048757, 0.02155137, -0.05400612, ..., -0.00113463,\n",
" 0.00228987, 0.05581069],\n",
" [ 0.07061854, -0.06960931, 0.07038955, ..., -0.00384101,\n",
" 0.00034875, 0.02878492],\n",
" [-0.06022581, 0.01577859, -0.02585464, ..., -0.00527829,\n",
" 0.00272203, -0.06793761]], dtype=float32)"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights, biases = hidden1.get_weights()\n",
"weights"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(784, 300)"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"weights.shape"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"biases"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(300,)"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"biases.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Compiling the model"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=\"sgd\",\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is equivalent to:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"# extra code this cell is equivalent to the previous cell\n",
"model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy,\n",
" optimizer=tf.keras.optimizers.SGD(),\n",
" metrics=[tf.keras.metrics.sparse_categorical_accuracy])"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code shows how to convert class ids to one-hot vectors\n",
"tf.keras.utils.to_categorical([0, 5, 1, 0], num_classes=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: it's important to set `num_classes` when the number of classes is greater than the maximum class id in the sample."
]
},
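{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a small illustrative sketch (an addition, not part of the book's code): the cell below encodes the same class ids with and without `num_classes`. The variable names `sample_ids`, `inferred` and `explicit` are hypothetical. Assuming the default behavior of inferring the class count from the largest id, the first call should produce vectors with only 6 columns, while the second produces the 10 columns that match the model's 10 output units."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: illustrative sketch, compares inferred vs. explicit num_classes\n",
"import tensorflow as tf\n",
"\n",
"sample_ids = [0, 5, 1, 0]  # hypothetical sample: class ids 6 to 9 never appear\n",
"\n",
"# without num_classes, the width is inferred as max(sample_ids) + 1\n",
"inferred = tf.keras.utils.to_categorical(sample_ids)\n",
"# with num_classes, the width is fixed regardless of which ids appear\n",
"explicit = tf.keras.utils.to_categorical(sample_ids, num_classes=10)\n",
"\n",
"print(inferred.shape)  # (4, 6)\n",
"print(explicit.shape)  # (4, 10)"
]
},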
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0, 5, 1, 0])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code shows how to convert one-hot vectors to class ids\n",
"np.argmax(\n",
" [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],\n",
" axis=1\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training and evaluating the model"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.7220 - sparse_categorical_accuracy: 0.7649 - val_loss: 0.4959 - val_sparse_categorical_accuracy: 0.8332\n",
"Epoch 2/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4825 - sparse_categorical_accuracy: 0.8332 - val_loss: 0.4567 - val_sparse_categorical_accuracy: 0.8384\n",
"Epoch 3/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4369 - sparse_categorical_accuracy: 0.8480 - val_loss: 0.4228 - val_sparse_categorical_accuracy: 0.8542\n",
"Epoch 4/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4122 - sparse_categorical_accuracy: 0.8558 - val_loss: 0.3966 - val_sparse_categorical_accuracy: 0.8624\n",
"Epoch 5/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3910 - sparse_categorical_accuracy: 0.8631 - val_loss: 0.3890 - val_sparse_categorical_accuracy: 0.8632\n",
"Epoch 6/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3751 - sparse_categorical_accuracy: 0.8686 - val_loss: 0.3912 - val_sparse_categorical_accuracy: 0.8600\n",
"Epoch 7/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3628 - sparse_categorical_accuracy: 0.8710 - val_loss: 0.3723 - val_sparse_categorical_accuracy: 0.8698\n",
"Epoch 8/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3514 - sparse_categorical_accuracy: 0.8755 - val_loss: 0.3767 - val_sparse_categorical_accuracy: 0.8612\n",
"Epoch 9/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3406 - sparse_categorical_accuracy: 0.8795 - val_loss: 0.3513 - val_sparse_categorical_accuracy: 0.8726\n",
"Epoch 10/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3306 - sparse_categorical_accuracy: 0.8812 - val_loss: 0.3539 - val_sparse_categorical_accuracy: 0.8738\n",
"Epoch 11/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3223 - sparse_categorical_accuracy: 0.8860 - val_loss: 0.3606 - val_sparse_categorical_accuracy: 0.8712\n",
"Epoch 12/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3146 - sparse_categorical_accuracy: 0.8869 - val_loss: 0.3472 - val_sparse_categorical_accuracy: 0.8742\n",
"Epoch 13/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3071 - sparse_categorical_accuracy: 0.8900 - val_loss: 0.3284 - val_sparse_categorical_accuracy: 0.8800\n",
"Epoch 14/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3001 - sparse_categorical_accuracy: 0.8922 - val_loss: 0.3413 - val_sparse_categorical_accuracy: 0.8780\n",
"Epoch 15/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2938 - sparse_categorical_accuracy: 0.8945 - val_loss: 0.3376 - val_sparse_categorical_accuracy: 0.8822\n",
"Epoch 16/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2867 - sparse_categorical_accuracy: 0.8971 - val_loss: 0.3272 - val_sparse_categorical_accuracy: 0.8796\n",
"Epoch 17/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2822 - sparse_categorical_accuracy: 0.8978 - val_loss: 0.3317 - val_sparse_categorical_accuracy: 0.8796\n",
"Epoch 18/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2757 - sparse_categorical_accuracy: 0.9001 - val_loss: 0.3240 - val_sparse_categorical_accuracy: 0.8824\n",
"Epoch 19/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2711 - sparse_categorical_accuracy: 0.9030 - val_loss: 0.3484 - val_sparse_categorical_accuracy: 0.8720\n",
"Epoch 20/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2662 - sparse_categorical_accuracy: 0.9045 - val_loss: 0.3209 - val_sparse_categorical_accuracy: 0.8800\n",
"Epoch 21/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2613 - sparse_categorical_accuracy: 0.9046 - val_loss: 0.3178 - val_sparse_categorical_accuracy: 0.8862\n",
"Epoch 22/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2563 - sparse_categorical_accuracy: 0.9069 - val_loss: 0.3122 - val_sparse_categorical_accuracy: 0.8848\n",
"Epoch 23/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2520 - sparse_categorical_accuracy: 0.9098 - val_loss: 0.3480 - val_sparse_categorical_accuracy: 0.8716\n",
"Epoch 24/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2469 - sparse_categorical_accuracy: 0.9113 - val_loss: 0.3202 - val_sparse_categorical_accuracy: 0.8878\n",
"Epoch 25/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2428 - sparse_categorical_accuracy: 0.9123 - val_loss: 0.3152 - val_sparse_categorical_accuracy: 0.8856\n",
"Epoch 26/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2393 - sparse_categorical_accuracy: 0.9143 - val_loss: 0.3102 - val_sparse_categorical_accuracy: 0.8852\n",
"Epoch 27/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2341 - sparse_categorical_accuracy: 0.9147 - val_loss: 0.3200 - val_sparse_categorical_accuracy: 0.8850\n",
"Epoch 28/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2313 - sparse_categorical_accuracy: 0.9169 - val_loss: 0.3100 - val_sparse_categorical_accuracy: 0.8900\n",
"Epoch 29/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2268 - sparse_categorical_accuracy: 0.9185 - val_loss: 0.3215 - val_sparse_categorical_accuracy: 0.8864\n",
"Epoch 30/30\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2235 - sparse_categorical_accuracy: 0.9200 - val_loss: 0.3056 - val_sparse_categorical_accuracy: 0.8894\n"
]
}
],
"source": [
"history = model.fit(X_train, y_train, epochs=30,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'verbose': 1, 'epochs': 30, 'steps': 1719}"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"history.params"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n"
]
}
],
"source": [
"print(history.epoch)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAFYCAYAAABNvsbFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABwKElEQVR4nO3dd3gUVdsG8Puk9xAISSBAqAm9iUgUBAEFBIwIWAABUREUxVdREQtir2D5AEUELFhRBCyIlFAUpPfQIZTQa3rb8/3xZLK7ySZZ0jYb7t91zbW7M7MzZ8+WefZUpbUGERERkTNxcXQCiIiIiK4WAxgiIiJyOgxgiIiIyOkwgCEiIiKnwwCGiIiInA4DGCIiInI6RQYwSqlZSqkzSqmdBWxXSqmPlVIHlFLblVJtSz+ZRERERGb2lMDMAdCzkO29ADTKWUYCmF7yZBEREREVrMgARmu9CsCFQnaJAfCVFusAVFFK1SitBBIRERHlVRptYMIBHLN4fDxnHREREVGZcCuFYygb62zOT6CUGgmpZoKXl9d1derUKYXTU2FMJhNcXNhWu6wxn8sP87p8MJ/LB/O5cPv27Tunta5ua1tpBDDHAdS2eFwLQIKtHbXWMwDMAICoqCi9d+/eUjg9FSY2NhZdunRxdDIqPeZz+WFelw/mc/lgPhdOKRVf0LbSCPsWAhia0xupA4DLWuuTpXBcIiIiIpuKLIFRSn0HoAuAYKXUcQATAbgDgNb6UwB/ALgdwAEAKQAeKKvEEhEREQF2BDBa6/uK2K4BPFZqKSIiIiIqAlsOERERkdNhAENEREROhwEMEREROR0GMEREROR0GMAQERGR02EAQ0RERE6HAQwRERE5HQYwRERE5HQYwBAREZHTYQBDRERETocBDBERETkdBjBERETkdBjAEBERkdNhAENEREROhwEMEREROR0GMEREROR0GMAQERFRxbN2LcKBsII2u5VnWoiIiKiCWrsWiI0FunQBoqNLfpzOnYGoKCAjw3qpWhWoUUPur1plXp+eLretWwNJSUC3bggDwgs6DQMYIiIiR1i7FnXmzgU8PUsnYOjSBWjfHsjMBLy8ZNuRI8Dly0BqKpCWJreBgcCNN8r22bOBs2eBffuAL78EsrMBNzdg5UpJ04gRwIULElwYAUaPHsBLL8nzGzUCkpPNAUhqKqAUoDXg4SGP85owAXjjDeDKFeDWW/Nvf/NNuc3IKPRlM4AhIqKKrbRLBiyPo7VctLOyrJegIMDVFbh4ETh/3nqfzZuBkyeBW24B/P2BPXusA4SMDGDcODn+3LlSymBsS00F3N2BZ54BunVDvdRU4IsvgHr1JOjQGggJkXQCwEMPAcuXy3qTSZaGDYEVK2R7x47AP/9Yv8727YH//pP7MTHA9u3W27t2BZYtk/uvvw4cOmS9PStLzh8dLYFNYqIEWZ6ekkYjOAKklEUp2ebhIXmzerWkMyMD6NNHFg8P89KkiTy3ShXZ13KbhwcQHAzExQEeHtCpqbqgt5MBDBERlQ17Ao/MTODcOakySEyU26QkoG1bICwM+OknYPBguai6ugJ33SUXuLFjgchIYM0aYPJk6yqIjAwpWWjcGPjuO/nHn5gogYjh+++Be+6R5xrBhqUTJ4CaNYGPPgImTcq/3dVVLrYxMXKsvJ56CnBxATZuBBYskIu+t7fcBgVJvmRkQAESnHh7S3pdXIBq1czHadxYXpeLiyxKSfWLISDAXOKhlAQnQ4aYt7/7rpSQeHubF8vjb9wor2PLFuC22yTvPDzkPQMkfwszc6b147VrgW7dzMeZMKHg997NTQIwW6KjgWXLcPrGGxMKOjUDGCIiRyvLEoZiHie3aqN9eyAhAUhJMZcgpKRIiUH9+lIN8O235vXGPg0bAk8/LRdfAGjWTC7ARoDy0UcSQKxdK//i8/rlF6BfP+CPPyTIASSI+fVXuWjfc48EMImJwIED1v/gAwLMx6lRQ46/e7
dUhRgX+u3b5Rg33QS89ppcTC0X4xj9+gENGpjXL1ggQVF2tlykIyLkWEZwYgQJSsnzp0yRxdZ75eEBU3o6XDw9gc8/t/2e2QquLL30Um4wBA8PeS2Wx+nRo/DnBwXJbceOUipT0s9PTuBRKp/D6GicAE4VtFlpXWDpTJmKiorSe/fudci5ryWxsbHoYkTSVGaYz+WnQuW1vQFDQoJcPC1LGHx85OJiXMCNEoaYGGnk2LAh8Oyz8vyHHwaOHrVuCNmpk5QeANLo8ehRqe4w9O0LLFwo90ND5V84IBdwQKomPvpIivp9fc3rc6pKNADl7Q3Mnw/07Jn/Nb36qlw8jx0D6tQxr3d1lQt4t27Ab7/J8QAJdlq0APz8ZBk2TPLs9GkJSoz1xhIZKRfX1asln4wL9LJlxbso5i0ZcPRxco51aNYs1B8xomIErhWQUmqT1rqdrW0sgSEi51FajR7//RdYuhTo3l0aM54+LYtlCUNmpgQTAPDnn8C2bdbbz5+Xf+MZGfJvu1Ej2deoCgkPB3bulHWDBkmjSEtt2siFOTbWuoThzz+lkaVl0fr589IQ08NDgoPAQPM/Z0CqDf75B9iwwVzCYNlO4YEH5BxGqYBSQIcO5vtPPGG+v3YtsHo1lNby2jZskNIBHx9z6YKPj5TAAFLNcvKkeZu7u/k4S5aYL/TffGP7PQsNBR55pOD3qlOnilUyUMolDEfT01G/pEFHdHSlC1zswRKYSq5C/VutxCptPpfWP7vVq+Vi1rat1OmnpEiJQKtWcjHet08CCmO9sbzwglwgf/1V2iFs2watNZRSQN26ctzwcGD6dOCdd+TfvuWyd6/U90+caN6elWVOl7e3XIzmzgWmTrVOs7u7uRfE8OHSQ8N4jre3VCcYjTuVkhKT1q3NJQjh4cBzz8lz/v4buHRJ1vv7y21QkFQ/VLSSgZzj5FZtlLCEobKWDJSWivTbcfIkcO+9wA8/SPMjRzt5EqhZs2mS1rv9bW1nCQxRZVTYhcP4Z52aKhc6Hx/pIbFzp7mEITVVShzeekv+ubu7S8PA6tWt2zqMHi3H37QJGDPGeltKirQV8PeXkg5bXSJXrgRuvhlYvx547DHzeuNf/qhREsBoLSUQWpsbPfr4SHUFANSuLa/V1VUWFxdzI0tA0jh2rKxbt07yxsiH2FjpKnrLLdYNHb29zemZNk2CJC8vcylG3oDhyy8Lvkjb6ipqqGglA9HROPnDKtw50h8LPk9EWLTN0nv701SBApfSukCX5nHGjm2Nv/6qGOl57TVps/vqq/KRrwjpAXz8CtxBa+2QJTIyUlPZW7FihaOTUPn9+68++NBDWv/7b4mPo99803wck0nrjAzz/c2btV65UuvfftP6u++0njFD61WrZHtqqtajRmk9eLDWnTpp7eKiNaC1u7sc7/RprYODtfbxMW8DtH73XXn+vn3mdZaLUnLr6iq3bm5aBwRoHRqqdd26Ws+fL8/fvl3r227TOiZG63vv1fqBB7R+7DFZ/+abOkHV1DcjVp9UNbS+5x6tFy7UetkyrS9elOcnJ2t96pTWiYlaZ2cXnD/e3jrbxUVrb+/i53fOcbSra8mOo7VOWLhB31w3Xp9ctKHYx6iIRo/WWimTHj26ZMdJSND65pu1PnmydNJVUg89JB//hx+Wr1RxjR4txylp/pRWPtuTHpNJ68uXtT54UOt167RetEjr2bPlJ8D4euddXF21fvttradP1/rbb7X+/XetV6+Wr3V8vNaXLtn+uhaWnuxsra9ckc/Gvn1ab9pk/ln7/nutZ86UnxlzOq7TuoA4glVIlVxFKp50emlpUmVw7pz59vBhYNIk6PR0KFdX4PbbpYunMV5E48bAiy/K8x9+GDh+3HqsiRtvlKqNtWulzYPJJPsapSJDhpirLjw985dijB4tf5WysqSkwt8fSEnByVPAvfgeP6j7EPbG49LG4Zln8pcw3Hwz0K6dlJYsXWrdzmHvXuDBB80lDH/9Je
0RrtbatXi00w58lv0gHnH9AtNWtyj2v/KTizaaSwb6FL9k4OSijbj3iRD88MmZEh3n0UeBzz6TJhwl+cdaWq72n29
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"\n",
"pd.DataFrame(history.history).plot(\n",
" figsize=(8, 5), xlim=[0, 29], ylim=[0, 1], grid=True, xlabel=\"Epoch\",\n",
" style=[\"r--\", \"r--.\", \"b-\", \"b-*\"])\n",
"plt.legend(loc=\"lower left\") # extra code\n",
"save_fig(\"keras_learning_curves_plot\") # extra code\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeMAAAFFCAYAAADfMoXLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABnV0lEQVR4nO3dd3gU1frA8e/ZlE0PJRAgoQtIL1IVAQUFRAERFek2BNGr/sQGKmBDr4qKVyl6wQaKckHEgvSmoFQNEES69E4akHZ+f5xMdjd1Awm7Sd7P88yzu1N2zpwt75wyZ5TWGiGEEEJ4js3TCRBCCCFKOwnGQgghhIdJMBZCCCE8TIKxEEII4WESjIUQQggPk2AshBBCeFi+wVgpNV0pdVwptTWX5UopNUkptUsp9adSqkXhJ1MIIYQoudwpGX8CdMtjeXegTsY0DJh8+ckSQgghSo98g7HWehVwOo9VegGfaWMdUEYpVbmwEiiEEEKUdIXRZhwF/OP0+mDGPCGEEEK4wbcQ3kPlMC/HMTaVUsMwVdkEBgZeU7Vq1VzfND09HZtN+pe5Q/KqYCS/Ckbyy32SVwVTGvNr586dJ7XWFbLOL4xgfBBwjqrRwOGcVtRaTwOmAbRs2VJv2LAh1zddsWIFnTp1KoTklXySVwUj+VUwkl/uk7wqmNKYX0qp/TnNL4xTku+AwRm9qtsC57TWRwrhfYUQQohSId+SsVLqS6ATEKGUOgiMBfwAtNZTgB+BW4BdQBJwb1ElVgghhCiJ8g3GWut78lmugZGFliIhhBCilCldLedCCCGEF5JgLIQQQniYBGMhhBDCwyQYCyGEEB4mwVgIIYTwMAnGQgghhIdJMBZCCCE8TIKxEEII4WESjIUQQggPk2AshBBCeJgEYyGEEMLDJBgLIYQQHibBWAghhPAwCcZCCCGEh0kwFkIIITxMgrEQQghhWbsWJkwwj1dwW9+C700IIYRww9q1sGIFdOoE7doV3bZaw8WLkJRkpl9+gV274MYboXJl8/zCBTNdvGgeBw8GPz9YsgTWrDHz9+yBuXMhLQ0CAmDpUtiwAb7/3iy3Jj8/+O03s+9hw+Drr+H8eUhOBh8f8Pc32xbgmCUYCyFESVZYAbFtWxNszp83AS8sDEJC4Nw52LzZMT8pyTwvXx4GDTLBy2aDoUMhKsoEzvR0GDLE7OPPP+GLL8w8azp8GBYsgJQU8PWF1q0hMNDx/klJJkDWrg2TJsHjj5v3debjA6++CgMGwMcfZz+2O+6AMmVg0SJ4800TQJWC1FSzPDnZHDuYY7TbITQUIiLMo+Xaa03aNm82QT0tzbGtBGMhhPAyBQmK6ekmENnt5vV//wvr1kGDBlCrlglwtWpBy5bmz/+990wAcC69depkAkfnziY4+viY/YaFmYAzZAj07w9HjkCvXmZ/qalmSkmBfv1g4kTzXunpJlA5B7wZM0yA3bYNbrgh+zH072/SZAVY54Bos0GHDqaEuXs3/Oc/Zp5S5jE52bGt1vD331CjBgQFmQAaFGSCNJg8eP55My8oyOTx/PmOoFi2LKxebfIyIMDxaAXU116D1183+1271uRXcrIJztZn9dxzuX9WQ4eaKadtC0CCsRCi+LlS1Z/ubpuQAAkJBBw6BDExkJho/vCbNTPLX30Vxo83gc7HxwS/jh3h0UfN8q5d4cABiIuD+HgzDRoEn31m9vnAA9nT8tBDJhApBU8+6Zjv62sCTnCwKbElJ5v5aWmmujY62qxz8aJj/YgI8+jnZx59feHQIUdAVMqkt0sXR9Czjr9hQ1i2zOzLWhYYaALovHmO4LR4sSlFKuVI64oVcPvtpqSbNZ+dA9vcubl/VtdeayZLq1awcKFj29tvz/tz9nUKg+3amerlS/l+XM62SDAWQniCOwHRqqJUygSGfftMsNq4EV56yQQ2q13v9GnTTnjxoqOEqDV89JF5r7ffNn/QJ06YYJmebkpCa9aY/T/1lA
koVslPa9PW+OOP5vXw4SaY7N1rlikFTZuaqkkwbZPr19PWOf3XXWfeH0zJNSXFPE9NhR9+MGm3gnGlShAebkprYWHmsUULs2zFCpNWK80jRpj0RESY5TYbnDljArDdbl4757O/f95BrUIFx3Fm/Yxmz3Zs+9prOX9W4eE5l4yjojwT2C4zKNKuXcG3KYRtJRgLUZoVVimxeXM4edK1TS8xEdq0McFl2zbzB5mYCDt3mjbCtDQTPJYtMyW2d95xlAzj4kwnmxMnTND58EMTDLKy2uaOHjXrWAHJ39+U0CxWeuLiTFADE1Stdr1y5aBKFTPfKrlVrOjYvlIlEzydq2mtYAimZHr6NLEHDlD/mmtMqdR5+//+F+66ywTknDr3fPpp7vncqZM5JisoDhgAjRq5rlOmTM7bejKoWe/hgcB2Wdt6iARjIbzBpQZFrU1b4tKlUK8eXHVVZpUpCQmm9Fa3rukQ88EHrsv++Qc2bTIlNV9fU31pt5sgaU1TpjiCR79+Zl2rLc7qmGO3wzPPwLhx2dP3+++m2nDtWnjssezLU1LMcderZ0pS9es7SodhYWbfYHq+duxolu3ZY6ptrcBmdS56773c8+mFF8yUW7teXm2CYI6ta1fXbV96ybH87rsBOLZiBfVzaiu87TZz0uGFJb0jR8xHO3u2OecoyLbFUZ7HewX2DfXr5bRMgrEQhSVrQE1LM8HL6oSzfbspnSUmOkqP1aubdTp3NiVBHx/TxhUWZpb36GFKQmfPmvd1LnUmJcF995lSldXJJqv33jPB+MwZeOMN0/vVmhISTECztvP3N8HQx8cxlStn1qtUyZTsrPkbNphq4fR0E5zi42HaNEebYVCQKR1efbV57/79zXEFB5sTgC5dsneS6dMn97ytV89MYNatVcvrqj+PHIHHHmvGzz/n8ifvpUHx5ZdNbfpLL5nKhYK4nMCWb34V0X49dbzWviEoJKdlEoxFyXM5pczVq03bYLNmppTp4+OoEpw507RdxsWZSx3i4kzwevZZs8/27R2BzdfXBNn+/c12YC7PSEx03eewYaaXaHKy2X9qqrmko1w5E9CaNzfr2e0mcFtBzno8edLRycZmM/sbOtQRcKOizB/Iww2Y/U8KlSo7dZ7JWkr8739zzq8VK0wnnQ8+yH3bO+7IO6+tAA2ms00+ATHfPz0vrP58+WWIiQm/pD95a/tLDRKXIjDQnP9ZJk82k91uKlKCgsxz5/5WhZnmy8mv3PartTlvPX4cjh1znSZMMOfHFut4bTbzMyxb1kxlyuT8GB7uut8PPjCd1OPiHJPVwpJ13r//7brvnCid9dqsK6Rly5Z6w4YNuS5fsWIFnQrYNby0KnF5lVcwTUsznXVOnDC/uPh4UwUI5vKIOXNMQLV6gLZp4xgNZ8gQWL6ci4mJ2G02UyqsV89x8X6jRqZt01m7dvDrr+Z548awdav59YaFmV/nTTeZTkITJsCYMY7OPddfb0p/TZpAr14muHQ5yezntlApyscRnCpWNNWuzoEtS3tinoEpa1DMYaCBhx+GqVNN59tsf3r5nLgcOQLdup3l55/L5LzvywmoecgzzUXInTSfO2dq+A8eNI8PP+y4NNWZzWa+Hv7+jmZs69H5+b//nfP2AQHmz76wxcWZ881vvjFXADkH5KyUcq3ssKbNm3OuiPHxgYEDzbE7X6nkPE2enHNg8vExTe9K5T5NmJBzXillWjmOH3d0Hndms5nLnlNSHN0GbDYTaMPDzd/I2bM5v/fl8vU1rSvJyaYyS+uWaL0h2ymOlIxF0Vi71rSRtW1relrabCaIHjniGAnHmqx/rPXrzb/gpEmOS0BuuAF+/tn82h5/HN5/nyPpFenHV8ymH5WC4h2lzb/+gthY1w46Vg9WMMHUx4dTx49TpXp18yuJinIsv/pqU5WstUnvPffAyJGZi498uYJ+w8sw+xubawkTTEAKCHAExddfdwlQL7wAq2MjeGZxFyZNMqUSP7+MUk
dkZJ4lxbxKAQmN2nHqizWcWvYHJ2u34dTeBpzaYArMr76ac0nAx8ecN4SEQHBwO0Ki2hFyBEIWudZih4TAiy+a0sv
"text/plain": [
"<Figure size 576x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code shows how to shift the training curve by -1/2 epoch\n",
"plt.figure(figsize=(8, 5))\n",
"for key, style in zip(history.history, [\"r--\", \"r--.\", \"b-\", \"b-*\"]):\n",
" epochs = np.array(history.epoch) + (0 if key.startswith(\"val_\") else -0.5)\n",
" plt.plot(epochs, history.history[key], style, label=key)\n",
"plt.xlabel(\"Epoch\")\n",
"plt.axis([-0.5, 29, 0., 1])\n",
"plt.legend(loc=\"lower left\")\n",
"plt.grid()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 867us/step - loss: 0.3243 - sparse_categorical_accuracy: 0.8864\n"
]
},
{
"data": {
"text/plain": [
"[0.32431697845458984, 0.8863999843597412]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using the model to make predictions"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.01, 0. , 0.02, 0. , 0.97],\n",
" [0. , 0. , 0.99, 0. , 0.01, 0. , 0. , 0. , 0. , 0. ],\n",
" [0. , 1. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]],\n",
" dtype=float32)"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_new = X_test[:3]\n",
"y_proba = model.predict(X_new)\n",
"y_proba.round(2)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([9, 2, 1])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred = y_proba.argmax(axis=-1)\n",
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['Ankle boot', 'Pullover', 'Trouser'], dtype='<U11')"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"np.array(class_names)[y_pred]"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([9, 2, 1], dtype=uint8)"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_new = y_test[:3]\n",
"y_new"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAACVCAYAAAAe9i7zAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAYOUlEQVR4nO3de7BV1X0H8O8PBXmDgAIivRQBIT4AY7HgC4OpDzAGYqO2wdhGSnVSOzFOmZqRaqYZZtDOqH84NomvTCJ2RAwmqQGxPsIzNihCeAtcUOT9unIFQVb/2Jv0rO9enL3v4cI69/L9zNzh/s7ZZ+3N2eucdff67bWWOecgIiJysrWIfQAiInJqUgMkIiJRqAESEZEo1ACJiEgUaoBERCQKNUAiIhJFk2+AzGykmTkz61ZmG2dmtxznfo67DGkazOxOM/v0WLGINI7oDZCZDTWzL8xsXuxjic3M+qQN3aWxj6UpM7Pn0vfRmdkhM1tnZo+aWbvYxybVraTeHOvnudjH2JycHvsAAEwA8CSAO8xskHNuRewDkmZhDoDxAFoCuBLATwG0A3B3zIM6HmbW0jl3KPZxNHM9S34fA+An9NhnpRtX8zmp5mM7KuoVkJm1AfA3SE7ydADfoeePXhF8w8xeN7N6M1tuZl8tU+YZZvaKmS02s7OPsU0vM3vRzHanP78xs/4FDrlHum29mdWa2beo3IvMbI6ZfWZmu9K/xDuVPN/CzB40s01mdtDMlprZzSVFrE//fTf9f79V4Jgk7KBzbotzbpNz7gUAvwDwdTN7yMyWlW5YSRebmU00s7Vm9nn674SS56aZ2cu0fYv0vH8vjc3M/sXMPkzry9LS+lRS9283s/8xs88ATKzgfZAGSOvMFufcFgB7Sh8D0BrAHj4neZ/rY/VscLe+mU1Ov1cOmtkWM/tZyXPNs74456L9IPkLdUn6+0gA2wC0LHm+DwAHYCWAmwD0B/A8gJ0A2pe8zgHoBqAjgDcBvA2gY0k5DsAt6e9tAawG8ByAiwEMRPLXcS2AtmWO1aX7nQhgAIAfADgC4NKScj8G8EsAFwG4Ot3PyyVlfA/APiSN7gAAPwTwBYAh6fN/ke7nOgA9AHSJeX6a6k96bn9Njz0BYAeAhwAso+fuBPBpA+KxAA4B+G56Hv8pjW9Knx8N4ACAziWvuQbAYQA90vhHAFYBuB7An6d1Yj+A0VT3NwC4Jd3m3Njv7an0k77vriQOnpMCn+ujr7uUyi/9XvpGWsZoAH8G4FIA3y3ZtlnWl9gn+G0A96e/W/rmfSNwwieWPNYrfeyKNB6ZxoMA/AHAqwBalznRfw9gDQAref40JI3LN8scqwPwE3psDoCfp79PALAXQIeS548eW780/hjAZCrjrZIyghVVPw2uV8+hpAECMAxJ4/NfaJwGaB6AZwL7nJv+fjqSP6a+U/L8TwHMSn9vh6Qr50oq4zEA/0114fux389T9QfHboC+T9tV9Lmm76X7kDQwLQPH0WzrS7QuODPrB+ByAC8A6VlOuknuCmz+Qcnvm9N/uXttNoCPAIxzzh0os+svI/nroM7MPk27XvYCOBPAeTmHvSAQfyn9fRCAD5xzdSXPz0dylfQlM+sI4BwkX16l5paUIY3n+vT8HkBynt5BcqXSGAahzHl0zh1G0tj9LZB0CyP5C/fn6bZfQtKd89ujdTCth3cjWwf/t5GOWRrPn85JI36uX0JSJ9ab2dNm9tdpvQGacX2JeRPCXUiuPDaa2dHHDADMrLdzblPJtn9KpDnnXLo9N56/BvBNJN1f75XZbwsA7wO4LfDcruKHn2FI/gIJccf4vdxjcnzeAfAPSOrOZpcmY83sCNJ6VqJlBeXnncefA5hvZr0AXAagFYBX0ueO1t2bAGykMjhpvL+CY5MTK3ROytWHI+m///9FZ+bVOefcJjM7H8AoANcC+A8A/2Zml6EZ15coV0BmdjqAbwP4VwBDSn4GI7na+bsKin
0QwFMA5pjZkDLbLQbQD8AO59xa+slrgP4yEB+9a285gMFm1qHk+RFI3uMVzrl9SK7erqAyrkhfCwCfp/+elnMckq8+Pae1zr8TaDuA7lbyVw+SutcQK1D+PMI5twjAhwBuR3Il9Evn3NEbHZYDOAigJlAHaxt4LBJRwc/19vTf0rvphgTKOuCc+41z7ntI8sEXIOklarb1JdYV0GgkNw38xDm3s/QJM3sRwN1m9u8NLdQ594P0i2WOmY1yzi0JbPYLAPcDmGlmk5H8RdEbwM0AnnLOrSmzi3Fm9i6S/t1bkPy1cllJuQ8D+Fla7pkA/hPADOfc2nSbRwD80MzWIMlXfQvJLcJfTp/fhqSv9zoz2wDggHNubwPfBinvLQBdADyQ1rWRSM5lQzwC4CUz+wOSrt/rkTQy42i7o13KfZDcuAAAcM7VmdmjAB5N6+s7ANoj+YPmiHPuxw08Homr7OfaOfeZmS0EMMnMPgTQCcCU0gLM7E4k38eLAHwK4FYkVzdrmnV9iZF4QnKjwOxjPNcXyaXrX6FY8m5kGncreX4KkqTzYN4+jbsDeBbJF/5BJLc/P1NaRuC4HJK7nn6LpJHYCODbtM1FAN5In9+NJDHdqeT5Fkiu1DYhudpZCuDrVMZdadlfAHgrxvlp6j8I3AVHz09EctfjfgAvAvhnNOAmhPSxfwSwFsmXxFoAEwL7OS+tN1sBnE7PGZKc1NG/brcDeB3AV9Png3VfPye1Hh3rJgT+PiryuT6aN6xPn7+Svse+jiRXuSetl+8CGNPc64ulBy8iInJSRZ+KR0RETk1qgEREJAo1QCIiEoUaIBERiUINkIiIRJE3Dki3yDVfPBtAY2oS9aauri7z2O9//3svHjVq1HHvZ/HixV7cvn17Lx4wYMBx7+Mkavb1hu8M9scsA2+88UbmNU888YQXDxkyxIu3bNnixf369cuU8emn/oTsu3fv9uLTT/e/rtevX58p45VXXsk8ViWC9UZXQCIiEoUaIBERiSJvIGpVXBLLCdHsulIOHPAnQX/ssce8eNq0aV7MXRwAsH37di9u06ZN7mvytG7dumzMXSsAcNVVV3nxhAkTvPj6669v8HE0kmZXb9iRI0e8uEUL/+/0K67gad+AefN4MuzyOnbsmHmsvr7eiw8fPuzFXBc/+8xbnBUA8Ktf/cqLx4wZ06DjOoHUBSciItVDDZCIiEShBkhERKJQDujU1aT78idNmpR57Mc/9mel37dvnxe3bdvWi7lPHcjmY7if/dAhf/2vL774IlPGGWec4cW8H/7MHTx4MFMG75f3M3z4cC9+5513MmWcIE263jSGDh06ZB5r2dJf0/Css87y4v37/XXiQvWGc4NcJtebtWvXgj3yyCNefP/992e2iUQ5IBERqR5qgEREJAo1QCIiEoUaIBERiSJvLjiRqsA3GEydOjWzTY8ePby4Xbt2XsxzeoVuwOGbDPIGkXKZQHbgIg8oZFwmkJ0v7rTTTvNiHvh40003ZcrgQYnSOHjONgDo1q2bF/MNMDy4lW9UCW3D+wm9hm3atCl3m2qiKyAREYlCDZCIiEShBkhERKJQDkiahAcffNCLQ5M5cj6GB/vxmiwhnTt39uK8iUND+QCeFLVr165ljys0GSkPTuV8Vffu3b04NBB1x44dXsx5Cilm69atudvwOQzlBkuF8oI88JTzflxm6DOwbdu2svutNroCEhGRKNQAiYhIFGqAREQkCuWApEnYu3evF4fGRHCehHM+d999txdPnDgxU8Yll1zixTyW6KOPPvLi0MSUNTU1Xsw5BD52LhMAevXqVfY1dXV1XhxanGzdunVerBxQZZYtW5a7TatWrbyYzwfnc0J5Px4HxPW5yFgizvtVO10BiYhIFGqAREQkCjVAIiIShXJA0iTwuJjQ/Gk5iytiypQpXtypU6fMNtzPXl9f78UjR4704jfffLPsPgFg0KBBXrxy5Uov5nnDAODxxx/3Yh4HxQuehRY4mz
t3rhcPGzYs91gla8mSJV7M+R4gWx+53vDYMM5pAtnxYnlzF4YWMuScZbXTFZCIiEShBkhERKJQAyQiIlGoARIRkSh
"text/plain": [
"<Figure size 518.4x172.8 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code - this cell generates and saves Figure 10-12\n",
"plt.figure(figsize=(7.2, 2.4))\n",
"for index, image in enumerate(X_new):\n",
" plt.subplot(1, 3, index + 1)\n",
" plt.imshow(image, cmap=\"binary\", interpolation=\"nearest\")\n",
" plt.axis('off')\n",
" plt.title(class_names[y_test[index]])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.5)\n",
"save_fig('fashion_mnist_images_plot', tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building a Regression MLP Using the Sequential API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load, split and scale the California housing dataset (the original one, not the modified one as in chapter 2):"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"# extra code - load and split the California housing dataset, like earlier\n",
"housing = fetch_california_housing()\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
" housing.data, housing.target, random_state=42)\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X_train_full, y_train_full, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.9051 - root_mean_squared_error: 0.9514 - val_loss: 0.4030 - val_root_mean_squared_error: 0.6348\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3843 - root_mean_squared_error: 0.6199 - val_loss: 0.8436 - val_root_mean_squared_error: 0.9185\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3609 - root_mean_squared_error: 0.6007 - val_loss: 0.3744 - val_root_mean_squared_error: 0.6119\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3416 - root_mean_squared_error: 0.5844 - val_loss: 0.4343 - val_root_mean_squared_error: 0.6590\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3301 - root_mean_squared_error: 0.5746 - val_loss: 0.3085 - val_root_mean_squared_error: 0.5554\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3168 - root_mean_squared_error: 0.5629 - val_loss: 0.4544 - val_root_mean_squared_error: 0.6741\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3162 - root_mean_squared_error: 0.5623 - val_loss: 0.2941 - val_root_mean_squared_error: 0.5423\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3045 - root_mean_squared_error: 0.5518 - val_loss: 0.3333 - val_root_mean_squared_error: 0.5773\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2974 - root_mean_squared_error: 0.5453 - val_loss: 0.3446 - val_root_mean_squared_error: 0.5870\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2921 - root_mean_squared_error: 0.5404 - val_loss: 0.2874 - val_root_mean_squared_error: 0.5361\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2863 - root_mean_squared_error: 0.5351 - val_loss: 0.4141 - val_root_mean_squared_error: 0.6435\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2942 - root_mean_squared_error: 0.5424 - val_loss: 1.0956 - val_root_mean_squared_error: 1.0467\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2864 - root_mean_squared_error: 0.5352 - val_loss: 0.3063 - val_root_mean_squared_error: 0.5534\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2804 - root_mean_squared_error: 0.5295 - val_loss: 0.2709 - val_root_mean_squared_error: 0.5205\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2784 - root_mean_squared_error: 0.5276 - val_loss: 0.3680 - val_root_mean_squared_error: 0.6066\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2757 - root_mean_squared_error: 0.5250 - val_loss: 0.2730 - val_root_mean_squared_error: 0.5225\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2739 - root_mean_squared_error: 0.5234 - val_loss: 0.3668 - val_root_mean_squared_error: 0.6056\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2694 - root_mean_squared_error: 0.5191 - val_loss: 0.4188 - val_root_mean_squared_error: 0.6472\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2677 - root_mean_squared_error: 0.5174 - val_loss: 0.9663 - val_root_mean_squared_error: 0.9830\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.2755 - root_mean_squared_error: 0.5249 - val_loss: 0.2978 - val_root_mean_squared_error: 0.5457\n",
"162/162 [==============================] - 0s 508us/step - loss: 0.2806 - root_mean_squared_error: 0.5297\n"
]
}
],
"source": [
"tf.random.set_seed(42)\n",
"norm_layer = tf.keras.layers.Normalization(input_shape=X_train.shape[1:])\n",
"model = tf.keras.Sequential([\n",
" norm_layer,\n",
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
" tf.keras.layers.Dense(50, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"norm_layer.adapt(X_train)\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
" validation_data=(X_valid, y_valid))\n",
"mse_test, rmse_test = model.evaluate(X_test, y_test)\n",
"X_new = X_test[:3]\n",
"y_pred = model.predict(X_new)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.5297096967697144"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rmse_test"
]
},
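{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `RootMeanSquaredError` metric should simply be the square root of the MSE loss. Here is a quick sanity check, as a sketch only: it hardcodes the rounded test values printed at the end of the training log above (the variable names are illustrative), so small rounding differences are expected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"mse_logged = 0.2806   # test loss (MSE), copied from the log above\n",
"rmse_logged = 0.5297  # test RMSE, copied from the same log line\n",
"rmse_from_mse = math.sqrt(mse_logged)\n",
"print(f\"sqrt(MSE) = {rmse_from_mse:.4f}, logged RMSE = {rmse_logged:.4f}\")\n",
"# the two values should agree up to rounding of the logged numbers\n",
"assert math.isclose(rmse_from_mse, rmse_logged, abs_tol=1e-3)"
]
},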
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0.4969182],\n",
" [1.195265 ],\n",
" [4.9428763]], dtype=float32)"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building Complex Models Using the Functional API"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not all neural network models are simply sequential. Some may have complex topologies. Some may have multiple inputs and/or multiple outputs. For example, a Wide & Deep neural network (see [paper](https://ai.google/research/pubs/pub45413)) connects all or part of the inputs directly to the output layer."
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# extra code - reset the name counters and make the code reproducible\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"normalization_layer = tf.keras.layers.Normalization()\n",
"hidden_layer1 = tf.keras.layers.Dense(30, activation=\"relu\")\n",
"hidden_layer2 = tf.keras.layers.Dense(30, activation=\"relu\")\n",
"concat_layer = tf.keras.layers.Concatenate()\n",
"output_layer = tf.keras.layers.Dense(1)\n",
"\n",
"input_ = tf.keras.layers.Input(shape=X_train.shape[1:])\n",
"normalized = normalization_layer(input_)\n",
"hidden1 = hidden_layer1(normalized)\n",
"hidden2 = hidden_layer2(hidden1)\n",
"concat = concat_layer([normalized, hidden2])\n",
"output = output_layer(concat)\n",
"\n",
"model = tf.keras.Model(inputs=[input_], outputs=[output])"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_1 (InputLayer) [(None, 8)] 0 [] \n",
" \n",
" normalization (Normalization) (None, 8) 17 ['input_1[0][0]'] \n",
" \n",
" dense (Dense) (None, 30) 270 ['normalization[0][0]'] \n",
" \n",
" dense_1 (Dense) (None, 30) 930 ['dense[0][0]'] \n",
" \n",
" concatenate (Concatenate) (None, 38) 0 ['input_1[0][0]', \n",
" 'dense_1[0][0]'] \n",
" \n",
" dense_2 (Dense) (None, 1) 39 ['concatenate[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 1,256\n",
"Trainable params: 1,239\n",
"Non-trainable params: 17\n",
"__________________________________________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
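{
"cell_type": "markdown",
"metadata": {},
"source": [
"The parameter counts in the summary above can be reproduced by hand: a `Dense` layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. The sketch below uses a hypothetical helper, `dense_params()` (it is not part of Keras), and assumes 8 input features, as in the California housing dataset. The breakdown of the 17 non-trainable parameters (one mean and one variance per feature, plus a sample count held by the `Normalization` layer) is an assumption about how that layer stores its state."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def dense_params(n_in, n_out):\n",
"    \"\"\"Weights plus biases of a fully connected layer (hypothetical helper).\"\"\"\n",
"    return n_in * n_out + n_out\n",
"\n",
"n_features = 8\n",
"n_params_hidden1 = dense_params(n_features, 30)      # first hidden layer\n",
"n_params_hidden2 = dense_params(30, 30)              # second hidden layer\n",
"n_params_output = dense_params(n_features + 30, 1)   # wide path (8) + deep path (30)\n",
"n_trainable = n_params_hidden1 + n_params_hidden2 + n_params_output\n",
"n_non_trainable = 2 * n_features + 1                 # assumed: mean, variance, count\n",
"print(n_params_hidden1, n_params_hidden2, n_params_output)\n",
"print(n_trainable, n_non_trainable, n_trainable + n_non_trainable)\n",
"# totals should match the \"Trainable params\" and \"Total params\" lines of the summary\n",
"assert n_trainable == 1239 and n_trainable + n_non_trainable == 1256"
]
},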
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 1ms/step - loss: 122.3226 - root_mean_squared_error: 11.0600 - val_loss: 305.9134 - val_root_mean_squared_error: 17.4904\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 5.5425 - root_mean_squared_error: 2.3543 - val_loss: 183.4622 - val_root_mean_squared_error: 13.5448\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 979us/step - loss: 3.0631 - root_mean_squared_error: 1.7502 - val_loss: 87.2228 - val_root_mean_squared_error: 9.3393\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 1.5796 - root_mean_squared_error: 1.2568 - val_loss: 35.3699 - val_root_mean_squared_error: 5.9473\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.9536 - root_mean_squared_error: 0.9765 - val_loss: 12.3882 - val_root_mean_squared_error: 3.5197\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.6322 - root_mean_squared_error: 0.7951 - val_loss: 4.1676 - val_root_mean_squared_error: 2.0415\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.5069 - root_mean_squared_error: 0.7120 - val_loss: 1.2937 - val_root_mean_squared_error: 1.1374\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 980us/step - loss: 0.4525 - root_mean_squared_error: 0.6727 - val_loss: 0.4837 - val_root_mean_squared_error: 0.6955\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4293 - root_mean_squared_error: 0.6552 - val_loss: 0.4343 - val_root_mean_squared_error: 0.6590\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 962us/step - loss: 0.4120 - root_mean_squared_error: 0.6419 - val_loss: 0.3996 - val_root_mean_squared_error: 0.6321\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 988us/step - loss: 0.4203 - root_mean_squared_error: 0.6483 - val_loss: 0.4149 - val_root_mean_squared_error: 0.6441\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 952us/step - loss: 0.3916 - root_mean_squared_error: 0.6257 - val_loss: 0.4569 - val_root_mean_squared_error: 0.6759\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 957us/step - loss: 0.4147 - root_mean_squared_error: 0.6440 - val_loss: 0.3736 - val_root_mean_squared_error: 0.6113\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 949us/step - loss: 0.3824 - root_mean_squared_error: 0.6184 - val_loss: 0.4550 - val_root_mean_squared_error: 0.6745\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 982us/step - loss: 0.4003 - root_mean_squared_error: 0.6327 - val_loss: 0.8553 - val_root_mean_squared_error: 0.9248\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 960us/step - loss: 0.4245 - root_mean_squared_error: 0.6516 - val_loss: 1.9204 - val_root_mean_squared_error: 1.3858\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 987us/step - loss: 0.4580 - root_mean_squared_error: 0.6767 - val_loss: 2.0632 - val_root_mean_squared_error: 1.4364\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 961us/step - loss: 0.4692 - root_mean_squared_error: 0.6850 - val_loss: 3.5730 - val_root_mean_squared_error: 1.8902\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4367 - root_mean_squared_error: 0.6608 - val_loss: 3.9989 - val_root_mean_squared_error: 1.9997\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4683 - root_mean_squared_error: 0.6843 - val_loss: 2.2966 - val_root_mean_squared_error: 1.5155\n",
"162/162 [==============================] - 0s 612us/step - loss: 0.5723 - root_mean_squared_error: 0.7565\n"
]
}
],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"normalization_layer.adapt(X_train)\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
" validation_data=(X_valid, y_valid))\n",
"mse_test, rmse_test = model.evaluate(X_test, y_test)\n",
"y_pred = model.predict(X_new)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What if you want to send different subsets of input features through the wide and deep paths? We will send 5 features through the wide path (features 0 to 4), and 6 features through the deep path (features 2 to 7). Note that 3 features will go through both paths (features 2, 3 and 4)."
]
},
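{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see which column indices each path receives, here is a small sketch using plain Python lists (the names `wide_idx`, `deep_idx` and `shared_idx` are illustrative only). It mirrors the slices `[:, :5]` and `[:, 2:]` that the training code applies to the 8-feature arrays."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_features = 8\n",
"all_idx = list(range(n_features))\n",
"wide_idx = all_idx[:5]   # features 0 to 4\n",
"deep_idx = all_idx[2:]   # features 2 to 7\n",
"shared_idx = sorted(set(wide_idx) & set(deep_idx))\n",
"print(\"wide:\", wide_idx)\n",
"print(\"deep:\", deep_idx)\n",
"print(\"both:\", shared_idx)\n",
"# the overlap should be exactly features 2, 3 and 4\n",
"assert shared_idx == [2, 3, 4]"
]
},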
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"input_wide = tf.keras.layers.Input(shape=[5]) # features 0 to 4\n",
"input_deep = tf.keras.layers.Input(shape=[6]) # features 2 to 7\n",
"norm_layer_wide = tf.keras.layers.Normalization()\n",
"norm_layer_deep = tf.keras.layers.Normalization()\n",
"norm_wide = norm_layer_wide(input_wide)\n",
"norm_deep = norm_layer_deep(input_deep)\n",
"hidden1 = tf.keras.layers.Dense(30, activation=\"relu\")(norm_deep)\n",
"hidden2 = tf.keras.layers.Dense(30, activation=\"relu\")(hidden1)\n",
"concat = tf.keras.layers.concatenate([norm_wide, hidden2])\n",
"output = tf.keras.layers.Dense(1)(concat)\n",
"model = tf.keras.Model(inputs=[input_wide, input_deep], outputs=[output])"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 2ms/step - loss: 1.2768 - root_mean_squared_error: 1.1300 - val_loss: 0.9497 - val_root_mean_squared_error: 0.9745\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4767 - root_mean_squared_error: 0.6904 - val_loss: 1.4311 - val_root_mean_squared_error: 1.1963\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4433 - root_mean_squared_error: 0.6658 - val_loss: 0.4258 - val_root_mean_squared_error: 0.6525\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4057 - root_mean_squared_error: 0.6370 - val_loss: 0.4016 - val_root_mean_squared_error: 0.6338\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3940 - root_mean_squared_error: 0.6277 - val_loss: 1.4914 - val_root_mean_squared_error: 1.2212\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3873 - root_mean_squared_error: 0.6224 - val_loss: 2.6759 - val_root_mean_squared_error: 1.6358\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3914 - root_mean_squared_error: 0.6257 - val_loss: 3.0592 - val_root_mean_squared_error: 1.7490\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3735 - root_mean_squared_error: 0.6112 - val_loss: 3.3043 - val_root_mean_squared_error: 1.8178\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3712 - root_mean_squared_error: 0.6093 - val_loss: 2.1298 - val_root_mean_squared_error: 1.4594\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3693 - root_mean_squared_error: 0.6077 - val_loss: 1.7402 - val_root_mean_squared_error: 1.3192\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3578 - root_mean_squared_error: 0.5982 - val_loss: 0.6127 - val_root_mean_squared_error: 0.7827\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3605 - root_mean_squared_error: 0.6005 - val_loss: 1.3970 - val_root_mean_squared_error: 1.1819\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3527 - root_mean_squared_error: 0.5939 - val_loss: 0.9449 - val_root_mean_squared_error: 0.9721\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3436 - root_mean_squared_error: 0.5861 - val_loss: 0.7757 - val_root_mean_squared_error: 0.8807\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3421 - root_mean_squared_error: 0.5849 - val_loss: 0.8920 - val_root_mean_squared_error: 0.9445\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3405 - root_mean_squared_error: 0.5835 - val_loss: 0.9334 - val_root_mean_squared_error: 0.9661\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3394 - root_mean_squared_error: 0.5826 - val_loss: 1.3433 - val_root_mean_squared_error: 1.1590\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3384 - root_mean_squared_error: 0.5817 - val_loss: 2.6406 - val_root_mean_squared_error: 1.6250\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3459 - root_mean_squared_error: 0.5881 - val_loss: 2.2482 - val_root_mean_squared_error: 1.4994\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3503 - root_mean_squared_error: 0.5919 - val_loss: 1.4407 - val_root_mean_squared_error: 1.2003\n",
"162/162 [==============================] - 0s 672us/step - loss: 0.3388 - root_mean_squared_error: 0.5821\n"
]
}
],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"\n",
"X_train_wide, X_train_deep = X_train[:, :5], X_train[:, 2:]\n",
"X_valid_wide, X_valid_deep = X_valid[:, :5], X_valid[:, 2:]\n",
"X_test_wide, X_test_deep = X_test[:, :5], X_test[:, 2:]\n",
"X_new_wide, X_new_deep = X_test_wide[:3], X_test_deep[:3]\n",
"\n",
"norm_layer_wide.adapt(X_train_wide)\n",
"norm_layer_deep.adapt(X_train_deep)\n",
"history = model.fit((X_train_wide, X_train_deep), y_train, epochs=20,\n",
" validation_data=((X_valid_wide, X_valid_deep), y_valid))\n",
"mse_test, rmse_test = model.evaluate((X_test_wide, X_test_deep), y_test)\n",
"y_pred = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding an auxiliary output for regularization:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"input_wide = tf.keras.layers.Input(shape=[5]) # features 0 to 4\n",
"input_deep = tf.keras.layers.Input(shape=[6]) # features 2 to 7\n",
"norm_layer_wide = tf.keras.layers.Normalization()\n",
"norm_layer_deep = tf.keras.layers.Normalization()\n",
"norm_wide = norm_layer_wide(input_wide)\n",
"norm_deep = norm_layer_deep(input_deep)\n",
"hidden1 = tf.keras.layers.Dense(30, activation=\"relu\")(norm_deep)\n",
"hidden2 = tf.keras.layers.Dense(30, activation=\"relu\")(hidden1)\n",
"concat = tf.keras.layers.concatenate([norm_wide, hidden2])\n",
"output = tf.keras.layers.Dense(1)(concat)\n",
"aux_output = tf.keras.layers.Dense(1)(hidden2)\n",
"model = tf.keras.Model(inputs=[input_wide, input_deep],\n",
" outputs=[output, aux_output])"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=(\"mse\", \"mse\"), loss_weights=(0.9, 0.1), optimizer=optimizer,\n",
" metrics=[\"RootMeanSquaredError\"])"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"363/363 [==============================] - 1s 2ms/step - loss: 1.3490 - dense_2_loss: 1.2742 - dense_3_loss: 2.0215 - dense_2_root_mean_squared_error: 1.1288 - dense_3_root_mean_squared_error: 1.4218 - val_loss: 1.5415 - val_dense_2_loss: 0.9593 - val_dense_3_loss: 6.7806 - val_dense_2_root_mean_squared_error: 0.9795 - val_dense_3_root_mean_squared_error: 2.6040\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.5101 - dense_2_loss: 0.4785 - dense_3_loss: 0.7952 - dense_2_root_mean_squared_error: 0.6917 - dense_3_root_mean_squared_error: 0.8917 - val_loss: 1.3624 - val_dense_2_loss: 1.0094 - val_dense_3_loss: 4.5401 - val_dense_2_root_mean_squared_error: 1.0047 - val_dense_3_root_mean_squared_error: 2.1307\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4618 - dense_2_loss: 0.4404 - dense_3_loss: 0.6546 - dense_2_root_mean_squared_error: 0.6636 - dense_3_root_mean_squared_error: 0.8091 - val_loss: 0.5361 - val_dense_2_loss: 0.3975 - val_dense_3_loss: 1.7837 - val_dense_2_root_mean_squared_error: 0.6305 - val_dense_3_root_mean_squared_error: 1.3356\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4252 - dense_2_loss: 0.4059 - dense_3_loss: 0.5985 - dense_2_root_mean_squared_error: 0.6371 - dense_3_root_mean_squared_error: 0.7736 - val_loss: 0.5182 - val_dense_2_loss: 0.4590 - val_dense_3_loss: 1.0517 - val_dense_2_root_mean_squared_error: 0.6775 - val_dense_3_root_mean_squared_error: 1.0255\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4106 - dense_2_loss: 0.3931 - dense_3_loss: 0.5690 - dense_2_root_mean_squared_error: 0.6269 - dense_3_root_mean_squared_error: 0.7543 - val_loss: 0.4049 - val_dense_2_loss: 0.3588 - val_dense_3_loss: 0.8196 - val_dense_2_root_mean_squared_error: 0.5990 - val_dense_3_root_mean_squared_error: 0.9053\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3944 - dense_2_loss: 0.3780 - dense_3_loss: 0.5424 - dense_2_root_mean_squared_error: 0.6148 - dense_3_root_mean_squared_error: 0.7365 - val_loss: 0.4168 - val_dense_2_loss: 0.3934 - val_dense_3_loss: 0.6275 - val_dense_2_root_mean_squared_error: 0.6272 - val_dense_3_root_mean_squared_error: 0.7921\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - dense_2_loss: 0.3694 - dense_3_loss: 0.5126 - dense_2_root_mean_squared_error: 0.6078 - dense_3_root_mean_squared_error: 0.7160 - val_loss: 0.3661 - val_dense_2_loss: 0.3430 - val_dense_3_loss: 0.5747 - val_dense_2_root_mean_squared_error: 0.5856 - val_dense_3_root_mean_squared_error: 0.7581\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3731 - dense_2_loss: 0.3608 - dense_3_loss: 0.4840 - dense_2_root_mean_squared_error: 0.6007 - dense_3_root_mean_squared_error: 0.6957 - val_loss: 0.8555 - val_dense_2_loss: 0.8704 - val_dense_3_loss: 0.7218 - val_dense_2_root_mean_squared_error: 0.9330 - val_dense_3_root_mean_squared_error: 0.8496\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3672 - dense_2_loss: 0.3567 - dense_3_loss: 0.4624 - dense_2_root_mean_squared_error: 0.5972 - dense_3_root_mean_squared_error: 0.6800 - val_loss: 2.6877 - val_dense_2_loss: 2.9011 - val_dense_3_loss: 0.7675 - val_dense_2_root_mean_squared_error: 1.7033 - val_dense_3_root_mean_squared_error: 0.8761\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - dense_2_loss: 0.3765 - dense_3_loss: 0.4481 - dense_2_root_mean_squared_error: 0.6136 - dense_3_root_mean_squared_error: 0.6694 - val_loss: 3.6017 - val_dense_2_loss: 3.8004 - val_dense_3_loss: 1.8132 - val_dense_2_root_mean_squared_error: 1.9495 - val_dense_3_root_mean_squared_error: 1.3466\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3728 - dense_2_loss: 0.3656 - dense_3_loss: 0.4377 - dense_2_root_mean_squared_error: 0.6046 - dense_3_root_mean_squared_error: 0.6616 - val_loss: 0.6115 - val_dense_2_loss: 0.6325 - val_dense_3_loss: 0.4226 - val_dense_2_root_mean_squared_error: 0.7953 - val_dense_3_root_mean_squared_error: 0.6501\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3750 - dense_2_loss: 0.3688 - dense_3_loss: 0.4303 - dense_2_root_mean_squared_error: 0.6073 - dense_3_root_mean_squared_error: 0.6560 - val_loss: 0.9371 - val_dense_2_loss: 0.9545 - val_dense_3_loss: 0.7799 - val_dense_2_root_mean_squared_error: 0.9770 - val_dense_3_root_mean_squared_error: 0.8831\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3570 - dense_2_loss: 0.3499 - dense_3_loss: 0.4203 - dense_2_root_mean_squared_error: 0.5915 - dense_3_root_mean_squared_error: 0.6483 - val_loss: 0.4224 - val_dense_2_loss: 0.4245 - val_dense_3_loss: 0.4039 - val_dense_2_root_mean_squared_error: 0.6515 - val_dense_3_root_mean_squared_error: 0.6355\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3493 - dense_2_loss: 0.3421 - dense_3_loss: 0.4148 - dense_2_root_mean_squared_error: 0.5849 - dense_3_root_mean_squared_error: 0.6440 - val_loss: 0.3410 - val_dense_2_loss: 0.3221 - val_dense_3_loss: 0.5105 - val_dense_2_root_mean_squared_error: 0.5676 - val_dense_3_root_mean_squared_error: 0.7145\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3496 - dense_2_loss: 0.3432 - dense_3_loss: 0.4076 - dense_2_root_mean_squared_error: 0.5858 - dense_3_root_mean_squared_error: 0.6384 - val_loss: 0.6461 - val_dense_2_loss: 0.6671 - val_dense_3_loss: 0.4570 - val_dense_2_root_mean_squared_error: 0.8168 - val_dense_3_root_mean_squared_error: 0.6760\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3435 - dense_2_loss: 0.3370 - dense_3_loss: 0.4022 - dense_2_root_mean_squared_error: 0.5805 - dense_3_root_mean_squared_error: 0.6342 - val_loss: 0.6875 - val_dense_2_loss: 0.6841 - val_dense_3_loss: 0.7182 - val_dense_2_root_mean_squared_error: 0.8271 - val_dense_3_root_mean_squared_error: 0.8475\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3458 - dense_2_loss: 0.3393 - dense_3_loss: 0.4037 - dense_2_root_mean_squared_error: 0.5825 - dense_3_root_mean_squared_error: 0.6354 - val_loss: 1.1564 - val_dense_2_loss: 1.2129 - val_dense_3_loss: 0.6483 - val_dense_2_root_mean_squared_error: 1.1013 - val_dense_3_root_mean_squared_error: 0.8052\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3446 - dense_2_loss: 0.3385 - dense_3_loss: 0.3994 - dense_2_root_mean_squared_error: 0.5818 - dense_3_root_mean_squared_error: 0.6320 - val_loss: 3.9325 - val_dense_2_loss: 4.0947 - val_dense_3_loss: 2.4722 - val_dense_2_root_mean_squared_error: 2.0235 - val_dense_3_root_mean_squared_error: 1.5723\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3563 - dense_2_loss: 0.3511 - dense_3_loss: 0.4029 - dense_2_root_mean_squared_error: 0.5925 - dense_3_root_mean_squared_error: 0.6347 - val_loss: 1.4560 - val_dense_2_loss: 1.5433 - val_dense_3_loss: 0.6697 - val_dense_2_root_mean_squared_error: 1.2423 - val_dense_3_root_mean_squared_error: 0.8183\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3546 - dense_2_loss: 0.3498 - dense_3_loss: 0.3981 - dense_2_root_mean_squared_error: 0.5914 - dense_3_root_mean_squared_error: 0.6310 - val_loss: 1.1709 - val_dense_2_loss: 1.1945 - val_dense_3_loss: 0.9589 - val_dense_2_root_mean_squared_error: 1.0929 - val_dense_3_root_mean_squared_error: 0.9792\n"
]
}
],
"source": [
"norm_layer_wide.adapt(X_train_wide)\n",
"norm_layer_deep.adapt(X_train_deep)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=20,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"162/162 [==============================] - 0s 778us/step - loss: 0.3446 - dense_2_loss: 0.3381 - dense_3_loss: 0.4031 - dense_2_root_mean_squared_error: 0.5815 - dense_3_root_mean_squared_error: 0.6349\n"
]
}
],
"source": [
"eval_results = model.evaluate((X_test_wide, X_test_deep), (y_test, y_test))\n",
"weighted_sum_of_losses, main_loss, aux_loss, main_rmse, aux_rmse = eval_results"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7fb250e69310> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
]
}
],
"source": [
"y_pred_main, y_pred_aux = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"y_pred_tuple = model.predict((X_new_wide, X_new_deep))\n",
"y_pred = dict(zip(model.output_names, y_pred_tuple))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the Subclassing API to Build Dynamic Models"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"class WideAndDeepModel(tf.keras.Model):\n",
" def __init__(self, units=30, activation=\"relu\", **kwargs):\n",
" super().__init__(**kwargs) # needed to support naming the model\n",
" self.norm_layer_wide = tf.keras.layers.Normalization()\n",
" self.norm_layer_deep = tf.keras.layers.Normalization()\n",
" self.hidden1 = tf.keras.layers.Dense(units, activation=activation)\n",
" self.hidden2 = tf.keras.layers.Dense(units, activation=activation)\n",
" self.main_output = tf.keras.layers.Dense(1)\n",
" self.aux_output = tf.keras.layers.Dense(1)\n",
" \n",
" def call(self, inputs):\n",
" input_wide, input_deep = inputs\n",
" norm_wide = self.norm_layer_wide(input_wide)\n",
" norm_deep = self.norm_layer_deep(input_deep)\n",
" hidden1 = self.hidden1(norm_deep)\n",
" hidden2 = self.hidden2(hidden1)\n",
" concat = tf.keras.layers.concatenate([norm_wide, hidden2])\n",
" output = self.main_output(concat)\n",
" aux_output = self.aux_output(hidden2)\n",
" return output, aux_output\n",
"\n",
"tf.random.set_seed(42) # extra code just for reproducibility\n",
"model = WideAndDeepModel(30, activation=\"relu\", name=\"my_cool_model\")"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"363/363 [==============================] - 1s 2ms/step - loss: 1.3490 - output_1_loss: 1.2742 - output_2_loss: 2.0215 - output_1_root_mean_squared_error: 1.1288 - output_2_root_mean_squared_error: 1.4218 - val_loss: 1.5415 - val_output_1_loss: 0.9593 - val_output_2_loss: 6.7806 - val_output_1_root_mean_squared_error: 0.9795 - val_output_2_root_mean_squared_error: 2.6040\n",
"Epoch 2/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.5101 - output_1_loss: 0.4785 - output_2_loss: 0.7952 - output_1_root_mean_squared_error: 0.6917 - output_2_root_mean_squared_error: 0.8917 - val_loss: 1.3624 - val_output_1_loss: 1.0094 - val_output_2_loss: 4.5401 - val_output_1_root_mean_squared_error: 1.0047 - val_output_2_root_mean_squared_error: 2.1307\n",
"Epoch 3/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4618 - output_1_loss: 0.4404 - output_2_loss: 0.6546 - output_1_root_mean_squared_error: 0.6636 - output_2_root_mean_squared_error: 0.8091 - val_loss: 0.5361 - val_output_1_loss: 0.3975 - val_output_2_loss: 1.7837 - val_output_1_root_mean_squared_error: 0.6305 - val_output_2_root_mean_squared_error: 1.3356\n",
"Epoch 4/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4252 - output_1_loss: 0.4059 - output_2_loss: 0.5985 - output_1_root_mean_squared_error: 0.6371 - output_2_root_mean_squared_error: 0.7736 - val_loss: 0.5182 - val_output_1_loss: 0.4590 - val_output_2_loss: 1.0517 - val_output_1_root_mean_squared_error: 0.6775 - val_output_2_root_mean_squared_error: 1.0255\n",
"Epoch 5/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.4106 - output_1_loss: 0.3931 - output_2_loss: 0.5690 - output_1_root_mean_squared_error: 0.6269 - output_2_root_mean_squared_error: 0.7543 - val_loss: 0.4049 - val_output_1_loss: 0.3588 - val_output_2_loss: 0.8196 - val_output_1_root_mean_squared_error: 0.5990 - val_output_2_root_mean_squared_error: 0.9053\n",
"Epoch 6/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3944 - output_1_loss: 0.3780 - output_2_loss: 0.5424 - output_1_root_mean_squared_error: 0.6148 - output_2_root_mean_squared_error: 0.7365 - val_loss: 0.4168 - val_output_1_loss: 0.3934 - val_output_2_loss: 0.6275 - val_output_1_root_mean_squared_error: 0.6272 - val_output_2_root_mean_squared_error: 0.7921\n",
"Epoch 7/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - output_1_loss: 0.3694 - output_2_loss: 0.5126 - output_1_root_mean_squared_error: 0.6078 - output_2_root_mean_squared_error: 0.7160 - val_loss: 0.3661 - val_output_1_loss: 0.3430 - val_output_2_loss: 0.5747 - val_output_1_root_mean_squared_error: 0.5856 - val_output_2_root_mean_squared_error: 0.7581\n",
"Epoch 8/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3731 - output_1_loss: 0.3608 - output_2_loss: 0.4840 - output_1_root_mean_squared_error: 0.6007 - output_2_root_mean_squared_error: 0.6957 - val_loss: 0.8555 - val_output_1_loss: 0.8704 - val_output_2_loss: 0.7218 - val_output_1_root_mean_squared_error: 0.9330 - val_output_2_root_mean_squared_error: 0.8496\n",
"Epoch 9/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3672 - output_1_loss: 0.3567 - output_2_loss: 0.4624 - output_1_root_mean_squared_error: 0.5972 - output_2_root_mean_squared_error: 0.6800 - val_loss: 2.6877 - val_output_1_loss: 2.9011 - val_output_2_loss: 0.7675 - val_output_1_root_mean_squared_error: 1.7033 - val_output_2_root_mean_squared_error: 0.8761\n",
"Epoch 10/10\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3837 - output_1_loss: 0.3765 - output_2_loss: 0.4481 - output_1_root_mean_squared_error: 0.6136 - output_2_root_mean_squared_error: 0.6694 - val_loss: 3.6017 - val_output_1_loss: 3.8004 - val_output_2_loss: 1.8132 - val_output_1_root_mean_squared_error: 1.9495 - val_output_2_root_mean_squared_error: 1.3466\n",
"162/162 [==============================] - 0s 781us/step - loss: 0.3652 - output_1_loss: 0.3570 - output_2_loss: 0.4387 - output_1_root_mean_squared_error: 0.5975 - output_2_root_mean_squared_error: 0.6624\n",
"WARNING:tensorflow:6 out of the last 7 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7fb250b9d820> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
]
}
],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", loss_weights=[0.9, 0.1], optimizer=optimizer,\n",
" metrics=[\"RootMeanSquaredError\"])\n",
"model.norm_layer_wide.adapt(X_train_wide)\n",
"model.norm_layer_deep.adapt(X_train_deep)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=10,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)))\n",
"eval_results = model.evaluate((X_test_wide, X_test_deep), (y_test, y_test))\n",
"weighted_sum_of_losses, main_loss, aux_loss, main_rmse, aux_rmse = eval_results\n",
"y_pred_main, y_pred_aux = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving and Restoring a Model"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"# extra code: delete the directory, in case it already exists\n",
"\n",
"import shutil\n",
"\n",
"shutil.rmtree(\"my_keras_model\", ignore_errors=True)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_keras_model/assets\n"
]
}
],
"source": [
"model.save(\"my_keras_model\", save_format=\"tf\")"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"my_keras_model/assets\n",
"my_keras_model/keras_metadata.pb\n",
"my_keras_model/saved_model.pb\n",
"my_keras_model/variables\n",
"my_keras_model/variables/variables.data-00000-of-00001\n",
"my_keras_model/variables/variables.index\n"
]
}
],
"source": [
"# extra code: show the contents of the my_keras_model/ directory\n",
"for path in sorted(Path(\"my_keras_model\").glob(\"**/*\")):\n",
" print(path)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.models.load_model(\"my_keras_model\")\n",
"y_pred_main, y_pred_aux = model.predict((X_new_wide, X_new_deep))"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"model.save_weights(\"my_weights\")"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fb210482490>"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.load_weights(\"my_weights\")"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"my_weights.data-00000-of-00001\n",
"my_weights.index\n"
]
}
],
"source": [
"# extra code: show the list of my_weights.* files\n",
"for path in sorted(Path().glob(\"my_weights.*\")):\n",
" print(path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using Callbacks"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"shutil.rmtree(\"my_checkpoints\", ignore_errors=True) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"363/363 [==============================] - 1s 2ms/step - loss: 0.3775 - output_1_loss: 0.3706 - output_2_loss: 0.4402 - output_1_root_mean_squared_error: 0.6088 - output_2_root_mean_squared_error: 0.6635 - val_loss: 0.3369 - val_output_1_loss: 0.3234 - val_output_2_loss: 0.4587 - val_output_1_root_mean_squared_error: 0.5687 - val_output_2_root_mean_squared_error: 0.6773\n",
"Epoch 2/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3556 - output_1_loss: 0.3480 - output_2_loss: 0.4242 - output_1_root_mean_squared_error: 0.5899 - output_2_root_mean_squared_error: 0.6513 - val_loss: 0.4940 - val_output_1_loss: 0.4650 - val_output_2_loss: 0.7551 - val_output_1_root_mean_squared_error: 0.6819 - val_output_2_root_mean_squared_error: 0.8689\n",
"Epoch 3/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3612 - output_1_loss: 0.3547 - output_2_loss: 0.4198 - output_1_root_mean_squared_error: 0.5956 - output_2_root_mean_squared_error: 0.6480 - val_loss: 0.3443 - val_output_1_loss: 0.3355 - val_output_2_loss: 0.4241 - val_output_1_root_mean_squared_error: 0.5792 - val_output_2_root_mean_squared_error: 0.6512\n",
"Epoch 4/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3493 - output_1_loss: 0.3425 - output_2_loss: 0.4110 - output_1_root_mean_squared_error: 0.5852 - output_2_root_mean_squared_error: 0.6411 - val_loss: 0.4676 - val_output_1_loss: 0.4635 - val_output_2_loss: 0.5046 - val_output_1_root_mean_squared_error: 0.6808 - val_output_2_root_mean_squared_error: 0.7104\n",
"Epoch 5/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3525 - output_1_loss: 0.3465 - output_2_loss: 0.4069 - output_1_root_mean_squared_error: 0.5886 - output_2_root_mean_squared_error: 0.6379 - val_loss: 1.3020 - val_output_1_loss: 1.3842 - val_output_2_loss: 0.5623 - val_output_1_root_mean_squared_error: 1.1765 - val_output_2_root_mean_squared_error: 0.7499\n",
"Epoch 6/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3512 - output_1_loss: 0.3453 - output_2_loss: 0.4039 - output_1_root_mean_squared_error: 0.5876 - output_2_root_mean_squared_error: 0.6356 - val_loss: 1.6719 - val_output_1_loss: 1.7502 - val_output_2_loss: 0.9670 - val_output_1_root_mean_squared_error: 1.3230 - val_output_2_root_mean_squared_error: 0.9833\n",
"Epoch 7/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3533 - output_1_loss: 0.3477 - output_2_loss: 0.4038 - output_1_root_mean_squared_error: 0.5897 - output_2_root_mean_squared_error: 0.6355 - val_loss: 0.6855 - val_output_1_loss: 0.7149 - val_output_2_loss: 0.4210 - val_output_1_root_mean_squared_error: 0.8455 - val_output_2_root_mean_squared_error: 0.6488\n",
"Epoch 8/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3409 - output_1_loss: 0.3348 - output_2_loss: 0.3965 - output_1_root_mean_squared_error: 0.5786 - output_2_root_mean_squared_error: 0.6297 - val_loss: 2.0126 - val_output_1_loss: 1.9280 - val_output_2_loss: 2.7742 - val_output_1_root_mean_squared_error: 1.3885 - val_output_2_root_mean_squared_error: 1.6656\n",
"Epoch 9/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3441 - output_1_loss: 0.3375 - output_2_loss: 0.4028 - output_1_root_mean_squared_error: 0.5810 - output_2_root_mean_squared_error: 0.6347 - val_loss: 1.6894 - val_output_1_loss: 1.8009 - val_output_2_loss: 0.6859 - val_output_1_root_mean_squared_error: 1.3420 - val_output_2_root_mean_squared_error: 0.8282\n",
"Epoch 10/10\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3517 - output_1_loss: 0.3468 - output_2_loss: 0.3962 - output_1_root_mean_squared_error: 0.5889 - output_2_root_mean_squared_error: 0.6294 - val_loss: 1.2969 - val_output_1_loss: 1.3365 - val_output_2_loss: 0.9407 - val_output_1_root_mean_squared_error: 1.1561 - val_output_2_root_mean_squared_error: 0.9699\n"
]
}
],
"source": [
"checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_checkpoints\",\n",
" save_weights_only=True)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=10,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)),\n",
" callbacks=[checkpoint_cb])"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3405 - output_1_loss: 0.3349 - output_2_loss: 0.3910 - output_1_root_mean_squared_error: 0.5787 - output_2_root_mean_squared_error: 0.6253 - val_loss: 0.6245 - val_output_1_loss: 0.6502 - val_output_2_loss: 0.3937 - val_output_1_root_mean_squared_error: 0.8063 - val_output_2_root_mean_squared_error: 0.6275\n",
"Epoch 2/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3400 - output_1_loss: 0.3344 - output_2_loss: 0.3900 - output_1_root_mean_squared_error: 0.5783 - output_2_root_mean_squared_error: 0.6245 - val_loss: 0.9552 - val_output_1_loss: 0.9508 - val_output_2_loss: 0.9947 - val_output_1_root_mean_squared_error: 0.9751 - val_output_2_root_mean_squared_error: 0.9974\n",
"Epoch 3/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3442 - output_1_loss: 0.3389 - output_2_loss: 0.3921 - output_1_root_mean_squared_error: 0.5821 - output_2_root_mean_squared_error: 0.6262 - val_loss: 0.3574 - val_output_1_loss: 0.3552 - val_output_2_loss: 0.3766 - val_output_1_root_mean_squared_error: 0.5960 - val_output_2_root_mean_squared_error: 0.6137\n",
"Epoch 4/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3347 - output_1_loss: 0.3289 - output_2_loss: 0.3865 - output_1_root_mean_squared_error: 0.5735 - output_2_root_mean_squared_error: 0.6217 - val_loss: 0.4521 - val_output_1_loss: 0.4401 - val_output_2_loss: 0.5609 - val_output_1_root_mean_squared_error: 0.6634 - val_output_2_root_mean_squared_error: 0.7489\n",
"Epoch 5/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3363 - output_1_loss: 0.3311 - output_2_loss: 0.3832 - output_1_root_mean_squared_error: 0.5754 - output_2_root_mean_squared_error: 0.6190 - val_loss: 0.4903 - val_output_1_loss: 0.5018 - val_output_2_loss: 0.3869 - val_output_1_root_mean_squared_error: 0.7084 - val_output_2_root_mean_squared_error: 0.6220\n",
"Epoch 6/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3300 - output_1_loss: 0.3245 - output_2_loss: 0.3801 - output_1_root_mean_squared_error: 0.5696 - output_2_root_mean_squared_error: 0.6165 - val_loss: 0.8351 - val_output_1_loss: 0.8434 - val_output_2_loss: 0.7602 - val_output_1_root_mean_squared_error: 0.9184 - val_output_2_root_mean_squared_error: 0.8719\n",
"Epoch 7/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3324 - output_1_loss: 0.3270 - output_2_loss: 0.3814 - output_1_root_mean_squared_error: 0.5718 - output_2_root_mean_squared_error: 0.6176 - val_loss: 0.6880 - val_output_1_loss: 0.7171 - val_output_2_loss: 0.4259 - val_output_1_root_mean_squared_error: 0.8468 - val_output_2_root_mean_squared_error: 0.6526\n",
"Epoch 8/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3286 - output_1_loss: 0.3231 - output_2_loss: 0.3774 - output_1_root_mean_squared_error: 0.5684 - output_2_root_mean_squared_error: 0.6143 - val_loss: 4.4284 - val_output_1_loss: 4.2604 - val_output_2_loss: 5.9404 - val_output_1_root_mean_squared_error: 2.0641 - val_output_2_root_mean_squared_error: 2.4373\n",
"Epoch 9/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3378 - output_1_loss: 0.3322 - output_2_loss: 0.3886 - output_1_root_mean_squared_error: 0.5764 - output_2_root_mean_squared_error: 0.6234 - val_loss: 1.7043 - val_output_1_loss: 1.7984 - val_output_2_loss: 0.8578 - val_output_1_root_mean_squared_error: 1.3410 - val_output_2_root_mean_squared_error: 0.9262\n",
"Epoch 10/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3401 - output_1_loss: 0.3354 - output_2_loss: 0.3824 - output_1_root_mean_squared_error: 0.5792 - output_2_root_mean_squared_error: 0.6184 - val_loss: 0.6170 - val_output_1_loss: 0.6282 - val_output_2_loss: 0.5169 - val_output_1_root_mean_squared_error: 0.7926 - val_output_2_root_mean_squared_error: 0.7190\n",
"Epoch 11/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3230 - output_1_loss: 0.3177 - output_2_loss: 0.3706 - output_1_root_mean_squared_error: 0.5637 - output_2_root_mean_squared_error: 0.6088 - val_loss: 0.3558 - val_output_1_loss: 0.3490 - val_output_2_loss: 0.4170 - val_output_1_root_mean_squared_error: 0.5907 - val_output_2_root_mean_squared_error: 0.6457\n",
"Epoch 12/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3253 - output_1_loss: 0.3201 - output_2_loss: 0.3727 - output_1_root_mean_squared_error: 0.5658 - output_2_root_mean_squared_error: 0.6105 - val_loss: 0.4612 - val_output_1_loss: 0.4597 - val_output_2_loss: 0.4745 - val_output_1_root_mean_squared_error: 0.6780 - val_output_2_root_mean_squared_error: 0.6888\n",
"Epoch 13/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3221 - output_1_loss: 0.3167 - output_2_loss: 0.3699 - output_1_root_mean_squared_error: 0.5628 - output_2_root_mean_squared_error: 0.6082 - val_loss: 0.3120 - val_output_1_loss: 0.3056 - val_output_2_loss: 0.3694 - val_output_1_root_mean_squared_error: 0.5528 - val_output_2_root_mean_squared_error: 0.6078\n",
"Epoch 14/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3204 - output_1_loss: 0.3149 - output_2_loss: 0.3695 - output_1_root_mean_squared_error: 0.5612 - output_2_root_mean_squared_error: 0.6078 - val_loss: 0.4120 - val_output_1_loss: 0.4013 - val_output_2_loss: 0.5076 - val_output_1_root_mean_squared_error: 0.6335 - val_output_2_root_mean_squared_error: 0.7124\n",
"Epoch 15/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3196 - output_1_loss: 0.3144 - output_2_loss: 0.3662 - output_1_root_mean_squared_error: 0.5607 - output_2_root_mean_squared_error: 0.6052 - val_loss: 0.3304 - val_output_1_loss: 0.3269 - val_output_2_loss: 0.3619 - val_output_1_root_mean_squared_error: 0.5718 - val_output_2_root_mean_squared_error: 0.6016\n",
"Epoch 16/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3166 - output_1_loss: 0.3113 - output_2_loss: 0.3639 - output_1_root_mean_squared_error: 0.5579 - output_2_root_mean_squared_error: 0.6032 - val_loss: 0.4455 - val_output_1_loss: 0.4414 - val_output_2_loss: 0.4819 - val_output_1_root_mean_squared_error: 0.6644 - val_output_2_root_mean_squared_error: 0.6942\n",
"Epoch 17/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3186 - output_1_loss: 0.3134 - output_2_loss: 0.3650 - output_1_root_mean_squared_error: 0.5599 - output_2_root_mean_squared_error: 0.6041 - val_loss: 0.3255 - val_output_1_loss: 0.3212 - val_output_2_loss: 0.3643 - val_output_1_root_mean_squared_error: 0.5667 - val_output_2_root_mean_squared_error: 0.6035\n",
"Epoch 18/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3143 - output_1_loss: 0.3091 - output_2_loss: 0.3611 - output_1_root_mean_squared_error: 0.5560 - output_2_root_mean_squared_error: 0.6009 - val_loss: 1.6360 - val_output_1_loss: 1.6925 - val_output_2_loss: 1.1276 - val_output_1_root_mean_squared_error: 1.3010 - val_output_2_root_mean_squared_error: 1.0619\n",
"Epoch 19/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3169 - output_1_loss: 0.3122 - output_2_loss: 0.3601 - output_1_root_mean_squared_error: 0.5587 - output_2_root_mean_squared_error: 0.6001 - val_loss: 1.2441 - val_output_1_loss: 1.3093 - val_output_2_loss: 0.6572 - val_output_1_root_mean_squared_error: 1.1442 - val_output_2_root_mean_squared_error: 0.8107\n",
"Epoch 20/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3245 - output_1_loss: 0.3201 - output_2_loss: 0.3641 - output_1_root_mean_squared_error: 0.5658 - output_2_root_mean_squared_error: 0.6034 - val_loss: 1.5466 - val_output_1_loss: 1.5582 - val_output_2_loss: 1.4424 - val_output_1_root_mean_squared_error: 1.2483 - val_output_2_root_mean_squared_error: 1.2010\n",
"Epoch 21/100\n",
"363/363 [==============================] - 0s 1ms/step - loss: 0.3202 - output_1_loss: 0.3153 - output_2_loss: 0.3640 - output_1_root_mean_squared_error: 0.5615 - output_2_root_mean_squared_error: 0.6033 - val_loss: 0.6704 - val_output_1_loss: 0.6907 - val_output_2_loss: 0.4873 - val_output_1_root_mean_squared_error: 0.8311 - val_output_2_root_mean_squared_error: 0.6980\n",
"Epoch 22/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3150 - output_1_loss: 0.3103 - output_2_loss: 0.3573 - output_1_root_mean_squared_error: 0.5570 - output_2_root_mean_squared_error: 0.5978 - val_loss: 0.4909 - val_output_1_loss: 0.4955 - val_output_2_loss: 0.4493 - val_output_1_root_mean_squared_error: 0.7039 - val_output_2_root_mean_squared_error: 0.6703\n",
"Epoch 23/100\n",
"363/363 [==============================] - 1s 1ms/step - loss: 0.3104 - output_1_loss: 0.3054 - output_2_loss: 0.3552 - output_1_root_mean_squared_error: 0.5526 - output_2_root_mean_squared_error: 0.5960 - val_loss: 0.3845 - val_output_1_loss: 0.3803 - val_output_2_loss: 0.4228 - val_output_1_root_mean_squared_error: 0.6167 - val_output_2_root_mean_squared_error: 0.6502\n"
]
}
],
"source": [
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=10,\n",
" restore_best_weights=True)\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=100,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)),\n",
" callbacks=[checkpoint_cb, early_stopping_cb])"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"class PrintValTrainRatioCallback(tf.keras.callbacks.Callback):\n",
" def on_epoch_end(self, epoch, logs):\n",
" ratio = logs[\"val_loss\"] / logs[\"loss\"]\n",
" print(f\"Epoch={epoch}, val/train={ratio:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch=0, val/train=2.29\n",
"Epoch=1, val/train=1.03\n",
"Epoch=2, val/train=2.07\n",
"Epoch=3, val/train=1.76\n",
"Epoch=4, val/train=3.56\n",
"Epoch=5, val/train=1.86\n",
"Epoch=6, val/train=2.45\n",
"Epoch=7, val/train=7.86\n",
"Epoch=8, val/train=11.20\n",
"Epoch=9, val/train=1.14\n"
]
}
],
"source": [
"val_train_ratio_cb = PrintValTrainRatioCallback()\n",
"history = model.fit(\n",
" (X_train_wide, X_train_deep), (y_train, y_train), epochs=10,\n",
" validation_data=((X_valid_wide, X_valid_deep), (y_valid, y_valid)),\n",
" callbacks=[val_train_ratio_cb], verbose=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using TensorBoard for Visualization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"TensorBoard is preinstalled on Colab, but the `tensorboard-plugin-profile` package is not, so let's install it:"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"if \"google.colab\" in sys.modules: # extra code\n",
" %pip install -q -U tensorboard-plugin-profile"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"shutil.rmtree(\"my_logs\", ignore_errors=True)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from time import strftime\n",
"\n",
"def get_run_logdir(root_logdir=\"my_logs\"):\n",
" return Path(root_logdir) / strftime(\"run_%Y_%m_%d_%H_%M_%S\")\n",
"\n",
"run_logdir = get_run_logdir()"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"# extra code: builds the first regression model we used earlier\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"norm_layer = tf.keras.layers.Normalization(input_shape=X_train.shape[1:])\n",
"model = tf.keras.Sequential([\n",
" norm_layer,\n",
" tf.keras.layers.Dense(30, activation=\"relu\"),\n",
" tf.keras.layers.Dense(30, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)\n",
"model.compile(loss=\"mse\", optimizer=optimizer, metrics=[\"RootMeanSquaredError\"])\n",
"norm_layer.adapt(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-08-01 17:25:59.099970: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.\n",
"2022-08-01 17:25:59.099982: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.\n",
"2022-08-01 17:25:59.100137: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"261/363 [====================>.........] - ETA: 0s - loss: 2.3165 - root_mean_squared_error: 1.5220"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-08-01 17:25:59.430946: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.\n",
"2022-08-01 17:25:59.430962: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.\n",
"2022-08-01 17:25:59.510100: I tensorflow/core/profiler/lib/profiler_session.cc:67] Profiler session collecting data.\n",
"2022-08-01 17:25:59.524969: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.\n",
"2022-08-01 17:25:59.539451: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00\n",
"\n",
"2022-08-01 17:25:59.549606: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.trace.json.gz\n",
"2022-08-01 17:25:59.558338: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00\n",
"\n",
"2022-08-01 17:25:59.558474: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.memory_profile.json.gz\n",
"2022-08-01 17:25:59.559618: I tensorflow/core/profiler/rpc/client/capture_profile.cc:251] Creating directory: my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00\n",
"Dumped tool data for xplane.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.xplane.pb\n",
"Dumped tool data for overview_page.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.overview_page.pb\n",
"Dumped tool data for input_pipeline.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.input_pipeline.pb\n",
"Dumped tool data for tensorflow_stats.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.tensorflow_stats.pb\n",
"Dumped tool data for kernel_stats.pb to my_logs/run_2022_08_01_17_25_59/plugins/profile/2022_08_01_17_26_00/my_computer.kernel_stats.pb\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"363/363 [==============================] - 1s 1ms/step - loss: 1.8866 - root_mean_squared_error: 1.3736 - val_loss: 0.7126 - val_root_mean_squared_error: 0.8442\n",
"Epoch 2/20\n",
"363/363 [==============================] - 0s 907us/step - loss: 0.6577 - root_mean_squared_error: 0.8110 - val_loss: 0.6880 - val_root_mean_squared_error: 0.8295\n",
"Epoch 3/20\n",
"363/363 [==============================] - 0s 836us/step - loss: 0.5934 - root_mean_squared_error: 0.7703 - val_loss: 0.5803 - val_root_mean_squared_error: 0.7618\n",
"Epoch 4/20\n",
"363/363 [==============================] - 0s 832us/step - loss: 0.5557 - root_mean_squared_error: 0.7455 - val_loss: 0.5166 - val_root_mean_squared_error: 0.7188\n",
"Epoch 5/20\n",
"363/363 [==============================] - 0s 985us/step - loss: 0.5272 - root_mean_squared_error: 0.7261 - val_loss: 0.4895 - val_root_mean_squared_error: 0.6997\n",
"Epoch 6/20\n",
"363/363 [==============================] - 0s 887us/step - loss: 0.5033 - root_mean_squared_error: 0.7094 - val_loss: 0.4951 - val_root_mean_squared_error: 0.7036\n",
"Epoch 7/20\n",
"363/363 [==============================] - 0s 894us/step - loss: 0.4854 - root_mean_squared_error: 0.6967 - val_loss: 0.4862 - val_root_mean_squared_error: 0.6973\n",
"Epoch 8/20\n",
"363/363 [==============================] - 0s 868us/step - loss: 0.4709 - root_mean_squared_error: 0.6862 - val_loss: 0.4554 - val_root_mean_squared_error: 0.6748\n",
"Epoch 9/20\n",
"363/363 [==============================] - 0s 780us/step - loss: 0.4578 - root_mean_squared_error: 0.6766 - val_loss: 0.4413 - val_root_mean_squared_error: 0.6643\n",
"Epoch 10/20\n",
"363/363 [==============================] - 0s 819us/step - loss: 0.4474 - root_mean_squared_error: 0.6689 - val_loss: 0.4379 - val_root_mean_squared_error: 0.6617\n",
"Epoch 11/20\n",
"363/363 [==============================] - 0s 795us/step - loss: 0.4393 - root_mean_squared_error: 0.6628 - val_loss: 0.4396 - val_root_mean_squared_error: 0.6630\n",
"Epoch 12/20\n",
"363/363 [==============================] - 0s 852us/step - loss: 0.4318 - root_mean_squared_error: 0.6571 - val_loss: 0.4505 - val_root_mean_squared_error: 0.6712\n",
"Epoch 13/20\n",
"363/363 [==============================] - 0s 910us/step - loss: 0.4260 - root_mean_squared_error: 0.6527 - val_loss: 0.3997 - val_root_mean_squared_error: 0.6322\n",
"Epoch 14/20\n",
"363/363 [==============================] - 0s 796us/step - loss: 0.4202 - root_mean_squared_error: 0.6482 - val_loss: 0.3956 - val_root_mean_squared_error: 0.6290\n",
"Epoch 15/20\n",
"363/363 [==============================] - 0s 816us/step - loss: 0.4155 - root_mean_squared_error: 0.6446 - val_loss: 0.3916 - val_root_mean_squared_error: 0.6257\n",
"Epoch 16/20\n",
"363/363 [==============================] - 0s 759us/step - loss: 0.4112 - root_mean_squared_error: 0.6412 - val_loss: 0.3937 - val_root_mean_squared_error: 0.6275\n",
"Epoch 17/20\n",
"363/363 [==============================] - 0s 826us/step - loss: 0.4077 - root_mean_squared_error: 0.6385 - val_loss: 0.3809 - val_root_mean_squared_error: 0.6172\n",
"Epoch 18/20\n",
"363/363 [==============================] - 0s 832us/step - loss: 0.4039 - root_mean_squared_error: 0.6356 - val_loss: 0.3793 - val_root_mean_squared_error: 0.6159\n",
"Epoch 19/20\n",
"363/363 [==============================] - 0s 747us/step - loss: 0.4004 - root_mean_squared_error: 0.6328 - val_loss: 0.3850 - val_root_mean_squared_error: 0.6205\n",
"Epoch 20/20\n",
"363/363 [==============================] - 0s 755us/step - loss: 0.3980 - root_mean_squared_error: 0.6308 - val_loss: 0.3809 - val_root_mean_squared_error: 0.6172\n"
]
}
],
"source": [
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir,\n",
" profile_batch=(100, 200))\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[tensorboard_cb])"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"my_logs\n",
" run_2022_08_01_17_25_59\n",
" events.out.tfevents.1638910166.my_computer.profile-empty\n",
" plugins\n",
" profile\n",
" 2022_08_01_17_26_00\n",
" my_computer.input_pipeline.pb\n",
" my_computer.kernel_stats.pb\n",
" my_computer.memory_profile.json.gz\n",
" my_computer.overview_page.pb\n",
" my_computer.tensorflow_stats.pb\n",
" my_computer.trace.json.gz\n",
" my_computer.xplane.pb\n",
" train\n",
" events.out.tfevents.1638910166.my_computer.22294.0.v2\n",
" validation\n",
" events.out.tfevents.1638910166.my_computer.22294.1.v2\n"
]
}
],
"source": [
"print(\"my_logs\")\n",
"for path in sorted(Path(\"my_logs\").glob(\"**/*\")):\n",
" print(\" \" * (len(path.parts) - 1) + path.parts[-1])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the `tensorboard` Jupyter extension and start the TensorBoard server: "
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-18d562db9bb9706a\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-18d562db9bb9706a\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6006;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir=./my_logs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: if you prefer to access TensorBoard in a separate tab, click the \"localhost:6006\" link below:"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<a href=\"http://localhost:6006/\">http://localhost:6006/</a>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extra code\n",
"\n",
"if \"google.colab\" in sys.modules:\n",
" from google.colab import output\n",
"\n",
" output.serve_kernel_port_as_window(6006)\n",
"else:\n",
" from IPython.display import display, HTML\n",
"\n",
" display(HTML('<a href=\"http://localhost:6006/\">http://localhost:6006/</a>'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also visualize histograms, images, and text, and even listen to audio, using TensorBoard:"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"test_logdir = get_run_logdir()\n",
"writer = tf.summary.create_file_writer(str(test_logdir))\n",
"with writer.as_default():\n",
" for step in range(1, 1000 + 1):\n",
" tf.summary.scalar(\"my_scalar\", np.sin(step / 10), step=step)\n",
" \n",
" data = (np.random.randn(100) + 2) * step / 100 # gets larger\n",
" tf.summary.histogram(\"my_hist\", data, buckets=50, step=step)\n",
" \n",
" images = np.random.rand(2, 32, 32, 3) * step / 1000 # gets brighter\n",
" tf.summary.image(\"my_images\", images, step=step)\n",
" \n",
" texts = [\"The step is \" + str(step), \"Its square is \" + str(step ** 2)]\n",
" tf.summary.text(\"my_text\", texts, step=step)\n",
" \n",
" sine_wave = tf.math.sin(tf.range(12000) / 48000 * 2 * np.pi * step)\n",
" audio = tf.reshape(tf.cast(sine_wave, tf.float32), [1, -1, 1])\n",
" tf.summary.audio(\"my_audio\", audio, sample_rate=48000, step=step)"
]
},
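{
"cell_type": "markdown",
"metadata": {},
"source": [
"A short worked reading of the audio summary above (the symbol $n$ for the sample index is introduced here for illustration). The tensor passed to `tf.summary.audio()` at a given step is:\n",
"\n",
"$$\n",
"x[n] = \\sin\\left(2\\pi \\cdot \\text{step} \\cdot \\frac{n}{48000}\\right), \\qquad n = 0, 1, \\dots, 11999\n",
"$$\n",
"\n",
"Since the sample rate is 48,000 Hz, $n / 48000$ is the time in seconds, so each clip is a pure tone at `step` Hz lasting 12000 / 48000 = 0.25 seconds. The pitch therefore rises from 1 Hz at step 1 to 1,000 Hz at step 1000, and the earliest clips fall below the usual range of human hearing (roughly 20 Hz)."
]
},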
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: it used to be possible to easily share your TensorBoard logs with the world by uploading them to https://tensorboard.dev/. Sadly, this service will shut down in December 2023, so I have removed the corresponding code examples from this notebook."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When you stop this Jupyter kernel (a.k.a. Runtime), it will automatically stop the TensorBoard server as well. Another way to stop the TensorBoard server is to kill it, if you are running on Linux or MacOSX. First, you need to find its process ID:"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Known TensorBoard instances:\n",
" - port 6006: logdir ./my_logs (started 0:00:31 ago; pid 22701)\n"
]
}
],
"source": [
"# extra code: lists all running TensorBoard server instances\n",
"\n",
"from tensorboard import notebook\n",
"\n",
"notebook.list()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next you can use the following command on Linux or MacOSX, replacing `<pid>` with the pid listed above:\n",
"\n",
" !kill <pid>\n",
"\n",
"On Windows:\n",
"\n",
" !taskkill /F /PID <pid>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Fine-Tuning Neural Network Hyperparameters"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section we'll use the Fashion MNIST dataset again:"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]\n",
"X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"if \"google.colab\" in sys.modules:\n",
" %pip install -q -U keras_tuner"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
"import keras_tuner as kt\n",
"\n",
"def build_model(hp):\n",
" n_hidden = hp.Int(\"n_hidden\", min_value=0, max_value=8, default=2)\n",
" n_neurons = hp.Int(\"n_neurons\", min_value=16, max_value=256)\n",
" learning_rate = hp.Float(\"learning_rate\", min_value=1e-4, max_value=1e-2,\n",
" sampling=\"log\")\n",
" optimizer = hp.Choice(\"optimizer\", values=[\"sgd\", \"adam\"])\n",
" if optimizer == \"sgd\":\n",
" optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)\n",
" else:\n",
" optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
"\n",
" model = tf.keras.Sequential()\n",
" model.add(tf.keras.layers.Flatten())\n",
" for _ in range(n_hidden):\n",
" model.add(tf.keras.layers.Dense(n_neurons, activation=\"relu\"))\n",
" model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
" model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
" return model"
]
},
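{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why `sampling=\"log\"` is used for the learning rate in `build_model()`, here is one way to picture it (a sketch of the idea, not necessarily Keras Tuner's exact implementation; the symbols $\\eta$ and $u$ are introduced here for illustration):\n",
"\n",
"$$\n",
"\\eta = 10^{u}, \\qquad u \\sim \\mathcal{U}(-4, -2)\n",
"$$\n",
"\n",
"Under this picture each order of magnitude is equally likely: about half of the sampled learning rates fall between $10^{-4}$ and $10^{-3}$, and half between $10^{-3}$ and $10^{-2}$. Sampling $\\eta$ uniformly over the same range would instead place $(10^{-2} - 10^{-3}) / (10^{-2} - 10^{-4}) \\approx 91\\%$ of the values above $10^{-3}$, so small learning rates would rarely be tried."
]
},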
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 5 Complete [00h 00m 24s]\n",
"val_accuracy: 0.8736000061035156\n",
"\n",
"Best val_accuracy So Far: 0.8736000061035156\n",
"Total elapsed time: 00h 01m 43s\n",
"INFO:tensorflow:Oracle triggered exit\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1208 09:51:50.359315 4451454400 1158129808.py:4] Oracle triggered exit\n"
]
}
],
"source": [
"random_search_tuner = kt.RandomSearch(\n",
" build_model, objective=\"val_accuracy\", max_trials=5, overwrite=True,\n",
" directory=\"my_fashion_mnist\", project_name=\"my_rnd_search\", seed=42)\n",
"random_search_tuner.search(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"top3_models = random_search_tuner.get_best_models(num_models=3)\n",
"best_model = top3_models[0]"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'n_hidden': 5,\n",
" 'n_neurons': 70,\n",
" 'learning_rate': 0.00041268008323824807,\n",
" 'optimizer': 'adam'}"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"top3_params = random_search_tuner.get_best_hyperparameters(num_trials=3)\n",
"top3_params[0].values # best hyperparameter values"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial summary\n",
"Hyperparameters:\n",
"n_hidden: 5\n",
"n_neurons: 70\n",
"learning_rate: 0.00041268008323824807\n",
"optimizer: adam\n",
"Score: 0.8736000061035156\n"
]
}
],
"source": [
"best_trial = random_search_tuner.oracle.get_best_trials(num_trials=1)[0]\n",
"best_trial.summary()"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8736000061035156"
]
},
"execution_count": 99,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best_trial.metrics.get_last_value(\"val_accuracy\")"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1875/1875 [==============================] - 3s 1ms/step - loss: 0.3274 - accuracy: 0.8799\n",
"Epoch 2/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.3155 - accuracy: 0.8827\n",
"Epoch 3/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.3049 - accuracy: 0.8867\n",
"Epoch 4/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2962 - accuracy: 0.8914\n",
"Epoch 5/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2886 - accuracy: 0.8931\n",
"Epoch 6/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2831 - accuracy: 0.8935\n",
"Epoch 7/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2795 - accuracy: 0.8962\n",
"Epoch 8/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2701 - accuracy: 0.8999: 0s - loss: 0\n",
"Epoch 9/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2661 - accuracy: 0.9009\n",
"Epoch 10/10\n",
"1875/1875 [==============================] - 2s 1ms/step - loss: 0.2628 - accuracy: 0.9012\n",
"313/313 [==============================] - 0s 744us/step - loss: 0.3625 - accuracy: 0.8753\n"
]
}
],
"source": [
"best_model.fit(X_train_full, y_train_full, epochs=10)\n",
"test_loss, test_accuracy = best_model.evaluate(X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"class MyClassificationHyperModel(kt.HyperModel):\n",
" def build(self, hp):\n",
" return build_model(hp)\n",
"\n",
" def fit(self, hp, model, X, y, **kwargs):\n",
" if hp.Boolean(\"normalize\"):\n",
" norm_layer = tf.keras.layers.Normalization()\n",
" X = norm_layer(X)\n",
" return model.fit(X, y, **kwargs)"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"hyperband_tuner = kt.Hyperband(\n",
" MyClassificationHyperModel(), objective=\"val_accuracy\", seed=42,\n",
" max_epochs=10, factor=3, hyperband_iterations=2,\n",
" overwrite=True, directory=\"my_fashion_mnist\", project_name=\"hyperband\")"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 60 Complete [00h 00m 18s]\n",
"val_accuracy: 0.819599986076355\n",
"\n",
"Best val_accuracy So Far: 0.8704000115394592\n",
"Total elapsed time: 00h 08m 44s\n",
"INFO:tensorflow:Oracle triggered exit\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1208 10:00:59.856360 4451454400 3169670597.py:4] Oracle triggered exit\n"
]
}
],
"source": [
"root_logdir = Path(hyperband_tuner.project_dir) / \"tensorboard\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(root_logdir)\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=2)\n",
"hyperband_tuner.search(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[early_stopping_cb, tensorboard_cb])"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trial 10 Complete [00h 00m 13s]\n",
"val_accuracy: 0.7228000164031982\n",
"\n",
"Best val_accuracy So Far: 0.8636000156402588\n",
"Total elapsed time: 00h 02m 10s\n",
"INFO:tensorflow:Oracle triggered exit\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"I1208 10:03:10.004801 4451454400 1918178380.py:5] Oracle triggered exit\n"
]
}
],
"source": [
"bayesian_opt_tuner = kt.BayesianOptimization(\n",
" MyClassificationHyperModel(), objective=\"val_accuracy\", seed=42,\n",
" max_trials=10, alpha=1e-4, beta=2.6,\n",
" overwrite=True, directory=\"my_fashion_mnist\", project_name=\"bayesian_opt\")\n",
"bayesian_opt_tuner.search(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[early_stopping_cb])"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-e1aedbefe0e1f220\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-e1aedbefe0e1f220\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6007;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%tensorboard --logdir {root_logdir}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 9."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Visit the [TensorFlow Playground](https://playground.tensorflow.org/) and play around with it, as described in this exercise.\n",
"2. Here is a neural network based on the original artificial neurons that computes _A_ ⊕ _B_ (where ⊕ represents the exclusive OR), using the fact that _A_ ⊕ _B_ = (_A_ ∧ ¬ _B_) ∨ (¬ _A_ ∧ _B_). There are other solutions, for example using the fact that _A_ ⊕ _B_ = (_A_ ∨ _B_) ∧ ¬(_A_ ∧ _B_), or the fact that _A_ ⊕ _B_ = (_A_ ∨ _B_) ∧ (¬ _A_ ∨ ¬ _B_), and so on.<br /><img width=\"70%\" src=\"images/ann/exercise2.png\" />\n",
"3. A classical Perceptron will converge only if the dataset is linearly separable, and it won't be able to estimate class probabilities. In contrast, a Logistic Regression classifier will generally converge to a reasonably good solution even if the dataset is not linearly separable, and it will output class probabilities. If you change the Perceptron's activation function to the sigmoid activation function (or the softmax activation function if there are multiple neurons), and if you train it using Gradient Descent (or some other optimization algorithm minimizing the cost function, typically cross entropy), then it becomes equivalent to a Logistic Regression classifier.\n",
"4. The sigmoid activation function was a key ingredient in training the first MLPs because its derivative is always nonzero, so Gradient Descent can always roll down the slope. When the activation function is a step function, Gradient Descent cannot move, as there is no slope at all.\n",
"5. Popular activation functions include the step function, the sigmoid function, the hyperbolic tangent (tanh) function, and the Rectified Linear Unit (ReLU) function (see Figure 10-8). See Chapter 11 for other examples, such as ELU and variants of the ReLU function.\n",
"6. Considering the MLP described in the question, composed of one input layer with 10 passthrough neurons, followed by one hidden layer with 50 artificial neurons, and finally one output layer with 3 artificial neurons, where all artificial neurons use the ReLU activation function:\n",
" * The shape of the input matrix **X** is _m_ × 10, where _m_ represents the training batch size.\n",
" * The shape of the hidden layer's weight matrix **W**<sub>_h_</sub> is 10 × 50, and the length of its bias vector **b**<sub>_h_</sub> is 50.\n",
" * The shape of the output layer's weight matrix **W**<sub>_o_</sub> is 50 × 3, and the length of its bias vector **b**<sub>_o_</sub> is 3.\n",
" * The shape of the network's output matrix **Y** is _m_ × 3.\n",
" * **Y** = ReLU(ReLU(**X** **W**<sub>_h_</sub> + **b**<sub>_h_</sub>) **W**<sub>_o_</sub> + **b**<sub>_o_</sub>). Recall that the ReLU function just sets every negative number in the matrix to zero. Also note that when you are adding a bias vector to a matrix, it is added to every single row in the matrix, which is called _broadcasting_.\n",
"7. To classify email into spam or ham, you just need one neuron in the output layer of a neural network—for example, indicating the probability that the email is spam. You would typically use the sigmoid activation function in the output layer when estimating a probability. If instead you want to tackle MNIST, you need 10 neurons in the output layer, and you must replace the sigmoid function with the softmax activation function, which can handle multiple classes, outputting one probability per class. If you want your neural network to predict housing prices like in Chapter 2, then you need one output neuron, using no activation function at all in the output layer. Note: when the values to predict can vary by many orders of magnitude, you may want to predict the logarithm of the target value rather than the target value directly. Simply computing the exponential of the neural network's output will give you the estimated value (since exp(log _v_) = _v_).\n",
"8. Backpropagation is a technique used to train artificial neural networks. It first computes the gradients of the cost function with regard to every model parameter (all the weights and biases), then it performs a Gradient Descent step using these gradients. This backpropagation step is typically performed thousands or millions of times, using many training batches, until the model parameters converge to values that (hopefully) minimize the cost function. To compute the gradients, backpropagation uses reverse-mode autodiff (although it wasn't called that when backpropagation was invented, and it has been reinvented several times). Reverse-mode autodiff performs a forward pass through a computation graph, computing every node's value for the current training batch, and then it performs a reverse pass, computing all the gradients at once (see Appendix B for more details). So what's the difference? Well, backpropagation refers to the whole process of training an artificial neural network using multiple backpropagation steps, each of which computes gradients and uses them to perform a Gradient Descent step. In contrast, reverse-mode autodiff is just a technique to compute gradients efficiently, and it happens to be used by backpropagation.\n",
"9. Here is a list of all the hyperparameters you can tweak in a basic MLP: the number of hidden layers, the number of neurons in each hidden layer, and the activation function used in each hidden layer and in the output layer. In general, the ReLU activation function (or one of its variants; see Chapter 11) is a good default for the hidden layers. For the output layer, in general you will want the sigmoid activation function for binary classification, the softmax activation function for multiclass classification, or no activation function for regression. If the MLP overfits the training data, you can try reducing the number of hidden layers and reducing the number of neurons per hidden layer."
]
},
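{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the shapes in answer 6 (a worked sketch, not part of the original solution; $\\mathbf{H}$ is a name introduced here for the hidden layer's output), the matrix products chain as follows:\n",
"\n",
"$$\n",
"\\underbrace{\\mathbf{X}}_{m \\times 10} \\, \\underbrace{\\mathbf{W}_h}_{10 \\times 50} \\rightarrow \\underbrace{\\mathbf{H}}_{m \\times 50}, \\qquad \\underbrace{\\mathbf{H}}_{m \\times 50} \\, \\underbrace{\\mathbf{W}_o}_{50 \\times 3} \\rightarrow \\underbrace{\\mathbf{Y}}_{m \\times 3}\n",
"$$\n",
"\n",
"The inner dimensions match at each product (10 with 10, then 50 with 50), and the bias vectors of length 50 and 3 broadcast across the $m$ rows. The parameter count follows directly from these shapes: the hidden layer has 10 × 50 + 50 = 550 parameters, the output layer has 50 × 3 + 3 = 153, and the whole network has 550 + 153 = 703."
]
},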
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"*Exercise: Train a deep MLP on the MNIST dataset (you can load it using `tf.keras.datasets.mnist.load_data()`). See if you can get over 98% accuracy by manually tuning the hyperparameters. Try searching for the optimal learning rate by using the approach presented in this chapter (i.e., by growing the learning rate exponentially, plotting the loss, and finding the point where the loss shoots up). Next, try tuning the hyperparameters using Keras Tuner with all the bells and whistles: save checkpoints, use early stopping, and plot learning curves using TensorBoard.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**TODO**: update this solution to use Keras Tuner."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"(X_train_full, y_train_full), (X_test, y_test) = tf.keras.datasets.mnist.load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like for the Fashion MNIST dataset, the MNIST training set contains 60,000 grayscale images, each 28x28 pixels:"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(60000, 28, 28)"
]
},
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train_full.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each pixel intensity is also represented as a byte (0 to 255):"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dtype('uint8')"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train_full.dtype"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's split the full training set into a validation set and a (smaller) training set. We also scale the pixel intensities down to the 0-1 range and convert them to floats, by dividing by 255, just like we did for Fashion MNIST:"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
"X_valid, X_train = X_train_full[:5000] / 255., X_train_full[5000:] / 255.\n",
"y_valid, y_train = y_train_full[:5000], y_train_full[5000:]\n",
"X_test = X_test / 255."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's plot an image using Matplotlib's `imshow()` function, with a `'binary'`\n",
" color map:"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAADnCAYAAADl9EEgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAGHElEQVR4nO3cz4tNfQDH8blPU4Zc42dKydrCpJQaopSxIdlYsLSykDBbO1slJWExSjKRP2GytSEWyvjRGKUkGzYUcp/dU2rO9z7umTv3c++8XkufzpkjvTvl25lGq9UaAvL80+sHABYmTgglTgglTgglTgg13Gb3X7nQfY2F/tCbE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0KJE0IN9/oBlqPbt29Xbo1Go3jthg0bivvLly+L+/j4eHHft29fcWfpeHNCKHFCKHFCKHFCKHFCKHFCKHFCqJ6dc967d6+4P3v2rLhPTU0t5uMsqS9fvnR87fBw+Z/sx48fxX1kZKS4r1q1qnIbGxsrXvvgwYPivmnTpuLOn7w5IZQ4IZQ4IZQ4IZQ4IZQ4IZQ4IVSj1WqV9uLYzoULFyq3q1evFq/9/ft3nR9NDxw4cKC4T09PF/fNmzcv5uP0kwU/4vXmhFDihFDihFDihFDihFDihFDihFBdPefcunVr5fbhw4fite2+HVy5cmVHz7QY9u7dW9yPHTu2NA/SgZmZmeJ+586dym1+fr7Wz253Dnr//v3KbcC/BXXOCf1EnBBKnBBKnBBKnBBKnBBKnBCqq+ecr1+/rtxevHhRvHZiYqK4N5vNjp6Jsrm5ucrt8OHDxWtnZ2dr/ezLly9XbpOTk7XuHc45J/QTcUIocUIocUIocUIocUKorh6lMFgePnxY3I8fP17r/hs3bqzcPn/+XOve4RylQD8RJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4QSJ4Qa7vUDkOX69euV25MnT7r6s79//165PX36tHjtrl27Fvtxes6bE0KJE0KJE0KJE0KJE0KJE0KJE0L5vbU98PHjx8rt7t27xWuvXLmy2I/zh9Kz9dKaNWuK+9evX5foSbrC762FfiJOCCVOCCVOCCVOCCVOCCVOCOV7zg7MzMwU93bfHt68ebNye/fuXUfPNOhOnTrV60dYct6cEEqcEEqcEEqcEEqcEEqcEGpZHqW8efOmuJ8+fbq4P3r0aDEf569s27atuK9bt67W/S9dulS5jYyMFK89c+ZMcX/16lVHzzQ0NDS0ZcuWjq/tV96cEEqcEEqcEEqcEEqcEEqcEEqcEGpgzzlLv0Ly2rVrxWvn5uaK++rVq4v76OhocT9//nzl1u48b8+ePcW93TloN7X7e7fTbDYrtyNHjtS6dz/y5oRQ4oRQ4oRQ4oRQ4oRQ4oRQ4oRQA3vO+fjx48qt3Tnm0aNHi/vk5GRx379/f3HvV8+fPy/u79+/r3X/FStWVG7bt2+vde9+5M0JocQJocQJocQJocQJocQJocQJoQb2nPPGjRuV29jYWPHaixcvLvbjDIS3b98W90+fPtW6/8GDB2tdP2i8OSGUOCGUOCGUOCGUOCGUOCHUwB6lrF+/vnJzVNKZ0md4/8fatWuL+9mzZ2vdf9B4c0IocUIocUIocUIocUIocUIocUKogT3npDM7duyo3GZnZ2vd+9ChQ8V9fHy81v0HjTcnhBInhBInhBInhBInhBInhBInhHLOyR/m5+crt1+/fhWvHR0dLe7nzp3r4ImWL29OCCVOCCVOCCVOCCVOCCVOCCVOCOWcc5mZnp4u7t
++favcms1m8dpbt24Vd99r/h1vTgglTgglTgglTgglTgglTgglTgjVaLVapb04kufnz5/Ffffu3cW99LtpT5w4Ubx2amqquFOpsdAfenNCKHFCKHFCKHFCKHFCKHFCKJ+MDZhGY8H/lf/PyZMni/vOnTsrt4mJiU4eiQ55c0IocUIocUIocUIocUIocUIocUIon4xB7/lkDPqJOCGUOCGUOCGUOCGUOCGUOCFUu+85yx8HAl3jzQmhxAmhxAmhxAmhxAmhxAmh/gWlotX4VjU5XgAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(X_train[0], cmap=\"binary\")\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The labels are the class IDs (represented as uint8), from 0 to 9. Conveniently, the class IDs correspond to the digits represented in the images, so we don't need a `class_names` array:"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([7, 3, 4, ..., 5, 6, 8], dtype=uint8)"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_train"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation set contains 5,000 images, and the test set contains 10,000 images:"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5000, 28, 28)"
]
},
"execution_count": 112,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_valid.shape"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10000, 28, 28)"
]
},
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_test.shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at a sample of the images in the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x345.6 with 40 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_rows = 4\n",
"n_cols = 10\n",
"plt.figure(figsize=(n_cols * 1.2, n_rows * 1.2))\n",
"for row in range(n_rows):\n",
" for col in range(n_cols):\n",
" index = n_cols * row + col\n",
" plt.subplot(n_rows, n_cols, index + 1)\n",
" plt.imshow(X_train[index], cmap=\"binary\", interpolation=\"nearest\")\n",
" plt.axis('off')\n",
" plt.title(y_train[index])\n",
"plt.subplots_adjust(wspace=0.2, hspace=0.5)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build a simple dense network and find the optimal learning rate. We will need a callback to grow the learning rate at each iteration. It will also record the learning rate and the loss at each iteration:"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
"K = tf.keras.backend\n",
"\n",
"class ExponentialLearningRate(tf.keras.callbacks.Callback):\n",
" def __init__(self, factor):\n",
" self.factor = factor\n",
" self.rates = []\n",
" self.losses = []\n",
" def on_batch_end(self, batch, logs):\n",
" self.rates.append(K.get_value(self.model.optimizer.learning_rate))\n",
" self.losses.append(logs[\"loss\"])\n",
" K.set_value(self.model.optimizer.learning_rate, self.model.optimizer.learning_rate * self.factor)"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, activation=\"relu\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will start with a small learning rate of 1e-3, and grow it by 0.5% at each iteration:"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"expon_lr = ExponentialLearningRate(factor=1.005)"
]
},
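{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sanity check on this schedule, the next cell sketches how far the learning rate travels during one epoch. It assumes 1,719 batches per epoch (the batch count that appears in this notebook's training logs) and that the callback multiplies the rate by `factor` after every batch. The variable names are illustrative only and are not used elsewhere in this notebook. Under these assumptions the sweep covers several orders of magnitude, which is why a single epoch is enough:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"initial_lr = 1e-3      # starting learning rate passed to SGD\n",
"factor = 1.005         # growth factor per iteration (0.5%)\n",
"n_iterations = 1719    # assumption: number of batches in one epoch\n",
"\n",
"# After n multiplications the rate is initial_lr * factor**n\n",
"final_lr = initial_lr * factor ** n_iterations\n",
"# Number of iterations needed to multiply the rate by 10\n",
"iterations_per_decade = math.log(10) / math.log(factor)\n",
"\n",
"print(f'final learning rate ~ {final_lr:.2f}')\n",
"print(f'iterations per 10x increase ~ {iterations_per_decade:.0f}')"
]
},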
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's train the model for just 1 epoch:"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 3s 2ms/step - loss: nan - accuracy: 0.5843 - val_loss: nan - val_accuracy: 0.0958\n"
]
}
],
"source": [
"history = model.fit(X_train, y_train, epochs=1,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[expon_lr])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now plot the loss as a function of the learning rate:"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Loss')"
]
},
"execution_count": 120,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(expon_lr.rates, expon_lr.losses)\n",
"plt.gca().set_xscale('log')\n",
"plt.hlines(min(expon_lr.losses), min(expon_lr.rates), max(expon_lr.rates))\n",
"plt.axis([min(expon_lr.rates), max(expon_lr.rates), 0, expon_lr.losses[0]])\n",
"plt.grid()\n",
"plt.xlabel(\"Learning rate\")\n",
"plt.ylabel(\"Loss\")"
]
},
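{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reading the turning point off the plot can also be approximated in code. The cell below is only a sketch of one possible heuristic, not the procedure used in this notebook: it finds the lowest recorded loss, looks for the first later iteration where the loss exceeds `blowup_factor` times that minimum, and returns half of the corresponding learning rate. The function name `suggest_learning_rate`, the `blowup_factor` parameter, and the toy data are hypothetical. With the callback above you would pass `expon_lr.rates` and `expon_lr.losses` instead:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def suggest_learning_rate(rates, losses, blowup_factor=2.0):\n",
"    # Hypothetical helper, not part of Keras or of this notebook\n",
"    rates = np.asarray(rates, dtype=float)\n",
"    losses = np.asarray(losses, dtype=float)\n",
"    # Drop NaN/inf losses recorded after training has diverged\n",
"    finite = np.isfinite(losses)\n",
"    rates, losses = rates[finite], losses[finite]\n",
"    best = np.argmin(losses)\n",
"    # First iteration after the minimum where the loss has blown up\n",
"    blown = np.nonzero(losses[best:] > blowup_factor * losses[best])[0]\n",
"    turn = best + blown[0] if len(blown) else len(rates) - 1\n",
"    return rates[turn] / 2\n",
"\n",
"# Toy data, made up purely for illustration\n",
"toy_rates = [0.1, 0.2, 0.4, 0.8, 1.6]\n",
"toy_losses = [1.0, 0.5, 0.3, 0.9, float('nan')]\n",
"print(suggest_learning_rate(toy_rates, toy_losses))"
]
},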
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The loss starts shooting back up violently when the learning rate goes over 6e-1, so let's try using half of that, at 3e-1:"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"tf.keras.backend.clear_session()\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, activation=\"relu\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=3e-1)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('my_mnist_logs/run_001')"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"run_index = 1 # increment this at every run\n",
"run_logdir = Path() / \"my_mnist_logs\" / \"run_{:03d}\".format(run_index)\n",
"run_logdir"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2363 - accuracy: 0.9264 - val_loss: 0.0972 - val_accuracy: 0.9720\n",
"Epoch 2/100\n",
"1719/1719 [==============================] - 2s 997us/step - loss: 0.0948 - accuracy: 0.9702 - val_loss: 0.1035 - val_accuracy: 0.9706\n",
"Epoch 3/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0667 - accuracy: 0.9792 - val_loss: 0.0783 - val_accuracy: 0.9770\n",
"Epoch 4/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0463 - accuracy: 0.9848 - val_loss: 0.0827 - val_accuracy: 0.9766\n",
"Epoch 5/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0359 - accuracy: 0.9881 - val_loss: 0.0698 - val_accuracy: 0.9826\n",
"Epoch 6/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0297 - accuracy: 0.9908 - val_loss: 0.1048 - val_accuracy: 0.9758\n",
"Epoch 7/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0245 - accuracy: 0.9917 - val_loss: 0.0932 - val_accuracy: 0.9794\n",
"Epoch 8/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0239 - accuracy: 0.9922 - val_loss: 0.0816 - val_accuracy: 0.9798\n",
"Epoch 9/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0154 - accuracy: 0.9952 - val_loss: 0.0775 - val_accuracy: 0.9838\n",
"Epoch 10/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0126 - accuracy: 0.9960 - val_loss: 0.0805 - val_accuracy: 0.9812\n",
"Epoch 11/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0111 - accuracy: 0.9964 - val_loss: 0.0962 - val_accuracy: 0.9804\n",
"Epoch 12/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0118 - accuracy: 0.9963 - val_loss: 0.1044 - val_accuracy: 0.9774\n",
"Epoch 13/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0114 - accuracy: 0.9961 - val_loss: 0.1055 - val_accuracy: 0.9802\n",
"Epoch 14/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0150 - accuracy: 0.9948 - val_loss: 0.0993 - val_accuracy: 0.9826\n",
"Epoch 15/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0054 - accuracy: 0.9981 - val_loss: 0.0955 - val_accuracy: 0.9822\n",
"Epoch 16/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0046 - accuracy: 0.9984 - val_loss: 0.0982 - val_accuracy: 0.9822\n",
"Epoch 17/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0055 - accuracy: 0.9983 - val_loss: 0.0908 - val_accuracy: 0.9844\n",
"Epoch 18/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0070 - accuracy: 0.9978 - val_loss: 0.0883 - val_accuracy: 0.9840\n",
"Epoch 19/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0025 - accuracy: 0.9992 - val_loss: 0.0978 - val_accuracy: 0.9838\n",
"Epoch 20/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0058 - accuracy: 0.9983 - val_loss: 0.1011 - val_accuracy: 0.9830\n",
"Epoch 21/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.0039 - accuracy: 0.9989 - val_loss: 0.0991 - val_accuracy: 0.9840\n",
"Epoch 22/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 9.2480e-04 - accuracy: 0.9998 - val_loss: 0.0963 - val_accuracy: 0.9840\n",
"Epoch 23/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 1.2642e-04 - accuracy: 1.0000 - val_loss: 0.0970 - val_accuracy: 0.9846\n",
"Epoch 24/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 6.9068e-05 - accuracy: 1.0000 - val_loss: 0.0970 - val_accuracy: 0.9854\n",
"Epoch 25/100\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 5.1481e-05 - accuracy: 1.0000 - val_loss: 0.0977 - val_accuracy: 0.9850\n"
]
}
],
"source": [
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20)\n",
"checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_mnist_model\", save_best_only=True)\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"\n",
"history = model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[checkpoint_cb, early_stopping_cb, tensorboard_cb])"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 908us/step - loss: 0.0708 - accuracy: 0.9799\n"
]
},
{
"data": {
"text/plain": [
"[0.07079131156206131, 0.9799000024795532]"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.models.load_model(\"my_mnist_model\") # roll back to best model\n",
"model.evaluate(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We got almost 98% accuracy on the test set. Finally, let's look at the learning curves using TensorBoard:"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-9f95f24bb0151492\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-9f95f24bb0151492\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6008;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%tensorboard --logdir=./my_mnist_logs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"nav_menu": {
"height": "264px",
"width": "369px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}