{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 11 – Training Deep Neural Networks**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 11._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/11_training_deep_neural_networks.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/11_training_deep_neural_networks.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And TensorFlow ≥ 2.8:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import tensorflow as tf\n",
"\n",
"assert version.parse(tf.__version__) >= version.parse(\"2.8.0\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's also create the `images/deep` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"deep\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vanishing/Exploding Gradients Problem"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 11–1\n",
"\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([-5, 5], [1, 1], 'k--')\n",
"plt.plot([0, 0], [-0.2, 1.2], 'k-')\n",
"plt.plot([-5, 5], [-3/4, 7/4], 'g--')\n",
"plt.plot(z, sigmoid(z), \"b-\", linewidth=2,\n",
"         label=r\"$\\sigma(z) = \\dfrac{1}{1+e^{-z}}$\")\n",
"props = dict(facecolor='black', shrink=0.1)\n",
"plt.annotate('Saturating', xytext=(3.5, 0.7), xy=(5, 1), arrowprops=props,\n",
"             fontsize=14, ha=\"center\")\n",
"plt.annotate('Saturating', xytext=(-3.5, 0.3), xy=(-5, 0), arrowprops=props,\n",
"             fontsize=14, ha=\"center\")\n",
"plt.annotate('Linear', xytext=(2, 0.2), xy=(0, 0.5), arrowprops=props,\n",
"             fontsize=14, ha=\"center\")\n",
"plt.grid(True)\n",
"plt.axis([-5, 5, -0.2, 1.2])\n",
"plt.xlabel(\"$z$\")\n",
"plt.legend(loc=\"upper left\", fontsize=16)\n",
"\n",
"save_fig(\"sigmoid_saturation_plot\")\n",
"plt.show()"
]
},
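{
"cell_type": "markdown",
"metadata": {},
"source": [
"The flat regions labeled \"Saturating\" in the figure are where gradients vanish: the sigmoid's derivative, $\\sigma'(z) = \\sigma(z)(1 - \\sigma(z))$, peaks at 0.25 when $z = 0$ and shrinks toward 0 as $|z|$ grows. The following is a minimal sketch for checking this numerically, assuming only NumPy; the helper name `sigmoid_derivative` is made up for this illustration and is not part of the book's code:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def sigmoid_derivative(z):\n",
"    # hypothetical helper: derivative of the logistic function\n",
"    s = 1 / (1 + np.exp(-z))\n",
"    return s * (1 - s)\n",
"\n",
"for z_value in [0.0, 2.0, 5.0, 10.0]:\n",
"    print(f'z = {z_value:4.1f} -> derivative = {sigmoid_derivative(z_value):.6f}')\n",
"```\n",
"\n",
"Backpropagation multiplies factors like these layer after layer, so a stack of saturated sigmoid layers leaves very little gradient for the lower layers."
]
},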
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Xavier and He Initialization"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"relu\",\n",
"                              kernel_initializer=\"he_normal\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"he_avg_init = tf.keras.initializers.VarianceScaling(scale=2., mode=\"fan_avg\",\n",
"                                                    distribution=\"uniform\")\n",
"dense = tf.keras.layers.Dense(50, activation=\"sigmoid\",\n",
"                              kernel_initializer=he_avg_init)"
]
},
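{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for the weight scales these initializers produce, the numbers can be worked out by hand. The sketch below is an illustration rather than Keras's implementation. It assumes a layer with `fan_in = 100` inputs and `fan_out = 50` outputs (made-up sizes) and uses the documented formulas: `he_normal` targets a standard deviation of $\\sqrt{2 / fan_{in}}$, and `VarianceScaling` with a uniform distribution samples from $[-limit, +limit]$ with $limit = \\sqrt{3 \\cdot scale / n}$, where $n$ is `fan_avg` in the cell above:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"fan_in, fan_out = 100, 50  # assumed layer sizes, for illustration only\n",
"fan_avg = (fan_in + fan_out) / 2\n",
"\n",
"he_normal_stddev = np.sqrt(2 / fan_in)      # kernel_initializer='he_normal'\n",
"uniform_limit = np.sqrt(3 * 2.0 / fan_avg)  # scale=2., mode='fan_avg', uniform\n",
"\n",
"print('He normal target stddev:', he_normal_stddev)\n",
"print('Uniform limit for scale=2, fan_avg:', uniform_limit)\n",
"```"
]
},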
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nonsaturating Activation Functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Leaky ReLU"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 11–2\n",
"\n",
"def leaky_relu(z, alpha):\n",
"    return np.maximum(alpha * z, z)\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"plt.plot(z, leaky_relu(z, 0.1), \"b-\", linewidth=2, label=r\"$LeakyReLU(z) = max(\\alpha z, z)$\")\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([0, 0], [-1, 3.7], 'k-')\n",
"plt.grid(True)\n",
"props = dict(facecolor='black', shrink=0.1)\n",
"plt.annotate('Leak', xytext=(-3.5, 0.5), xy=(-5, -0.3), arrowprops=props,\n",
"             fontsize=14, ha=\"center\")\n",
"plt.xlabel(\"$z$\")\n",
"plt.axis([-5, 5, -1, 3.7])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.legend()\n",
"\n",
"save_fig(\"leaky_relu_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"leaky_relu = tf.keras.layers.LeakyReLU(alpha=0.2)  # defaults to alpha=0.3\n",
"dense = tf.keras.layers.Dense(50, activation=leaky_relu,\n",
"                              kernel_initializer=\"he_normal\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-16 11:22:41.636848: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"model = tf.keras.models.Sequential([\n",
"    # [...]  # more layers\n",
"    tf.keras.layers.Dense(50, kernel_initializer=\"he_normal\"),  # no activation\n",
"    tf.keras.layers.LeakyReLU(alpha=0.2),  # activation as a separate layer\n",
"    # [...]  # more layers\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ELU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Implementing ELU in TensorFlow is trivial. Just specify the activation function when building each layer, and use He initialization:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"elu\",\n",
"                              kernel_initializer=\"he_normal\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"### SELU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, the SELU hyperparameters (`scale` and `alpha`) are tuned so that the mean output of each neuron remains close to 0 and the standard deviation remains close to 1 (assuming the inputs are also standardized with mean 0 and standard deviation 1, and the other constraints explained in the book are respected). With this activation function, even a 1,000-layer-deep neural network preserves roughly mean 0 and standard deviation 1 across all layers, avoiding the exploding/vanishing gradients problem:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 11–3\n",
"\n",
"from scipy.special import erfc\n",
"\n",
"# alpha and scale to self normalize with mean 0 and standard deviation 1\n",
"# (see equation 14 in the paper):\n",
"alpha_0_1 = -np.sqrt(2 / np.pi) / (erfc(1 / np.sqrt(2)) * np.exp(1 / 2) - 1)\n",
"scale_0_1 = (\n",
"    (1 - erfc(1 / np.sqrt(2)) * np.sqrt(np.e))\n",
"    * np.sqrt(2 * np.pi)\n",
"    * (\n",
"        2 * erfc(np.sqrt(2)) * np.e ** 2\n",
"        + np.pi * erfc(1 / np.sqrt(2)) ** 2 * np.e\n",
"        - 2 * (2 + np.pi) * erfc(1 / np.sqrt(2)) * np.sqrt(np.e)\n",
"        + np.pi\n",
"        + 2\n",
"    ) ** (-1 / 2)\n",
")\n",
"\n",
"def elu(z, alpha=1):\n",
"    return np.where(z < 0, alpha * (np.exp(z) - 1), z)\n",
"\n",
"def selu(z, scale=scale_0_1, alpha=alpha_0_1):\n",
"    return scale * elu(z, alpha)\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"plt.plot(z, elu(z), \"b-\", linewidth=2, label=r\"ELU$_\\alpha(z) = \\alpha (e^z - 1)$ if $z < 0$, else $z$\")\n",
"plt.plot(z, selu(z), \"r--\", linewidth=2, label=r\"SELU$(z) = 1.05 \\, $ELU$_{1.67}(z)$\")\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([-5, 5], [-1, -1], 'k:', linewidth=2)\n",
"plt.plot([-5, 5], [-1.758, -1.758], 'k:', linewidth=2)\n",
"plt.plot([0, 0], [-2.2, 3.2], 'k-')\n",
"plt.grid(True)\n",
"plt.axis([-5, 5, -2.2, 3.2])\n",
"plt.xlabel(\"$z$\")\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.legend()\n",
"\n",
"save_fig(\"elu_selu_plot\")\n",
"plt.show()"
]
},
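{
"cell_type": "markdown",
"metadata": {},
"source": [
"The self-normalizing behavior described earlier can be checked without training anything. The sketch below is an illustration, not the book's code: it pushes standardized random inputs through 1,000 layers of 100 units each, with weights drawn using LeCun normal scaling (standard deviation $\\sqrt{1 / fan_{in}}$), and prints the activation statistics every 100 layers. The batch size, layer width, seed, and the helper name `selu_np` are arbitrary choices made for this example; the two constants are the standard SELU values.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"SELU_ALPHA = 1.6732632423543772  # standard SELU constants\n",
"SELU_SCALE = 1.0507009873554805\n",
"\n",
"def selu_np(z):\n",
"    # clamp the exponent's input so large positive values cannot overflow\n",
"    negative_part = SELU_ALPHA * (np.exp(np.minimum(z, 0)) - 1)\n",
"    return SELU_SCALE * np.where(z < 0, negative_part, z)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"Z = rng.normal(size=(500, 100))  # standardized inputs\n",
"for layer in range(1000):\n",
"    W = rng.normal(scale=np.sqrt(1 / 100), size=(100, 100))  # LeCun normal\n",
"    Z = selu_np(Z @ W)\n",
"    if layer % 100 == 0:\n",
"        print(f'Layer {layer}: mean {Z.mean():.2f}, std deviation {Z.std():.2f}')\n",
"```\n",
"\n",
"Swapping `selu_np` for a plain ReLU in this loop is a quick way to compare how the statistics drift without the self-normalizing property."
]
},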
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using SELU is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"selu\",\n",
"                              kernel_initializer=\"lecun_normal\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra material – an example of a self-normalizing network using SELU**\n",
"\n",
"Let's create a neural net for Fashion MNIST with 100 hidden layers, using the SELU activation function:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[28, 28]))\n",
"for layer in range(100):\n",
"    model.add(tf.keras.layers.Dense(100, activation=\"selu\",\n",
"                                    kernel_initializer=\"lecun_normal\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
"              optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
"              metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's train it. Do not forget to scale the inputs to mean 0 and standard deviation 1:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]\n",
"X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]\n",
"X_train, X_valid, X_test = X_train / 255, X_valid / 255, X_test / 255"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class_names = [\"T-shirt/top\", \"Trouser\", \"Pullover\", \"Dress\", \"Coat\",\n",
"               \"Sandal\", \"Shirt\", \"Sneaker\", \"Bag\", \"Ankle boot\"]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"pixel_means = X_train.mean(axis=0, keepdims=True)\n",
"pixel_stds = X_train.std(axis=0, keepdims=True)\n",
"X_train_scaled = (X_train - pixel_means) / pixel_stds\n",
"X_valid_scaled = (X_valid - pixel_means) / pixel_stds\n",
"X_test_scaled = (X_test - pixel_means) / pixel_stds"
]
},
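{
"cell_type": "markdown",
"metadata": {},
"source": [
"The scaling cell above computes the statistics on the training set only and reuses them for the validation and test sets. The same pattern can be tried on a small synthetic array. The sketch below uses made-up data and made-up variable names (`toy_train`, `toy_valid`); the `eps` guard against zero-variance features is an added assumption, not something the cell above does:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"toy_train = rng.uniform(0, 1, size=(1000, 4))  # made-up data\n",
"toy_valid = rng.uniform(0, 1, size=(200, 4))\n",
"\n",
"# statistics come from the training set only\n",
"toy_means = toy_train.mean(axis=0, keepdims=True)\n",
"toy_stds = toy_train.std(axis=0, keepdims=True)\n",
"eps = 1e-8  # assumed guard: avoids dividing by zero for constant features\n",
"\n",
"toy_train_scaled = (toy_train - toy_means) / (toy_stds + eps)\n",
"toy_valid_scaled = (toy_valid - toy_means) / (toy_stds + eps)\n",
"\n",
"print('train mean per feature:', toy_train_scaled.mean(axis=0).round(3))\n",
"print('train std per feature: ', toy_train_scaled.std(axis=0).round(3))\n",
"print('valid mean per feature:', toy_valid_scaled.mean(axis=0).round(3))\n",
"```\n",
"\n",
"The validation statistics should land near, but not exactly on, 0 and 1, since they are scaled with the training set's parameters."
]
},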
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-16 11:22:44.499697: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1719/1719 [==============================] - 13s 7ms/step - loss: 1.3735 - accuracy: 0.4548 - val_loss: 0.9599 - val_accuracy: 0.6444\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.7783 - accuracy: 0.7073 - val_loss: 0.6529 - val_accuracy: 0.7664\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.6462 - accuracy: 0.7611 - val_loss: 0.6048 - val_accuracy: 0.7748\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.5821 - accuracy: 0.7863 - val_loss: 0.5737 - val_accuracy: 0.7944\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.5401 - accuracy: 0.8041 - val_loss: 0.5333 - val_accuracy: 0.8046\n"
]
}
],
"source": [
"history = model.fit(X_train_scaled, y_train, epochs=5,\n",
"                    validation_data=(X_valid_scaled, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The network managed to learn, despite how deep it is. Now look at what happens if we try to use the ReLU activation function instead:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[28, 28]))\n",
"for layer in range(100):\n",
"    model.add(tf.keras.layers.Dense(100, activation=\"relu\",\n",
"                                    kernel_initializer=\"he_normal\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 22,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2019-06-10 04:48:00 +02:00
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
2022-02-19 06:19:26 +01:00
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
2019-02-17 13:31:28 +01:00
" metrics=[\"accuracy\"])"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 23,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1719/1719 [==============================] - 12s 6ms/step - loss: 1.6932 - accuracy: 0.3071 - val_loss: 1.2058 - val_accuracy: 0.5106\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 1.1132 - accuracy: 0.5297 - val_loss: 0.9682 - val_accuracy: 0.5718\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 10s 6ms/step - loss: 0.9480 - accuracy: 0.6117 - val_loss: 1.0552 - val_accuracy: 0.5102\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 10s 6ms/step - loss: 0.9763 - accuracy: 0.6003 - val_loss: 0.7764 - val_accuracy: 0.7070\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.7892 - accuracy: 0.6875 - val_loss: 0.7485 - val_accuracy: 0.7054\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"history = model.fit(X_train_scaled, y_train, epochs=5,\n",
" validation_data=(X_valid_scaled, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not great at all, we suffered from the vanishing/exploding gradients problem."
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"### GELU, Swish and Mish"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 24,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAADsCAYAAAAy23L6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABbDElEQVR4nO2dd3hURdfAf5PeSCMkEECKIUiXYoiAEHpVFLChIIIiCIIKSrGAH0XBV1+VgIKgqPBSFUSqoCRUKUoLHamBQEhICOltvj9udrObbJINKZsyv+e5z+6de+7Mmbt377kzc+aMkFKiUCgUCkVpY2VpBRQKhUJROVEGSKFQKBQWQRkghUKhUFgEZYAUCoVCYRGUAVIoFAqFRVAGSKFQKBQWocgGSAhRWwixUwhxWghxUggx3oSMEEJ8JYS4IIQ4LoRoVdRyFQqFQlG+sSmGPNKBCVLKf4QQVYC/hRDbpZSnDGR6Aw2ytrbA11mfCoVCoaikFLkFJKWMkFL+k/X9HnAaqJlDrD/wo9T4C3AXQtQoatkKhUKhKL8URwtIjxCiLtASOJDjUE3gmsF+eFZahIk8RgIjARwcHFo/8MADxalimSczMxMrq8o1NKfqXPG5du0aUkoq2/8ZKt9vDXDu3LkoKWW1guSKzQAJIVyAn4E3pZRxOQ+bOMVkDCAp5SJgEUDDhg3l2bNni0vFckFISAhBQUGWVqNUUXWu+AQFBREbG8vRo0ctrUqpU9l+awAhxBVz5IrFLAshbNGMz3Ip5S8mRMKB2gb7tYAbxVG2QqFQKMoOUVHmyxaHF5wAlgCnpZSf5yG2ARia5Q0XCNyVUubqflMoFApF+eXqVWjZ0nz54uiCaw8MAU4IIY5mpU0FHgCQUn4DbAb6ABeARODlYihXoVAoFGWEyEjo3h3Cw80/p8gGSEq5B9NjPIYyEhhT1LIUCoVCUba4GHORLadDWTL2Zc6dgxYt4Ngx884tVi84hUKhUFQeIu5F0O3H7lyKvQjOt/Dzm8y2bVC9unnni7K8IF1+XnCZmZmEh4eTkJBQylqVLMnJyTg4OFhajVJF1bnic/PmTaSU1KhhPP3P1tYWb29vXF1dLaRZyVORveCG/PISy078qO2kOxD6zEk6NquPEOJvKWWbgs4vty2gqKgohBA0bNiwQvnY37t3jypVqlhajVJF1bniY2VlRXp6Oo0aNdKnSSlJSkri+vXrABXaCFVEMjMh/ddgENegzi6CO66hY7P6hcqj3D65Y2Nj8fHxqVDGR6GoTAghcHJyombNmkRGRlpaHUUhkBLeegtW/lgFp3WbmR/4O2O69yt0PuW2BZSRkYGtra2l1VAoFEXE0dGRtLQ0S6uhKAQzZsBXX4GdHfz6swPdunW5r3zKrQEC7Q1KoVCUb9T/uHwgpWTKH1NIOdGPL6Z1wMoKVqyAbt3uP0/Vf6VQKBSKAvko9CPm7J3DF9E9oMFmFi2CAQOKlqcyQAqFQqHIl+jEaL7c+422Y5tEq5dWMmJE0fNVBkihUCgU+XL2aFWS5u+BmLrUz+jF/smLiyVfZYAUCoVCkSdhYdC3L6RE+PF84j6OvbcWO2u73IJ//glnzhQqb2WAFADExMTg4+PDv//+a5b8oEGD+PzzvGLPlj0qev0UiuJGSsmVK9CzJ8TGwpNPwo8LauBi75xb+K+/4PHHoUMHuHzZ7DKUAbIAt27d4q233qJBgwY4ODjg7e1Nu3btmDdvHvHx8QAMGzYMIUSuLTAwUJ/PsGHD6Ncvb9/7oKAgxo4dmyt96dKluLi4GKXNnj2bPn368OCDD5pVh2nTpjFz5kzu3r1rlnxJcvjwYYQQXM7nxi/P9VMoSpujN48SsLA9nfuHc+MGdOqkebzZmPKbPnkS+vSBxETo1w8KseigMkClzOXLl2nVqhVbt25lxowZ/PPPP/z5559MnDiRP/74g82bN+tlu3XrRkREhNFmeLy4SExMZPHixYwoxKhis2
bNqF+/PsuWLSt2fczl1KlTDB06lEGDBgHa9Ro+fDg5wzeV1/opFJbgwp0L9PypF4dv7edS5/Y81P4cv/4KJiNHXb4MPXpATAw88QQsXgyFCA5QYQyQEJbZCsvo0aOxsrLi8OHDPPfcczRu3JimTZsyYMAA1q9fz9NPP62Xtbe3p3r16kabp6dnMV41jc2bN2NlZUX79u2N0ufOnWuyFfbhhx8C8MQTT7BixYpi16egcgF++eUXWrRoQUJCAm+++SYAU6ZMITo6mmbNmrF+/foyWz+Foixz9PpJbsffAcDK6S7BC5NwczMheOuWtv6Crom0cmUeTaS8qTAGqDxw584dtm3bxpgxY3B2NtGPimUm5e3evZvWrVvnKnv06NFGra8JEyZQvXp1hg4dCkBAQAAHDx4kKSkpV56zZ8/GxcUl32337t0m9Smo3NTUVF577TV69uzJzz//TIcOHQDo2rUrv/76K926dWPkyJH62fUlUT+FoiKSkQGrZ/RHLv8NkeTF//ptpGuTFrkF09K0brcLF7QV6H79FRwdC11euY6EYEgZDuqt5/z580gpadiwoVF6rVq1iI2NBeDZZ59lyZIlAGzdujXXWM2YMWOYM2dOsep15cqVXFGKAapUqaIPmDlnzhxWrFhBSEgIfn5+APj6+pKWlsaNGzdyja2MGjWKZ555Jt9ya9asaTK9oHJPnDhBVFQUQ4YMMXn+kCFD2LJlCydOnKBVq1YlUj+FoqIhJYwbB2vWgKtrT7YMukS7Ni6mhW1t4fXX4T//ga1bMd1EKpgKY4DKM7t37yYjI4ORI0eSnJysT+/YsSOLFi0yknV3dy/28pOSkvDx8cnz+Mcff0xwcDA7d+7E399fn+6Y9cZjqoXg6elZ5O7CvMrVLSGSV2tRF6BWJ1cS9VMoKgqpGalYCStmzbBhwQKwt4cNG8jb+OgYMQJefFE74T5RXXCliJ+fH0IIzuTwla9Xrx5+fn44OTkZpTs5OeHn52e0eXl5mV2eq6urSS+u2NhY3AzeWLy8vIiJiTGZx6xZs1iwYAGhoaFGD2fQuhQBqlWrluu8onTBFVRu8+bNqVq1KsuXLzd57rJly/Dy8qJZs2YlVj+FoiKQkZnB0HVDaTVnINNnJmFlpQ3ldOpkQjgzEyZMgOPHs9OKYHxAtYBKlapVq9KjRw+Cg4N54403cnWvFTcNGzZk8+bNSCmNWgv//POPUTdgy5YtWbp0aa7zZ8yYwbfffktISIjJLqiwsDB8fX1Nti6K0gVXULl2dnYsWLCAwYMH8+yzz/LYY48BEBoayttvv83WrVtZuXIldnZ2JVY/haK8I6XkjS1vsOrkKi3hxd4Et9/Ck0+aGMuRUjM+X3wBq1ZpYz/FsKCiMkClzIIFC2jfvj2tW7dm+vTptGjRAhsbG/7++2+OHTtG586d9bIpKSncvHnT6Hxra2ujN/K4uDiOHj1qJOPu7k7dunUZPXq03ti9+uqrODg4sHnzZlasWMGvv/6ql+/ZsyeTJk0iOjqaqlWrAlrL4Msvv2TDhg04Ozvr9XB3d9ev5Ll792569eplsp732wVnTrkAzzzzDI0aNeKTTz5h7ty5gDZ3p2PHjhw7dozGjRuXaP0UiorAnYjsRQAfrd+MUSPyMCqzZ2vGx9YWvvuuWIwPoFnBsrr5+/vLvDh16lSex8o6ERERcty4cfLBBx+UdnZ20tnZWbZp00bOnj1bXr9+XUop5UsvvSSBXFvNmjX1+eQlM3DgQL3MwYMHZY8ePaS3t7d0dXWVAQEBct26dbl0CgwMlMHBwVJKKTMzM6Wrq6vJvHfs2CGllDIpKUm6urrK/fv3F/l6xMXFmV2uKQ4dOiQBeenSpTxlLFk/U+jqXFk4c+aMDAsLy/N4ef4/F8TOnTstrYJJDh6U0tlZStrPkQ9NHSzTMzJMC379tZQgpRBSrlplVt7AYWnGM97iRia/raIaoPyw1INpy5Yt0t/fX6anp5slHxwcLLt371
4sZZdGnS1ZP1MoA2RMRf0/S1k2DdCZM1J6eWkW4MUXpUxPzzQtuHKlZnhAym++KTjjjAwpk5PNNkDF4oQghPhOCBE
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – this cell generates and saves Figure 11– 4\n",
"\n",
"def swish(z, beta=1):\n",
" return z * sigmoid(beta * z)\n",
"\n",
"def approx_gelu(z):\n",
" return swish(z, beta=1.702)\n",
"\n",
"def softplus(z):\n",
" return np.log(1 + np.exp(z))\n",
"\n",
"def mish(z):\n",
" return z * np.tanh(softplus(z))\n",
"\n",
"z = np.linspace(-4, 2, 200)\n",
"\n",
"beta = 0.6\n",
"plt.plot(z, approx_gelu(z), \"b-\", linewidth=2,\n",
" label=r\"GELU$(z) = z\\,\\Phi(z)$\")\n",
"plt.plot(z, swish(z), \"r--\", linewidth=2,\n",
" label=r\"Swish$(z) = z\\,\\sigma(z)$\")\n",
"plt.plot(z, swish(z, beta), \"r:\", linewidth=2,\n",
" label=fr\"Swish$_{{\\beta={beta}}}(z)=z\\,\\sigma({beta}\\,z)$\")\n",
"plt.plot(z, mish(z), \"g:\", linewidth=3,\n",
" label=fr\"Mish$(z) = z\\,\\tanh($softplus$(z))$\")\n",
"plt.plot([-4, 2], [0, 0], 'k-')\n",
"plt.plot([0, 0], [-2.2, 3.2], 'k-')\n",
"plt.grid(True)\n",
"plt.axis([-4, 2, -1, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.xlabel(\"$z$\")\n",
"plt.legend(loc=\"upper left\")\n",
"\n",
"save_fig(\"gelu_swish_mish_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Batch Normalization"
2016-09-27 23:31:21 +02:00
]
},
2017-04-30 10:21:27 +02:00
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 25,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-04-30 10:21:27 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code - clear the name counters and set the random seed\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
2017-04-30 10:21:27 +02:00
]
},
2016-09-27 23:31:21 +02:00
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 26,
2018-03-24 22:50:29 +01:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(300, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 27,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten (Flatten) (None, 784) 0 \n",
"_________________________________________________________________\n",
"batch_normalization (BatchNo (None, 784) 3136 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 300) 235500 \n",
"_________________________________________________________________\n",
"batch_normalization_1 (Batch (None, 300) 1200 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 100) 30100 \n",
"_________________________________________________________________\n",
"batch_normalization_2 (Batch (None, 100) 400 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 1010 \n",
"=================================================================\n",
"Total params: 271,346\n",
"Trainable params: 268,978\n",
"Non-trainable params: 2,368\n",
"_________________________________________________________________\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"model.summary()"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 28,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"[('batch_normalization/gamma:0', True),\n",
" ('batch_normalization/beta:0', True),\n",
" ('batch_normalization/moving_mean:0', False),\n",
" ('batch_normalization/moving_variance:0', False)]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 28,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"[(var.name, var.trainable) for var in model.layers[1].variables]"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 29,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.5559 - accuracy: 0.8094 - val_loss: 0.4016 - val_accuracy: 0.8558\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4083 - accuracy: 0.8561 - val_loss: 0.3676 - val_accuracy: 0.8650\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa5d11505b0>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 29,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – just show that the model works! 😊\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\",\n",
" metrics=\"accuracy\")\n",
"model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid))"
2016-09-27 23:31:21 +02:00
]
},
2017-04-30 10:21:27 +02:00
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-04-30 10:21:27 +02:00
"source": [
2019-03-25 05:03:44 +01:00
"Sometimes applying BN before the activation function works better (there's a debate on this topic). Moreover, the layer before a `BatchNormalization` layer does not need to have bias terms, since the `BatchNormalization` layer some as well, it would be a waste of parameters, so you can set `use_bias=False` when creating those layers:"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 30,
2022-02-19 06:19:26 +01:00
"metadata": {},
"outputs": [],
"source": [
"# extra code - clear the name counters and set the random seed\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 31,
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-11-03 13:43:56 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
2022-02-19 06:19:26 +01:00
" tf.keras.layers.Dense(300, kernel_initializer=\"he_normal\", use_bias=False),\n",
2021-10-17 04:04:08 +02:00
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Activation(\"relu\"),\n",
2022-02-19 06:19:26 +01:00
" tf.keras.layers.Dense(100, kernel_initializer=\"he_normal\", use_bias=False),\n",
2021-10-17 04:04:08 +02:00
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Activation(\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
2019-02-17 13:31:28 +01:00
"])"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 32,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.6063 - accuracy: 0.7993 - val_loss: 0.4296 - val_accuracy: 0.8418\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4275 - accuracy: 0.8500 - val_loss: 0.3752 - val_accuracy: 0.8646\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa5fdd309d0>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 32,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – just show that the model works! 😊\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\",\n",
" metrics=\"accuracy\")\n",
"model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid))"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
"## Gradient Clipping"
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"All `tf.keras.optimizers` accept `clipnorm` or `clipvalue` arguments:"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 33,
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.SGD(clipvalue=1.0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer)"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 34,
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.SGD(clipnorm=1.0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer)"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"## Reusing Pretrained Layers"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "markdown",
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"### Reusing a Keras model"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"Let's split the fashion MNIST training set in two:\n",
2022-02-19 06:19:26 +01:00
"* `X_train_A`: all images of all items except for T-shirts/tops and pullovers (classes 0 and 2).\n",
"* `X_train_B`: a much smaller training set of just the first 200 images of T-shirts/tops and pullovers.\n",
2019-02-17 13:31:28 +01:00
"\n",
"The validation set and the test set are also split this way, but without restricting the number of images.\n",
"\n",
2022-02-19 06:19:26 +01:00
"We will train a model on set A (classification task with 8 classes), and try to reuse it to tackle set B (binary classification). We hope to transfer a little bit of knowledge from task A to task B, since classes in set A (trousers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots) are somewhat similar to classes in set B (T-shirts/tops and pullovers). However, since we are using `Dense` layers, only patterns that occur at the same location can be reused (in contrast, convolutional layers will transfer much better, since learned patterns can be detected anywhere on the image, as we will see in the chapter 14)."
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 35,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"1376/1376 [==============================] - 1s 908us/step - loss: 1.1385 - accuracy: 0.6260 - val_loss: 0.7101 - val_accuracy: 0.7603\n",
"Epoch 2/20\n",
"1376/1376 [==============================] - 1s 869us/step - loss: 0.6221 - accuracy: 0.7911 - val_loss: 0.5293 - val_accuracy: 0.8315\n",
"Epoch 3/20\n",
"1376/1376 [==============================] - 1s 852us/step - loss: 0.5016 - accuracy: 0.8394 - val_loss: 0.4515 - val_accuracy: 0.8581\n",
"Epoch 4/20\n",
"1376/1376 [==============================] - 1s 852us/step - loss: 0.4381 - accuracy: 0.8583 - val_loss: 0.4055 - val_accuracy: 0.8669\n",
"Epoch 5/20\n",
"1376/1376 [==============================] - 1s 844us/step - loss: 0.3979 - accuracy: 0.8692 - val_loss: 0.3748 - val_accuracy: 0.8706\n",
"Epoch 6/20\n",
"1376/1376 [==============================] - 1s 882us/step - loss: 0.3693 - accuracy: 0.8782 - val_loss: 0.3538 - val_accuracy: 0.8787\n",
"Epoch 7/20\n",
"1376/1376 [==============================] - 1s 863us/step - loss: 0.3487 - accuracy: 0.8825 - val_loss: 0.3376 - val_accuracy: 0.8834\n",
"Epoch 8/20\n",
"1376/1376 [==============================] - 2s 1ms/step - loss: 0.3324 - accuracy: 0.8879 - val_loss: 0.3315 - val_accuracy: 0.8847\n",
"Epoch 9/20\n",
"1376/1376 [==============================] - 1s 1ms/step - loss: 0.3198 - accuracy: 0.8920 - val_loss: 0.3174 - val_accuracy: 0.8879\n",
"Epoch 10/20\n",
"1376/1376 [==============================] - 2s 1ms/step - loss: 0.3088 - accuracy: 0.8947 - val_loss: 0.3118 - val_accuracy: 0.8904\n",
"Epoch 11/20\n",
"1376/1376 [==============================] - 1s 1ms/step - loss: 0.2994 - accuracy: 0.8979 - val_loss: 0.3039 - val_accuracy: 0.8925\n",
"Epoch 12/20\n",
"1376/1376 [==============================] - 1s 837us/step - loss: 0.2918 - accuracy: 0.8999 - val_loss: 0.2998 - val_accuracy: 0.8952\n",
"Epoch 13/20\n",
"1376/1376 [==============================] - 1s 840us/step - loss: 0.2852 - accuracy: 0.9016 - val_loss: 0.2932 - val_accuracy: 0.8980\n",
"Epoch 14/20\n",
"1376/1376 [==============================] - 1s 799us/step - loss: 0.2788 - accuracy: 0.9034 - val_loss: 0.2865 - val_accuracy: 0.8990\n",
"Epoch 15/20\n",
"1376/1376 [==============================] - 1s 922us/step - loss: 0.2736 - accuracy: 0.9052 - val_loss: 0.2824 - val_accuracy: 0.9015\n",
"Epoch 16/20\n",
"1376/1376 [==============================] - 1s 835us/step - loss: 0.2686 - accuracy: 0.9068 - val_loss: 0.2796 - val_accuracy: 0.9015\n",
"Epoch 17/20\n",
"1376/1376 [==============================] - 1s 863us/step - loss: 0.2641 - accuracy: 0.9085 - val_loss: 0.2748 - val_accuracy: 0.9015\n",
"Epoch 18/20\n",
"1376/1376 [==============================] - 1s 913us/step - loss: 0.2596 - accuracy: 0.9101 - val_loss: 0.2729 - val_accuracy: 0.9037\n",
"Epoch 19/20\n",
"1376/1376 [==============================] - 1s 909us/step - loss: 0.2558 - accuracy: 0.9119 - val_loss: 0.2715 - val_accuracy: 0.9040\n",
"Epoch 20/20\n",
"1376/1376 [==============================] - 1s 859us/step - loss: 0.2520 - accuracy: 0.9125 - val_loss: 0.2728 - val_accuracy: 0.9027\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-15 16:22:23.274500: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_model_A/assets\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – split Fashion MNIST into tasks A and B, then train and save\n",
"# model A to \"my_model_A\".\n",
"\n",
"pos_class_id = class_names.index(\"Pullover\")\n",
"neg_class_id = class_names.index(\"T-shirt/top\")\n",
"\n",
2019-02-17 13:31:28 +01:00
"def split_dataset(X, y):\n",
2022-02-19 06:19:26 +01:00
" y_for_B = (y == pos_class_id) | (y == neg_class_id)\n",
" y_A = y[~y_for_B]\n",
" y_B = (y[y_for_B] == pos_class_id).astype(np.float32)\n",
" old_class_ids = list(set(range(10)) - set([neg_class_id, pos_class_id]))\n",
" for old_class_id, new_class_id in zip(old_class_ids, range(8)):\n",
" y_A[y_A == old_class_id] = new_class_id # reorder class ids for A\n",
" return ((X[~y_for_B], y_A), (X[y_for_B], y_B))\n",
2019-02-17 13:31:28 +01:00
"\n",
"(X_train_A, y_train_A), (X_train_B, y_train_B) = split_dataset(X_train, y_train)\n",
"(X_valid_A, y_valid_A), (X_valid_B, y_valid_B) = split_dataset(X_valid, y_valid)\n",
"(X_test_A, y_test_A), (X_test_B, y_test_B) = split_dataset(X_test, y_test)\n",
"X_train_B = X_train_B[:200]\n",
2022-02-19 06:19:26 +01:00
"y_train_B = y_train_B[:200]\n",
"\n",
"tf.random.set_seed(42)\n",
"\n",
"model_A = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(8, activation=\"softmax\")\n",
"])\n",
"\n",
"model_A.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])\n",
"history = model_A.fit(X_train_A, y_train_A, epochs=20,\n",
" validation_data=(X_valid_A, y_valid_A))\n",
"model_A.save(\"my_model_A\")"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 36,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"7/7 [==============================] - 0s 20ms/step - loss: 0.7167 - accuracy: 0.5450 - val_loss: 0.7052 - val_accuracy: 0.5272\n",
"Epoch 2/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6805 - accuracy: 0.5800 - val_loss: 0.6758 - val_accuracy: 0.6004\n",
"Epoch 3/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6532 - accuracy: 0.6650 - val_loss: 0.6530 - val_accuracy: 0.6746\n",
"Epoch 4/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.6289 - accuracy: 0.7150 - val_loss: 0.6317 - val_accuracy: 0.7517\n",
"Epoch 5/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6079 - accuracy: 0.7800 - val_loss: 0.6105 - val_accuracy: 0.8091\n",
"Epoch 6/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5866 - accuracy: 0.8400 - val_loss: 0.5913 - val_accuracy: 0.8447\n",
"Epoch 7/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.5670 - accuracy: 0.8850 - val_loss: 0.5728 - val_accuracy: 0.8833\n",
"Epoch 8/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5499 - accuracy: 0.8900 - val_loss: 0.5571 - val_accuracy: 0.8971\n",
"Epoch 9/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5331 - accuracy: 0.9150 - val_loss: 0.5427 - val_accuracy: 0.9050\n",
"Epoch 10/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5180 - accuracy: 0.9250 - val_loss: 0.5290 - val_accuracy: 0.9080\n",
"Epoch 11/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.5038 - accuracy: 0.9350 - val_loss: 0.5160 - val_accuracy: 0.9189\n",
"Epoch 12/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4903 - accuracy: 0.9350 - val_loss: 0.5032 - val_accuracy: 0.9228\n",
"Epoch 13/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4770 - accuracy: 0.9400 - val_loss: 0.4925 - val_accuracy: 0.9228\n",
"Epoch 14/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4656 - accuracy: 0.9450 - val_loss: 0.4817 - val_accuracy: 0.9258\n",
"Epoch 15/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4546 - accuracy: 0.9550 - val_loss: 0.4708 - val_accuracy: 0.9298\n",
"Epoch 16/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4435 - accuracy: 0.9550 - val_loss: 0.4608 - val_accuracy: 0.9318\n",
"Epoch 17/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4330 - accuracy: 0.9600 - val_loss: 0.4510 - val_accuracy: 0.9337\n",
"Epoch 18/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4226 - accuracy: 0.9600 - val_loss: 0.4406 - val_accuracy: 0.9367\n",
"Epoch 19/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4119 - accuracy: 0.9600 - val_loss: 0.4311 - val_accuracy: 0.9377\n",
"Epoch 20/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4025 - accuracy: 0.9600 - val_loss: 0.4225 - val_accuracy: 0.9367\n",
"63/63 [==============================] - 0s 728us/step - loss: 0.4317 - accuracy: 0.9185\n"
]
},
{
"data": {
"text/plain": [
"[0.43168652057647705, 0.9185000061988831]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 36,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – train and evaluate model B, without reusing model A\n",
"\n",
"tf.random.set_seed(42)\n",
"model_B = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"\n",
"model_B.compile(loss=\"binary_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])\n",
"history = model_B.fit(X_train_B, y_train_B, epochs=20,\n",
" validation_data=(X_valid_B, y_valid_B))\n",
"model_B.evaluate(X_test_B, y_test_B)"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"Model B reaches 91.85% accuracy on the test set. Now let's try reusing the pretrained model A."
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 37,
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-11-03 13:43:56 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"model_A = tf.keras.models.load_model(\"my_model_A\")\n",
"model_B_on_A = tf.keras.Sequential(model_A.layers[:-1])\n",
"model_B_on_A.add(tf.keras.layers.Dense(1, activation=\"sigmoid\"))"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"Note that `model_B_on_A` and `model_A` actually share layers now, so when we train one, it will update both models. If we want to avoid that, we need to build `model_B_on_A` on top of a *clone* of `model_A`:"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 38,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"tf.random.set_seed(42) # extra code – ensure reproducibility"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 39,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"model_A_clone = tf.keras.models.clone_model(model_A)\n",
"model_A_clone.set_weights(model_A.get_weights())"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 40,
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – creating model_B_on_A just like in the previous cell\n",
"model_B_on_A = tf.keras.Sequential(model_A_clone.layers[:-1])\n",
"model_B_on_A.add(tf.keras.layers.Dense(1, activation=\"sigmoid\"))"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 41,
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"for layer in model_B_on_A.layers[:-1]:\n",
" layer.trainable = False\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model_B_on_A.compile(loss=\"binary_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 42,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"7/7 [==============================] - 0s 23ms/step - loss: 1.7893 - accuracy: 0.5550 - val_loss: 1.3324 - val_accuracy: 0.5084\n",
"Epoch 2/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 1.1235 - accuracy: 0.5350 - val_loss: 0.9199 - val_accuracy: 0.4807\n",
"Epoch 3/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.8836 - accuracy: 0.5000 - val_loss: 0.8266 - val_accuracy: 0.4837\n",
"Epoch 4/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.8202 - accuracy: 0.5250 - val_loss: 0.7795 - val_accuracy: 0.4985\n",
"Epoch 1/16\n",
"7/7 [==============================] - 0s 21ms/step - loss: 0.7348 - accuracy: 0.6050 - val_loss: 0.6372 - val_accuracy: 0.6914\n",
"Epoch 2/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6055 - accuracy: 0.7600 - val_loss: 0.5283 - val_accuracy: 0.8229\n",
"Epoch 3/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4992 - accuracy: 0.8400 - val_loss: 0.4742 - val_accuracy: 0.8180\n",
"Epoch 4/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4297 - accuracy: 0.8700 - val_loss: 0.4212 - val_accuracy: 0.8773\n",
"Epoch 5/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3825 - accuracy: 0.9050 - val_loss: 0.3797 - val_accuracy: 0.9031\n",
"Epoch 6/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.3438 - accuracy: 0.9250 - val_loss: 0.3534 - val_accuracy: 0.9149\n",
"Epoch 7/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3148 - accuracy: 0.9500 - val_loss: 0.3384 - val_accuracy: 0.9001\n",
"Epoch 8/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3012 - accuracy: 0.9450 - val_loss: 0.3179 - val_accuracy: 0.9209\n",
"Epoch 9/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2767 - accuracy: 0.9650 - val_loss: 0.3043 - val_accuracy: 0.9298\n",
"Epoch 10/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2623 - accuracy: 0.9550 - val_loss: 0.2929 - val_accuracy: 0.9308\n",
"Epoch 11/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2512 - accuracy: 0.9600 - val_loss: 0.2830 - val_accuracy: 0.9327\n",
"Epoch 12/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2397 - accuracy: 0.9600 - val_loss: 0.2744 - val_accuracy: 0.9318\n",
"Epoch 13/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2295 - accuracy: 0.9600 - val_loss: 0.2675 - val_accuracy: 0.9327\n",
"Epoch 14/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2225 - accuracy: 0.9600 - val_loss: 0.2598 - val_accuracy: 0.9347\n",
"Epoch 15/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2147 - accuracy: 0.9600 - val_loss: 0.2542 - val_accuracy: 0.9357\n",
"Epoch 16/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.2077 - accuracy: 0.9600 - val_loss: 0.2492 - val_accuracy: 0.9377\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history = model_B_on_A.fit(X_train_B, y_train_B, epochs=4,\n",
" validation_data=(X_valid_B, y_valid_B))\n",
"\n",
"for layer in model_B_on_A.layers[:-1]:\n",
" layer.trainable = True\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model_B_on_A.compile(loss=\"binary_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model_B_on_A.fit(X_train_B, y_train_B, epochs=16,\n",
" validation_data=(X_valid_B, y_valid_B))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, what's the final verdict?"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 43,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"63/63 [==============================] - 0s 667us/step - loss: 0.2546 - accuracy: 0.9385\n"
]
},
{
"data": {
"text/plain": [
"[0.2546142041683197, 0.9384999871253967]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 43,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"model_B_on_A.evaluate(X_test_B, y_test_B)"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"Great! We got a bit of transfer: the model's accuracy went up 2 percentage points, from 91.85% to 93.85%. This means the error rate dropped by almost 25%:"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 44,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.24539877300613477"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 44,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"1 - (100 - 93.85) / (100 - 91.85)"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2019-02-17 13:31:28 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# Faster Optimizers"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 45,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – a little function to test an optimizer on Fashion MNIST\n",
"\n",
"def build_model(seed=42):\n",
" tf.random.set_seed(seed)\n",
" return tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
" ])\n",
"\n",
"def build_and_train_model(optimizer):\n",
" model = build_model()\n",
" model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
" return model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
2017-06-05 18:48:03 +02:00
]
},
2021-10-07 06:41:46 +02:00
{
2022-02-19 06:19:26 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 46,
2021-10-07 06:41:46 +02:00
"metadata": {},
2022-02-19 06:19:26 +01:00
"outputs": [],
2021-10-07 06:41:46 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)"
2021-10-07 06:41:46 +02:00
]
},
2017-06-05 18:48:03 +02:00
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 47,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.6877 - accuracy: 0.7677 - val_loss: 0.4960 - val_accuracy: 0.8172\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.4619 - accuracy: 0.8378 - val_loss: 0.4421 - val_accuracy: 0.8404\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 868us/step - loss: 0.4179 - accuracy: 0.8525 - val_loss: 0.4188 - val_accuracy: 0.8538\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 1s 866us/step - loss: 0.3902 - accuracy: 0.8621 - val_loss: 0.3814 - val_accuracy: 0.8604\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 869us/step - loss: 0.3686 - accuracy: 0.8691 - val_loss: 0.3665 - val_accuracy: 0.8656\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 925us/step - loss: 0.3553 - accuracy: 0.8732 - val_loss: 0.3643 - val_accuracy: 0.8720\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 908us/step - loss: 0.3385 - accuracy: 0.8778 - val_loss: 0.3611 - val_accuracy: 0.8684\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 926us/step - loss: 0.3297 - accuracy: 0.8796 - val_loss: 0.3490 - val_accuracy: 0.8726\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 893us/step - loss: 0.3200 - accuracy: 0.8850 - val_loss: 0.3625 - val_accuracy: 0.8666\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 886us/step - loss: 0.3097 - accuracy: 0.8881 - val_loss: 0.3656 - val_accuracy: 0.8672\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_sgd = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Momentum optimization"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 48,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 49,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 941us/step - loss: 0.6877 - accuracy: 0.7677 - val_loss: 0.4960 - val_accuracy: 0.8172\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.4619 - accuracy: 0.8378 - val_loss: 0.4421 - val_accuracy: 0.8404\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 898us/step - loss: 0.4179 - accuracy: 0.8525 - val_loss: 0.4188 - val_accuracy: 0.8538\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 934us/step - loss: 0.3902 - accuracy: 0.8621 - val_loss: 0.3814 - val_accuracy: 0.8604\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 910us/step - loss: 0.3686 - accuracy: 0.8691 - val_loss: 0.3665 - val_accuracy: 0.8656\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3553 - accuracy: 0.8732 - val_loss: 0.3643 - val_accuracy: 0.8720\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 893us/step - loss: 0.3385 - accuracy: 0.8778 - val_loss: 0.3611 - val_accuracy: 0.8684\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 968us/step - loss: 0.3297 - accuracy: 0.8796 - val_loss: 0.3490 - val_accuracy: 0.8726\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3200 - accuracy: 0.8850 - val_loss: 0.3625 - val_accuracy: 0.8666\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 1s 858us/step - loss: 0.3097 - accuracy: 0.8881 - val_loss: 0.3656 - val_accuracy: 0.8672\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_momentum = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"## Nesterov Accelerated Gradient"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 50,
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9,\n",
" nesterov=True)"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 51,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 907us/step - loss: 0.6777 - accuracy: 0.7711 - val_loss: 0.4796 - val_accuracy: 0.8260\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 898us/step - loss: 0.4570 - accuracy: 0.8398 - val_loss: 0.4358 - val_accuracy: 0.8396\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 872us/step - loss: 0.4140 - accuracy: 0.8537 - val_loss: 0.4013 - val_accuracy: 0.8566\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 902us/step - loss: 0.3882 - accuracy: 0.8629 - val_loss: 0.3802 - val_accuracy: 0.8616\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3666 - accuracy: 0.8703 - val_loss: 0.3689 - val_accuracy: 0.8638\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 882us/step - loss: 0.3531 - accuracy: 0.8732 - val_loss: 0.3681 - val_accuracy: 0.8688\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 958us/step - loss: 0.3375 - accuracy: 0.8784 - val_loss: 0.3658 - val_accuracy: 0.8670\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 895us/step - loss: 0.3278 - accuracy: 0.8815 - val_loss: 0.3598 - val_accuracy: 0.8682\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.3183 - accuracy: 0.8855 - val_loss: 0.3472 - val_accuracy: 0.8720\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 921us/step - loss: 0.3081 - accuracy: 0.8891 - val_loss: 0.3624 - val_accuracy: 0.8708\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_nesterov = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"## AdaGrad"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 52,
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-11-03 13:43:56 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.001)"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 53,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 1.0003 - accuracy: 0.6822 - val_loss: 0.6876 - val_accuracy: 0.7744\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 912us/step - loss: 0.6389 - accuracy: 0.7904 - val_loss: 0.5837 - val_accuracy: 0.8048\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 930us/step - loss: 0.5682 - accuracy: 0.8105 - val_loss: 0.5379 - val_accuracy: 0.8154\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.5316 - accuracy: 0.8215 - val_loss: 0.5135 - val_accuracy: 0.8244\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 855us/step - loss: 0.5076 - accuracy: 0.8295 - val_loss: 0.4937 - val_accuracy: 0.8288\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 1s 868us/step - loss: 0.4905 - accuracy: 0.8338 - val_loss: 0.4821 - val_accuracy: 0.8312\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 940us/step - loss: 0.4776 - accuracy: 0.8371 - val_loss: 0.4705 - val_accuracy: 0.8348\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 966us/step - loss: 0.4674 - accuracy: 0.8409 - val_loss: 0.4611 - val_accuracy: 0.8362\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 892us/step - loss: 0.4587 - accuracy: 0.8435 - val_loss: 0.4548 - val_accuracy: 0.8406\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 873us/step - loss: 0.4511 - accuracy: 0.8458 - val_loss: 0.4469 - val_accuracy: 0.8424\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_adagrad = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"## RMSProp"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 54,
2022-02-19 06:19:26 +01:00
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001, rho=0.9)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 55,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5138 - accuracy: 0.8135 - val_loss: 0.4413 - val_accuracy: 0.8338\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 942us/step - loss: 0.3932 - accuracy: 0.8590 - val_loss: 0.4518 - val_accuracy: 0.8370\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.3711 - accuracy: 0.8692 - val_loss: 0.3914 - val_accuracy: 0.8686\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.3643 - accuracy: 0.8735 - val_loss: 0.4176 - val_accuracy: 0.8644\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 970us/step - loss: 0.3578 - accuracy: 0.8769 - val_loss: 0.3874 - val_accuracy: 0.8696\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3561 - accuracy: 0.8775 - val_loss: 0.4650 - val_accuracy: 0.8590\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3528 - accuracy: 0.8783 - val_loss: 0.4122 - val_accuracy: 0.8774\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 989us/step - loss: 0.3491 - accuracy: 0.8811 - val_loss: 0.5151 - val_accuracy: 0.8586\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3479 - accuracy: 0.8829 - val_loss: 0.4457 - val_accuracy: 0.8856\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1000us/step - loss: 0.3437 - accuracy: 0.8830 - val_loss: 0.4781 - val_accuracy: 0.8636\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_rmsprop = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"## Adam Optimization"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 56,
2018-03-24 22:50:29 +01:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 57,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4949 - accuracy: 0.8220 - val_loss: 0.4110 - val_accuracy: 0.8428\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3727 - accuracy: 0.8637 - val_loss: 0.4153 - val_accuracy: 0.8370\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3372 - accuracy: 0.8756 - val_loss: 0.3600 - val_accuracy: 0.8708\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3126 - accuracy: 0.8833 - val_loss: 0.3498 - val_accuracy: 0.8760\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2965 - accuracy: 0.8901 - val_loss: 0.3264 - val_accuracy: 0.8794\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2821 - accuracy: 0.8947 - val_loss: 0.3295 - val_accuracy: 0.8782\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2672 - accuracy: 0.8993 - val_loss: 0.3473 - val_accuracy: 0.8790\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2587 - accuracy: 0.9020 - val_loss: 0.3230 - val_accuracy: 0.8818\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2500 - accuracy: 0.9057 - val_loss: 0.3676 - val_accuracy: 0.8744\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2428 - accuracy: 0.9073 - val_loss: 0.3879 - val_accuracy: 0.8696\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_adam = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"**Adamax Optimization**"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 58,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 06:19:26 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.Adamax(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 59,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5327 - accuracy: 0.8151 - val_loss: 0.4402 - val_accuracy: 0.8340\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 935us/step - loss: 0.3950 - accuracy: 0.8591 - val_loss: 0.3907 - val_accuracy: 0.8512\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 933us/step - loss: 0.3563 - accuracy: 0.8715 - val_loss: 0.3730 - val_accuracy: 0.8676\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 942us/step - loss: 0.3335 - accuracy: 0.8797 - val_loss: 0.3453 - val_accuracy: 0.8738\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 993us/step - loss: 0.3129 - accuracy: 0.8853 - val_loss: 0.3270 - val_accuracy: 0.8792\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 926us/step - loss: 0.2986 - accuracy: 0.8913 - val_loss: 0.3396 - val_accuracy: 0.8772\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 939us/step - loss: 0.2854 - accuracy: 0.8949 - val_loss: 0.3390 - val_accuracy: 0.8770\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.2757 - accuracy: 0.8984 - val_loss: 0.3147 - val_accuracy: 0.8854\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 952us/step - loss: 0.2662 - accuracy: 0.9020 - val_loss: 0.3341 - val_accuracy: 0.8760\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 957us/step - loss: 0.2542 - accuracy: 0.9063 - val_loss: 0.3282 - val_accuracy: 0.8780\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_adamax = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2022-02-19 06:19:26 +01:00
"metadata": {
"tags": []
},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"**Nadam Optimization**"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 60,
2022-02-19 06:19:26 +01:00
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Nadam(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 61,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4826 - accuracy: 0.8284 - val_loss: 0.4092 - val_accuracy: 0.8456\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3610 - accuracy: 0.8667 - val_loss: 0.3893 - val_accuracy: 0.8592\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3270 - accuracy: 0.8784 - val_loss: 0.3653 - val_accuracy: 0.8712\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3049 - accuracy: 0.8874 - val_loss: 0.3444 - val_accuracy: 0.8726\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2897 - accuracy: 0.8905 - val_loss: 0.3174 - val_accuracy: 0.8810\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2753 - accuracy: 0.8981 - val_loss: 0.3389 - val_accuracy: 0.8830\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2652 - accuracy: 0.9000 - val_loss: 0.3725 - val_accuracy: 0.8734\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2563 - accuracy: 0.9034 - val_loss: 0.3229 - val_accuracy: 0.8828\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2463 - accuracy: 0.9079 - val_loss: 0.3353 - val_accuracy: 0.8818\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2402 - accuracy: 0.9091 - val_loss: 0.3813 - val_accuracy: 0.8740\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_nadam = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"**AdamW Optimization**"
2017-06-05 18:48:03 +02:00
]
},
2022-02-21 02:39:49 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2023-11-14 06:20:45 +01:00
"Note: Since TF 1.12, `AdamW` is no longer experimental. It is available at `tf.keras.optimizers.AdamW` instead of `tf.keras.optimizers.experimental.AdamW`."
2022-02-21 02:39:49 +01:00
]
},
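{
"cell_type": "markdown",
"metadata": {},
"source": [
"If this notebook has to run on TensorFlow releases from both sides of that move, one option is to look the class up at runtime. The snippet below is only a sketch and not part of the book's code: `resolve_adamw` is a hypothetical helper name, and it assumes nothing beyond the two attribute paths mentioned in the note above.\n",
"\n",
"```python\n",
"def resolve_adamw(optimizers):\n",
"    # Hypothetical helper: prefer the stable location, then fall back to\n",
"    # the experimental namespace if the stable one is missing.\n",
"    cls = getattr(optimizers, 'AdamW', None)\n",
"    if cls is None:\n",
"        experimental = getattr(optimizers, 'experimental', None)\n",
"        cls = getattr(experimental, 'AdamW', None)\n",
"    if cls is None:\n",
"        raise ImportError('AdamW not found in this optimizers namespace')\n",
"    return cls\n",
"```\n",
"\n",
"With TensorFlow imported as `tf`, calling `resolve_adamw(tf.keras.optimizers)` should return whichever of the two classes the installed release provides."
]
},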
2017-06-05 18:48:03 +02:00
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 62,
2022-02-19 06:19:26 +01:00
"metadata": {
"tags": []
},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2023-11-14 06:20:45 +01:00
"optimizer = tf.keras.optimizers.AdamW(weight_decay=1e-5, learning_rate=0.001,\n",
" beta_1=0.9, beta_2=0.999)"
2017-06-05 18:48:03 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "code",
2023-11-14 06:20:45 +01:00
"execution_count": 63,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4945 - accuracy: 0.8220 - val_loss: 0.4203 - val_accuracy: 0.8424\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3735 - accuracy: 0.8629 - val_loss: 0.4014 - val_accuracy: 0.8474\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3391 - accuracy: 0.8753 - val_loss: 0.3347 - val_accuracy: 0.8760\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3155 - accuracy: 0.8827 - val_loss: 0.3441 - val_accuracy: 0.8720\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2989 - accuracy: 0.8892 - val_loss: 0.3218 - val_accuracy: 0.8786\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2862 - accuracy: 0.8931 - val_loss: 0.3423 - val_accuracy: 0.8814\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2738 - accuracy: 0.8970 - val_loss: 0.3593 - val_accuracy: 0.8764\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2648 - accuracy: 0.8993 - val_loss: 0.3263 - val_accuracy: 0.8856\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2583 - accuracy: 0.9035 - val_loss: 0.3642 - val_accuracy: 0.8680\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2483 - accuracy: 0.9054 - val_loss: 0.3696 - val_accuracy: 0.8702\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"history_adamw = build_and_train_model(optimizer) # extra code"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-14 06:20:45 +01:00
"execution_count": 64,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAtgAAAHoCAYAAABzQZg1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAD0SElEQVR4nOzdd5gUVdbA4d+t6jzdk4c0wBBmQJAgoIiCSjChILqgIihmXOPnrqirmLMsxtXVlaC4LiYUs5hRMYEkRVByhsl5Ond9f3RPzzQzYKMwifM+zzx0Vd2qvl2Oevpy6hxlGAZCCCGEEEKIA0Nr7AkIIYQQQgjRkkiALYQQQgghxAEkAbYQQgghhBAHkATYQgghhBBCHEASYAshhBBCCHEASYAthBBCCCHEAdSgAbZS6lSl1G9KqfVKqX/Uc/xGpdSKyM8qpVRQKZXakHMUQgghhBDiz1ANVQdbKaUDa4GTgO3AEuA8wzBW72X8aOBvhmEMb5AJCiGEEEIIcQA05Ar2QGC9YRgbDcPwAa8AY/Yx/jzg5QaZmRBCCCGEEAdIQwbYmcC2WtvbI/vqUEo5gFOBNxpgXkIIIYQQQhwwpgZ8L1XPvr3lp4wGvjEMo6jeCyk1GZgMYLPZBnTs2PHAzLAZ0Coq0QsLATBsVgKtW8d9bigUQtPkudbfI/cpPnKf4if3Kj5yn+In9yo+cp/iI/cpfmvXri0wDCPj98Y1ZIC9HehQa7s9sHMvY8ezj/QQwzCeA54D6N69u/Hbb78dqDk2ef68PNYff0J4w2ym23ffoTsT4jp34cKFDB069OBNroWQ+xQfuU/xk3sVH7lP8ZN7FR+5T/GR+xQ/pdSWeMY15NeVJUCOUqqzUspCOIh+Z89BSqkk4ATg7QacW7NhbtUKa48e4Q2/n6rvv2vcCQkhhBBCiBgNFmAbhhEArgE+AtYArxmG8YtS6q9Kqb/WGnoW8LFhGJUNNbfmxnnccdHXFV8vasSZCCGEEEKIPTVkigiGYXwAfLDHvmf32H4BeKHhZtX8OI8bQuFzzwFQ8fVXGIaBUvWluAshhBBCiIYmGe3NkP2II9CcTgACO3fh27ChkWckhBBCCCGqNf8Au4Ea5TQlymwm4dhjo9uSJiKEEEII0XQ0aIrIwWCv2MKviz/hsIEn1Xu8rKyMvLw8/H5/A8/s4ApNuoDgmDMA2Gm1krdmze+ek5SUxJo4xrUkCQkJtG/fXsoPCSGEEKLBNPsA20QA95ePQz0BdllZGbm5uWRmZmK321tUnnLI78dbXZ5QKWzduqF0fZ/nlJeX43K5GmB2TUMoFGLHjh0UFBTQqlWrxp6OEEIIIQ4RLWJZr2/FN2xbt7LO/ry8PDIzM3E4HC0quAbQzGY0qy28YRiEKqXoyp40TaN169aUlpY29lSEEEIIcQhpEQG2pgx2Lni0zn6/34/dbm+EGTUMzeWMvg5VVDTiTJous9lMIBBo7GkIIYQQ4hDSIgJsgL4F71OYu73O/pa2cl1bdSURgGB5BcYh+MDn72nJ//yFEEII0TQ1+wDbgxUAm/Kz9r3HG3cyDUxzOFCRh/cMvw/D52vkGQkhhBBCiOYfYJuTo68P2/YK7sryxptMA1OahpZQK02k/ND57EIIIYQQTVWzD7CLTB6WmcIVIlIo56f3n2nkGR0Y+fn5XHXVVXTq1Amr1Urr1q0ZMWIEn3zySXTMxo0b+ettU+l+8skk9+9Px8MPZ9iwYcyZMwdfrdVspRRKKRITE3E4HHTp0oUJEyawaJHUzxZCCCGEONCafZk+d8jNC1m96L/hcwAy18wmGPg7uql5f7SxY8dSVVXFrFmzyM7OJi8vjy+//JLCwkIAfvzxR0aMGEGPww7jkVtuoXvnzlR5vWz0eJgZOWfw4MHR682YMYOhQ4diNpvZuHEjc+bM4fjjj+fhhx/mxhtvbKyPKYQQQgjR4jTvKDTiW7ayRXOSFaqgvbGL5Z/9j36nXNjY0/
rDSkpK+Prrr/nkk08YMWIEAFlZWRx11FEAGIbBhRdeSE5ODt9+9x2+DRswvF4AjsrKYsLEiXUeeExOTqZ169a4XC6ysrIYNmwY7dq145ZbbuGss84iOzu7YT+kEEIIIUQL1exTRAC8IR/PZPaLbtuXNO80EafTidPp5J133sHj8dQ5vmLFClavXs2UKVPQNA3dWTsPO1yuL57qGTfccAOhUIi33nrrgM1dCCGEEOJQ1yJWsAG+shVSgplk/BwWWMOvP3wMiR3qjOv0j/cbYXZhmx86Pa5xJpOJF154gcsvv5znnnuOfv36MXjwYM4++2yOPvpo1q5dC0D37t0B0FwuSjdvJnvECFAKlOLWW2/l1ltv3ef7pKWl0apVKzZu3PjnPpgQQgghhIhq9ivYJhX+jlAeqOCp1jWr2O4vH2+kGR0YY8eOZefOnbz77ruMHDmSb7/9lkGDBvHAAw/UGas5HLhcLr6fN4/vX3+ddm3bxjzkuC+GYUitaCGEEEKIA6jZB9iJemL09WfJXvyR130rvyXgb951oW02GyeddBJ33HEH3377LZdeeil33XUXnTp1AuDXX38FwuX6zC4XXTt2pGvHjljifMCzoKCA/Px8unTpcrA+ghBCCCHEIafZp4gkaAmk2lIp8hRR4Cvm2ZTeXFv8M5oyCLrL6oyPN02jKerZsyeBQIDDDjuMHj16MG3aNM455xx0XUdzOglG6mAboVBc13vkkUfQNI0xY8YczGkLIYQQQhxSmv0KtkJxfo/zo9sftLJSHV5aQu5muYpdWFjI8OHDeemll/jpp5/YtGkTr7/+OtOmTWPEiBEkJSXxwgsvsGHDBo455hjefvtt1u/axa8bN/L8G2+wY9cuNC32H21JSQm5ubls3bqVL774gosuuoiHH36Yhx56SCqICCGEEEIcQM1+BRvgnO7nMPPnmVQFqtjuy+MVZxcmVGxEYeApycWZUfdhx6bM6XQyaNAgnnjiCdavX4/X6yUzM5MJEyZw2223ATBw4ECWLVvGgw8+yLXXXsvu3buxW6306taNO6+9liuuujrmmpdffjkAVquVtm3bMmjQIBYuXMjxxx/f4J9PCCGEEKIlaxEBdpI1ibO7nc2c1XMAmNc2gwnrwpUxbP4iQsF2aLremFPcL1arlQceeKDeBxpry87OZtasWdFt/65dBCKNaEy1Vu6ra2KXl5fjcrkOwoyFEEIIIUS1Zp8iUu2Cnhdg0sLfF9YFdvGptTUAJkJUleY15tQajFarHnYwUg9bCCGEEEI0rBYTYLdOaM2oLqOi2y9kdo6+tngK63Q2bIm0hIRwHWzA8HkJxVmqTwghhBBCHDgtJsAGuPjwi1GEA8yVxk58KvzxLPipKitszKk1CKVp4SA7IhSpKiKEEEIIIRpOiwqwuyR3YViHYdHtct0cfa1X5h8Sq9h6rRzrUIWkiQghhBBCNLQWFWADXNL7kuhrHyG8kRVtGx48lXXrYrc0MXnYlZVx18QWQgghhBAHRosLsPtm9GVA6wEAGBjkme3RY0ZFbmNNq8FoVivKYglvhEKEqqoad0JCCCGEEIeYFhdgA1zSq2YVu1wLEoi8tgcr8XpafsCp11rFDkk1ESGEEEKIBtUiA+zjMo8jJyUHCNeAzjPZgHCBDX/p7sacWoPQauVhByvkQUchhBBCiIbUIgNspRQXH35xdLtUJ9o+3REoa5bt0/dHTLk+r5TrE0IIIYRoSC0ywAY4tfOp6Fq4e2PICJGvWwHQVLh9ektWp1yfVBMRQgghhGgwLTbANmtmnOaaXOQSk0Z1kb5w+/Rg40wsDhdddBFKKS677LI6x2666SaUUowaNaqeM2vE5GE3YoCtlGLevHmN9v5CCCGEEA2txQbYAHaTPbqKHTCCFGnhutjNoX16hw4dePXVV6msrIzuCwQC/Pe//6Vjx46/e762Rz1sKdcnhBBCCNEwWnSArSmNVFtqdLvQZIquYjf19ul9+vQhJyeH1157Lb
rv/fffx2azMXTo0Oi+UCjEvffeS4cOHbBarfTu3Zu3334bZbGgzGa27NiB/fDDefnFFxk5ciR2u51+/frx008/sWr
"text/plain": [
"<Figure size 864x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAtgAAAHoCAYAAABzQZg1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAADxsklEQVR4nOzdd3hUVf7H8fednslMei+kQICEXkQURIqKBdsiFixgw13XtmsXe5dVd9Xdn+4KKqxrRbGLlaig0ntCD6GkkN6n398fkwwZkkCAkPp9PQ8PmXvPvXPmEjKfnPnecxRVVRFCCCGEEEK0DU1Hd0AIIYQQQojuRAK2EEIIIYQQbUgCthBCCCGEEG1IArYQQgghhBBtSAK2EEIIIYQQbUgCthBCCCGEEG2oXQO2oihnK4qyVVGUHYqi3NfM/rsVRVlX/2eToihuRVHC2rOPQgghhBBCHA+lvebBVhRFC2wDzgT2ASuBK1RVzWqh/fnAX1RVndguHRRCCCGEEKINtOcI9ihgh6qqu1RVdQDvARcepv0VwLvt0jMhhBBCCCHaSHsG7Hhgb6PH++q3NaEoihk4G/ioHfolhBBCCCFEm9G143MpzWxrqT7lfGCZqqqlzZ5IUWYBswBMJtOIXr16tU0PuzmPx4NGI/e1Holcp9aR69R6cq1aR65T68m1ah25Tq0j16n1tm3bVqyqauSR2rVnwN4HJDZ6nADktdD2cg5THqKq6n+A/wD069dP3bp1a1v1sVvLzMxk/PjxHd2NTk+uU+vIdWo9uVatI9ep9eRatY5cp9aR69R6iqLktqZde/66shJIUxQlRVEUA94Q/dmhjRRFCQZOBz5tx74JIYQQQgjRJtptBFtVVZeiKLcA3wBa4A1VVTcrivLH+v2v1Te9GPhWVdWa9uqbEEIIIYQQbaU9S0RQVfUr4KtDtr12yOO3gLfar1dCCCGEEEK0HaloF0IIIYQQog11/YDdTgvlCCGEEEII0RrtWiJyIugKCnCVlaELDW12f2VlJQcOHMDpdLZzzzqf4OBgsrOzO7ob7SowMJCEhASZfkgIIYQQ7abLB2zF4WTPzGvp9dabTUJ2ZWUlhYWFxMfHExAQgKI0NxV3z1FVVYXVau3obrQbj8fD/v37KS4uJioqqqO7I4QQQogeolsM69m3bmXPddfjLi/3237gwAHi4+Mxm809Plz3RBqNhujoaCoqKjq6K0IIIYToQbp+wK7PzfbsbG/IbhSmnE4nAQEBHdQx0Rno9XpcLldHd0MIIYQQPUiXD9jusHCoH522ZWWx5/obcFdW+vbLyHXPJv/+QgghhGhvXT5geyyBxD75hO+xbdMm9txwI+6qqg7slRBCCCGE6Km6fMAGCJk6lZjHH/M9tm3YwN4bbkT1eDqwV0IIIYQQoifqFgEbIPTSS4l59BHf47r163GXlqK63R3Yq2NXVFTEzTffTHJyMkajkejoaCZNmsR3333na7Nr1y5uuOEGkpKSMBqNxMXFMWHCBObPn4/D4fC1UxQFRVEICgrCbDaTmprK9OnTWbp0aUe8NCGEEEKIbq3LT9PXWOjll6O63RQ+8SQAqsOBIzcXQ1ISilbbwb07OlOnTqW2tpZ58+bRp08fDhw4wE8//URJSQkAq1atYtKkSaSnp/PKK6/Qv39/amtryc7O5vXXX6dPnz6MGTPGd77XX3+d8ePHo9fr2bVrF/Pnz2fcuHE899xz3H333R31MoUQQgghup1uFbABwq68EjwqhU89BYCnthZH7h4MSb26TMguLy/nl19+4bvvvmPSpEkAJCUlcdJJJwGgqiozZswgLS2NX3/91W8RlaFDh3LFFVegHrLCZUhICNHR0VitVpKSkpgwYQJxcXHcf//9XHzxxfTp06f9XqAQQgghRDfWbUpEGgu7+iqi77/P99hTW4Njz54uU5NtsViwWCx89tln2Gy2JvvXrVtHVlYWd911V4srFLZm9ow777wTj8
fDJ598crxdFkIIIYQQ9brdCHaDsBkzyFu50vfYU+MN2f1e29xhfdr97HmtaqfT6Xjrrbe48cYb+c9//sOwYcMYM2YM06ZN4+STT2bbtm0A9OvXz3dMRUUF8fHxvscPPPAADzzwwGGfJzw8nKioKHbt2nUMr0YIIYQQQjSnW45gN9BaLOiio32PPdXVHdibozN16lTy8vL4/PPPOeecc/j1118ZPXo0Tz/9dLPtrVYr69atY926dcTFxfnd5Hg4qqrKXNFCCCGEEG2oWwdsAH1kJLqo6CM37IRMJhNnnnkmDz/8ML/++ivXX389jz76KMnJyQBs2bLF11aj0dCnTx/69OmDwWBo1fmLi4spKioiNTX1RHRfCCGEEKJH6rYlIo3poyIBFdeBA2RflQSAxmrFkJiI0kINc2eUkZGBy+Wif//+pKenM2fOHC699FK0x3jz5gsvvIBGo+HCCy9s454KIYQQQvRcPSJgA+ijokAFV9EBADxVVTj37kXfCUN2SUkJ06ZN47rrrmPw4MFYrVZWrVrFnDlzmDRpEsHBwbz11lucccYZnHLKKcyePZv09HTcbjfLli1j3759TUJ3eXk5hYWFlJWVsXPnTubPn8+CBQuYM2eOzCAihBBCCNGGekzABtA1jGQXFQF4l1Pfuw99YkKnCtkWi4XRo0fz0ksvsWPHDux2O/Hx8UyfPp0HH3wQgFGjRrFmzRqeeeYZbr31VgoKCggICGDw4ME89dRT3HDDDX7nvPHGGwEwGo3ExsYyevRoMjMzGTduXLu/PiGEEEKI7qxHBWxFUdA1jGQXN4TsSti3D31C5wnZRqORp59+usUbGhv06dOHefPmHfF8DXNiV1VVYbVa26SPQgghhBCieZ0jUbYjRVHQRUehi4jwbXNXVuLct6/LzJMthBBCCCE6rx4XsKEhZEejCz8kZO/f32QFRCGEEEIIIY5GjwzYUB+yY6LRhYf7trkrKrwj2RKyhRBCCCHEMeqxARsaQnZMMyFbRrKFEEIIIcSx6dEBGxqF7LAw3zZ3RbmUiwghhBBCiGPS4wM21Ifs2Fi0jUN2eTnO/XkSsoUQQgghxFGRgF1PURT0sbFoQ0N929zlZTjzJGQLIYQQQojWk4DdiKIo6OPi/EN2WRnOvHwJ2UIIIYQQolUkYB/CF7JDQnzb3GWlOPMlZAshhBBCiCOTgN0MRVHQx8f7h+zSUlwSsoUQQgghxBFIwG6BL2QHB/u2uUpLcRUUnPCQPXPmTBRF4YYbbmiy75577kFRFKZMmXJC+9BWFEVh4cKFHd0NIYQQQoh2IwH7MBRFQZ+Q4B+yS0raJWQnJiby/vvvU1NTc/C5XS7++9//0qtXrxP63EIIIYQQ4thJwD4CX8gOOiRkFxae0JA9ePBg0tLS+OCDD3zbvvzyS0wmE+PHj/dt83g8PPHEEyQmJmI0Ghk0aBCffvqpb//u3btRFIX33nuPc845h4CAAIYNG8aGDRvYtGkTp556KoGBgYwdO5acnBy/Pnz++eeMGDECk8lESkoKs2fPxuFw+PYnJyfz5JNPctNNNxEUFERCQgJ/+9vf/PYDTJs2DUVRfI8fffRRBg4c6Pdcb731FhaLxfe4oc38+fNJTk7GYrFw7bXX4nA4+L//+z8SExMJDw/nr3/9Kx6P55ivsxBCCCFEW5OA3QrekB2PNijIt81VXHzCQ/b111/PG2+84Xv8xhtvcO2116Ioim/bSy+9xN/+9jeee+45Nm7cyMUXX8wf/vAH1q1b53euRx55hDvuuIO1a9cSEhLC9OnTufXWW3nqqadYsWIFNpuN2267zdf+m2++4corr+SWW25h8+bNvPHGGyxcuJAHHnjA77x///vfGTRoEGvWrOHee+/lnnvu4bfffgNg5cqVALz++uvk5+f7HrfW7t27+fTTT/niiy/46KOP+PDDD7nwwgtZuXIl3377LXPnzuWVV15h0aJFR3VeIYQQQogTSdfRHWh3jwYfuU0zFM
Bw3M9dcVTNp0+fzl133cX27duxWq0sXryYV155hYcfftjX5vnnn+euu+5i+vTpADz++OP8/PPPPP/887z99tu+dn/
"text/plain": [
"<Figure size 864x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – visualize the learning curves of all the optimizers\n",
"\n",
"for loss in (\"loss\", \"val_loss\"):\n",
" plt.figure(figsize=(12, 8))\n",
" opt_names = \"SGD Momentum Nesterov AdaGrad RMSProp Adam Adamax Nadam AdamW\"\n",
" for history, opt_name in zip((history_sgd, history_momentum, history_nesterov,\n",
" history_adagrad, history_rmsprop, history_adam,\n",
" history_adamax, history_nadam, history_adamw),\n",
" opt_names.split()):\n",
" plt.plot(history.history[loss], label=f\"{opt_name}\", linewidth=3)\n",
"\n",
" plt.grid()\n",
" plt.xlabel(\"Epochs\")\n",
" plt.ylabel({\"loss\": \"Training loss\", \"val_loss\": \"Validation loss\"}[loss])\n",
" plt.legend(loc=\"upper left\")\n",
" plt.axis([0, 9, 0.1, 0.7])\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Rate Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Power Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"learning_rate = initial_learning_rate / (1 + step / decay_steps)**power\n",
"```\n",
"\n",
"Keras uses `power = 1`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: The `decay` argument in optimizers is deprecated. The old optimizers which implement the `decay` argument are still available in `tf.keras.optimizers.legacy`, but you should use the schedulers in `tf.keras.optimizers.schedules` instead."
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"# DEPRECATED:\n",
"optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=0.01, decay=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"# RECOMMENDED:\n",
"lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(\n",
" initial_learning_rate=0.01,\n",
" decay_steps=10_000,\n",
" decay_rate=1.0,\n",
" staircase=False\n",
")\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `InverseTimeDecay` scheduler uses `learning_rate = initial_learning_rate / (1 + decay_rate * step / decay_steps)`. If you set `staircase=True`, then it replaces `step / decay_steps` with `floor(step / decay_steps)`."
]
},
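To make the `InverseTimeDecay` formula concrete, here is a minimal pure-Python sketch that evaluates it by hand, with and without the staircase behavior. The helper name `inverse_time_decay` is hypothetical (it is not part of Keras); its default arguments simply mirror the values passed to the scheduler above.

```python
import math


def inverse_time_decay(step, initial_learning_rate=0.01, decay_steps=10_000,
                       decay_rate=1.0, staircase=False):
    """Evaluate lr = lr0 / (1 + decay_rate * step / decay_steps) by hand.

    Hypothetical helper for illustration only: it re-implements the formula
    quoted in the text rather than calling Keras.
    """
    progress = step / decay_steps
    if staircase:
        # staircase=True: the rate only changes once every `decay_steps` steps
        progress = math.floor(progress)
    return initial_learning_rate / (1 + decay_rate * progress)


for step in (0, 5_000, 10_000, 15_000, 30_000):
    smooth = inverse_time_decay(step)
    stepped = inverse_time_decay(step, staircase=True)
    print(f"step={step:>6}  smooth={smooth:.5f}  staircase={stepped:.5f}")
```

With these settings the smooth schedule halves the learning rate after `decay_steps` steps (10,000) and divides it by 4 after 30,000 steps, while the staircase variant holds each value until the next multiple of `decay_steps`.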
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.7004 - accuracy: 0.7588 - val_loss: 0.4991 - val_accuracy: 0.8206\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4781 - accuracy: 0.8316 - val_loss: 0.4477 - val_accuracy: 0.8372\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4293 - accuracy: 0.8487 - val_loss: 0.4177 - val_accuracy: 0.8498\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4053 - accuracy: 0.8563 - val_loss: 0.3987 - val_accuracy: 0.8602\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3864 - accuracy: 0.8633 - val_loss: 0.3859 - val_accuracy: 0.8612\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3720 - accuracy: 0.8675 - val_loss: 0.3942 - val_accuracy: 0.8584\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3616 - accuracy: 0.8709 - val_loss: 0.3706 - val_accuracy: 0.8670\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3529 - accuracy: 0.8741 - val_loss: 0.3758 - val_accuracy: 0.8638\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3452 - accuracy: 0.8765 - val_loss: 0.3587 - val_accuracy: 0.8680\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3379 - accuracy: 0.8793 - val_loss: 0.3569 - val_accuracy: 0.8714\n"
]
}
],
"source": [
"history_power_scheduling = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHNCAYAAADolfQeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAACEuklEQVR4nO3dd1gU1/4G8HfYXZYO0os0WyxYwV6w91iiCZqE6FVzf8ZYiTFqTCwpxhi92I1Go97kGk00RiMaURE1YgXsXRBFEOkdFnZ+fyAbV4orLuwC7+d5iHDmzMx35mTl6zlnzgiiKIogIiIiohcy0HUARERERNUFEyciIiIiDTFxIiIiItIQEyciIiIiDTFxIiIiItIQEyciIiIiDTFxIiIiItIQEyciIiIiDTFxIiIiItIQEyciogqIjo6GIAgYO3asTs7fvXt3CILwSsco6xrGjh0LQRAQHR39SscnqomYOBHVUMW/FJ/9MjQ0hKurK95++21cunRJ1yFWiezsbHz99ddo06YNzMzMYGRkhLp166Jr166YM2cO7t69q+sQiagakeo6ACKqXPXr18e7774LAMjMzMTp06exfft27N69G0ePHkWnTp10HGHlycjIQJcuXXDp0iU0aNAA7777LqysrPDgwQNcvXoV33zzDerXr4/69evrOlS9snjxYsyePRsuLi66DoVI7zBxIqrhGjRogAULFqiVzZs3D1999RU+/fRThISE6CawKhAYGIhLly5h/Pjx2LhxY4mhraioKOTl5ekoOv3l5OQEJycnXYdBpJc4VEdUC02ZMgUAcO7cOVVZQUEB/vOf/6Bly5YwNjaGpaUlevTogf3796vtGxkZCUEQMH36dLXyX3/9FYIgwNTUFPn5+WrbHB0d0aRJE7UyURSxefNmdO7cGRYWFjAxMYGPjw82b95cIt4FCxZAEAQcO3YMW7duhbe3N0xMTNC9e/dyrzMsLAwAMHny5FLnA3l6eqJx48YlyhMSEjBz5ky89tprMDIygrW1NTp06IBly5aVep579+5h5MiRqFOnDkxNTdG7d29cvHix1LoJCQmYMWMGGjRoALlcDltbW4wYMQJXrlwptf7Jkyfh6+sLU1NT2NjYwM/PDw8ePCi1bnlzk569hy9S2nGOHTsGQRCwYMEChIeHo1+/fjA3N4elpSWGDx9e5nyo3bt3w8fHB8bGxnBwcMD777+PlJQUeHh4wMPD44WxEOkbJk5EtdDzSYQoivDz80NAQAByc3Px4YcfquZBDR48GCtXrlTVbdmyJaytrUv0VBX/Qs7OzsaZM2dU5devX8fjx4/Ro0cPtfO9++67GD9+PBITE/H2229jwoQJyMrKwvjx4zFz5sxS4166dCk++OADNGzYEFOnTkWXLl3KvU5ra2sAwJ07d158U566ffs22rRpg2XLlsHe3h7Tpk3D22+/DSMjI3z11Vcl6kdHR6N9+/Z48uQJxo0bhz59+uDIkSPo0aMHHj9+rFb37t278Pb2xooVK9CgQQNMmTIFAwcOxMGDB9GhQwe1+wYAR44cQc+ePXHmzBmMHDkS//73vxEVFYXOnTsjJSVF42vSpvPnz6Nr166QSqX4v//7P/j4+GDPnj3o3bs3cnNz1epu3rwZI0aMwN27d/Hee+9hzJgxCAsLQ58+faBQKHQSP9ErE4moRoqKihIBiP369Sux7dNPPxUBiN27dxdFURS3bdsmAhB9fX3FvLw8Vb0HDx6I9vb2okwmE+/du6cqHz58uCgIgvjkyRNVWZMmTcTu3buLEolEXLhwoap8zZo1IgBx586dqrINGzaIAMTx48eLCoVCVZ6Xlye+/vrrIgDx/PnzqvL58+eLAERTU1Px0qVLGt+DPXv2iABECwsL8ZNPPhGPHDkiJicnl7tPu3btRADihg0bSmx78OCB6vvi+wtA/Oabb9TqzZs3TwQgLl68WK28U6dOolQqFQ8dOqRWfvPmTdHc3Fxs3r
y5qqywsFCsV6+eKAiCeOLECVW5UqkU3377bdW5nzVmzBgRgBgVFVUi9uJ7GBISUuIaxowZ88LjhISEqM75yy+/qNX39/cXAYjbt29XlaWkpIhmZmaiubm5ePfuXVW5QqEQe/fuLQIQ3d3dS8RJpO+YOBHVUMW/FOvXry/Onz9fnD9/vvjRRx+JnTt3FgGIRkZG4qlTp0RRFMWePXuKAMQzZ86UOM7ixYtFAOIXX3yhKluxYoUIQPz1119FURTF+Ph4EYD4n//8R2zXrp3o6+urqjty5EgRgPj48WNVWYsWLURTU1MxJyenxPkuXbokAhA/+ugjVVnxL/0ZM2a89H349ttvRTMzM9Uv/eJ78uGHH4q3bt1Sq3v27FkRgNitW7cXHrf4/np6eoqFhYWlbnvjjTdUZeHh4apksTQBAQEiAPHy5cuiKIpiaGioCEB8/fXXS9SNjo4WJRKJThKn0u5N8baAgABV2ZYtW8pss7CwMCZOVG1xcjhRDXf37l0sXLgQACCTyeDg4IC3334bs2fPRvPmzQEAERERMDY2Rrt27UrsXzyPKDIyUlVWPOwWEhKCkSNHqobtevTogfj4eAQGBiI3NxdyuRyhoaFo1qwZ7O3tARQN5V2+fBnOzs745ptvSpyveAjnxo0bJbaVFt+LfPzxx5g4cSIOHjyIU6dO4fz58zhz5gzWrFmDTZs2YceOHRgyZAgA4OzZswCAvn37anz8li1bwsBAfdZD3bp1AQCpqamqstOnTwMA4uPjS0zWB/653hs3bsDLy0s1R6pr164l6rq7u8PV1VUn6yy1adOmRFlp11scf2lPbbZr1w5SKX/9UPXE/3OJarh+/frh4MGD5dZJT0+Hq6trqdscHR0BAGlpaaoyLy8v2NnZqRKmkJAQ2NjYoEWLFoiPj8eSJUtw6tQp2NnZ4cmTJ/Dz81Ptm5KSAlEUERsbq0roSpOVlVWizMHBodzrKIu5uTnefPNNvPnmm6prmTt3LtauXYvx48cjNjYWhoaGql/8L/MYvqWlZYmy4qSgsLBQVZacnAwA2L9/f4kJ988qvu7i+12ccD7PwcFBJ4mTptebnp4OALCzsytR38DAALa2tpUUIVHl4uRwIoKFhUWJiczFisstLCxUZYIgwNfXF9evX0d8fDyOHTsGX19fCIKALl26QCaTISQkRDVh/NmJ4cXH8fb2hlg0XaDUr9KWSXjVlbKLWVpaYvXq1XB3d0diYiIuX74MALCysgIAxMbGauU8zyq+7lWrVpV73WPGjFHFCBQ9hVea0tqruOeroKCgxLZnE9+qUHy9T548KbFNqVQiMTGxSuMh0hYmTkSE1q1bIycnRzVU9azQ0FAAQKtWrdTKi4fwfv75Z9y6dQs9e/YEAJiamqJdu3Y4evQoQkJCVElWMXNzczRp0gTXr19XG9qpaoIgwMTERK2seCjw0KFDWj9f+/btAfyzRMKLtGzZEgBw4sSJEtvu379f6pIEderUAVB64hcREaFxrNpQHP+pU6dKbDt79mypyR1RdcDEiYhUvRxz5sxRe0w8NjYWy5cvh1QqxTvvvKO2T3Ev0pIlS9R+Lv7+3LlzCAkJQfPmzWFjY6O279SpU5GdnY3333+/1CG5qKgorQxDff/992prVT1r9+7duHHjBqysrODl5QUAaNu2Ldq1a4fjx49j48aNJfZ5lZ6odu3aoX379ti+fTt27NhRYrtSqVQlqQDQpUsXeHp64s8//8TJkydV5aIoYu7cuWrDYsV8fHwAAFu2bFEr/+2339SOXRWGDh0KMzMz/PDDD4iKilKVFxQU4LPPPqvSWIi0iXOciAj+/v7YvXs3/vjjD7Ro0QKDBw9GVlYWdu7ciaSkJCxbtgz16tVT26dp06ZwcHDA48eP4eDggKZNm6q29ejRA19++SVSU1NVSdmz/u///g+nT5/G1q1b8ffff6N3795wdnbG48ePcePGDZw5cwb/+9//XnmBxAMHDmDixIlo0KABOn
fuDGdnZ2RmZiIyMhInTpyAgYEB1q5dC7lcrtrnp59+Qvfu3fHvf/8b//3vf9GxY0fk5ubi6tWriIiIQFJSUoXj2b5
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell plots power scheduling with staircase=True or False\n",
"\n",
"initial_learning_rate = 0.01\n",
"decay_rate = 1.0\n",
"decay_steps = 10_000\n",
"\n",
"steps = np.arange(100_000)\n",
"lrs = initial_learning_rate / (1 + decay_rate * steps / decay_steps)\n",
"lrs2 = initial_learning_rate / (1 + decay_rate * np.floor(steps / decay_steps))\n",
"\n",
"plt.plot(steps, lrs, \"-\", label=\"staircase=False\")\n",
"plt.plot(steps, lrs2, \"-\", label=\"staircase=True\")\n",
"plt.axis([0, steps.max(), 0, 0.0105])\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Power Scheduling\", fontsize=14)\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exponential Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"learning_rate = initial_learning_rate * decay_rate ** (step / decay_steps)\n",
"```"
]
},
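As a quick sanity check of the exponential formula, the sketch below evaluates it in plain Python. The function name `exponential_decay_lr` and its defaults are assumptions chosen for illustration (a tenfold decay every 20,000 steps); the staircase branch assumes the same floor-based behavior described for `InverseTimeDecay`.

```python
def exponential_decay_lr(step, initial_learning_rate=0.01, decay_steps=20_000,
                         decay_rate=0.1, staircase=False):
    """Evaluate lr = lr0 * decay_rate ** (step / decay_steps) by hand.

    Hypothetical helper for illustration only; it does not call Keras.
    """
    exponent = step / decay_steps
    if staircase:
        # Assumed staircase behavior: use the integer number of completed periods
        exponent = step // decay_steps
    return initial_learning_rate * decay_rate ** exponent


for step in (0, 10_000, 20_000, 40_000):
    smooth = exponential_decay_lr(step)
    stepped = exponential_decay_lr(step, staircase=True)
    print(f"step={step:>6}  smooth={smooth:.6f}  staircase={stepped:.6f}")
```

With `decay_rate=0.1`, the learning rate is divided by 10 every `decay_steps` steps: 0.01 at step 0, 0.001 at step 20,000, and 0.0001 at step 40,000.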
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n",
" initial_learning_rate=0.01,\n",
" decay_steps=20_000,\n",
" decay_rate=0.1,\n",
" staircase=False\n",
")\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.6916 - accuracy: 0.7632 - val_loss: 0.5030 - val_accuracy: 0.8254\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4832 - accuracy: 0.8311 - val_loss: 0.4601 - val_accuracy: 0.8358\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4372 - accuracy: 0.8449 - val_loss: 0.4256 - val_accuracy: 0.8524\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4131 - accuracy: 0.8546 - val_loss: 0.4037 - val_accuracy: 0.8568\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3952 - accuracy: 0.8596 - val_loss: 0.3950 - val_accuracy: 0.8598\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3825 - accuracy: 0.8640 - val_loss: 0.4010 - val_accuracy: 0.8584\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3739 - accuracy: 0.8667 - val_loss: 0.3851 - val_accuracy: 0.8650\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3664 - accuracy: 0.8696 - val_loss: 0.3811 - val_accuracy: 0.8616\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3606 - accuracy: 0.8720 - val_loss: 0.3749 - val_accuracy: 0.8662\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3555 - accuracy: 0.8743 - val_loss: 0.3706 - val_accuracy: 0.8662\n"
]
}
],
"source": [
"history_exponential_scheduling = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHNCAYAAADolfQeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAB99klEQVR4nO3dd1wU194G8GcLLB2kF2n23iB2xQpGTTTRRE1CNJZ7DdGgpFhiYklRY/TFEjUmRjRFzVVjiWjEAliwIsZuVBRFEBGlCizsvH8gGze74LIs7ALP9/Ph3nD27MxvZ0Qe55w5IxIEQQARERERPZfY0AUQERER1RQMTkRERERaYnAiIiIi0hKDExEREZGWGJyIiIiItMTgRERERKQlBiciIiIiLTE4EREREWmJwYmIiIhISwxORFRr+fj4wMfHp1LbmDNnDkQiEaKjo/VSU2Xo4/PoSl/HQdNniIiIgEgkQkRERKW2TVQdGJyIqsmtW7cgEonK/WrXrp2hy6xRxowZA5FIhFu3bhm6FCVBEPDzzz+jT58+cHBwgKmpKVxcXNC+fXuEhIQgJibG0CUSUSVIDV0AUV3TsGFDvPXWWxpfc3V1reZqarcDBw5U+z7Hjh2LiIgI1KtXD4MHD4a7uzvS09Nx7do1rF27FllZWQgICKj2uozZK6+8gs6dO8PNzc3QpRA9F4MTUTVr1KgR5syZY+gy6oSGDRtW6/4OHz6MiIgItGvXDjExMbCxsVF5/fHjx7h06VK11lQT2NrawtbW1tBlEGmFQ3VERuqLL76ASCTC5MmT1V4rnW8ydepUtbbo6Gh8//33aNmyJczMzODl5YUZM2YgPz9f437++OMP9O7dG7a2tjA3N0e7du0QHh6O4uJilX6lQ41jxozBzZs3MXz4cNSrVw+Wlpbo168fzp07p3H7aWlpmDp1Kho1agSZTAZHR0cMGzYMFy5cUOtbOv8lNzcXYWFh8PDwgEwmQ5s2bbBlyxa1vuvXrwcA+Pr6Koc7e/Xqpba9Z927dw+zZ89G586d4ezsDJlMBh8fH4SEhCAtLU3jZ9BWXFwcAGD06NFqoQkA7Ozs0LVrV7X2wsJCLF26FB07doS1tTWsrKzQokULhIWF4dGjR2r9tTk+z257yZIl6NChAywtLWFtbY0ePXpg586dGvvfuXMHo0aNgr29PaysrBAQEIDY2FiNfcubmxQdHQ2RSKTVPxLK2k7p+Xzw4AHGjh0LZ2dnmJubo3PnzmXOtfrrr78wcOBAWFtbw9bWFgMHDsSFCxeMcliXaiZecSIyUjNnzkRUVBRWrFiBwMBAvPTSSwCAo0eP4osvvkCbNm2wYMECtfctXrwY0dHRGDFiBAYPHozIyEgsWLAAZ8+exZ49eyASiZR9ly5diilTpsDe3h5vvPEGLC0tsWvXLkydOhWHDx/Gli1bVPoDJQGqU6dOaNGiBcaOHYsbN25gx44d6N27Ny5fvgwXFxdl3xs3bqBXr15ITk5GYGAghg4dirS0NGzduhV//vknDhw4gE6dOqlsXy6XIzAwEBkZGXj11VeRl5eHTZs24fXXX8fevXsRGBgIAJgyZQoiIiJw7tw5hIaGws7ODgCeO3k6NjYWixcvRt++fdGpUyeYmJjg7NmzWLVqFf7880/Ex8frfPXD3t4eAHD9+nWt35Ofn4+goCDExsaicePGeOeddyCTyfD3339j9erVePvtt1GvXj1lf22PDwAUFBRgwIABiI6ORvv27TFu3DjI5XLs3r0bQ4YMwfLlyzFp0iRl/5SUFHTp0gXJyckICgpChw4dcPnyZfTv3x+9e/fW6ZhU1uPHj9GtWzfY2NjgzTffRFpaGjZv3oygoCCcOXMGrVq1UvY9d+4cevTogby8PLz66qto1KgRzpw5g+7du6Nt27YGqZ9qIYGIqkViYqIAQGjYsKEwe/ZsjV979uxReU9SUpJQr149wdHRUbh3757w+PFjwcfHRzA3Nx
cuXryo0nf27NkCAMHMzEy4cOGCsl0ulwv9+/cXAAgbNmxQtt+4cUOQSqWCs7OzkJSUpGwvKCgQAgICBADCTz/9pFY/AGHBggUq+541a5YAQJg/f75Ke9euXQWpVCrs27dPpf3q1auCtbW10Lp1a5V2b29vAYAwZMgQoaCgQNm+f/9+AYAQFBSk0n/06NECACExMVHteJduz9vbW6Xt/v37QnZ2tlrf9evXCwCEL774QqW99LgeOnRI4z6elZSUJFhbWwtisVh4++23hd9//13l2Gry0UcfCQCE4OBgoaioSOW1x48fq9Ra0eMzc+ZMAYAwZ84cQaFQKNuzsrIEf39/wdTUVEhOTla2lx7Pfx+D7777Tnnunz0O69atEwAI69atU/tchw4dEgAIs2fPVmnXdE7K2k7pPkNCQoTi4mJl+w8//CAAEP773/+q9O/evbsAQPjf//6n0l56Dsv7s0KkLQYnomrybPAo6ys0NFTtfVu2bBEACP369RNGjhwpABBWrlyp1q/0l8OECRPUXjt16pQAQOjbt6+ybd68eQIAYeHChWr94+Li1PqX1u/r66vyS+zZ11599VVlW3x8vABAGDdunMbjERYWJgAQzp8/r2wrDQY3b95U6+/t7S3Y29urtOkSnMqiUCgEGxsboVevXirtFQlOgiAIe/fuFTw9PVXOq5OTk/D6668LBw4cUOlbVFQk2NjYCLa2tkJGRsZzt12R41NcXCzUq1dPaNSokUpoKrVz504BgLB8+XJBEEoCs5mZmeDs7Cw8efJEpW9xcbHQpEkTgwQnS0tLtaArl8sFqVQqdOjQQdl269YtAYDQvn17tVpyc3MFe3t7BifSCw7VEVWzoKAg7N27V+v+w4YNw/jx4/HDDz8AAIYMGYJ33323zP49evRQa/P394e5uTkSEhKUbWfPngUAlTlBpTp37qzWv1Tbtm0hFqtOj6xfvz6AkmGVUsePHwcApKamapzncuXKFeX/PzvcYmdnB19fX7X+9evXV84hqqxt27bhu+++Q3x8PB49eqQyn+vevXuV2nZQUBBu3ryJ6OhoxMbG4syZMzhy5Ah+++03/Pbbb5gxYwa++uorACWfPSsrC/369VMZjiuPtsfn6tWrePToEdzd3TF37ly1/g8ePFDWUNo/Pz8fffr0gZmZmUpfsViMrl274tq1a9odBD1q3LgxrKysVNqkUilcXFxU/ryVzrHTNIfMwsICbdu2xaFDh6q0VqobGJyIaoBXX31VGZzee++9cvs6OzuX2Z6cnKz8PisrCwBU5iSV17+Upvk/UmnJXyXPBpCMjAwAwO7du7F79+4y683NzX3u9kv3oVAoytyOthYvXowPP/wQTk5OCAwMRP369WFubg4ACA8PR0FBQaX3IZVK0a9fP/Tr1w8AUFRUhIiICLz77ruYP38+hg8fjg4dOih/8Xt4eGi9bW2PT+nxv3jxIi5evFjm9kqPf2ZmJoCy//yU9eekqpX3eZ/981b659nJyUljf0PVT7UPgxORkcvIyMB//vMfWFlZQS6XY9KkSYiPj4elpaXG/mXdGZaWlqbyS6j0rq/79+/D29tbY39Nd4Zpq/S9/56AbEhFRUX4/PPP4e7ujoSEBJVfsoIg4Ouvv66S/UqlUowfPx6HDx/Ghg0bcOjQIXTo0EE5oV1TQK2s0uM/bNiwMu+4e1bpn42y/vzcv39fra30ymNRUZHaa6VBrLqUft7SK2n/pql+Il1wOQIiIzdhwgTcvXsXK1aswIIFC3Dt2jWEhoaW2f/w4cNqbadPn8aTJ09UViZv3749AGi8rfvkyZNq/Suq9G45fQ2vaSKRSABAbemEsqSnpyMzMxOdO3dWuzJReoyq0r/DbtOmTWFjY4NTp05pXHagMpo3bw4bGxucPn0acrn8uf2bNm0KMzMznD59Wm3pCoVCgWPHjqm9p3R4UVPwKx0Kri6ld81pqjMvL6/M5TKIKorBiciIff/999i2bRtGjBiB0a
NHIzQ0FEFBQVi7dm2ZVxF++uknlaGZoqIizJw5E0DJ+kKl3njjDUilUixZskRlXo9cLsf06dMBlDzSRFcdO3ZEp06
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell plots exponential scheduling\n",
"\n",
"initial_learning_rate = 0.01\n",
"decay_rate = 0.1\n",
"decay_steps = 20_000\n",
"\n",
"steps = np.arange(100_000)\n",
"lrs = initial_learning_rate * decay_rate ** (steps / decay_steps)\n",
"lrs2 = initial_learning_rate * decay_rate ** np.floor(steps / decay_steps)\n",
"\n",
"plt.plot(steps, lrs, \"-\", label=\"staircase=False\")\n",
"plt.plot(steps, lrs2, \"-\", label=\"staircase=True\")\n",
"plt.axis([0, steps.max(), 0, 0.0105])\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Exponential Scheduling\", fontsize=14)\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras also provides a `LearningRateScheduler` callback class that lets you define your own scheduling function. Let's see how you could use it to implement exponential decay. Note that in this case the learning rate only changes at each epoch, not at each step:"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay_fn(epoch):\n",
" return 0.01 * 0.1 ** (epoch / 20)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay(lr0, s):\n",
" def exponential_decay_fn(epoch):\n",
" return lr0 * 0.1 ** (epoch / s)\n",
" return exponential_decay_fn\n",
"\n",
"exponential_decay_fn = exponential_decay(lr0=0.01, s=20)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"# extra code – build and compile a model for Fashion MNIST\n",
"\n",
"tf.random.set_seed(42)\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.6905 - accuracy: 0.7643 - val_loss: 0.4814 - val_accuracy: 0.8330 - lr: 0.0100\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4672 - accuracy: 0.8357 - val_loss: 0.4488 - val_accuracy: 0.8374 - lr: 0.0089\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4212 - accuracy: 0.8503 - val_loss: 0.4118 - val_accuracy: 0.8532 - lr: 0.0079\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3975 - accuracy: 0.8593 - val_loss: 0.3884 - val_accuracy: 0.8636 - lr: 0.0071\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3781 - accuracy: 0.8657 - val_loss: 0.3772 - val_accuracy: 0.8642 - lr: 0.0063\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3634 - accuracy: 0.8710 - val_loss: 0.3779 - val_accuracy: 0.8662 - lr: 0.0056\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3530 - accuracy: 0.8744 - val_loss: 0.3674 - val_accuracy: 0.8652 - lr: 0.0050\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3437 - accuracy: 0.8771 - val_loss: 0.3616 - val_accuracy: 0.8686 - lr: 0.0045\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3359 - accuracy: 0.8801 - val_loss: 0.3509 - val_accuracy: 0.8728 - lr: 0.0040\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3290 - accuracy: 0.8826 - val_loss: 0.3504 - val_accuracy: 0.8720 - lr: 0.0035\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3236 - accuracy: 0.8844 - val_loss: 0.3458 - val_accuracy: 0.8736 - lr: 0.0032\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3186 - accuracy: 0.8869 - val_loss: 0.3459 - val_accuracy: 0.8752 - lr: 0.0028\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3147 - accuracy: 0.8878 - val_loss: 0.3359 - val_accuracy: 0.8770 - lr: 0.0025\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3109 - accuracy: 0.8890 - val_loss: 0.3404 - val_accuracy: 0.8762 - lr: 0.0022\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3076 - accuracy: 0.8902 - val_loss: 0.3398 - val_accuracy: 0.8790 - lr: 0.0020\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3043 - accuracy: 0.8915 - val_loss: 0.3331 - val_accuracy: 0.8784 - lr: 0.0018\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3020 - accuracy: 0.8924 - val_loss: 0.3363 - val_accuracy: 0.8774 - lr: 0.0016\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2998 - accuracy: 0.8927 - val_loss: 0.3356 - val_accuracy: 0.8778 - lr: 0.0014\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2979 - accuracy: 0.8935 - val_loss: 0.3309 - val_accuracy: 0.8796 - lr: 0.0013\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2961 - accuracy: 0.8940 - val_loss: 0.3308 - val_accuracy: 0.8782 - lr: 0.0011\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2944 - accuracy: 0.8951 - val_loss: 0.3286 - val_accuracy: 0.8802 - lr: 0.0010\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2930 - accuracy: 0.8953 - val_loss: 0.3313 - val_accuracy: 0.8804 - lr: 8.9125e-04\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2916 - accuracy: 0.8957 - val_loss: 0.3285 - val_accuracy: 0.8796 - lr: 7.9433e-04\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2904 - accuracy: 0.8961 - val_loss: 0.3313 - val_accuracy: 0.8786 - lr: 7.0795e-04\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2896 - accuracy: 0.8962 - val_loss: 0.3296 - val_accuracy: 0.8812 - lr: 6.3096e-04\n"
]
}
],
"source": [
"n_epochs = 25\n",
"\n",
"lr_scheduler = tf.keras.callbacks.LearningRateScheduler(exponential_decay_fn)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[lr_scheduler])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, the schedule function can take the current learning rate as a second argument:"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay_fn(epoch, lr):\n",
" return lr * 0.1 ** (1 / 20)"
]
},
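One consequence of this variant is that the schedule now depends on the optimizer's starting learning rate instead of a hard-coded `lr0`. The sketch below simulates this in plain Python, under the assumption that the callback applies the function at the start of every epoch, including the first one. The name `simulate_schedule` is hypothetical and only mimics that loop; it does not use Keras.

```python
def exponential_decay_fn(epoch, lr):
    # Same update as in the cell above: shrink the current rate by a fixed factor
    return lr * 0.1 ** (1 / 20)


def simulate_schedule(schedule, initial_lr, n_epochs):
    """Mimic a per-epoch scheduler (assumed behavior): call `schedule` once
    at the start of each epoch and keep the returned learning rate."""
    lr = initial_lr
    rates = []
    for epoch in range(n_epochs):
        lr = schedule(epoch, lr)
        rates.append(lr)
    return rates


rates = simulate_schedule(exponential_decay_fn, initial_lr=0.01, n_epochs=25)
for epoch in (0, 9, 19, 24):
    print(f"epoch {epoch:>2}: lr = {rates[epoch]:.6f}")
```

Under that assumption the rate used during epoch `e` is `initial_lr * 0.1 ** ((e + 1) / 20)`, so it is already slightly below `initial_lr` in the first epoch and reaches one tenth of it at epoch index 19. Starting from a different optimizer learning rate shifts the whole curve accordingly, so the optimizer's initial value needs to be set with that in mind.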
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra material**: if you want to use a custom scheduling function that updates the learning rate at each iteration rather than at each epoch, you can write your own callback class like this:"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"K = tf.keras.backend\n",
2017-06-05 18:48:03 +02:00
"\n",
"class ExponentialDecay(tf.keras.callbacks.Callback):\n",
"    def __init__(self, n_steps=40_000):\n",
"        super().__init__()\n",
"        self.n_steps = n_steps\n",
"\n",
"    def on_batch_begin(self, batch, logs=None):\n",
"        # Note: the `batch` argument is reset at each epoch\n",
"        lr = K.get_value(self.model.optimizer.learning_rate)\n",
"        new_learning_rate = lr * 0.1 ** (1 / self.n_steps)\n",
"        K.set_value(self.model.optimizer.learning_rate, new_learning_rate)\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        logs = logs or {}\n",
"        logs['lr'] = K.get_value(self.model.optimizer.learning_rate)"
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 78,
2022-02-19 06:19:26 +01:00
"metadata": {},
"outputs": [],
"source": [
2019-02-17 13:31:28 +01:00
"lr0 = 0.01\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
"              metrics=[\"accuracy\"])"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 79,
2018-03-24 22:50:29 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.6947 - accuracy: 0.7635 - val_loss: 0.5014 - val_accuracy: 0.8224 - lr: 0.0091\n",
2022-02-19 10:24:54 +01:00
"Epoch 2/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4718 - accuracy: 0.8349 - val_loss: 0.4530 - val_accuracy: 0.8382 - lr: 0.0083\n",
2022-02-19 10:24:54 +01:00
"Epoch 3/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4255 - accuracy: 0.8500 - val_loss: 0.4216 - val_accuracy: 0.8526 - lr: 0.0076\n",
2022-02-19 10:24:54 +01:00
"Epoch 4/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4025 - accuracy: 0.8587 - val_loss: 0.3954 - val_accuracy: 0.8618 - lr: 0.0069\n",
2022-02-19 10:24:54 +01:00
"Epoch 5/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3840 - accuracy: 0.8643 - val_loss: 0.3847 - val_accuracy: 0.8612 - lr: 0.0063\n",
2022-02-19 10:24:54 +01:00
"Epoch 6/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3696 - accuracy: 0.8689 - val_loss: 0.3908 - val_accuracy: 0.8558 - lr: 0.0058\n",
2022-02-19 10:24:54 +01:00
"Epoch 7/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3590 - accuracy: 0.8722 - val_loss: 0.3744 - val_accuracy: 0.8670 - lr: 0.0052\n",
2022-02-19 10:24:54 +01:00
"Epoch 8/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3498 - accuracy: 0.8749 - val_loss: 0.3754 - val_accuracy: 0.8640 - lr: 0.0048\n",
2022-02-19 10:24:54 +01:00
"Epoch 9/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3415 - accuracy: 0.8783 - val_loss: 0.3592 - val_accuracy: 0.8700 - lr: 0.0044\n",
2022-02-19 10:24:54 +01:00
"Epoch 10/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3340 - accuracy: 0.8803 - val_loss: 0.3575 - val_accuracy: 0.8724 - lr: 0.0040\n",
2022-02-19 10:24:54 +01:00
"Epoch 11/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3281 - accuracy: 0.8833 - val_loss: 0.3573 - val_accuracy: 0.8718 - lr: 0.0036\n",
2022-02-19 10:24:54 +01:00
"Epoch 12/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3228 - accuracy: 0.8847 - val_loss: 0.3579 - val_accuracy: 0.8688 - lr: 0.0033\n",
2022-02-19 10:24:54 +01:00
"Epoch 13/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3182 - accuracy: 0.8865 - val_loss: 0.3421 - val_accuracy: 0.8756 - lr: 0.0030\n",
2022-02-19 10:24:54 +01:00
"Epoch 14/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3138 - accuracy: 0.8882 - val_loss: 0.3468 - val_accuracy: 0.8766 - lr: 0.0028\n",
2022-02-19 10:24:54 +01:00
"Epoch 15/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3101 - accuracy: 0.8889 - val_loss: 0.3471 - val_accuracy: 0.8766 - lr: 0.0025\n",
2022-02-19 10:24:54 +01:00
"Epoch 16/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3064 - accuracy: 0.8898 - val_loss: 0.3386 - val_accuracy: 0.8752 - lr: 0.0023\n",
2022-02-19 10:24:54 +01:00
"Epoch 17/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3035 - accuracy: 0.8903 - val_loss: 0.3417 - val_accuracy: 0.8758 - lr: 0.0021\n",
2022-02-19 10:24:54 +01:00
"Epoch 18/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3005 - accuracy: 0.8919 - val_loss: 0.3398 - val_accuracy: 0.8768 - lr: 0.0019\n",
2022-02-19 10:24:54 +01:00
"Epoch 19/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2983 - accuracy: 0.8929 - val_loss: 0.3357 - val_accuracy: 0.8766 - lr: 0.0017\n",
2022-02-19 10:24:54 +01:00
"Epoch 20/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2959 - accuracy: 0.8939 - val_loss: 0.3370 - val_accuracy: 0.8752 - lr: 0.0016\n",
2022-02-19 10:24:54 +01:00
"Epoch 21/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2940 - accuracy: 0.8938 - val_loss: 0.3346 - val_accuracy: 0.8782 - lr: 0.0014\n",
2022-02-19 10:24:54 +01:00
"Epoch 22/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2917 - accuracy: 0.8949 - val_loss: 0.3361 - val_accuracy: 0.8766 - lr: 0.0013\n",
2022-02-19 10:24:54 +01:00
"Epoch 23/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2902 - accuracy: 0.8955 - val_loss: 0.3349 - val_accuracy: 0.8796 - lr: 0.0012\n",
2022-02-19 10:24:54 +01:00
"Epoch 24/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2884 - accuracy: 0.8959 - val_loss: 0.3364 - val_accuracy: 0.8796 - lr: 0.0011\n",
2022-02-19 10:24:54 +01:00
"Epoch 25/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2871 - accuracy: 0.8969 - val_loss: 0.3352 - val_accuracy: 0.8802 - lr: 1.0000e-03\n"
2022-02-19 10:24:54 +01:00
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2023-11-15 09:23:37 +01:00
"import math\n",
"\n",
2022-02-19 06:19:26 +01:00
"batch_size = 32\n",
2023-11-15 09:23:37 +01:00
"n_steps = n_epochs * math.ceil(len(X_train) / batch_size)\n",
"exp_decay = ExponentialDecay(n_steps)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
"                    validation_data=(X_valid, y_valid),\n",
"                    callbacks=[exp_decay])"
2017-06-05 18:48:03 +02:00
]
},
2023-11-15 06:29:39 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Piecewise Constant Scheduling"
]
},
2017-06-05 18:48:03 +02:00
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\n",
"    boundaries=[50_000, 80_000],\n",
"    values=[0.01, 0.005, 0.001]\n",
")\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.6942 - accuracy: 0.7617 - val_loss: 0.4892 - val_accuracy: 0.8318\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4751 - accuracy: 0.8340 - val_loss: 0.4603 - val_accuracy: 0.8346\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4280 - accuracy: 0.8500 - val_loss: 0.4245 - val_accuracy: 0.8542\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4035 - accuracy: 0.8581 - val_loss: 0.3867 - val_accuracy: 0.8626\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3828 - accuracy: 0.8650 - val_loss: 0.3827 - val_accuracy: 0.8634\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3665 - accuracy: 0.8700 - val_loss: 0.3880 - val_accuracy: 0.8608\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3539 - accuracy: 0.8730 - val_loss: 0.3669 - val_accuracy: 0.8688\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3423 - accuracy: 0.8773 - val_loss: 0.3583 - val_accuracy: 0.8708\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3322 - accuracy: 0.8807 - val_loss: 0.3447 - val_accuracy: 0.8758\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3218 - accuracy: 0.8832 - val_loss: 0.3488 - val_accuracy: 0.8716\n"
]
}
],
"source": [
"history_piecewise_scheduling = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 82,
2019-02-17 13:31:28 +01:00
"metadata": {
"scrolled": true
},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
2023-11-15 06:29:39 +01:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHNCAYAAADolfQeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAABTFUlEQVR4nO3deVxU1f8/8NfIMoAKKiiIsrokRLmAKcqikbgvpUlqpLn0JVJR+pjikssnJXP5kCmSRi4tSqWlFZqoiBtuLGauH01FCURcQCVhgPP7wx/34zQDXkZkBng9Hw8fOWfec++ZOYO8uvfccxVCCAEiIiIieqJ6+u4AERERUU3B4EREREQkE4MTERERkUwMTkREREQyMTgRERERycTgRERERCQTgxMRERGRTAxORERERDIxOBERERHJxOBEBMDZ2RnOzs767sYz06NHDygUCn13g2qZefPmQaFQYN++fdW+7ytXrkChUGDMmDFPtZ3y3oNCoUCPHj2eattUOzE4Ua1U9o/q439MTU3h4OCAkSNH4vfff9d3F+u0CxcuYNKkSXj++edhaWkJpVIJR0dHDBs2DFu2bEFpaam+u6ihqn5R62r9+vVQKBRYv359pV/7xx9/YPTo0XB2doZSqYSVlRVat26N1157DZ9++il45y0i+Yz13QGiZ6lVq1Z48803AQD379/HkSNHsGnTJmzduhV79+5Ft27dAAB79uzRZzefuY0bN6KgoEDf3QAALFu2DNOnT0dpaSl8fHzQq1cvWFhY4Nq1a9i9eze2bNmCsWPHIjY2Vt9drRUSEhIwYMAAFBcXIyAgAK+++ioA4M8//8ShQ4fw448/4r333oOxMX8dPO7s2bOwsLDQdzfIAPEnhWq11q1bY968eWpts2fPxsKFCzFr1iwkJiYCeBSwajNHR0d9dwEAsGbNGvzrX/+Cs7MztmzZgk6dOqk9X1xcjA0bNuDAgQN66mHt8+6776KkpAS7d+9Gz5491Z4TQmDXrl0wMjLSU+8MV7t27fTdBTJQPFVHdc6kSZMAAMePH5faypvjJITAl19+ie7du8PS0hIWFhbw8vLCl19+qXXbQghs2LABfn5+aNSoESwsLNCmTRuEhIQgIyNDrfbevXuYO3cunn/+eZibm6NRo0bo06cPDh48qFY3ZcoUKBQKpKenq7X3798fCoUC48ePV2vfsWMHFAoFFi9eLLVpm+NUWlqKL774Ai+99BKaNGkCCwsLODs7Y8iQIdi/f7/Ge9u/fz8GDhwIGxsbKJVKtGnTBrNnz5Z9JCsvLw/Tpk2Dqakpfv31V43QBADGxsYYN24cPv/8c7X2goICzJs3D+3atYOZmRmaNGmC/v374/DhwxrbeHzOynfffYdOnTrB3NwczZs3x+TJk/H3339rvGbLli3w9/dHs2bNYGZmBgcHB/Tp0wc//fQTgEenyVxcXAAAGzZsUDsFXDY35q+//sLcuXPRtWtXNGvWDEqlEs7OzggNDUVOTo7GPseMGQOFQoErV64gOjoabm5uMDMzg5OTE+bPn692unLMmDF4++23AQBvv/222v4rkpOTg0uXLsHDw0MjNAGP5vH07t1b63YOHDiAV199Fba2tlAqlXBwcMBrr72m8f0sI/ezBir3XSopKcHixYvRunVrmJmZoXXr1oiMjCz3dG5Fc5MqM5dR23YqM2ZlCgoK8MEHH8DBwQFmZmbw8PDA2rVrsW/fPigUCo3/sSPDxyNOVOfInSQthMCbb76Jb7/9Fm3btsXIkSNhamqKhIQEjBs3DmfOnMHSpUvV6keMGIG4uDi0aNECI0aMgKWlJa5cuYK4uDj06dNHOvJz+/Zt+Pn54fTp0/D19UXv3r2Rl5eHbdu2oWfPnvj+++8xZMgQAEDPnj3x6aefIjExER06dADw6JdJ2S+wsqNmZcp+kWv7Rfm4iIgIfPLJJ2jVqhVGjhyJhg0bIjMzEwcOHMDevXvh5+cn1cbExC
A0NBSNGzfGwIED0bRpUxw/fhwLFy5EYmIiEhMTYWpqWuH+vv/+e+Tn52PkyJFwd3evsFapVEp/LywsREBAAI4cOYJOnTphypQpyMnJQVxcHHbt2oW4uDi89tprGttYtWoVduzYgcGDB6NHjx7YuXMnPvvsM9y6dQvffPONVLd69WqEhoaiefPmePXVV2FtbY2srCwcO3YMP/30E4YMGYIOHTogLCwMn376Kdq3by+NDQDpF/H+/fuxbNkyBAQEoEuXLjAxMUFaWhpWr16N3377DampqbCystLo57Rp07Bv3z4MGDAAgYGB+OmnnzBv3jwUFRVh4cKFAIAhQ4bg7t272LZtGwYPHix9D57EysoKRkZGyMrKwoMHD1C/fn1Zr1u1ahUmTZoEc3NzvPrqq3B0dERmZiYOHjyIH374AT4+Pjp91kDlv0vvvPMOvvzyS7i4uOC9997Dw4cPsXz5cq2hubrIGTPg0c/pgAEDkJiYiPbt22PkyJG4ffs23n//fU48r8kEUS10+fJlAUD07t1b47lZs2YJAKJHjx5Sm5OTk3ByclKrW7NmjQAgxo0bJ1QqldReWFgoBg4cKACIEydOSO2rVq0SAERAQIAoKChQ21ZBQYG4deuW9HjkyJECgPjyyy/V6rKzs4WDg4No2rSp+Pvvv4UQQty5c0fUq1dPDBw4UKo7evSotC8A4urVq9JznTt3Fg0bNhTFxcVSm7+/v/jnj3uTJk1EixYtxIMHD9TaS0tL1fp6+vRpYWxsLDp27KjWLoQQkZGRAoBYunSpeJIxY8YIAOKLL754Yu3jFixYIACIUaNGidLSUqn95MmTQqlUisaNG4v8/Hypfe7cuQKAsLKyEufOnZPaCwoKRNu2bYVCoRCZmZlSe6dOnYSpqanIycnR2Hdubq7097Lv1OjRo7X288aNG+LevXsa7Rs2bBAAxEcffaTWPnr0aAFAuLi4iL/++ktqv3nzpmjUqJFo2LChKCwslNrXrVsnAIh169Zp3X95hgwZIgCIDh06iOjoaJGeni6KiorKrf/999+FkZGRsLe3F5cvX1Z7rrS0VO2zq+xnXdnvUmJiogAg2rdvL+7fvy+1X79+XdjY2GgdDwDC399f63vT9nNe9h4SExOfuJ3KjtkXX3whAIhBgwaJkpISqf3s2bPCzMxMABBz587V2lcyXDxVR7XaxYsXMW/ePMybNw//+te/4OPjg4ULF8LMzAyLFi2q8LUrV65E/fr1sXLlSrWJs6amptL/VW7atElqX7VqFYyMjLB69WqYm5urbcvc3BxNmjQBAOTm5iIuLg4BAQHS6Zcytra2mDZtGm7evIndu3cDABo1aoT27dtj//79KCkpAfDoKNPjh/n37t0LAMjPz0dqaip8fX1lzVsxNTXVmBSsUCikvgLA559/juLiYqxYsUKtHQA++OADNG3aVO1zKE92djYAoGXLlk+sfdz69ethYmKCjz/+WO1o4YsvvogxY8bgzp072LZtm8brwsLC8Nxzz0mPzc3NMWLECAghkJKSolZrYmICExMTjW1YW1vL7mezZs3QoEEDjfbg4GBYWlpK4/lPc+bMQfPmzaXHNjY2GDx4MO7du4fz58/L3n951q5di/79+yM9PR2hoaHo0KEDGjRogO7du2PFihUap9NiYmJQUlKCjz76SOO0lkKhgL29vcY+5H7Wlf0ubdy4EQDw4Ycfqh0ta9GiBcLCwir/YVQRuWP29ddfAwD+/e9/o169//26bdeuHUaPHl19HaYqxVN1VKtdunQJ8+fPB/Dol6OtrS1GjhyJGTNm4IUXXij3dQUFBTh16hTs7e3x8ccfazyvUqkAAOfOnQMAPHjwAGfOnEHr1q3Rpk2bCvt0/PhxlJSU4OHDh1rnN/z3v/+Vtj1gwAAAj067paWlITU1FZ07d5YO/fv4+MDOzg6JiYkYM2aMFK6edJoOAIYPH46YmBh4eHggKCgI/v7+8Pb21jidc+TIEQDAzp07tf
7yNzExkT6Hqpafn48///wTbm5uWgNXjx498PnnnyM9PV26erKMtjlUZdu4e/eu1DZ8+HDMmDEDHh4eeOONN9CjRw/
2022-02-19 10:24:54 +01:00
"text/plain": [
2023-11-15 06:29:39 +01:00
"<Figure size 640x480 with 1 Axes>"
2022-02-19 10:24:54 +01:00
]
},
2023-11-15 06:29:39 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"output_type": "display_data"
}
],
2017-06-05 18:48:03 +02:00
"source": [
"# extra code – this cell plots piecewise constant scheduling\n",
"\n",
"boundaries = [50_000, 80_000]\n",
"values = [0.01, 0.005, 0.001]\n",
"\n",
"steps = np.arange(100_000)\n",
"\n",
"lrs = np.full(len(steps), values[0])\n",
"for boundary, value in zip(boundaries, values[1:]):\n",
"    lrs[boundary:] = value\n",
2022-02-19 06:19:26 +01:00
"\n",
2023-11-15 06:29:39 +01:00
"plt.plot(steps, lrs, \"-\")\n",
"plt.axis([0, steps.max(), 0, 0.0105])\n",
"plt.xlabel(\"Step\")\n",
2019-02-17 13:31:28 +01:00
"plt.ylabel(\"Learning Rate\")\n",
2023-11-15 06:29:39 +01:00
"plt.title(\"Piecewise Constant Scheduling\", fontsize=14)\n",
2019-02-17 13:31:28 +01:00
"plt.grid(True)\n",
"plt.show()"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
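"Just like we did with exponential scheduling, we could also implement piecewise constant scheduling manually, using a schedule function of the epoch index (note that the boundaries are now expressed in epochs rather than training steps):"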
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 83,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
"def piecewise_constant_fn(epoch):\n",
"    if epoch < 5:\n",
"        return 0.01\n",
"    elif epoch < 15:\n",
"        return 0.005\n",
"    else:\n",
"        return 0.001"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 84,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – this cell demonstrates a more general way to define\n",
"# piecewise constant scheduling.\n",
"\n",
"def piecewise_constant(boundaries, values):\n",
"    boundaries = np.array([0] + boundaries)\n",
"    values = np.array(values)\n",
"    def piecewise_constant_fn(epoch):\n",
"        # index of the first boundary above `epoch`, minus 1; past the last\n",
"        # boundary argmax() returns 0, so index -1 selects the last value\n",
"        return values[(boundaries > epoch).argmax() - 1]\n",
"    return piecewise_constant_fn\n",
"\n",
"piecewise_constant_fn = piecewise_constant([5, 15], [0.01, 0.005, 0.001])"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 85,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 5s 2ms/step - loss: 0.5433 - accuracy: 0.8087 - val_loss: 0.4586 - val_accuracy: 0.8288 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 2/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4487 - accuracy: 0.8439 - val_loss: 0.4608 - val_accuracy: 0.8350 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 3/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4263 - accuracy: 0.8502 - val_loss: 0.4234 - val_accuracy: 0.8568 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 4/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4241 - accuracy: 0.8537 - val_loss: 0.4359 - val_accuracy: 0.8490 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 5/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4080 - accuracy: 0.8584 - val_loss: 0.4165 - val_accuracy: 0.8560 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 6/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3544 - accuracy: 0.8738 - val_loss: 0.3830 - val_accuracy: 0.8662 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 7/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3464 - accuracy: 0.8761 - val_loss: 0.4026 - val_accuracy: 0.8652 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 8/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3426 - accuracy: 0.8772 - val_loss: 0.4212 - val_accuracy: 0.8544 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 9/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3417 - accuracy: 0.8793 - val_loss: 0.4116 - val_accuracy: 0.8612 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 10/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3339 - accuracy: 0.8804 - val_loss: 0.4090 - val_accuracy: 0.8618 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 11/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3309 - accuracy: 0.8819 - val_loss: 0.4033 - val_accuracy: 0.8746 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 12/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.3270 - accuracy: 0.8826 - val_loss: 0.4518 - val_accuracy: 0.8630 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 13/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3270 - accuracy: 0.8837 - val_loss: 0.3714 - val_accuracy: 0.8674 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 14/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3247 - accuracy: 0.8844 - val_loss: 0.4026 - val_accuracy: 0.8652 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 15/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3204 - accuracy: 0.8852 - val_loss: 0.3993 - val_accuracy: 0.8724 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 16/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2859 - accuracy: 0.8963 - val_loss: 0.3930 - val_accuracy: 0.8736 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 17/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2781 - accuracy: 0.8978 - val_loss: 0.4021 - val_accuracy: 0.8714 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 18/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2743 - accuracy: 0.8984 - val_loss: 0.3955 - val_accuracy: 0.8754 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 19/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2704 - accuracy: 0.8999 - val_loss: 0.4015 - val_accuracy: 0.8756 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 20/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2683 - accuracy: 0.9015 - val_loss: 0.4161 - val_accuracy: 0.8756 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 21/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2655 - accuracy: 0.9020 - val_loss: 0.4207 - val_accuracy: 0.8740 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 22/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2646 - accuracy: 0.9020 - val_loss: 0.4497 - val_accuracy: 0.8746 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 23/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2626 - accuracy: 0.9032 - val_loss: 0.4429 - val_accuracy: 0.8762 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 24/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2608 - accuracy: 0.9038 - val_loss: 0.4566 - val_accuracy: 0.8748 - lr: 0.0010\n",
2022-02-19 10:24:54 +01:00
"Epoch 25/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2587 - accuracy: 0.9038 - val_loss: 0.4726 - val_accuracy: 0.8770 - lr: 0.0010\n"
2022-02-19 10:24:54 +01:00
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – use a tf.keras.callbacks.LearningRateScheduler like earlier\n",
2019-02-17 13:31:28 +01:00
"\n",
"n_epochs = 25\n",
"\n",
"lr_scheduler = tf.keras.callbacks.LearningRateScheduler(piecewise_constant_fn)\n",
"\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
"              metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
"                    validation_data=(X_valid, y_valid),\n",
"                    callbacks=[lr_scheduler])"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've looked at `InverseTimeDecay`, `ExponentialDecay`, and `PiecewiseConstantDecay`. A few more schedulers are available in `tf.keras.optimizers.schedules`; here is the full list:"
]
},
2017-06-05 18:48:03 +02:00
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 86,
2018-05-08 20:21:23 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
2023-11-15 06:29:39 +01:00
"name": "stdout",
"output_type": "stream",
"text": [
"• CosineDecay – A LearningRateSchedule that uses a cosine decay with optional warmup.\n",
"• CosineDecayRestarts – A LearningRateSchedule that uses a cosine decay schedule with restarts.\n",
"• ExponentialDecay – A LearningRateSchedule that uses an exponential decay schedule.\n",
"• InverseTimeDecay – A LearningRateSchedule that uses an inverse time decay schedule.\n",
"• LearningRateSchedule – The learning rate schedule base class.\n",
"• PiecewiseConstantDecay – A LearningRateSchedule that uses a piecewise constant decay schedule.\n",
"• PolynomialDecay – A LearningRateSchedule that uses a polynomial decay schedule.\n"
]
2022-02-19 10:24:54 +01:00
}
],
2017-06-05 18:48:03 +02:00
"source": [
"for name in sorted(dir(tf.keras.optimizers.schedules)):\n",
"    if name[0] == name[0].lower():  # skip names not starting with a capital letter\n",
"        continue\n",
"    scheduler_class = getattr(tf.keras.optimizers.schedules, name)\n",
"    print(f\"• {name} – {scheduler_class.__doc__.splitlines()[0]}\")"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"### Performance Scheduling"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 87,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
"# extra code – build and compile the model\n",
"\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
"              metrics=[\"accuracy\"])"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 88,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.6807 - accuracy: 0.7679 - val_loss: 0.4814 - val_accuracy: 0.8310 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 2/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4659 - accuracy: 0.8343 - val_loss: 0.4615 - val_accuracy: 0.8306 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 3/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4201 - accuracy: 0.8505 - val_loss: 0.4199 - val_accuracy: 0.8490 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 4/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3957 - accuracy: 0.8590 - val_loss: 0.3845 - val_accuracy: 0.8614 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 5/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3754 - accuracy: 0.8658 - val_loss: 0.3742 - val_accuracy: 0.8614 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 6/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3588 - accuracy: 0.8709 - val_loss: 0.3853 - val_accuracy: 0.8628 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 7/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3469 - accuracy: 0.8740 - val_loss: 0.3627 - val_accuracy: 0.8690 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 8/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3346 - accuracy: 0.8785 - val_loss: 0.3574 - val_accuracy: 0.8680 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 9/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3244 - accuracy: 0.8828 - val_loss: 0.3410 - val_accuracy: 0.8748 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 10/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3149 - accuracy: 0.8850 - val_loss: 0.3410 - val_accuracy: 0.8720 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 11/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3074 - accuracy: 0.8879 - val_loss: 0.3629 - val_accuracy: 0.8678 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 12/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2990 - accuracy: 0.8920 - val_loss: 0.3379 - val_accuracy: 0.8746 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 13/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.2929 - accuracy: 0.8938 - val_loss: 0.3223 - val_accuracy: 0.8808 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 14/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2867 - accuracy: 0.8947 - val_loss: 0.3405 - val_accuracy: 0.8754 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 15/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2807 - accuracy: 0.8972 - val_loss: 0.3480 - val_accuracy: 0.8730 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 16/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2743 - accuracy: 0.8998 - val_loss: 0.3350 - val_accuracy: 0.8766 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 17/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2694 - accuracy: 0.9019 - val_loss: 0.3421 - val_accuracy: 0.8764 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 18/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2631 - accuracy: 0.9032 - val_loss: 0.3360 - val_accuracy: 0.8772 - lr: 0.0100\n",
2022-02-19 10:24:54 +01:00
"Epoch 19/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2445 - accuracy: 0.9110 - val_loss: 0.3162 - val_accuracy: 0.8874 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 20/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2410 - accuracy: 0.9131 - val_loss: 0.3221 - val_accuracy: 0.8812 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 21/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2380 - accuracy: 0.9137 - val_loss: 0.3166 - val_accuracy: 0.8828 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 22/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2351 - accuracy: 0.9148 - val_loss: 0.3146 - val_accuracy: 0.8854 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 23/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2330 - accuracy: 0.9160 - val_loss: 0.3191 - val_accuracy: 0.8836 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 24/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.2300 - accuracy: 0.9161 - val_loss: 0.3175 - val_accuracy: 0.8878 - lr: 0.0050\n",
2022-02-19 10:24:54 +01:00
"Epoch 25/25\n",
2023-11-15 06:29:39 +01:00
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2276 - accuracy: 0.9174 - val_loss: 0.3205 - val_accuracy: 0.8868 - lr: 0.0050\n"
2022-02-19 10:24:54 +01:00
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
"lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
"                    validation_data=(X_valid, y_valid),\n",
"                    callbacks=[lr_scheduler])"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 89,
2018-05-08 20:21:23 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"image/png": "<base64-encoded PNG omitted: learning rate and validation loss per epoch>",
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – this cell plots performance scheduling\n",
"\n",
2019-02-17 13:31:28 +01:00
"plt.plot(history.epoch, history.history[\"lr\"], \"bo-\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Learning Rate\", color='b')\n",
"plt.tick_params('y', colors='b')\n",
"plt.gca().set_xlim(0, n_epochs - 1)\n",
"plt.grid(True)\n",
"\n",
"ax2 = plt.gca().twinx()\n",
"ax2.plot(history.epoch, history.history[\"val_loss\"], \"r^-\")\n",
"ax2.set_ylabel('Validation Loss', color='r')\n",
"ax2.tick_params('y', colors='r')\n",
"\n",
"plt.title(\"Reduce LR on Plateau\", fontsize=14)\n",
"plt.show()"
2017-06-05 18:48:03 +02:00
]
},
2019-05-05 06:42:08 +02:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1Cycle scheduling"
]
},
2022-02-19 06:19:26 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ExponentialLearningRate` custom callback updates the learning rate at the end of each batch, multiplying it by a constant `factor`. It also records the learning rate and the loss at each batch. Since `logs[\"loss\"]` is actually the mean loss since the start of the epoch, and we want the batch loss instead, we multiply this mean by the number of batches seen so far in the epoch to get the total loss so far, then we subtract the total loss at the previous batch to get the current batch's loss."
]
},
2019-05-05 06:42:08 +02:00
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 90,
2019-05-05 06:42:08 +02:00
"metadata": {},
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"K = tf.keras.backend\n",
2019-05-05 06:42:08 +02:00
"\n",
2021-10-17 04:04:08 +02:00
"class ExponentialLearningRate(tf.keras.callbacks.Callback):\n",
2019-05-05 06:42:08 +02:00
" def __init__(self, factor):\n",
" self.factor = factor\n",
" self.rates = []\n",
" self.losses = []\n",
"\n",
2022-02-19 06:19:26 +01:00
" def on_epoch_begin(self, epoch, logs=None):\n",
" self.sum_of_epoch_losses = 0\n",
"\n",
" def on_batch_end(self, batch, logs=None):\n",
" mean_epoch_loss = logs[\"loss\"] # the epoch's mean loss so far \n",
" new_sum_of_epoch_losses = mean_epoch_loss * (batch + 1)\n",
" batch_loss = new_sum_of_epoch_losses - self.sum_of_epoch_losses\n",
" self.sum_of_epoch_losses = new_sum_of_epoch_losses\n",
" self.rates.append(K.get_value(self.model.optimizer.learning_rate))\n",
" self.losses.append(batch_loss)\n",
" K.set_value(self.model.optimizer.learning_rate,\n",
" self.model.optimizer.learning_rate * self.factor)"
]
},
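{
"cell_type": "markdown",
"metadata": {},
"source": [
"The running-mean trick used in `on_batch_end()` can be checked in isolation. The following sketch is not part of the original notebook: it uses made-up batch losses (the `true_batch_losses` list is purely illustrative), builds the running mean that `logs[\"loss\"]` would contain, assuming all batches have the same size, and then recovers each batch loss the same way the callback does."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – illustrative sketch with made-up losses (not from the book)\n",
"true_batch_losses = [2.0, 1.5, 1.8, 1.2]  # hypothetical per-batch losses\n",
"\n",
"sum_of_epoch_losses = 0.0  # reset at the start of each epoch\n",
"running_total = 0.0\n",
"recovered = []\n",
"for batch, loss in enumerate(true_batch_losses):\n",
"    running_total += loss\n",
"    mean_epoch_loss = running_total / (batch + 1)  # what logs[\"loss\"] would hold\n",
"    new_sum_of_epoch_losses = mean_epoch_loss * (batch + 1)\n",
"    recovered.append(new_sum_of_epoch_losses - sum_of_epoch_losses)\n",
"    sum_of_epoch_losses = new_sum_of_epoch_losses\n",
"\n",
"# the recovered values match the made-up batch losses, up to rounding errors\n",
"assert all(abs(r - t) < 1e-9 for r, t in zip(recovered, true_batch_losses))"
]
},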
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `find_learning_rate()` function trains the model using the `ExponentialLearningRate` callback, and it returns the learning rates and corresponding batch losses. At the end, it restores the model and its optimizer to their initial state."
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 91,
2022-02-19 06:19:26 +01:00
"metadata": {},
"outputs": [],
"source": [
"def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=1e-4,\n",
" max_rate=1):\n",
2019-05-05 06:42:08 +02:00
" init_weights = model.get_weights()\n",
2020-09-25 14:25:17 +02:00
" iterations = math.ceil(len(X) / batch_size) * epochs\n",
2022-02-19 06:19:26 +01:00
" factor = (max_rate / min_rate) ** (1 / iterations)\n",
2021-08-31 10:54:35 +02:00
" init_lr = K.get_value(model.optimizer.learning_rate)\n",
" K.set_value(model.optimizer.learning_rate, min_rate)\n",
2019-05-05 06:42:08 +02:00
" exp_lr = ExponentialLearningRate(factor)\n",
" history = model.fit(X, y, epochs=epochs, batch_size=batch_size,\n",
" callbacks=[exp_lr])\n",
2021-08-31 10:54:35 +02:00
" K.set_value(model.optimizer.learning_rate, init_lr)\n",
2019-05-05 06:42:08 +02:00
" model.set_weights(init_weights)\n",
2022-02-19 06:19:26 +01:00
" return exp_lr.rates, exp_lr.losses"
]
},
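{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why `factor` is computed as `(max_rate / min_rate) ** (1 / iterations)`, here is a small sketch that is not part of the original notebook. Multiplying the learning rate by `factor` once per batch takes it from `min_rate` to `max_rate` in exactly `iterations` steps. The values below are illustrative: 55,000 samples is an assumption consistent with the 430 steps per epoch shown in this notebook's logs for a batch size of 128."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – illustrative sketch (not from the book)\n",
"import math\n",
"\n",
"n_samples, batch_size, epochs = 55_000, 128, 1  # illustrative values\n",
"min_rate, max_rate = 1e-4, 1\n",
"\n",
"iterations = math.ceil(n_samples / batch_size) * epochs\n",
"factor = (max_rate / min_rate) ** (1 / iterations)\n",
"\n",
"lr = min_rate\n",
"for _ in range(iterations):\n",
"    lr *= factor  # what the callback does at the end of each batch\n",
"\n",
"print(iterations, factor, lr)  # lr ends up (almost exactly) at max_rate\n",
"assert math.isclose(lr, max_rate, rel_tol=1e-6)"
]
},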
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `plot_lr_vs_loss()` function plots the loss as a function of the learning rate, using a log scale for the learning rate. The optimal learning rate to use as the maximum learning rate in 1cycle is near the bottom of the curve."
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 92,
2022-02-19 06:19:26 +01:00
"metadata": {},
"outputs": [],
"source": [
2019-05-05 06:42:08 +02:00
"def plot_lr_vs_loss(rates, losses):\n",
2022-02-19 06:19:26 +01:00
" plt.plot(rates, losses, \"b\")\n",
2019-05-05 06:42:08 +02:00
" plt.gca().set_xscale('log')\n",
2022-02-19 06:19:26 +01:00
" max_loss = losses[0] + min(losses)\n",
" plt.hlines(min(losses), min(rates), max(rates), color=\"k\")\n",
" plt.axis([min(rates), max(rates), 0, max_loss])\n",
2019-05-05 06:42:08 +02:00
" plt.xlabel(\"Learning rate\")\n",
2022-02-19 06:19:26 +01:00
" plt.ylabel(\"Loss\")\n",
" plt.grid()"
2019-05-05 06:42:08 +02:00
]
},
2021-03-18 22:50:13 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-02-19 06:19:26 +01:00
"Let's build a simple Fashion MNIST model and compile it:"
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 93,
2022-02-19 06:19:26 +01:00
"metadata": {},
"outputs": [],
"source": [
"model = build_model()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's find the optimal max learning rate for 1cycle:"
2021-03-18 22:50:13 +01:00
]
},
2019-05-05 06:42:08 +02:00
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 94,
2019-05-05 06:42:08 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"430/430 [==============================] - 1s 1ms/step - loss: 1.7725 - accuracy: 0.4122\n"
]
},
{
"data": {
"image/png": "<base64-encoded PNG omitted: loss vs. learning rate, with the learning rate on a log scale>",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2019-05-05 06:42:08 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"batch_size = 128\n",
"rates, losses = find_learning_rate(model, X_train, y_train, epochs=1,\n",
" batch_size=batch_size)\n",
"plot_lr_vs_loss(rates, losses)"
2019-05-05 06:42:08 +02:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2019-05-05 06:42:08 +02:00
"metadata": {},
"source": [
2022-02-19 06:19:26 +01:00
"Looks like the max learning rate to use for 1cycle is around 10<sup>–1</sup>."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `OneCycleScheduler` custom callback updates the learning rate at the beginning of each batch. It applies the logic described in the book: increase the learning rate linearly during about half of training, then reduce it linearly back to the initial learning rate, and lastly reduce it down to close to zero linearly for the very last part of training."
2019-05-05 06:42:08 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 95,
2019-05-05 06:42:08 +02:00
"metadata": {},
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"class OneCycleScheduler(tf.keras.callbacks.Callback):\n",
2022-02-19 06:19:26 +01:00
" def __init__(self, iterations, max_lr=1e-3, start_lr=None,\n",
" last_iterations=None, last_lr=None):\n",
2019-05-05 06:42:08 +02:00
" self.iterations = iterations\n",
2022-02-19 06:19:26 +01:00
" self.max_lr = max_lr\n",
" self.start_lr = start_lr or max_lr / 10\n",
2019-05-05 06:42:08 +02:00
" self.last_iterations = last_iterations or iterations // 10 + 1\n",
" self.half_iteration = (iterations - self.last_iterations) // 2\n",
2022-02-19 06:19:26 +01:00
" self.last_lr = last_lr or self.start_lr / 1000\n",
2019-05-05 06:42:08 +02:00
" self.iteration = 0\n",
2022-02-19 06:19:26 +01:00
"\n",
" def _interpolate(self, iter1, iter2, lr1, lr2):\n",
" return (lr2 - lr1) * (self.iteration - iter1) / (iter2 - iter1) + lr1\n",
"\n",
2019-05-05 06:42:08 +02:00
" def on_batch_begin(self, batch, logs):\n",
" if self.iteration < self.half_iteration:\n",
2022-02-19 06:19:26 +01:00
" lr = self._interpolate(0, self.half_iteration, self.start_lr,\n",
" self.max_lr)\n",
2019-05-05 06:42:08 +02:00
" elif self.iteration < 2 * self.half_iteration:\n",
2022-02-19 06:19:26 +01:00
" lr = self._interpolate(self.half_iteration, 2 * self.half_iteration,\n",
" self.max_lr, self.start_lr)\n",
2019-05-05 06:42:08 +02:00
" else:\n",
2022-02-19 06:19:26 +01:00
" lr = self._interpolate(2 * self.half_iteration, self.iterations,\n",
" self.start_lr, self.last_lr)\n",
2019-05-05 06:42:08 +02:00
" self.iteration += 1\n",
2022-02-19 06:19:26 +01:00
" K.set_value(self.model.optimizer.learning_rate, lr)"
]
},
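{
"cell_type": "markdown",
"metadata": {},
"source": [
"The piecewise-linear schedule is easier to picture with a few numbers. The following sketch is not part of the original notebook: `one_cycle_lr()` is a hypothetical standalone helper that mirrors the logic of `OneCycleScheduler` (same defaults) without needing a model, so we can inspect the learning rate at a few key iterations. The total of 1,000 iterations is an arbitrary choice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – illustrative sketch (not from the book)\n",
"def one_cycle_lr(iteration, iterations, max_lr=1e-3, start_lr=None,\n",
"                 last_iterations=None, last_lr=None):\n",
"    start_lr = start_lr or max_lr / 10\n",
"    last_iterations = last_iterations or iterations // 10 + 1\n",
"    half_iteration = (iterations - last_iterations) // 2\n",
"    last_lr = last_lr or start_lr / 1000\n",
"\n",
"    def interpolate(iter1, iter2, lr1, lr2):\n",
"        return (lr2 - lr1) * (iteration - iter1) / (iter2 - iter1) + lr1\n",
"\n",
"    if iteration < half_iteration:  # first phase: linear warm-up\n",
"        return interpolate(0, half_iteration, start_lr, max_lr)\n",
"    elif iteration < 2 * half_iteration:  # second phase: back to start_lr\n",
"        return interpolate(half_iteration, 2 * half_iteration,\n",
"                           max_lr, start_lr)\n",
"    else:  # last phase: decay towards last_lr\n",
"        return interpolate(2 * half_iteration, iterations, start_lr, last_lr)\n",
"\n",
"total = 1000  # hypothetical number of training iterations\n",
"half = (total - (total // 10 + 1)) // 2  # same formula as half_iteration\n",
"# starts at max_lr / 10, peaks at max_lr, returns to max_lr / 10,\n",
"# then decays towards last_lr during the final iterations\n",
"for it in [0, half, 2 * half, total - 1]:\n",
"    print(it, one_cycle_lr(it, total, max_lr=0.1))"
]
},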
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build and compile a simple Fashion MNIST model, then train it using the `OneCycleScheduler` callback:"
2019-05-05 06:42:08 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 96,
2019-05-05 06:42:08 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"430/430 [==============================] - 1s 2ms/step - loss: 0.9502 - accuracy: 0.6913 - val_loss: 0.6003 - val_accuracy: 0.7874\n",
"Epoch 2/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.5695 - accuracy: 0.8025 - val_loss: 0.4918 - val_accuracy: 0.8248\n",
"Epoch 3/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4954 - accuracy: 0.8252 - val_loss: 0.4762 - val_accuracy: 0.8264\n",
"Epoch 4/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4515 - accuracy: 0.8402 - val_loss: 0.4261 - val_accuracy: 0.8478\n",
"Epoch 5/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4225 - accuracy: 0.8492 - val_loss: 0.4066 - val_accuracy: 0.8486\n",
"Epoch 6/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3958 - accuracy: 0.8571 - val_loss: 0.4787 - val_accuracy: 0.8224\n",
"Epoch 7/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3787 - accuracy: 0.8626 - val_loss: 0.3917 - val_accuracy: 0.8566\n",
"Epoch 8/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3630 - accuracy: 0.8683 - val_loss: 0.4719 - val_accuracy: 0.8296\n",
"Epoch 9/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3512 - accuracy: 0.8724 - val_loss: 0.3673 - val_accuracy: 0.8652\n",
"Epoch 10/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3360 - accuracy: 0.8766 - val_loss: 0.4957 - val_accuracy: 0.8466\n",
"Epoch 11/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3287 - accuracy: 0.8786 - val_loss: 0.4187 - val_accuracy: 0.8370\n",
"Epoch 12/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3173 - accuracy: 0.8815 - val_loss: 0.3425 - val_accuracy: 0.8728\n",
"Epoch 13/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2961 - accuracy: 0.8910 - val_loss: 0.3217 - val_accuracy: 0.8792\n",
"Epoch 14/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2818 - accuracy: 0.8958 - val_loss: 0.3734 - val_accuracy: 0.8692\n",
"Epoch 15/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2675 - accuracy: 0.9003 - val_loss: 0.3261 - val_accuracy: 0.8844\n",
"Epoch 16/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2558 - accuracy: 0.9055 - val_loss: 0.3205 - val_accuracy: 0.8820\n",
"Epoch 17/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2464 - accuracy: 0.9091 - val_loss: 0.3089 - val_accuracy: 0.8894\n",
"Epoch 18/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2368 - accuracy: 0.9115 - val_loss: 0.3130 - val_accuracy: 0.8870\n",
"Epoch 19/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2292 - accuracy: 0.9145 - val_loss: 0.3078 - val_accuracy: 0.8854\n",
"Epoch 20/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2205 - accuracy: 0.9186 - val_loss: 0.3092 - val_accuracy: 0.8886\n",
"Epoch 21/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2138 - accuracy: 0.9209 - val_loss: 0.3022 - val_accuracy: 0.8914\n",
"Epoch 22/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2073 - accuracy: 0.9232 - val_loss: 0.3054 - val_accuracy: 0.8914\n",
"Epoch 23/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2020 - accuracy: 0.9261 - val_loss: 0.3026 - val_accuracy: 0.8896\n",
"Epoch 24/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.1989 - accuracy: 0.9273 - val_loss: 0.3020 - val_accuracy: 0.8922\n",
"Epoch 25/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.1967 - accuracy: 0.9276 - val_loss: 0.3016 - val_accuracy: 0.8920\n"
]
}
],
2019-05-05 06:42:08 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"model = build_model()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(),\n",
" metrics=[\"accuracy\"])\n",
2019-05-05 06:42:08 +02:00
"n_epochs = 25\n",
2022-02-19 06:19:26 +01:00
"onecycle = OneCycleScheduler(math.ceil(len(X_train) / batch_size) * n_epochs,\n",
" max_lr=0.1)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid, y_valid),\n",
2019-05-05 06:42:08 +02:00
" callbacks=[onecycle])"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"# Avoiding Overfitting Through Regularization"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"## $\\ell_1$ and $\\ell_2$ regularization"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 97,
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-11-03 13:43:56 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"layer = tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\",\n",
" kernel_regularizer=tf.keras.regularizers.l2(0.01))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or use `l1(0.1)` for ℓ<sub>1</sub> regularization with a factor of 0.1, or `l1_l2(0.1, 0.01)` for both ℓ<sub>1</sub> and ℓ<sub>2</sub> regularization, with factors 0.1 and 0.01 respectively."
2017-06-05 18:48:03 +02:00
]
},
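{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of these alternatives (not part of the original notebook), the layers below use the `l1()` and `l1_l2()` factories from `tf.keras.regularizers` with the factors mentioned above. The variable names and the layer size are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – illustrative sketch (not from the book)\n",
"import tensorflow as tf\n",
"\n",
"# l1 regularization only, with a factor of 0.1\n",
"l1_layer = tf.keras.layers.Dense(\n",
"    100, activation=\"relu\", kernel_initializer=\"he_normal\",\n",
"    kernel_regularizer=tf.keras.regularizers.l1(0.1))\n",
"\n",
"# both l1 (factor 0.1) and l2 (factor 0.01) regularization\n",
"l1_l2_layer = tf.keras.layers.Dense(\n",
"    100, activation=\"relu\", kernel_initializer=\"he_normal\",\n",
"    kernel_regularizer=tf.keras.regularizers.l1_l2(0.1, 0.01))"
]
},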
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 98,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"tf.random.set_seed(42) # extra code – for reproducibility"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 99,
2017-06-21 15:35:47 +02:00
"metadata": {},
2019-02-17 13:31:28 +01:00
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"from functools import partial\n",
"\n",
2021-10-17 04:04:08 +02:00
"RegularizedDense = partial(tf.keras.layers.Dense,\n",
2022-02-19 06:19:26 +01:00
" activation=\"relu\",\n",
2019-02-17 13:31:28 +01:00
" kernel_initializer=\"he_normal\",\n",
2021-10-17 04:04:08 +02:00
" kernel_regularizer=tf.keras.regularizers.l2(0.01))\n",
2019-02-17 13:31:28 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
2022-02-19 06:19:26 +01:00
" RegularizedDense(100),\n",
2019-02-17 13:31:28 +01:00
" RegularizedDense(100),\n",
" RegularizedDense(10, activation=\"softmax\")\n",
2022-02-19 06:19:26 +01:00
"])"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 100,
2018-05-08 20:21:23 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 3.1224 - accuracy: 0.7748 - val_loss: 1.8602 - val_accuracy: 0.8264\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 1s 814us/step - loss: 1.4263 - accuracy: 0.8159 - val_loss: 1.1269 - val_accuracy: 0.8182\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – compile and train the model\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.02)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=2,\n",
" validation_data=(X_valid, y_valid))"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"## Dropout"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 101,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"tf.random.set_seed(42) # extra code – for reproducibility"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 102,
2019-02-28 12:48:06 +01:00
"metadata": {},
"outputs": [],
2017-06-05 18:48:03 +02:00
"source": [
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
2022-02-19 06:19:26 +01:00
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
2021-10-17 04:04:08 +02:00
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
2022-02-19 06:19:26 +01:00
"])"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 103,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.6703 - accuracy: 0.7536 - val_loss: 0.4498 - val_accuracy: 0.8342\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 996us/step - loss: 0.5103 - accuracy: 0.8136 - val_loss: 0.4401 - val_accuracy: 0.8296\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 998us/step - loss: 0.4712 - accuracy: 0.8263 - val_loss: 0.3806 - val_accuracy: 0.8554\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 977us/step - loss: 0.4488 - accuracy: 0.8337 - val_loss: 0.3711 - val_accuracy: 0.8608\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4342 - accuracy: 0.8409 - val_loss: 0.3672 - val_accuracy: 0.8606\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 983us/step - loss: 0.4245 - accuracy: 0.8427 - val_loss: 0.3706 - val_accuracy: 0.8600\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 995us/step - loss: 0.4131 - accuracy: 0.8467 - val_loss: 0.3582 - val_accuracy: 0.8650\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 959us/step - loss: 0.4074 - accuracy: 0.8484 - val_loss: 0.3478 - val_accuracy: 0.8708\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 997us/step - loss: 0.4024 - accuracy: 0.8533 - val_loss: 0.3556 - val_accuracy: 0.8690\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 998us/step - loss: 0.3903 - accuracy: 0.8552 - val_loss: 0.3453 - val_accuracy: 0.8732\n"
]
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – compile and train the model\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
2019-02-28 12:48:06 +01:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2019-02-28 12:48:06 +01:00
"metadata": {},
"source": [
2022-02-19 06:19:26 +01:00
"The training accuracy looks like it's lower than the validation accuracy, but that's just because dropout is only active during training. If we evaluate the model on the training set after training (i.e., with dropout turned off), we get the \"real\" training accuracy, which is very slightly higher than the validation accuracy and the test accuracy:"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 104,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 1s 578us/step - loss: 0.3082 - accuracy: 0.8849\n"
]
},
{
"data": {
"text/plain": [
"[0.30816400051116943, 0.8849090933799744]"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 104,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"model.evaluate(X_train, y_train)"
2019-02-28 12:48:06 +01:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 105,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 588us/step - loss: 0.3629 - accuracy: 0.8700\n"
]
},
{
"data": {
"text/plain": [
"[0.3628920316696167, 0.8700000047683716]"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 105,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"model.evaluate(X_test, y_test)"
2019-02-28 12:48:06 +01:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2019-02-28 12:48:06 +01:00
"metadata": {},
"source": [
2022-02-19 06:19:26 +01:00
"**Note**: make sure to use `AlphaDropout` instead of `Dropout` if you want to build a self-normalizing neural net using SELU."
2019-02-28 12:48:06 +01:00
]
},
{
2022-02-19 06:19:26 +01:00
"cell_type": "markdown",
2019-02-28 12:48:06 +01:00
"metadata": {},
"source": [
2022-02-19 06:19:26 +01:00
"## MC Dropout"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 106,
2019-02-28 12:48:06 +01:00
"metadata": {},
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"tf.random.set_seed(42) # extra code – for reproducibility"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 107,
2019-02-28 12:48:06 +01:00
"metadata": {},
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"y_probas = np.stack([model(X_test, training=True)\n",
" for sample in range(100)])\n",
"y_proba = y_probas.mean(axis=0)"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 108,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.024, 0. , 0.132, 0. ,\n",
" 0.844]], dtype=float32)"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 108,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"model.predict(X_test[:1]).round(3)"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 109,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([0. , 0. , 0. , 0. , 0. , 0.067, 0. , 0.209, 0.001,\n",
" 0.723], dtype=float32)"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 109,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"y_proba[0].round(3)"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 110,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([0. , 0. , 0. , 0.001, 0. , 0.096, 0. , 0.162, 0.001,\n",
" 0.183], dtype=float32)"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 110,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"y_std = y_probas.std(axis=0)\n",
"y_std[0].round(3)"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 111,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.8717"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 111,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"y_pred = y_proba.argmax(axis=1)\n",
"accuracy = (y_pred == y_test).sum() / len(y_test)\n",
2019-02-28 12:48:06 +01:00
"accuracy"
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 112,
2019-02-28 12:48:06 +01:00
"metadata": {},
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"class MCDropout(tf.keras.layers.Dropout):\n",
2022-02-19 06:19:26 +01:00
" def call(self, inputs, training=None):\n",
2019-02-28 12:48:06 +01:00
" return super().call(inputs, training=True)"
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 113,
2019-02-28 12:48:06 +01:00
"metadata": {},
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – shows how to convert Dropout to MCDropout in a Sequential model\n",
"Dropout = tf.keras.layers.Dropout\n",
2021-10-17 04:04:08 +02:00
"mc_model = tf.keras.Sequential([\n",
2022-02-19 06:19:26 +01:00
" MCDropout(layer.rate) if isinstance(layer, Dropout) else layer\n",
2019-02-28 12:48:06 +01:00
" for layer in model.layers\n",
2022-02-19 06:19:26 +01:00
"])\n",
"mc_model.set_weights(model.get_weights())"
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 114,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_25\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten_22 (Flatten) (None, 784) 0 \n",
"_________________________________________________________________\n",
"mc_dropout (MCDropout) (None, 784) 0 \n",
"_________________________________________________________________\n",
"dense_89 (Dense) (None, 100) 78500 \n",
"_________________________________________________________________\n",
"mc_dropout_1 (MCDropout) (None, 100) 0 \n",
"_________________________________________________________________\n",
"dense_90 (Dense) (None, 100) 10100 \n",
"_________________________________________________________________\n",
"mc_dropout_2 (MCDropout) (None, 100) 0 \n",
"_________________________________________________________________\n",
"dense_91 (Dense) (None, 10) 1010 \n",
"=================================================================\n",
"Total params: 89,610\n",
"Trainable params: 89,610\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
2019-02-28 12:48:06 +01:00
"source": [
"mc_model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use the model with MC Dropout:"
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 115,
2019-02-28 12:48:06 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.07, 0. , 0.17, 0. , 0.76]],\n",
" dtype=float32)"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 115,
2022-02-19 10:24:54 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-02-28 12:48:06 +01:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – shows that the model works without retraining\n",
"tf.random.set_seed(42)\n",
"np.mean([mc_model.predict(X_test[:1])\n",
" for sample in range(100)], axis=0).round(2)"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"## Max norm"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 116,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"outputs": [],
"source": [
2022-02-19 06:19:26 +01:00
"dense = tf.keras.layers.Dense(\n",
" 100, activation=\"relu\", kernel_initializer=\"he_normal\",\n",
" kernel_constraint=tf.keras.constraints.max_norm(1.))"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 117,
2019-02-17 13:31:28 +01:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5500 - accuracy: 0.8015 - val_loss: 0.4510 - val_accuracy: 0.8242\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 960us/step - loss: 0.4089 - accuracy: 0.8499 - val_loss: 0.3956 - val_accuracy: 0.8504\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 974us/step - loss: 0.3777 - accuracy: 0.8604 - val_loss: 0.3693 - val_accuracy: 0.8680\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 943us/step - loss: 0.3581 - accuracy: 0.8690 - val_loss: 0.3517 - val_accuracy: 0.8716\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.3416 - accuracy: 0.8729 - val_loss: 0.3433 - val_accuracy: 0.8682\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 951us/step - loss: 0.3368 - accuracy: 0.8756 - val_loss: 0.4045 - val_accuracy: 0.8582\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 935us/step - loss: 0.3293 - accuracy: 0.8767 - val_loss: 0.4168 - val_accuracy: 0.8476\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 951us/step - loss: 0.3258 - accuracy: 0.8779 - val_loss: 0.3570 - val_accuracy: 0.8674\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 970us/step - loss: 0.3269 - accuracy: 0.8787 - val_loss: 0.3702 - val_accuracy: 0.8578\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.3169 - accuracy: 0.8809 - val_loss: 0.3907 - val_accuracy: 0.8578\n"
]
}
],
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"# extra code – shows how to apply max norm to every hidden layer in a model\n",
"\n",
2021-10-17 04:04:08 +02:00
"MaxNormDense = partial(tf.keras.layers.Dense,\n",
2022-02-19 06:19:26 +01:00
" activation=\"relu\", kernel_initializer=\"he_normal\",\n",
2021-10-17 04:04:08 +02:00
" kernel_constraint=tf.keras.constraints.max_norm(1.))\n",
2017-06-05 18:48:03 +02:00
"\n",
2022-02-19 06:19:26 +01:00
"tf.random.set_seed(42)\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
2022-02-19 06:19:26 +01:00
" MaxNormDense(100),\n",
2019-02-17 13:31:28 +01:00
" MaxNormDense(100),\n",
2021-10-17 04:04:08 +02:00
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
2019-02-17 13:31:28 +01:00
"])\n",
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
2017-06-05 18:48:03 +02:00
]
},
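{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the constraint less of a black box, here is a minimal NumPy sketch of the rescaling that a max-norm constraint applies to a layer's kernel after each training step. It assumes the constraint acts on each neuron's incoming weight vector (one kernel column per neuron, i.e., `axis=0`). The name `apply_max_norm` is a hypothetical helper written for this illustration, not part of Keras:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def apply_max_norm(kernel, max_value=1.0, axis=0, eps=1e-7):\n",
"    # L2 norm of each neuron's incoming weight vector (one column per neuron)\n",
"    norms = np.sqrt(np.sum(np.square(kernel), axis=axis, keepdims=True))\n",
"    # norms above max_value are clipped; smaller norms are kept as they are\n",
"    desired = np.clip(norms, 0.0, max_value)\n",
"    # rescale each column so that its norm is at most max_value\n",
"    return kernel * (desired / (eps + norms))\n",
"\n",
"rng = np.random.default_rng(42)\n",
"kernel = rng.normal(scale=2.0, size=(5, 3))  # 5 inputs, 3 neurons\n",
"constrained = apply_max_norm(kernel, max_value=1.0)\n",
"print(np.linalg.norm(kernel, axis=0).round(3))       # norms before\n",
"print(np.linalg.norm(constrained, axis=0).round(3))  # norms after\n",
"```\n",
"\n",
"Columns whose norm is already below `max_value` are left (almost) unchanged, while larger ones are scaled down to a norm of roughly `max_value`."
]
},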
{
"cell_type": "markdown",
2020-04-06 09:13:12 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-28 12:48:06 +01:00
"# Exercises"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2019-02-17 13:31:28 +01:00
"## 1. to 7."
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"1. Glorot initialization and He initialization were designed to make the output standard deviation as close as possible to the input standard deviation, at least at the beginning of training. This reduces the vanishing/exploding gradients problem.\n",
"2. No, all weights should be sampled independently; they should not all have the same initial value. One important goal of sampling weights randomly is to break symmetry: if all the weights have the same initial value, even if that value is not zero, then symmetry is not broken (i.e., all neurons in a given layer are equivalent), and backpropagation will be unable to break it. Concretely, this means that all the neurons in any given layer will always have the same weights. It's like having just one neuron per layer, and much slower. It is virtually impossible for such a configuration to converge to a good solution.\n",
"3. It is perfectly fine to initialize the bias terms to zero. Some people like to initialize them just like weights, and that's OK too; it does not make much difference.\n",
"4. ReLU is usually a good default for the hidden layers, as it is fast and yields good results. Its ability to output precisely zero can also be useful in some cases (e.g., see Chapter 17). Moreover, it can sometimes benefit from optimized implementations as well as from hardware acceleration. The leaky variants of ReLU can improve the model's quality without hindering its speed too much compared to ReLU. For large neural nets and more complex problems, GLU, Swish and Mish can give you a slightly higher quality model, but at a higher computational cost. The hyperbolic tangent (tanh) can be useful in the output layer if you need to output a number in a fixed range (by default between -1 and 1), but nowadays it is not used much in hidden layers, except in recurrent nets. The sigmoid activation function is also useful in the output layer when you need to estimate a probability (e.g., for binary classification), but it is rarely used in hidden layers (there are exceptions, such as the coding layer of variational autoencoders; see Chapter 17). The softplus activation function is useful in the output layer when you need to ensure that the output will always be positive. The softmax activation function is useful in the output layer to estimate probabilities for mutually exclusive classes, but it is rarely (if ever) used in hidden layers.\n",
"5. If you set the `momentum` hyperparameter too close to 1 (e.g., 0.99999) when using an `SGD` optimizer, then the algorithm will likely pick up a lot of speed, hopefully moving roughly toward the global minimum, but its momentum will carry it right past the minimum. Then it will slow down and come back, accelerate again, overshoot again, and so on. It may oscillate this way many times before converging, so overall it will take much longer to converge than with a smaller `momentum` value.\n",
"6. One way to produce a sparse model (i.e., with most weights equal to zero) is to train the model normally, then zero out tiny weights. For more sparsity, you can apply ℓ<sub>1</sub> regularization during training, which pushes the optimizer toward sparsity. A third option is to use the TensorFlow Model Optimization Toolkit.\n",
"7. Yes, dropout does slow down training, in general roughly by a factor of two. However, it has no impact on inference speed since it is only turned on during training. MC Dropout is exactly like dropout during training, but it is still active during inference, so each inference is slowed down slightly. More importantly, when using MC Dropout you generally want to run inference 10 times or more to get better predictions. This means that making predictions is slowed down by a factor of 10 or more."
2017-06-05 18:48:03 +02:00
]
},
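{
"cell_type": "markdown",
"metadata": {},
"source": [
"The symmetry argument in answer 2 can be checked with a tiny NumPy experiment. This is only a sketch under assumed settings (random toy data, one hidden layer of 3 tanh neurons, plain gradient descent, every weight initialized to 0.5), not code from the chapter:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X = rng.normal(size=(20, 4))  # 20 instances, 4 features (toy data)\n",
"y = rng.normal(size=(20, 1))  # arbitrary regression targets\n",
"\n",
"# every weight starts at the same non-zero value, so symmetry is not broken\n",
"W1 = np.full((4, 3), 0.5)     # 3 hidden neurons\n",
"b1 = np.zeros(3)\n",
"W2 = np.full((3, 1), 0.5)\n",
"b2 = np.zeros(1)\n",
"learning_rate = 0.05\n",
"\n",
"for step in range(100):\n",
"    # forward pass\n",
"    h = np.tanh(X @ W1 + b1)\n",
"    y_pred = h @ W2 + b2\n",
"    # backward pass for the mean squared error\n",
"    grad_out = 2 * (y_pred - y) / len(X)\n",
"    grad_W2 = h.T @ grad_out\n",
"    grad_b2 = grad_out.sum(axis=0)\n",
"    grad_z = (grad_out @ W2.T) * (1 - h ** 2)\n",
"    grad_W1 = X.T @ grad_z\n",
"    grad_b1 = grad_z.sum(axis=0)\n",
"    # plain gradient descent step\n",
"    W1 -= learning_rate * grad_W1\n",
"    b1 -= learning_rate * grad_b1\n",
"    W2 -= learning_rate * grad_W2\n",
"    b2 -= learning_rate * grad_b2\n",
"\n",
"# compare every hidden neuron's incoming weights with those of the first one\n",
"print(np.allclose(W1, W1[:, :1]))\n",
"```\n",
"\n",
"If symmetry is never broken, the final check reports that all three columns of `W1` are still identical, so the hidden layer behaves like a single neuron."
]
},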
{
2019-02-17 13:31:28 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"## 8. Deep Learning on CIFAR10"
2017-06-05 18:48:03 +02:00
]
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"### a.\n",
2022-02-19 06:19:26 +01:00
"*Exercise: Build a DNN with 20 hidden layers of 100 neurons each (that's too many, but it's the point of this exercise). Use He initialization and the Swish activation function.*"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 118,
2019-02-17 13:31:28 +01:00
"metadata": {},
2017-11-03 13:43:56 +01:00
"outputs": [],
2020-03-10 21:55:45 +01:00
"source": [
"tf.random.set_seed(42)\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
2020-03-10 21:55:45 +01:00
"for _ in range(20):\n",
2021-10-17 04:04:08 +02:00
" model.add(tf.keras.layers.Dense(100,\n",
2022-02-19 06:19:26 +01:00
" activation=\"swish\",\n",
" kernel_initializer=\"he_normal\"))"
2020-03-10 21:55:45 +01:00
]
2017-06-05 18:48:03 +02:00
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"### b.\n",
2021-10-17 04:04:08 +02:00
"*Exercise: Using Nadam optimization and early stopping, train the network on the CIFAR10 dataset. You can load it with `tf.keras.datasets.cifar10.load_data()`. The dataset is composed of 60,000 32 × 32-pixel color images (50,000 for training, 10,000 for testing) with 10 classes, so you'll need a softmax output layer with 10 neurons. Remember to search for the right learning rate each time you change the model's architecture or hyperparameters.*"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-05 18:48:03 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"Let's add the output layer to the model:"
2017-06-05 18:48:03 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 119,
2018-05-08 20:21:23 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
2016-09-27 23:31:21 +02:00
]
},
2017-04-30 10:21:27 +02:00
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-04-30 10:21:27 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"Let's use a Nadam optimizer with a learning rate of 5e-5. I tried learning rates 1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3 and 1e-2, and I compared their learning curves for 10 epochs each (using the TensorBoard callback, below). The learning rates 3e-5 and 1e-4 were pretty good, so I tried 5e-5, which turned out to be slightly better."
2017-04-30 10:21:27 +02:00
]
},
2016-09-27 23:31:21 +02:00
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 120,
2018-05-08 20:21:23 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-5)\n",
2020-03-10 21:55:45 +01:00
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
2016-09-27 23:31:21 +02:00
]
},
{
2017-06-05 18:48:03 +02:00
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"Let's load the CIFAR10 dataset. We also want to use early stopping, so we need a validation set. Let's use the first 5,000 images of the original training set as the validation set:"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 121,
2018-05-08 20:21:23 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
2020-03-10 21:55:45 +01:00
"source": [
2022-02-20 05:36:52 +01:00
"cifar10 = tf.keras.datasets.cifar10.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = cifar10\n",
2020-03-10 21:55:45 +01:00
"\n",
"X_train = X_train_full[5000:]\n",
"y_train = y_train_full[5000:]\n",
"X_valid = X_train_full[:5000]\n",
"y_valid = y_train_full[:5000]"
]
2019-02-17 13:31:28 +01:00
},
{
"cell_type": "markdown",
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"Now we can create the callbacks we need and train the model:"
2016-09-27 23:31:21 +02:00
]
},
{
2020-03-10 21:55:45 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 122,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-20 05:36:52 +01:00
"outputs": [],
2016-09-27 23:31:21 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,\n",
" restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_cifar10_model\",\n",
" save_best_only=True)\n",
2020-03-10 21:55:45 +01:00
"run_index = 1 # increment every time you train the model\n",
2022-02-19 06:19:26 +01:00
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_{run_index:03d}\"\n",
2021-10-17 04:04:08 +02:00
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
2020-03-10 21:55:45 +01:00
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 123,
2018-05-08 20:21:23 +02:00
"metadata": {},
2022-02-19 10:24:54 +01:00
"outputs": [
{
2022-02-20 05:36:52 +01:00
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-d05c16b556c70d97\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-d05c16b556c70d97\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6006;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
2022-02-19 10:24:54 +01:00
}
],
2020-03-10 21:55:45 +01:00
"source": [
2022-02-20 05:36:52 +01:00
"%load_ext tensorboard\n",
2022-02-19 06:19:26 +01:00
"%tensorboard --logdir=./my_cifar10_logs"
2020-03-10 21:55:45 +01:00
]
2016-09-27 23:31:21 +02:00
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 124,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-20 05:36:52 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 4.0493 - accuracy: 0.1598INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 17s 10ms/step - loss: 4.0462 - accuracy: 0.1597 - val_loss: 2.1441 - val_accuracy: 0.2036\n",
"Epoch 2/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 2.0667 - accuracy: 0.2320INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 2.0667 - accuracy: 0.2320 - val_loss: 2.0134 - val_accuracy: 0.2472\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.9472 - accuracy: 0.2819INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.9472 - accuracy: 0.2819 - val_loss: 1.9427 - val_accuracy: 0.2796\n",
"Epoch 4/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.8636 - accuracy: 0.3182INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.8637 - accuracy: 0.3182 - val_loss: 1.8934 - val_accuracy: 0.3222\n",
"Epoch 5/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.7975 - accuracy: 0.3464INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.7974 - accuracy: 0.3465 - val_loss: 1.8389 - val_accuracy: 0.3284\n",
"Epoch 6/100\n",
"1407/1407 [==============================] - 9s 7ms/step - loss: 1.7446 - accuracy: 0.3664 - val_loss: 2.0006 - val_accuracy: 0.3030\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.6974 - accuracy: 0.3852INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6974 - accuracy: 0.3852 - val_loss: 1.7075 - val_accuracy: 0.3738\n",
"Epoch 8/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.6605 - accuracy: 0.3984INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6604 - accuracy: 0.3984 - val_loss: 1.6788 - val_accuracy: 0.3836\n",
"Epoch 9/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.6322 - accuracy: 0.4114INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.6321 - accuracy: 0.4114 - val_loss: 1.6477 - val_accuracy: 0.4014\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6065 - accuracy: 0.4205 - val_loss: 1.6623 - val_accuracy: 0.3980\n",
"Epoch 11/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.5843 - accuracy: 0.4287INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.5845 - accuracy: 0.4285 - val_loss: 1.6032 - val_accuracy: 0.4198\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 9s 6ms/step - loss: 1.5634 - accuracy: 0.4367 - val_loss: 1.6063 - val_accuracy: 0.4258\n",
"Epoch 13/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.5443 - accuracy: 0.4420INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"<<47 more lines>>\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.3247 - accuracy: 0.5256 - val_loss: 1.5130 - val_accuracy: 0.4616\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.3164 - accuracy: 0.5286 - val_loss: 1.5284 - val_accuracy: 0.4686\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.3091 - accuracy: 0.5303 - val_loss: 1.5208 - val_accuracy: 0.4682\n",
"Epoch 35/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.3026 - accuracy: 0.5319 - val_loss: 1.5479 - val_accuracy: 0.4604\n",
"Epoch 36/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2930 - accuracy: 0.5378 - val_loss: 1.5443 - val_accuracy: 0.4580\n",
"Epoch 37/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.2833 - accuracy: 0.5406 - val_loss: 1.5165 - val_accuracy: 0.4710\n",
"Epoch 38/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2763 - accuracy: 0.5433 - val_loss: 1.5345 - val_accuracy: 0.4672\n",
"Epoch 39/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.2687 - accuracy: 0.5437 - val_loss: 1.5162 - val_accuracy: 0.4712\n",
"Epoch 40/100\n",
"1407/1407 [==============================] - 11s 7ms/step - loss: 1.2623 - accuracy: 0.5490 - val_loss: 1.5717 - val_accuracy: 0.4566\n",
"Epoch 41/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.2580 - accuracy: 0.5467 - val_loss: 1.5296 - val_accuracy: 0.4738\n",
"Epoch 42/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2469 - accuracy: 0.5532 - val_loss: 1.5179 - val_accuracy: 0.4690\n",
"Epoch 43/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2404 - accuracy: 0.5542 - val_loss: 1.5542 - val_accuracy: 0.4566\n",
"Epoch 44/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.2292 - accuracy: 0.5605 - val_loss: 1.5536 - val_accuracy: 0.4608\n",
"Epoch 45/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.2276 - accuracy: 0.5606 - val_loss: 1.5522 - val_accuracy: 0.4624\n",
"Epoch 46/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2200 - accuracy: 0.5637 - val_loss: 1.5339 - val_accuracy: 0.4794\n",
"Epoch 47/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2080 - accuracy: 0.5677 - val_loss: 1.5451 - val_accuracy: 0.4688\n",
"Epoch 48/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.2050 - accuracy: 0.5675 - val_loss: 1.5209 - val_accuracy: 0.4770\n",
"Epoch 49/100\n",
"1407/1407 [==============================] - 10s 7ms/step - loss: 1.1947 - accuracy: 0.5718 - val_loss: 1.5435 - val_accuracy: 0.4736\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fb9f02fc070>"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 124,
2022-02-20 05:36:52 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-03-10 21:55:45 +01:00
"source": [
"model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=callbacks)"
]
2016-09-27 23:31:21 +02:00
},
{
2019-02-17 13:31:28 +01:00
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 125,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-20 05:36:52 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"157/157 [==============================] - 0s 2ms/step - loss: 1.5062 - accuracy: 0.4676\n"
]
},
{
"data": {
"text/plain": [
"[1.5061508417129517, 0.4675999879837036]"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 125,
2022-02-20 05:36:52 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2016-09-27 23:31:21 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"model.evaluate(X_valid, y_valid)"
2016-09-27 23:31:21 +02:00
]
},
2017-04-30 10:21:27 +02:00
{
2019-02-17 13:31:28 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-04-30 10:21:27 +02:00
"source": [
2022-02-20 05:36:52 +01:00
"The model with the lowest validation loss gets about 46.8% accuracy on the validation set. It took 29 epochs to reach the lowest validation loss, with roughly 10 seconds per epoch on my laptop (without a GPU). Let's see if we can improve the model using Batch Normalization."
2017-04-30 10:21:27 +02:00
]
},
2016-09-27 23:31:21 +02:00
{
2017-06-05 18:48:03 +02:00
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"### c.\n",
"*Exercise: Now try adding Batch Normalization and compare the learning curves: Is it converging faster than before? Does it produce a better model? How does it affect training speed?*"
2016-09-27 23:31:21 +02:00
]
},
{
2020-03-10 21:55:45 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2020-03-10 21:55:45 +01:00
"source": [
"The code below is very similar to the code above, with a few changes:\n",
"\n",
2022-02-19 06:19:26 +01:00
"* I added a BN layer after every Dense layer (before the activation function), except for the output layer.\n",
2020-03-10 21:55:45 +01:00
"* I changed the learning rate to 5e-4. I experimented with 1e-5, 3e-5, 5e-5, 1e-4, 3e-4, 5e-4, 1e-3 and 3e-3, and I chose the one with the best validation performance after 20 epochs.\n",
2022-02-19 06:19:26 +01:00
"* I renamed the run directories to `run_bn_*` and the model file name to `my_cifar10_bn_model`."
2020-03-10 21:55:45 +01:00
]
2016-09-27 23:31:21 +02:00
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 126,
2018-05-08 20:21:23 +02:00
"metadata": {},
2022-02-20 05:36:52 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 2.0377 - accuracy: 0.2523INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 32s 18ms/step - loss: 2.0374 - accuracy: 0.2525 - val_loss: 1.8766 - val_accuracy: 0.3154\n",
"Epoch 2/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.7874 - accuracy: 0.3542 - val_loss: 1.8784 - val_accuracy: 0.3268\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - 20s 15ms/step - loss: 1.6806 - accuracy: 0.3969 - val_loss: 1.9764 - val_accuracy: 0.3252\n",
"Epoch 4/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.6111 - accuracy: 0.4229INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 24s 17ms/step - loss: 1.6112 - accuracy: 0.4228 - val_loss: 1.7087 - val_accuracy: 0.3750\n",
"Epoch 5/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.5520 - accuracy: 0.4478INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 21s 15ms/step - loss: 1.5521 - accuracy: 0.4476 - val_loss: 1.6272 - val_accuracy: 0.4176\n",
"Epoch 6/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5030 - accuracy: 0.4659INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 23s 16ms/step - loss: 1.5030 - accuracy: 0.4660 - val_loss: 1.5401 - val_accuracy: 0.4452\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.4559 - accuracy: 0.4812 - val_loss: 1.6990 - val_accuracy: 0.3952\n",
"Epoch 8/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.4169 - accuracy: 0.4987INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 21s 15ms/step - loss: 1.4168 - accuracy: 0.4987 - val_loss: 1.5078 - val_accuracy: 0.4652\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.3863 - accuracy: 0.5123 - val_loss: 1.5513 - val_accuracy: 0.4470\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.3514 - accuracy: 0.5216 - val_loss: 1.5208 - val_accuracy: 0.4562\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.3220 - accuracy: 0.5314 - val_loss: 1.7301 - val_accuracy: 0.4206\n",
"Epoch 12/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.2933 - accuracy: 0.5410INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 25s 18ms/step - loss: 1.2931 - accuracy: 0.5410 - val_loss: 1.4909 - val_accuracy: 0.4734\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.2702 - accuracy: 0.5490 - val_loss: 1.5256 - val_accuracy: 0.4636\n",
"Epoch 14/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.2424 - accuracy: 0.5591 - val_loss: 1.5569 - val_accuracy: 0.4624\n",
"Epoch 15/100\n",
"<<12 more lines>>\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.1174 - accuracy: 0.6066 - val_loss: 1.5241 - val_accuracy: 0.4828\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0978 - accuracy: 0.6128 - val_loss: 1.5313 - val_accuracy: 0.4772\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.0844 - accuracy: 0.6198 - val_loss: 1.4993 - val_accuracy: 0.4924\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.0677 - accuracy: 0.6244 - val_loss: 1.4622 - val_accuracy: 0.5078\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0571 - accuracy: 0.6297 - val_loss: 1.4917 - val_accuracy: 0.4990\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 1.0395 - accuracy: 0.6327 - val_loss: 1.4888 - val_accuracy: 0.4896\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0298 - accuracy: 0.6370 - val_loss: 1.5358 - val_accuracy: 0.5024\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0150 - accuracy: 0.6444 - val_loss: 1.5219 - val_accuracy: 0.5030\n",
"Epoch 29/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.0100 - accuracy: 0.6456 - val_loss: 1.4933 - val_accuracy: 0.5098\n",
"Epoch 30/100\n",
"1407/1407 [==============================] - 20s 14ms/step - loss: 0.9956 - accuracy: 0.6492 - val_loss: 1.4756 - val_accuracy: 0.5012\n",
"Epoch 31/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 0.9787 - accuracy: 0.6576 - val_loss: 1.5181 - val_accuracy: 0.4936\n",
"Epoch 32/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9710 - accuracy: 0.6565 - val_loss: 1.7510 - val_accuracy: 0.4568\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 20s 14ms/step - loss: 0.9613 - accuracy: 0.6628 - val_loss: 1.5576 - val_accuracy: 0.4910\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 0.9530 - accuracy: 0.6651 - val_loss: 1.5087 - val_accuracy: 0.5046\n",
"Epoch 35/100\n",
"1407/1407 [==============================] - 19s 13ms/step - loss: 0.9388 - accuracy: 0.6701 - val_loss: 1.5534 - val_accuracy: 0.4950\n",
"Epoch 36/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9331 - accuracy: 0.6743 - val_loss: 1.5033 - val_accuracy: 0.5046\n",
"Epoch 37/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 0.9144 - accuracy: 0.6808 - val_loss: 1.5679 - val_accuracy: 0.5028\n",
"157/157 [==============================] - 0s 2ms/step - loss: 1.4236 - accuracy: 0.5074\n"
]
},
{
"data": {
"text/plain": [
"[1.4236289262771606, 0.5073999762535095]"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 126,
2022-02-20 05:36:52 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-03-10 21:55:45 +01:00
"source": [
"tf.random.set_seed(42)\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
2020-03-10 21:55:45 +01:00
"for _ in range(20):\n",
2021-10-17 04:04:08 +02:00
" model.add(tf.keras.layers.Dense(100, kernel_initializer=\"he_normal\"))\n",
" model.add(tf.keras.layers.BatchNormalization())\n",
2022-02-19 06:19:26 +01:00
" model.add(tf.keras.layers.Activation(\"swish\"))\n",
2022-02-20 05:36:52 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
2020-03-10 21:55:45 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-4)\n",
2020-03-10 21:55:45 +01:00
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
2022-02-19 06:19:26 +01:00
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,\n",
" restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_cifar10_bn_model\",\n",
" save_best_only=True)\n",
2020-03-10 21:55:45 +01:00
"run_index = 1 # increment every time you train the model\n",
2022-02-19 06:19:26 +01:00
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_bn_{run_index:03d}\"\n",
2021-10-17 04:04:08 +02:00
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
2020-03-10 21:55:45 +01:00
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid, y_valid)"
]
2019-02-17 13:31:28 +01:00
},
{
"cell_type": "markdown",
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"* *Is the model converging faster than before?* Much faster! The previous model took 29 epochs to reach the lowest validation loss, while the new model achieved that same loss in just 12 epochs and continued to make progress until the 17th epoch. The BN layers stabilized training and allowed us to use a much larger learning rate, so convergence was faster.\n",
"* *Does BN produce a better model?* Yes! The final model is also much better, with 50.7% validation accuracy instead of 46.7%. It's still not a very good model, but at least it's much better than before (a Convolutional Neural Network would do much better, but that's a different topic, see chapter 14).\n",
"* *How does BN affect training speed?* Although the model converged much faster, each epoch took about 15s instead of 10s, because of the extra computations required by the BN layers. But overall the training time (wall time) to reach the best model was shortened by about 10%."
2016-09-27 23:31:21 +02:00
]
},
{
2017-06-05 18:48:03 +02:00
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"### d.\n",
"*Exercise: Try replacing Batch Normalization with SELU, and make the necessary adjustments to ensure the network self-normalizes (i.e., standardize the input features, use LeCun normal initialization, make sure the DNN contains only a sequence of dense layers, etc.).*"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 127,
2020-03-10 21:55:45 +01:00
"metadata": {
"scrolled": true
},
2022-02-20 05:36:52 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.9386 - accuracy: 0.3045INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 20s 13ms/step - loss: 1.9385 - accuracy: 0.3046 - val_loss: 1.8175 - val_accuracy: 0.3510\n",
"Epoch 2/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.7241 - accuracy: 0.3869INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.7241 - accuracy: 0.3869 - val_loss: 1.7677 - val_accuracy: 0.3614\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.6272 - accuracy: 0.4263INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.6272 - accuracy: 0.4263 - val_loss: 1.6878 - val_accuracy: 0.4054\n",
"Epoch 4/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5644 - accuracy: 0.4492INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.5643 - accuracy: 0.4492 - val_loss: 1.6589 - val_accuracy: 0.4304\n",
"Epoch 5/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.5080 - accuracy: 0.4712INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.5080 - accuracy: 0.4712 - val_loss: 1.5651 - val_accuracy: 0.4538\n",
"Epoch 6/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.4611 - accuracy: 0.4873INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.4613 - accuracy: 0.4872 - val_loss: 1.5305 - val_accuracy: 0.4678\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.4174 - accuracy: 0.5077 - val_loss: 1.5346 - val_accuracy: 0.4558\n",
"Epoch 8/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.3781 - accuracy: 0.5175INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.3781 - accuracy: 0.5175 - val_loss: 1.4773 - val_accuracy: 0.4882\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.3413 - accuracy: 0.5345 - val_loss: 1.5021 - val_accuracy: 0.4764\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.3182 - accuracy: 0.5422 - val_loss: 1.5709 - val_accuracy: 0.4762\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2832 - accuracy: 0.5571 - val_loss: 1.5345 - val_accuracy: 0.4868\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2557 - accuracy: 0.5667 - val_loss: 1.5024 - val_accuracy: 0.4900\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2373 - accuracy: 0.5710 - val_loss: 1.5114 - val_accuracy: 0.5028\n",
"Epoch 14/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.2071 - accuracy: 0.5846INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.2073 - accuracy: 0.5847 - val_loss: 1.4608 - val_accuracy: 0.5026\n",
"Epoch 15/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1843 - accuracy: 0.5940 - val_loss: 1.4962 - val_accuracy: 0.5038\n",
"Epoch 16/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.1617 - accuracy: 0.6026 - val_loss: 1.5255 - val_accuracy: 0.5062\n",
"Epoch 17/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1452 - accuracy: 0.6084 - val_loss: 1.5057 - val_accuracy: 0.5036\n",
"Epoch 18/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.1297 - accuracy: 0.6145 - val_loss: 1.5097 - val_accuracy: 0.5010\n",
"Epoch 19/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.1004 - accuracy: 0.6245 - val_loss: 1.5218 - val_accuracy: 0.5014\n",
"Epoch 20/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0971 - accuracy: 0.6304 - val_loss: 1.5253 - val_accuracy: 0.5090\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.0670 - accuracy: 0.6345 - val_loss: 1.5006 - val_accuracy: 0.5034\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.0544 - accuracy: 0.6407 - val_loss: 1.5244 - val_accuracy: 0.5010\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0338 - accuracy: 0.6502 - val_loss: 1.5355 - val_accuracy: 0.5096\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.0281 - accuracy: 0.6514 - val_loss: 1.5257 - val_accuracy: 0.5164\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.4097 - accuracy: 0.6478 - val_loss: 1.8203 - val_accuracy: 0.3514\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.3733 - accuracy: 0.5157 - val_loss: 1.5600 - val_accuracy: 0.4664\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2032 - accuracy: 0.5814 - val_loss: 1.5367 - val_accuracy: 0.4944\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1291 - accuracy: 0.6121 - val_loss: 1.5333 - val_accuracy: 0.4852\n",
"Epoch 29/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0734 - accuracy: 0.6317 - val_loss: 1.5475 - val_accuracy: 0.5032\n",
"Epoch 30/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0294 - accuracy: 0.6469 - val_loss: 1.5400 - val_accuracy: 0.5052\n",
"Epoch 31/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0081 - accuracy: 0.6605 - val_loss: 1.5617 - val_accuracy: 0.4856\n",
"Epoch 32/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0109 - accuracy: 0.6603 - val_loss: 1.5727 - val_accuracy: 0.5124\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9646 - accuracy: 0.6762 - val_loss: 1.5333 - val_accuracy: 0.5174\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 0.9597 - accuracy: 0.6789 - val_loss: 1.5601 - val_accuracy: 0.5016\n",
"157/157 [==============================] - 0s 1ms/step - loss: 1.4608 - accuracy: 0.5026\n"
]
},
{
"data": {
"text/plain": [
"[1.4607702493667603, 0.5026000142097473]"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 127,
2022-02-20 05:36:52 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-03-10 21:55:45 +01:00
"source": [
"tf.random.set_seed(42)\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
2020-03-10 21:55:45 +01:00
"for _ in range(20):\n",
2021-10-17 04:04:08 +02:00
" model.add(tf.keras.layers.Dense(100,\n",
2022-02-19 06:19:26 +01:00
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
2022-02-20 05:36:52 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
2020-03-10 21:55:45 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"optimizer = tf.keras.optimizers.Nadam(learning_rate=7e-4)\n",
2020-03-10 21:55:45 +01:00
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
2022-02-19 06:19:26 +01:00
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n",
" patience=20, restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n",
" \"my_cifar10_selu_model\", save_best_only=True)\n",
2020-03-10 21:55:45 +01:00
"run_index = 1 # increment every time you train the model\n",
2022-02-19 06:19:26 +01:00
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_selu_{run_index:03d}\"\n",
2021-10-17 04:04:08 +02:00
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
2020-03-10 21:55:45 +01:00
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"X_means = X_train.mean(axis=0)\n",
"X_stds = X_train.std(axis=0)\n",
"X_train_scaled = (X_train - X_means) / X_stds\n",
"X_valid_scaled = (X_valid - X_means) / X_stds\n",
"X_test_scaled = (X_test - X_means) / X_stds\n",
"\n",
"model.fit(X_train_scaled, y_train, epochs=100,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid_scaled, y_valid)"
]
2017-06-14 09:09:23 +02:00
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"This model reached the first model's validation loss in just 8 epochs. After 14 epochs, it reached its lowest validation loss, with about 50.3% accuracy, which is better than the original model (46.7%), but not quite as good as the model using batch normalization (50.7%). In the run shown above, most epochs took roughly 14 to 17 seconds, slightly less than the 16 to 20 seconds per epoch of the batch normalization model, since there are no BN layers to compute."
2017-06-14 09:09:23 +02:00
]
},
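{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why these adjustments matter, here is a minimal NumPy sketch (not part of the original solution) that pushes standardized inputs through 20 dense layers of 100 SELU units each and checks that the activations keep a mean close to 0 and a standard deviation close to 1. It assumes an untruncated normal distribution as a stand-in for Keras's `lecun_normal` initializer, and the helper name `selu` is chosen just for illustration:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# SELU constants from the self-normalizing networks paper\n",
"ALPHA = 1.6732632423543772\n",
"SCALE = 1.0507009873554805\n",
"\n",
"def selu(z):\n",
"    # scale * z for z > 0, scale * alpha * (exp(z) - 1) otherwise;\n",
"    # np.minimum avoids overflow in exp for large positive z\n",
"    return SCALE * np.where(z > 0, z, ALPHA * (np.exp(np.minimum(z, 0)) - 1))\n",
"\n",
"rng = np.random.default_rng(42)\n",
"n_units = 100\n",
"Z = rng.normal(size=(1000, n_units))  # standardized inputs: mean 0, std 1\n",
"\n",
"for _ in range(20):\n",
"    # LeCun-style initialization (assumption: plain normal, std = sqrt(1 / fan_in))\n",
"    W = rng.normal(scale=np.sqrt(1 / n_units), size=(n_units, n_units))\n",
"    Z = selu(Z @ W)\n",
"\n",
"final_mean, final_std = Z.mean(), Z.std()\n",
"print(final_mean, final_std)\n",
"```\n",
"\n",
"If the inputs were not standardized, or if another layer type were inserted between the dense layers, the mean and standard deviation would no longer be guaranteed to stay near 0 and 1."
]
},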
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"### e.\n",
"*Exercise: Try regularizing the model with alpha dropout. Then, without retraining your model, see if you can achieve better accuracy using MC Dropout.*"
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 128,
2018-05-08 20:21:23 +02:00
"metadata": {},
2022-02-20 05:36:52 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.8953 - accuracy: 0.3240INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 18s 11ms/step - loss: 1.8950 - accuracy: 0.3239 - val_loss: 1.7556 - val_accuracy: 0.3812\n",
"Epoch 2/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.6618 - accuracy: 0.4129INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.6618 - accuracy: 0.4130 - val_loss: 1.6563 - val_accuracy: 0.4114\n",
"Epoch 3/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.5772 - accuracy: 0.4431INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.5770 - accuracy: 0.4432 - val_loss: 1.6507 - val_accuracy: 0.4232\n",
"Epoch 4/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5081 - accuracy: 0.4673INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.5081 - accuracy: 0.4672 - val_loss: 1.5892 - val_accuracy: 0.4566\n",
"Epoch 5/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.4560 - accuracy: 0.4902INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.4561 - accuracy: 0.4902 - val_loss: 1.5382 - val_accuracy: 0.4696\n",
"Epoch 6/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.4095 - accuracy: 0.5050INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.4094 - accuracy: 0.5050 - val_loss: 1.5236 - val_accuracy: 0.4818\n",
"Epoch 7/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.3634 - accuracy: 0.5234INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.3636 - accuracy: 0.5232 - val_loss: 1.5139 - val_accuracy: 0.4840\n",
"Epoch 8/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.3297 - accuracy: 0.5377INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.3296 - accuracy: 0.5378 - val_loss: 1.4780 - val_accuracy: 0.4982\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2907 - accuracy: 0.5485 - val_loss: 1.5151 - val_accuracy: 0.4854\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 1.2559 - accuracy: 0.5646 - val_loss: 1.4980 - val_accuracy: 0.4976\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.2221 - accuracy: 0.5767 - val_loss: 1.5199 - val_accuracy: 0.4990\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.1960 - accuracy: 0.5870 - val_loss: 1.5167 - val_accuracy: 0.5030\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.1684 - accuracy: 0.5955 - val_loss: 1.5815 - val_accuracy: 0.5014\n",
"Epoch 14/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.1463 - accuracy: 0.6025 - val_loss: 1.5427 - val_accuracy: 0.5112\n",
"Epoch 15/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.1125 - accuracy: 0.6169 - val_loss: 1.5868 - val_accuracy: 0.5212\n",
"Epoch 16/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.0854 - accuracy: 0.6243 - val_loss: 1.6234 - val_accuracy: 0.5090\n",
"Epoch 17/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0668 - accuracy: 0.6328 - val_loss: 1.6162 - val_accuracy: 0.5072\n",
"Epoch 18/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.0440 - accuracy: 0.6442 - val_loss: 1.5748 - val_accuracy: 0.5162\n",
"Epoch 19/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.0272 - accuracy: 0.6477 - val_loss: 1.6518 - val_accuracy: 0.5200\n",
"Epoch 20/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 1.0007 - accuracy: 0.6594 - val_loss: 1.6224 - val_accuracy: 0.5186\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 0.9824 - accuracy: 0.6639 - val_loss: 1.6972 - val_accuracy: 0.5136\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 0.9660 - accuracy: 0.6714 - val_loss: 1.7210 - val_accuracy: 0.5278\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 0.9472 - accuracy: 0.6780 - val_loss: 1.6436 - val_accuracy: 0.5006\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 0.9314 - accuracy: 0.6819 - val_loss: 1.7059 - val_accuracy: 0.5160\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 0.9172 - accuracy: 0.6888 - val_loss: 1.6926 - val_accuracy: 0.5200\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 0.8990 - accuracy: 0.6947 - val_loss: 1.7705 - val_accuracy: 0.5148\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 0.8758 - accuracy: 0.7028 - val_loss: 1.7023 - val_accuracy: 0.5198\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 0.8622 - accuracy: 0.7090 - val_loss: 1.7567 - val_accuracy: 0.5184\n",
"157/157 [==============================] - 0s 1ms/step - loss: 1.4780 - accuracy: 0.4982\n"
]
},
{
"data": {
"text/plain": [
"[1.4779616594314575, 0.498199999332428]"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 128,
2022-02-20 05:36:52 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-03-10 21:55:45 +01:00
"source": [
"tf.random.set_seed(42)\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
2020-03-10 21:55:45 +01:00
"for _ in range(20):\n",
2021-10-17 04:04:08 +02:00
" model.add(tf.keras.layers.Dense(100,\n",
2022-02-19 06:19:26 +01:00
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
2020-03-10 21:55:45 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
2020-03-10 21:55:45 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-4)\n",
2020-03-10 21:55:45 +01:00
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
2022-02-19 06:19:26 +01:00
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n",
" patience=20, restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n",
" \"my_cifar10_alpha_dropout_model\", save_best_only=True)\n",
2020-03-10 21:55:45 +01:00
"run_index = 1 # increment every time you train the model\n",
2022-02-19 06:19:26 +01:00
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_alpha_dropout_{run_index:03d}\"\n",
2021-10-17 04:04:08 +02:00
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
2020-03-10 21:55:45 +01:00
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"X_means = X_train.mean(axis=0)\n",
"X_stds = X_train.std(axis=0)\n",
"X_train_scaled = (X_train - X_means) / X_stds\n",
"X_valid_scaled = (X_valid - X_means) / X_stds\n",
"X_test_scaled = (X_test - X_means) / X_stds\n",
"\n",
"model.fit(X_train_scaled, y_train, epochs=100,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid_scaled, y_valid)"
]
2017-06-14 09:09:23 +02:00
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"source": [
2022-02-19 06:19:26 +01:00
"The model reaches 49.8% accuracy on the validation set. That's slightly worse than without dropout (50.3%). With an extensive hyperparameter search, it might be possible to do better (I tried dropout rates of 5%, 10%, 20% and 40%, and learning rates 1e-4, 3e-4, 5e-4, and 1e-3), but probably not much better in this case."
2017-06-14 09:09:23 +02:00
]
},
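{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about what alpha dropout does, here is a small NumPy sketch (an illustration only, not the Keras source code). Instead of setting dropped activations to zero, it sets them to SELU's negative saturation value and then applies an affine correction so that inputs with mean 0 and standard deviation 1 keep those statistics. The function name `alpha_dropout` and its signature are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"ALPHA = 1.6732632423543772  # SELU alpha\n",
"SCALE = 1.0507009873554805  # SELU scale\n",
"\n",
"def alpha_dropout(x, rate, rng):\n",
"    alpha_p = -ALPHA * SCALE  # value SELU saturates to for very negative inputs\n",
"    keep = rng.random(x.shape) >= rate  # each unit is kept with probability 1 - rate\n",
"    # affine correction that restores zero mean and unit variance\n",
"    a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5\n",
"    b = -a * alpha_p * rate\n",
"    return a * np.where(keep, x, alpha_p) + b\n",
"\n",
"rng = np.random.default_rng(42)\n",
"x = rng.normal(size=(10_000, 100))  # assumed standardized activations\n",
"y = alpha_dropout(x, rate=0.1, rng=rng)\n",
"out_mean, out_std = y.mean(), y.std()\n",
"print(out_mean, out_std)\n",
"```\n",
"\n",
"Regular dropout would shift the mean and shrink the variance of the activations, which would break the self-normalizing property that the SELU layers rely on."
]
},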
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"Let's use MC Dropout now. We will need the `MCAlphaDropout` class we used earlier, so let's just copy it here for convenience:"
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 129,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"outputs": [],
2020-03-10 21:55:45 +01:00
"source": [
2021-10-17 04:04:08 +02:00
"class MCAlphaDropout(tf.keras.layers.AlphaDropout):\n",
2020-03-10 21:55:45 +01:00
" def call(self, inputs):\n",
" return super().call(inputs, training=True)"
]
2017-06-14 09:09:23 +02:00
},
{
2020-03-10 21:55:45 +01:00
"cell_type": "markdown",
2018-05-08 20:21:23 +02:00
"metadata": {},
2020-03-10 21:55:45 +01:00
"source": [
"Now let's create a new model, identical to the one we just trained (with the same weights), but with `MCAlphaDropout` dropout layers instead of `AlphaDropout` layers:"
]
2017-06-14 09:09:23 +02:00
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 130,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"mc_model = tf.keras.Sequential([\n",
2022-02-19 06:19:26 +01:00
" (\n",
" MCAlphaDropout(layer.rate)\n",
" if isinstance(layer, tf.keras.layers.AlphaDropout)\n",
" else layer\n",
" )\n",
2020-03-10 21:55:45 +01:00
" for layer in model.layers\n",
"])"
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"Then let's add a couple of utility functions. The first runs the model many times (10 by default) and returns the mean predicted class probabilities. The second uses these mean probabilities to predict the most likely class for each instance:"
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 131,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"outputs": [],
2020-03-10 21:55:45 +01:00
"source": [
"def mc_dropout_predict_probas(mc_model, X, n_samples=10):\n",
" Y_probas = [mc_model.predict(X) for sample in range(n_samples)]\n",
" return np.mean(Y_probas, axis=0)\n",
"\n",
"def mc_dropout_predict_classes(mc_model, X, n_samples=10):\n",
" Y_probas = mc_dropout_predict_probas(mc_model, X, n_samples)\n",
2022-02-19 06:19:26 +01:00
" return Y_probas.argmax(axis=1)"
2020-03-10 21:55:45 +01:00
]
2017-06-14 09:09:23 +02:00
},
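{
"cell_type": "markdown",
"metadata": {},
"source": [
"The averaging step can be checked in isolation on toy data. The sketch below uses a hypothetical `noisy_predict` function as a stand-in for `mc_model.predict` (it just adds random noise to the inputs and applies a softmax), so the shapes are easy to follow: stacking the samples gives an array of shape `[n_samples, n_instances, n_classes]`, averaging over axis 0 gives one probability vector per instance, and `argmax` over axis 1 gives the predicted class:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"\n",
"def noisy_predict(X):\n",
"    # hypothetical stand-in for a stochastic model: noisy logits, then softmax\n",
"    logits = X + rng.normal(scale=0.5, size=X.shape)\n",
"    exp = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"    return exp / exp.sum(axis=1, keepdims=True)\n",
"\n",
"X_demo = rng.normal(size=(5, 3))  # 5 instances, 3 classes\n",
"samples = np.stack([noisy_predict(X_demo) for _ in range(10)])\n",
"mean_probas = samples.mean(axis=0)  # average over the 10 stochastic passes\n",
"y_demo = mean_probas.argmax(axis=1)  # most likely class per instance\n",
"print(samples.shape, mean_probas.shape, y_demo)\n",
"```\n",
"\n",
"Since each sample is a valid probability distribution, the averaged probabilities still sum to 1 for every instance."
]
},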
{
2020-03-10 21:55:45 +01:00
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2020-03-10 21:55:45 +01:00
"source": [
"Now let's make predictions for all the instances in the validation set, and compute the accuracy:"
]
2017-06-14 09:09:23 +02:00
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 132,
2018-05-08 20:21:23 +02:00
"metadata": {},
2022-02-20 05:36:52 +01:00
"outputs": [
{
"data": {
"text/plain": [
"0.4984"
]
},
2023-11-15 06:29:39 +01:00
"execution_count": 132,
2022-02-20 05:36:52 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-14 09:09:23 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"tf.random.set_seed(42)\n",
"\n",
"y_pred = mc_dropout_predict_classes(mc_model, X_valid_scaled)\n",
2022-02-19 06:19:26 +01:00
"accuracy = (y_pred == y_valid[:, 0]).mean()\n",
2020-03-10 21:55:45 +01:00
"accuracy"
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"source": [
2022-02-20 05:36:52 +01:00
"MC Dropout barely changes the result in this case: we get about 49.8% accuracy, essentially the same as without MC Dropout, and still slightly below the model without dropout (about 50.3%).\n",
2020-03-10 21:55:45 +01:00
"\n",
"So the best model we got in this exercise is the Batch Normalization model."
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "markdown",
2017-06-21 15:35:47 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"source": [
2020-03-10 21:55:45 +01:00
"### f.\n",
"*Exercise: Retrain your model using 1cycle scheduling and see if it improves training speed and model accuracy.*"
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 133,
2018-05-08 20:21:23 +02:00
"metadata": {},
2017-06-14 09:09:23 +02:00
"outputs": [],
"source": [
2020-03-10 21:55:45 +01:00
"tf.random.set_seed(42)\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
2020-03-10 21:55:45 +01:00
"for _ in range(20):\n",
2021-10-17 04:04:08 +02:00
" model.add(tf.keras.layers.Dense(100,\n",
2022-02-19 06:19:26 +01:00
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
2020-03-10 21:55:45 +01:00
"\n",
2021-10-17 04:04:08 +02:00
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
2020-03-10 21:55:45 +01:00
"\n",
2022-02-19 06:19:26 +01:00
"optimizer = tf.keras.optimizers.SGD()\n",
2020-03-10 21:55:45 +01:00
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
2017-06-14 09:09:23 +02:00
]
},
{
"cell_type": "code",
2023-11-15 06:29:39 +01:00
"execution_count": 134,
2017-06-21 15:35:47 +02:00
"metadata": {},
2022-02-20 05:36:52 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"352/352 [==============================] - 3s 8ms/step - loss: nan - accuracy: 0.1706\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXwAAAEOCAYAAACKDawAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAu0klEQVR4nO3deXxU5bkH8N8TguwQE5IQdhDBIIoIolcqm0VxrVtdqN66VCrKtVbrVtuqvS612l4XXFFBiytaXBBQSwkoWhBEWYxBEFBZw04WAiTv/eOZ13NmMjOZCZMzkzm/7+eTz2xn5rxzkjznPc+7iTEGRESU/jKSXQAiIvIGAz4RkU8w4BMR+QQDPhGRTzDgExH5BAM+EZFPZCa7ANFkZWWZXr16JbsYaaO8vBytWrVKdjHSAo9lYiXzeJaUAJWVQHU1kJkJ9O+flGIkzOLFi7caY3LDvZbSAT8/Px+LFi1KdjHSRlFREYYPH57sYqQFHsvESubxHDYMKC4GSkuBrCygsYccEVkX6TWmdIjI14wBMnwSCX3yNYmIwqupAZo00fvpPvEAAz4R+ZoxDPhERL7gTukw4BMRpTGmdIiIfII1fCIin2AOn4jIJ9wpnXTHgE9EvsaUDhGRTzClQ0TkE+ylQ0TkE0zpEBH5BFM6REQ+wZQOEZFPMKVDROQTnB6ZiMgnmNIhIvIJNtoSEfkEc/hERD7BlA4RkU8wpUNE5BOs4RMR+QS7ZRIR+QRTOkREPlFTwxo+EZEv7NsHNG+e7FJ4gwGfiHytqooBn4jIFxjwiYh8wBhg716gRYtkl8QbDPhE5Fv79+stAz4RUZqrqtJbpnSIiNIcAz4RkU/s3au3TOkQEaU51vCJiHzCBvxmzZJbDq8w4BORbzHgNzARaSIiS0Rkutf7JiJyszl8BvyG8xsAxUnYLxFRkNAc/hVXJK8sXsj0cmci0hnAGQDuBXCjl/smIgrlTulUVQGZnkZE73n99R4GcAuANpE2EJGxAMYCQG5uLoqKijwpmB+UlZXxeCYIj2ViJet4LlqUA+AoLFu2CPv2lXm+f695FvBF5EwAW4wxi0VkeKTtjDHPAHgGAPr06WOGD4+4KcWpqKgIPJ6JwWOZWMk6nqWlejtkyCD06+f57j3nZQ5/CICzRWQtgFcBjBSRKR7un4goiG20ZT/8BDPG3G6M6WyM6Q7gYgD/NsZc6tX+iYhCsVsmEZFP+C3gJ6VN2hhTBKAoGfsmIrL8FvBZwyci32IOn4jIJ2wN/5BDklsOrzDgE5FvVVVpsBdJdkm8wYBPRL5VVeWf/D3AgE9EPrZ3LwM+EZEvVFX5p8EWYMAnIh9jSoeIyCcY8ImIfII5fCIin2AOn4jIJ5jSISLyCQZ8IiKfYA6fiMgn9u5lDp+IyBf27AHatUt2KbzDgE9EvrVrF9C2bbJL4R0GfCLypQMHgIoKBnwiorS3Z4/eMuATEaW53bv1lgGfiCjNMeATEfkEAz4RkU8w4BMR+YQN+OyHT0SU5ljDb2SM0R8iongx4DcyZ54JtGyZ7FIQUWO0ezcgArRqleySeCcz2QU4GDNmJLsERNRY7d4NtGkDZDTqam98Uvqr7tuXgexsYM0aYMcOYMqU8NsxrUNE8fLbPDpAigf8qqoM7NgBlJQAkyYBl10GfP+9vnbggLOdHSJNRBSr3bsZ8FNKdbUA0DPx6tX63MaNert+vbPdjh3O/TvvBF57zaMCElHK2bEDuPtuoLo6+nYM+CmmpqZ2wN+0SW/XrXO2cwf8CROA55/3qIBElHLefRe46y5g6dLo223f7q8++ECKB3x3Df/bb/W5zZv1du1aZzsb8Kuq9Jf4zTfelZGIUsuWLXpru12GU1GhJ4Sjj/amTKmiUQT87dudAB+thu9+rarKmzISUWopLdXbaG178+cD+/cDI0d6U6ZU0SgC/ooV+ssBnKBuG28BPSHU1Div1dQ4VwRuX3/NHj1E6S6WGv6cOU
BmJvCTn3hTplSR0gHf5vCXLHGesymdrVuBrl31/tVXAxdc4DToAsDKlcGf9e67QGEh8MorDVhgIko6G/Cj1fAXLAAGDABat/amTKnCs4AvIs1FZKGIfCkiK0Tk7rreY2v4P/ygj7t2dWrx27YB3bsDTZro4w8+CO658803zlUBADz5pN5+8snBfhMiSmWxpHRKS4GOHb0pTyrxsoZfBWCkMaY/gGMAjBaRE6K9wQZ8a9Cg4Bp++/aavgGA8nKgqEiHSnfooIO0DjkEePppfc+sWbrdnDnAffcBe/cm8qsRUaqIJaWzYwdw6KHelCeVeBbwjSoLPGwa+ImaUXcH/Pz82jX8nJzgnPw77wB5ecCDDwJffqnPzZ6tDb7GaErnq6+AO+4Apk1L2FcjohQSWsOfMwcYNy54mx07gOxsb8uVCjydS0dEmgBYDKAXgMeNMQvCbDMWwFh9NPDH59u124OKii0oKzsM06d/hK1bh6C8/HsA3QAAmZk12LcvA61bl6FTp0W44YaOePjh3ti0qRSzZ28CcBT69fsexcVdAADPPrsZBQXFDfp9U01ZWRmKioqSXYy0wGOZWIk6npWVGaioGAoAKCnZiKKiEjz55GGYOrULBg78DJs2tcDxx29Defkw7Ny5BkVF6+r4xDRjjPH8B0AWgDkA+kXfbqBp0UInQT7nHGPeflvvz5yptw89ZCdINuaii/T2xBPNj04+2ZgTTjDm2Wf1tZISY+67z5jzzzcmK8uYfftMVI88YsykSdG3SaQZM4yprGy4z58zZ07DfbjP8FgmVqKO55o1Tkz4+c/1uauuch7n5hqzaZM+fvzxhOwy5QBYZCLE1KT00jHG7ARQBGB0Xdva6Y87dwb699f7//633ubkAC+/DDz0EPDii9ow++CDznsLCrTnjs3pdekC3H47MGYMsHNn3Q24f/878MwzcXyxg7B8OXD66cD48d7sjygd2f91wMnh79rlvLZzp3bjBpjDb1AikisiWYH7LQD8FMDX0d7Tps0BnH++3s/O1hx+Vpbm5QFttL3kEuCmm7SB9pprgBNPdN5vA/7mzToNaosW+vyoUbr9u+9G3ndlJfDdd9rf/803gc8/r9/3jpX945w3r2H3Q5TONmzQ23btnBz+zp16u3Wr9tyzvfkY8BtWAYA5IrIUwGcAPjTGTI/6hoLKH4N027baA6d/fyf45uTUscMCYN8+nW0zL895vk0bYPjw6AF/9Wq9MNywAbjiCr0yaEj2j9JdQyGi+BQHmuUGDXICvq3h28bcNWv01o+Ntl720llqjBlgjDnaGNPPGPPn2N6nt/aXY9M6QN0B3/az/fLL4IAP6GpZK1fqL7+8HPjVr4CBA7UXz9y5wFNP6XY1NfqH8/HHevKoj/vv155B0djLTPvHSUTx++orTf927Fg7pbNtm97agO/HGn7Kr3j1pz/p4KoxY/Tx2WcDjz6q99u3j/7eggK93bgRGDw4+DU7pHrBAr0CeP55TRcdc0zwgC2rogJYuDD+odjPPgv8/vd6/957I29nA77dF5duJIpfcTHQt69mBEJr+Ha6ZNbwU1hOjjaeNmumj92THWVlRX+vDfhA7Rp+v35A8+bAZ59p0D/qKG04HTtWXwvnyiu1T2883nxTb0WcQWLhuAM+Z/skil9NjRPw27TRGr4xta+a7TxbdcWPdJTyAT+UiI6avfHGutei7NRJG2cBIDc3+LWmTXUujYULNeXTv79eBk6Y4Myjba8gWrUC7rlHA/Frr8WX2ikp0VtjdL6fG24Iv5293AR0kjciis933+nVcWGhBvwDB7SWHzqqfu1avQKw07L4SaML+ABw6qnA3/5W93YtWwK33KL3beB3GzxYc/MbNgTPiy2ijairVukfRu/emoMfNEhPBnl5wCmnBE/BvH69pm/ctfjKSv3jOuoofTxtmtOl1M0YDfj2KsQu9kJEsbPdrPv3d1aysvNwuW3Z4s90DtBIA3487rwT+L//A37969
qvnX66c9/dGAxot6527TTI27x99+7Af/6jl4gffgi89JLWIMaP14aiq6/W9JC1apUG8xEjnOfcE7wBerWQkaGzeHb
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"batch_size = 128\n",
"rates, losses = find_learning_rate(model, X_train_scaled, y_train, epochs=1,\n",
" batch_size=batch_size)\n",
"plot_lr_vs_loss(rates, losses)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=2e-2)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 2.0559 - accuracy: 0.2839 - val_loss: 1.7917 - val_accuracy: 0.3768\n",
"Epoch 2/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.7596 - accuracy: 0.3797 - val_loss: 1.6566 - val_accuracy: 0.4258\n",
"Epoch 3/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.6199 - accuracy: 0.4247 - val_loss: 1.6395 - val_accuracy: 0.4260\n",
"Epoch 4/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.5451 - accuracy: 0.4524 - val_loss: 1.6202 - val_accuracy: 0.4408\n",
"Epoch 5/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.4952 - accuracy: 0.4691 - val_loss: 1.5981 - val_accuracy: 0.4488\n",
"Epoch 6/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.4541 - accuracy: 0.4842 - val_loss: 1.5720 - val_accuracy: 0.4490\n",
"Epoch 7/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.4171 - accuracy: 0.4967 - val_loss: 1.6035 - val_accuracy: 0.4470\n",
"Epoch 8/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.3497 - accuracy: 0.5194 - val_loss: 1.4918 - val_accuracy: 0.4864\n",
"Epoch 9/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.2788 - accuracy: 0.5459 - val_loss: 1.5597 - val_accuracy: 0.4672\n",
"Epoch 10/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.2070 - accuracy: 0.5707 - val_loss: 1.5845 - val_accuracy: 0.4864\n",
"Epoch 11/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 1.1433 - accuracy: 0.5926 - val_loss: 1.5293 - val_accuracy: 0.4998\n",
"Epoch 12/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.0745 - accuracy: 0.6182 - val_loss: 1.5118 - val_accuracy: 0.5072\n",
"Epoch 13/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 1.0030 - accuracy: 0.6413 - val_loss: 1.5388 - val_accuracy: 0.5204\n",
"Epoch 14/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 0.9388 - accuracy: 0.6654 - val_loss: 1.5547 - val_accuracy: 0.5210\n",
"Epoch 15/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 0.8989 - accuracy: 0.6805 - val_loss: 1.5835 - val_accuracy: 0.5242\n"
]
}
],
"source": [
"n_epochs = 15\n",
"n_iterations = math.ceil(len(X_train_scaled) / batch_size) * n_epochs\n",
"onecycle = OneCycleScheduler(n_iterations, max_lr=0.05)\n",
"history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=[onecycle])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One cycle allowed us to train the model in just 15 epochs, each taking only 2 seconds (thanks to the larger batch size). This is several times faster than the fastest model we trained so far. Moreover, we improved the model's performance (from 50.7% to 52.0%)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"nav_menu": {
"height": "360px",
"width": "416px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}