{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 11 - Training Deep Neural Networks**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 11._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/11_training_deep_neural_networks.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/11_training_deep_neural_networks.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And TensorFlow ≥ 2.8:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from packaging import version\n",
"import tensorflow as tf\n",
"\n",
"assert version.parse(tf.__version__) >= version.parse(\"2.8.0\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/deep` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"deep\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vanishing/Exploding Gradients Problem"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# extra code - this cell generates and saves Figure 11-1\n",
"\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
" return 1 / (1 + np.exp(-z))\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([-5, 5], [1, 1], 'k--')\n",
"plt.plot([0, 0], [-0.2, 1.2], 'k-')\n",
"plt.plot([-5, 5], [-3/4, 7/4], 'g--')\n",
"plt.plot(z, sigmoid(z), \"b-\", linewidth=2,\n",
" label=r\"$\\sigma(z) = \\dfrac{1}{1+e^{-z}}$\")\n",
"props = dict(facecolor='black', shrink=0.1)\n",
"plt.annotate('Saturating', xytext=(3.5, 0.7), xy=(5, 1), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.annotate('Saturating', xytext=(-3.5, 0.3), xy=(-5, 0), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.annotate('Linear', xytext=(2, 0.2), xy=(0, 0.5), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.grid(True)\n",
"plt.axis([-5, 5, -0.2, 1.2])\n",
"plt.xlabel(\"$z$\")\n",
"plt.legend(loc=\"upper left\", fontsize=16)\n",
"\n",
"save_fig(\"sigmoid_saturation_plot\")\n",
"plt.show()"
]
},
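{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why saturation hurts, here is a minimal NumPy sketch. The sigmoid's derivative $\\sigma'(z) = \\sigma(z)\\,(1 - \\sigma(z))$ peaks at 0.25 (at $z = 0$) and is nearly 0 in the saturating regions. If we ignore the weights (a simplifying assumption made only for illustration), backpropagating through $n$ sigmoid layers scales the gradient by at most $0.25^n$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1 / (1 + np.exp(-z))\n",
"\n",
"def sigmoid_derivative(z):\n",
"    s = sigmoid(z)\n",
"    return s * (1 - s)\n",
"\n",
"z = np.linspace(-5, 5, 201)\n",
"print(\"max derivative:\", sigmoid_derivative(z).max())  # about 0.25, reached near z = 0\n",
"print(\"derivative at z = 5:\", sigmoid_derivative(5.0))  # close to 0: saturated\n",
"\n",
"# upper bound on the gradient scale after n sigmoid layers (weights ignored)\n",
"for n_layers in (1, 5, 10, 20):\n",
"    print(f\"{n_layers:2d} layers: {0.25 ** n_layers:.1e}\")"
]
},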
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Xavier and He Initialization"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\")"
]
},
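{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate why He initialization matters, here is a rough NumPy sketch. It is not how Keras implements it: it samples from a plain normal distribution rather than the truncated normal behind `\"he_normal\"`, and it omits biases. With ReLU units, weights drawn with variance $2 / \\text{fan}_\\text{in}$ keep the activations' scale roughly stable across 50 layers. An arbitrary small standard deviation of 0.01 instead makes them collapse toward zero:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"n_units, n_layers = 100, 50\n",
"X = rng.standard_normal((1000, n_units))  # standardized inputs\n",
"\n",
"def final_activation_std(weight_std):\n",
"    a = X\n",
"    for _ in range(n_layers):\n",
"        W = rng.normal(0, weight_std, size=(n_units, n_units))\n",
"        a = np.maximum(0, a @ W)  # dense layer (no bias) followed by ReLU\n",
"    return a.std()\n",
"\n",
"he_std = np.sqrt(2 / n_units)  # He: Var(w) = 2 / fan_in\n",
"print(\"He init:   \", final_activation_std(he_std))\n",
"print(\"std = 0.01:\", final_activation_std(0.01))"
]
},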
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"he_avg_init = tf.keras.initializers.VarianceScaling(scale=2., mode=\"fan_avg\",\n",
" distribution=\"uniform\")\n",
"dense = tf.keras.layers.Dense(50, activation=\"sigmoid\",\n",
" kernel_initializer=he_avg_init)"
]
},
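{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, here is an illustrative sketch using a hypothetical 784 → 50 layer shape. The `VarianceScaling` initializer with `scale=2.`, `mode=\"fan_avg\"` and `distribution=\"uniform\"` samples weights from $[-r, r]$ with $r = \\sqrt{3 \\times \\text{scale} / \\text{fan}_\\text{avg}}$, where $\\text{fan}_\\text{avg} = (\\text{fan}_\\text{in} + \\text{fan}_\\text{out}) / 2$. Their standard deviation should therefore be close to $\\sqrt{\\text{scale} / \\text{fan}_\\text{avg}}$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"fan_in, fan_out = 784, 50  # hypothetical layer shape, for illustration only\n",
"init = tf.keras.initializers.VarianceScaling(scale=2., mode=\"fan_avg\",\n",
"                                             distribution=\"uniform\", seed=42)\n",
"W = init(shape=(fan_in, fan_out)).numpy()\n",
"\n",
"fan_avg = (fan_in + fan_out) / 2\n",
"print(\"empirical std:\", W.std())\n",
"print(\"expected std: \", np.sqrt(2. / fan_avg))"
]
},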
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nonsaturating Activation Functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Leaky ReLU"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# extra code - this cell generates and saves Figure 11-2\n",
"\n",
"def leaky_relu(z, alpha):\n",
" return np.maximum(alpha * z, z)\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"plt.plot(z, leaky_relu(z, 0.1), \"b-\", linewidth=2, label=r\"$LeakyReLU(z) = max(\\alpha z, z)$\")\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([0, 0], [-1, 3.7], 'k-')\n",
"plt.grid(True)\n",
"props = dict(facecolor='black', shrink=0.1)\n",
"plt.annotate('Leak', xytext=(-3.5, 0.5), xy=(-5, -0.3), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.xlabel(\"$z$\")\n",
"plt.axis([-5, 5, -1, 3.7])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.legend()\n",
"\n",
"save_fig(\"leaky_relu_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"leaky_relu = tf.keras.layers.LeakyReLU(alpha=0.2) # defaults to alpha=0.3\n",
"dense = tf.keras.layers.Dense(50, activation=leaky_relu,\n",
" kernel_initializer=\"he_normal\")"
]
},
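{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gradients show the motivation for the leak. This NumPy sketch uses the same `alpha=0.2` as above. ReLU's gradient is exactly 0 for negative inputs, so a neuron whose inputs are always negative stops learning (it \"dies\"). Leaky ReLU instead keeps a small nonzero gradient equal to `alpha`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def relu_grad(z):\n",
"    return np.where(z > 0, 1.0, 0.0)\n",
"\n",
"def leaky_relu_grad(z, alpha=0.2):\n",
"    return np.where(z > 0, 1.0, alpha)\n",
"\n",
"z = np.array([-3.0, -0.5, 0.5, 3.0])\n",
"print(\"ReLU gradient:      \", relu_grad(z))        # zero for negative inputs\n",
"print(\"Leaky ReLU gradient:\", leaky_relu_grad(z))  # alpha for negative inputs"
]
},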
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-16 11:22:41.636848: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"model = tf.keras.models.Sequential([\n",
" # [...] # more layers\n",
" tf.keras.layers.Dense(50, kernel_initializer=\"he_normal\"), # no activation\n",
" tf.keras.layers.LeakyReLU(alpha=0.2), # activation as a separate layer\n",
" # [...] # more layers\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ELU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Implementing ELU in TensorFlow is trivial: just specify the activation function when building each layer, and use He initialization:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"elu\",\n",
" kernel_initializer=\"he_normal\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"### SELU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, the SELU hyperparameters (`scale` and `alpha`) are tuned so that the mean output of each neuron remains close to 0 and the standard deviation remains close to 1 (assuming the inputs are also standardized to mean 0 and standard deviation 1, and the other constraints explained in the book are respected). With this activation function, even a 1,000-layer-deep neural network roughly preserves mean 0 and standard deviation 1 across all layers, avoiding the exploding/vanishing gradients problem:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# extra code - this cell generates and saves Figure 11-3\n",
"\n",
"from scipy.special import erfc\n",
"\n",
"# alpha and scale to self-normalize with mean 0 and standard deviation 1\n",
"# (see equation 14 in the paper):\n",
"alpha_0_1 = -np.sqrt(2 / np.pi) / (erfc(1 / np.sqrt(2)) * np.exp(1 / 2) - 1)\n",
"scale_0_1 = (\n",
" (1 - erfc(1 / np.sqrt(2)) * np.sqrt(np.e))\n",
" * np.sqrt(2 * np.pi)\n",
" * (\n",
" 2 * erfc(np.sqrt(2)) * np.e ** 2\n",
" + np.pi * erfc(1 / np.sqrt(2)) ** 2 * np.e\n",
" - 2 * (2 + np.pi) * erfc(1 / np.sqrt(2)) * np.sqrt(np.e)\n",
" + np.pi\n",
" + 2\n",
" ) ** (-1 / 2)\n",
")\n",
"\n",
"def elu(z, alpha=1):\n",
" return np.where(z < 0, alpha * (np.exp(z) - 1), z)\n",
"\n",
"def selu(z, scale=scale_0_1, alpha=alpha_0_1):\n",
" return scale * elu(z, alpha)\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"plt.plot(z, elu(z), \"b-\", linewidth=2, label=r\"ELU$_\\alpha(z) = \\alpha (e^z - 1)$ if $z < 0$, else $z$\")\n",
"plt.plot(z, selu(z), \"r--\", linewidth=2, label=r\"SELU$(z) = 1.05 \\, $ELU$_{1.67}(z)$\")\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([-5, 5], [-1, -1], 'k:', linewidth=2)\n",
"plt.plot([-5, 5], [-1.758, -1.758], 'k:', linewidth=2)\n",
"plt.plot([0, 0], [-2.2, 3.2], 'k-')\n",
"plt.grid(True)\n",
"plt.axis([-5, 5, -2.2, 3.2])\n",
"plt.xlabel(\"$z$\")\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.legend()\n",
"\n",
"save_fig(\"elu_selu_plot\")\n",
"plt.show()"
]
},
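{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the self-normalizing claim numerically, here is a minimal NumPy sketch (not the book's code). It stacks 100 SELU layers on standardized inputs, using LeCun normal initialization (weight variance $1 / \\text{fan}_\\text{in}$, sampled here from a plain normal distribution and without biases). It prints the activations' mean and standard deviation every 20 layers, and these should stay roughly near 0 and 1. The constants are the ones documented for Keras's built-in `\"selu\"` activation, which agree with `alpha_0_1` and `scale_0_1` computed above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"SELU_ALPHA = 1.67326324  # values documented for tf.keras.activations.selu\n",
"SELU_SCALE = 1.05070098\n",
"\n",
"def selu_np(z):\n",
"    return SELU_SCALE * np.where(z < 0, SELU_ALPHA * (np.exp(z) - 1), z)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"n_units = 100\n",
"a = rng.standard_normal((1000, n_units))  # standardized inputs\n",
"for layer in range(1, 101):\n",
"    W = rng.normal(0, np.sqrt(1 / n_units), size=(n_units, n_units))  # LeCun normal\n",
"    a = selu_np(a @ W)\n",
"    if layer % 20 == 0:\n",
"        print(f\"layer {layer:3d}: mean {a.mean():+.2f}, std {a.std():.2f}\")"
]
},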
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using SELU is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"selu\",\n",
" kernel_initializer=\"lecun_normal\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra material - an example of a self-normalizing network using SELU**\n",
"\n",
"Let's create a neural net for Fashion MNIST with 100 hidden layers, using the SELU activation function:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[28, 28]))\n",
"for layer in range(100):\n",
" model.add(tf.keras.layers.Dense(100, activation=\"selu\",\n",
" kernel_initializer=\"lecun_normal\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's train it. Do not forget to scale the inputs to mean 0 and standard deviation 1:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]\n",
"X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]\n",
"X_train, X_valid, X_test = X_train / 255, X_valid / 255, X_test / 255"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"class_names = [\"T-shirt/top\", \"Trouser\", \"Pullover\", \"Dress\", \"Coat\",\n",
" \"Sandal\", \"Shirt\", \"Sneaker\", \"Bag\", \"Ankle boot\"]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"pixel_means = X_train.mean(axis=0, keepdims=True)\n",
"pixel_stds = X_train.std(axis=0, keepdims=True)\n",
"X_train_scaled = (X_train - pixel_means) / pixel_stds\n",
"X_valid_scaled = (X_valid - pixel_means) / pixel_stds\n",
"X_test_scaled = (X_test - pixel_means) / pixel_stds"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-16 11:22:44.499697: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1719/1719 [==============================] - 13s 7ms/step - loss: 1.3735 - accuracy: 0.4548 - val_loss: 0.9599 - val_accuracy: 0.6444\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.7783 - accuracy: 0.7073 - val_loss: 0.6529 - val_accuracy: 0.7664\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.6462 - accuracy: 0.7611 - val_loss: 0.6048 - val_accuracy: 0.7748\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.5821 - accuracy: 0.7863 - val_loss: 0.5737 - val_accuracy: 0.7944\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.5401 - accuracy: 0.8041 - val_loss: 0.5333 - val_accuracy: 0.8046\n"
]
}
],
"source": [
"history = model.fit(X_train_scaled, y_train, epochs=5,\n",
" validation_data=(X_valid_scaled, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The network managed to learn, despite how deep it is. Now look at what happens if we try to use the ReLU activation function instead:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[28, 28]))\n",
"for layer in range(100):\n",
" model.add(tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1719/1719 [==============================] - 12s 6ms/step - loss: 1.6932 - accuracy: 0.3071 - val_loss: 1.2058 - val_accuracy: 0.5106\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 1.1132 - accuracy: 0.5297 - val_loss: 0.9682 - val_accuracy: 0.5718\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 10s 6ms/step - loss: 0.9480 - accuracy: 0.6117 - val_loss: 1.0552 - val_accuracy: 0.5102\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 10s 6ms/step - loss: 0.9763 - accuracy: 0.6003 - val_loss: 0.7764 - val_accuracy: 0.7070\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.7892 - accuracy: 0.6875 - val_loss: 0.7485 - val_accuracy: 0.7054\n"
]
}
],
"source": [
"history = model.fit(X_train_scaled, y_train, epochs=5,\n",
" validation_data=(X_valid_scaled, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not great at all: we suffered from the vanishing/exploding gradients problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### GELU, Swish and Mish"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"# extra code - this cell generates and saves Figure 11-4\n",
"\n",
"def swish(z, beta=1):\n",
" return z * sigmoid(beta * z)\n",
"\n",
"def approx_gelu(z):\n",
" return swish(z, beta=1.702)\n",
"\n",
"def softplus(z):\n",
" return np.log(1 + np.exp(z))\n",
"\n",
"def mish(z):\n",
" return z * np.tanh(softplus(z))\n",
"\n",
"z = np.linspace(-4, 2, 200)\n",
"\n",
"beta = 0.6\n",
"plt.plot(z, approx_gelu(z), \"b-\", linewidth=2,\n",
" label=r\"GELU$(z) = z\\,\\Phi(z)$\")\n",
"plt.plot(z, swish(z), \"r--\", linewidth=2,\n",
" label=r\"Swish$(z) = z\\,\\sigma(z)$\")\n",
"plt.plot(z, swish(z, beta), \"r:\", linewidth=2,\n",
" label=fr\"Swish$_{{\\beta={beta}}}(z)=z\\,\\sigma({beta}\\,z)$\")\n",
"plt.plot(z, mish(z), \"g:\", linewidth=3,\n",
" label=fr\"Mish$(z) = z\\,\\tanh($softplus$(z))$\")\n",
"plt.plot([-4, 2], [0, 0], 'k-')\n",
"plt.plot([0, 0], [-2.2, 3.2], 'k-')\n",
"plt.grid(True)\n",
"plt.axis([-4, 2, -1, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.xlabel(\"$z$\")\n",
"plt.legend(loc=\"upper left\")\n",
"\n",
"save_fig(\"gelu_swish_mish_plot\")\n",
"plt.show()"
]
},
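{
"cell_type": "markdown",
"metadata": {},
"source": [
"The plot above uses the approximation GELU(*z*) ≈ *z* σ(1.702 *z*), i.e. `swish(z, beta=1.702)`. As a quick sanity check, here is a small sketch that compares it with the exact definition *z* Φ(*z*), where Φ is the standard Gaussian CDF computed with `math.erf`. The helper names `exact_gelu()` and `sigmoid_gelu()` are illustrative and are not part of the chapter's code. On this range the gap should stay small, on the order of 0.02:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - compare the exact GELU with its sigmoid-based approximation\n",
"import math\n",
"import numpy as np\n",
"\n",
"def exact_gelu(x):\n",
"    # GELU(x) = x * Phi(x), where Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))\n",
"    erf = np.vectorize(math.erf)\n",
"    return x * 0.5 * (1.0 + erf(x / math.sqrt(2.0)))\n",
"\n",
"def sigmoid_gelu(x):\n",
"    # same formula as approx_gelu() above: x * sigmoid(1.702 * x)\n",
"    return x / (1.0 + np.exp(-1.702 * x))\n",
"\n",
"xs = np.linspace(-4, 2, 200)\n",
"max_gap = np.max(np.abs(exact_gelu(xs) - sigmoid_gelu(xs)))\n",
"print(f\"max |GELU(x) - approximation| on [-4, 2]: {max_gap:.4f}\")"
]
},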
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Batch Normalization"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"# extra code - clear the name counters and set the random seed\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(300, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten (Flatten) (None, 784) 0 \n",
"_________________________________________________________________\n",
"batch_normalization (BatchNo (None, 784) 3136 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 300) 235500 \n",
"_________________________________________________________________\n",
"batch_normalization_1 (Batch (None, 300) 1200 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 100) 30100 \n",
"_________________________________________________________________\n",
"batch_normalization_2 (Batch (None, 100) 400 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 1010 \n",
"=================================================================\n",
"Total params: 271,346\n",
"Trainable params: 268,978\n",
"Non-trainable params: 2,368\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
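{
"cell_type": "markdown",
"metadata": {},
"source": [
"Where do the parameter counts above come from? Each `BatchNormalization` layer holds 4 parameters per input feature: the trainable scale γ and offset β, plus the non-trainable moving mean and moving variance. The arithmetic below (a small check written for this notebook, not a Keras API) reproduces the numbers in the summary:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - recompute the parameter counts shown by model.summary()\n",
"bn_input_sizes = [784, 300, 100]  # inputs of the 3 BatchNormalization layers\n",
"dense_shapes = [(784, 300), (300, 100), (100, 10)]  # (inputs, units)\n",
"\n",
"bn_params = [4 * n for n in bn_input_sizes]  # gamma, beta, moving mean & var\n",
"dense_params = [n_in * n_out + n_out for n_in, n_out in dense_shapes]  # W + b\n",
"non_trainable = sum(2 * n for n in bn_input_sizes)  # moving mean & var only\n",
"total = sum(bn_params) + sum(dense_params)\n",
"\n",
"print(bn_params)  # [3136, 1200, 400]\n",
"print(total, total - non_trainable, non_trainable)  # 271346 268978 2368"
]
},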
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('batch_normalization/gamma:0', True),\n",
" ('batch_normalization/beta:0', True),\n",
" ('batch_normalization/moving_mean:0', False),\n",
" ('batch_normalization/moving_variance:0', False)]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[(var.name, var.trainable) for var in model.layers[1].variables]"
]
},
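{
"cell_type": "markdown",
"metadata": {},
"source": [
"Only γ and β are trained by backpropagation; the moving averages are updated during training and used at inference time. The NumPy sketch below illustrates both modes. It is a simplification that assumes Keras's default `momentum=0.99` and `epsilon=0.001`; the function names are hypothetical and do not correspond to Keras internals:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a simplified NumPy sketch of what a BN layer computes\n",
"import numpy as np\n",
"\n",
"def bn_train_step(X, gamma, beta, moving_mean, moving_var,\n",
"                  momentum=0.99, epsilon=1e-3):\n",
"    # training: normalize using the current batch's mean and variance\n",
"    mu, var = X.mean(axis=0), X.var(axis=0)\n",
"    out = gamma * (X - mu) / np.sqrt(var + epsilon) + beta\n",
"    # update the non-trainable running statistics (exponential moving averages)\n",
"    moving_mean = momentum * moving_mean + (1 - momentum) * mu\n",
"    moving_var = momentum * moving_var + (1 - momentum) * var\n",
"    return out, moving_mean, moving_var\n",
"\n",
"def bn_inference(X, gamma, beta, moving_mean, moving_var, epsilon=1e-3):\n",
"    # inference: use the running statistics, so outputs don't depend on the batch\n",
"    return gamma * (X - moving_mean) / np.sqrt(moving_var + epsilon) + beta\n",
"\n",
"demo_rng = np.random.default_rng(42)\n",
"X_batch = demo_rng.normal(loc=5.0, scale=3.0, size=(32, 4))\n",
"bn_gamma, bn_beta = np.ones(4), np.zeros(4)  # Keras's default initial values\n",
"bn_mean, bn_var = np.zeros(4), np.ones(4)    # initial moving mean and variance\n",
"\n",
"bn_out, bn_mean, bn_var = bn_train_step(X_batch, bn_gamma, bn_beta,\n",
"                                        bn_mean, bn_var)\n",
"print(bn_out.mean(axis=0).round(3), bn_out.std(axis=0).round(3))\n",
"# after a single step, the moving statistics are still close to 0 and 1,\n",
"# so inference outputs are far from normalized; they improve over many batches\n",
"print(bn_inference(X_batch[:2], bn_gamma, bn_beta, bn_mean, bn_var))"
]
},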
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.5559 - accuracy: 0.8094 - val_loss: 0.4016 - val_accuracy: 0.8558\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4083 - accuracy: 0.8561 - val_loss: 0.3676 - val_accuracy: 0.8650\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa5d11505b0>"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - just show that the model works! 😊\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\",\n",
" metrics=\"accuracy\")\n",
"model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sometimes applying BN before the activation function works better (there's a debate on this topic). Moreover, the layer before a `BatchNormalization` layer does not need to have bias terms: the `BatchNormalization` layer already has an offset parameter (β) for each input, so a bias term would be a waste of parameters. You can therefore set `use_bias=False` when creating those layers:"
]
},
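{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why is the bias redundant? In training mode, BN subtracts each unit's batch mean, so any constant offset added by the previous layer cancels out, and BN's own β takes over the role of the bias. Here is a quick NumPy check of that claim (a sketch of the normalization step only, without γ and β; the variable names are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a bias added before BN is cancelled by the mean subtraction\n",
"import numpy as np\n",
"\n",
"def standardize(Z, epsilon=1e-3):\n",
"    # training-mode BN without gamma/beta: standardize each unit over the batch\n",
"    return (Z - Z.mean(axis=0)) / np.sqrt(Z.var(axis=0) + epsilon)\n",
"\n",
"demo_rng = np.random.default_rng(0)\n",
"X_demo = demo_rng.normal(size=(64, 20))  # a batch of 64 instances, 20 features\n",
"W_demo = demo_rng.normal(size=(20, 5))   # weights of a 5-unit Dense layer\n",
"b_demo = demo_rng.normal(size=5)         # the bias terms we want to drop\n",
"\n",
"with_bias = standardize(X_demo @ W_demo + b_demo)\n",
"without_bias = standardize(X_demo @ W_demo)\n",
"print(np.allclose(with_bias, without_bias))  # True"
]
},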
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"# extra code - clear the name counters and set the random seed\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, kernel_initializer=\"he_normal\", use_bias=False),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Activation(\"relu\"),\n",
" tf.keras.layers.Dense(100, kernel_initializer=\"he_normal\", use_bias=False),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Activation(\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.6063 - accuracy: 0.7993 - val_loss: 0.4296 - val_accuracy: 0.8418\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4275 - accuracy: 0.8500 - val_loss: 0.3752 - val_accuracy: 0.8646\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa5fdd309d0>"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - just show that the model works! 😊\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\",\n",
" metrics=\"accuracy\")\n",
"model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient Clipping"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All `tf.keras.optimizers` accept `clipnorm` or `clipvalue` arguments:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(clipvalue=1.0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(clipnorm=1.0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer)"
]
},
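{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two options can behave quite differently. `clipvalue` clips each gradient component independently, which may change the gradient's direction, while `clipnorm` rescales a gradient whose ℓ2 norm exceeds the threshold, which preserves its direction (Keras applies `clipnorm` to each weight's gradient separately; `global_clipnorm` uses the norm of all gradients combined). The NumPy sketch below mimics both rules on a single gradient vector; the helper names are hypothetical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - what clipvalue and clipnorm do to one gradient vector\n",
"import numpy as np\n",
"\n",
"def clip_by_value(grad, threshold):\n",
"    # clip every component to [-threshold, threshold]\n",
"    return np.clip(grad, -threshold, threshold)\n",
"\n",
"def clip_by_norm(grad, threshold):\n",
"    # scale the whole vector down if its L2 norm exceeds the threshold\n",
"    norm = np.linalg.norm(grad)\n",
"    return grad * (threshold / norm) if norm > threshold else grad\n",
"\n",
"grad_demo = np.array([0.9, 100.0])\n",
"print(clip_by_value(grad_demo, 1.0))  # [0.9, 1.0]: now points almost diagonally\n",
"print(clip_by_norm(grad_demo, 1.0))   # about [0.009, 1.0]: same direction as before"
]
},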
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reusing Pretrained Layers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reusing a Keras model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's split the Fashion MNIST training set in two:\n",
"* `X_train_A`: all images of all items except for T-shirts/tops and pullovers (classes 0 and 2).\n",
"* `X_train_B`: a much smaller training set of just the first 200 images of T-shirts/tops and pullovers.\n",
"\n",
"The validation set and the test set are also split this way, but without restricting the number of images.\n",
"\n",
"We will train a model on set A (classification task with 8 classes), and try to reuse it to tackle set B (binary classification). We hope to transfer a little bit of knowledge from task A to task B, since classes in set A (trousers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots) are somewhat similar to classes in set B (T-shirts/tops and pullovers). However, since we are using `Dense` layers, only patterns that occur at the same location can be reused (in contrast, convolutional layers will transfer much better, since learned patterns can be detected anywhere on the image, as we will see in chapter 14)."
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"1376/1376 [==============================] - 1s 908us/step - loss: 1.1385 - accuracy: 0.6260 - val_loss: 0.7101 - val_accuracy: 0.7603\n",
"Epoch 2/20\n",
"1376/1376 [==============================] - 1s 869us/step - loss: 0.6221 - accuracy: 0.7911 - val_loss: 0.5293 - val_accuracy: 0.8315\n",
"Epoch 3/20\n",
"1376/1376 [==============================] - 1s 852us/step - loss: 0.5016 - accuracy: 0.8394 - val_loss: 0.4515 - val_accuracy: 0.8581\n",
"Epoch 4/20\n",
"1376/1376 [==============================] - 1s 852us/step - loss: 0.4381 - accuracy: 0.8583 - val_loss: 0.4055 - val_accuracy: 0.8669\n",
"Epoch 5/20\n",
"1376/1376 [==============================] - 1s 844us/step - loss: 0.3979 - accuracy: 0.8692 - val_loss: 0.3748 - val_accuracy: 0.8706\n",
"Epoch 6/20\n",
"1376/1376 [==============================] - 1s 882us/step - loss: 0.3693 - accuracy: 0.8782 - val_loss: 0.3538 - val_accuracy: 0.8787\n",
"Epoch 7/20\n",
"1376/1376 [==============================] - 1s 863us/step - loss: 0.3487 - accuracy: 0.8825 - val_loss: 0.3376 - val_accuracy: 0.8834\n",
"Epoch 8/20\n",
"1376/1376 [==============================] - 2s 1ms/step - loss: 0.3324 - accuracy: 0.8879 - val_loss: 0.3315 - val_accuracy: 0.8847\n",
"Epoch 9/20\n",
"1376/1376 [==============================] - 1s 1ms/step - loss: 0.3198 - accuracy: 0.8920 - val_loss: 0.3174 - val_accuracy: 0.8879\n",
"Epoch 10/20\n",
"1376/1376 [==============================] - 2s 1ms/step - loss: 0.3088 - accuracy: 0.8947 - val_loss: 0.3118 - val_accuracy: 0.8904\n",
"Epoch 11/20\n",
"1376/1376 [==============================] - 1s 1ms/step - loss: 0.2994 - accuracy: 0.8979 - val_loss: 0.3039 - val_accuracy: 0.8925\n",
"Epoch 12/20\n",
"1376/1376 [==============================] - 1s 837us/step - loss: 0.2918 - accuracy: 0.8999 - val_loss: 0.2998 - val_accuracy: 0.8952\n",
"Epoch 13/20\n",
"1376/1376 [==============================] - 1s 840us/step - loss: 0.2852 - accuracy: 0.9016 - val_loss: 0.2932 - val_accuracy: 0.8980\n",
"Epoch 14/20\n",
"1376/1376 [==============================] - 1s 799us/step - loss: 0.2788 - accuracy: 0.9034 - val_loss: 0.2865 - val_accuracy: 0.8990\n",
"Epoch 15/20\n",
"1376/1376 [==============================] - 1s 922us/step - loss: 0.2736 - accuracy: 0.9052 - val_loss: 0.2824 - val_accuracy: 0.9015\n",
"Epoch 16/20\n",
"1376/1376 [==============================] - 1s 835us/step - loss: 0.2686 - accuracy: 0.9068 - val_loss: 0.2796 - val_accuracy: 0.9015\n",
"Epoch 17/20\n",
"1376/1376 [==============================] - 1s 863us/step - loss: 0.2641 - accuracy: 0.9085 - val_loss: 0.2748 - val_accuracy: 0.9015\n",
"Epoch 18/20\n",
"1376/1376 [==============================] - 1s 913us/step - loss: 0.2596 - accuracy: 0.9101 - val_loss: 0.2729 - val_accuracy: 0.9037\n",
"Epoch 19/20\n",
"1376/1376 [==============================] - 1s 909us/step - loss: 0.2558 - accuracy: 0.9119 - val_loss: 0.2715 - val_accuracy: 0.9040\n",
"Epoch 20/20\n",
"1376/1376 [==============================] - 1s 859us/step - loss: 0.2520 - accuracy: 0.9125 - val_loss: 0.2728 - val_accuracy: 0.9027\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-15 16:22:23.274500: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_model_A/assets\n"
]
}
],
"source": [
"# extra code - split Fashion MNIST into tasks A and B, then train and save\n",
"# model A to \"my_model_A\".\n",
"\n",
"pos_class_id = class_names.index(\"Pullover\")\n",
"neg_class_id = class_names.index(\"T-shirt/top\")\n",
"\n",
"def split_dataset(X, y):\n",
" y_for_B = (y == pos_class_id) | (y == neg_class_id)\n",
" y_A = y[~y_for_B]\n",
" y_B = (y[y_for_B] == pos_class_id).astype(np.float32)\n",
" old_class_ids = list(set(range(10)) - set([neg_class_id, pos_class_id]))\n",
" for old_class_id, new_class_id in zip(old_class_ids, range(8)):\n",
" y_A[y_A == old_class_id] = new_class_id # reorder class ids for A\n",
" return ((X[~y_for_B], y_A), (X[y_for_B], y_B))\n",
"\n",
"(X_train_A, y_train_A), (X_train_B, y_train_B) = split_dataset(X_train, y_train)\n",
"(X_valid_A, y_valid_A), (X_valid_B, y_valid_B) = split_dataset(X_valid, y_valid)\n",
"(X_test_A, y_test_A), (X_test_B, y_test_B) = split_dataset(X_test, y_test)\n",
"X_train_B = X_train_B[:200]\n",
"y_train_B = y_train_B[:200]\n",
"\n",
"tf.random.set_seed(42)\n",
"\n",
"model_A = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(8, activation=\"softmax\")\n",
"])\n",
"\n",
"model_A.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])\n",
"history = model_A.fit(X_train_A, y_train_A, epochs=20,\n",
" validation_data=(X_valid_A, y_valid_A))\n",
"model_A.save(\"my_model_A\")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"7/7 [==============================] - 0s 20ms/step - loss: 0.7167 - accuracy: 0.5450 - val_loss: 0.7052 - val_accuracy: 0.5272\n",
"Epoch 2/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6805 - accuracy: 0.5800 - val_loss: 0.6758 - val_accuracy: 0.6004\n",
"Epoch 3/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6532 - accuracy: 0.6650 - val_loss: 0.6530 - val_accuracy: 0.6746\n",
"Epoch 4/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.6289 - accuracy: 0.7150 - val_loss: 0.6317 - val_accuracy: 0.7517\n",
"Epoch 5/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6079 - accuracy: 0.7800 - val_loss: 0.6105 - val_accuracy: 0.8091\n",
"Epoch 6/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5866 - accuracy: 0.8400 - val_loss: 0.5913 - val_accuracy: 0.8447\n",
"Epoch 7/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.5670 - accuracy: 0.8850 - val_loss: 0.5728 - val_accuracy: 0.8833\n",
"Epoch 8/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5499 - accuracy: 0.8900 - val_loss: 0.5571 - val_accuracy: 0.8971\n",
"Epoch 9/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5331 - accuracy: 0.9150 - val_loss: 0.5427 - val_accuracy: 0.9050\n",
"Epoch 10/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5180 - accuracy: 0.9250 - val_loss: 0.5290 - val_accuracy: 0.9080\n",
"Epoch 11/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.5038 - accuracy: 0.9350 - val_loss: 0.5160 - val_accuracy: 0.9189\n",
"Epoch 12/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4903 - accuracy: 0.9350 - val_loss: 0.5032 - val_accuracy: 0.9228\n",
"Epoch 13/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4770 - accuracy: 0.9400 - val_loss: 0.4925 - val_accuracy: 0.9228\n",
"Epoch 14/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4656 - accuracy: 0.9450 - val_loss: 0.4817 - val_accuracy: 0.9258\n",
"Epoch 15/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4546 - accuracy: 0.9550 - val_loss: 0.4708 - val_accuracy: 0.9298\n",
"Epoch 16/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4435 - accuracy: 0.9550 - val_loss: 0.4608 - val_accuracy: 0.9318\n",
"Epoch 17/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4330 - accuracy: 0.9600 - val_loss: 0.4510 - val_accuracy: 0.9337\n",
"Epoch 18/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4226 - accuracy: 0.9600 - val_loss: 0.4406 - val_accuracy: 0.9367\n",
"Epoch 19/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4119 - accuracy: 0.9600 - val_loss: 0.4311 - val_accuracy: 0.9377\n",
"Epoch 20/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4025 - accuracy: 0.9600 - val_loss: 0.4225 - val_accuracy: 0.9367\n",
"63/63 [==============================] - 0s 728us/step - loss: 0.4317 - accuracy: 0.9185\n"
]
},
{
"data": {
"text/plain": [
"[0.43168652057647705, 0.9185000061988831]"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - train and evaluate model B, without reusing model A\n",
"\n",
"tf.random.set_seed(42)\n",
"model_B = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"\n",
"model_B.compile(loss=\"binary_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])\n",
"history = model_B.fit(X_train_B, y_train_B, epochs=20,\n",
" validation_data=(X_valid_B, y_valid_B))\n",
"model_B.evaluate(X_test_B, y_test_B)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Model B reaches 91.85% accuracy on the test set. Now let's try reusing the pretrained model A."
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"model_A = tf.keras.models.load_model(\"my_model_A\")\n",
"model_B_on_A = tf.keras.Sequential(model_A.layers[:-1])\n",
"model_B_on_A.add(tf.keras.layers.Dense(1, activation=\"sigmoid\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `model_B_on_A` and `model_A` actually share layers now, so when we train one, it will update both models. If we want to avoid that, we need to build `model_B_on_A` on top of a *clone* of `model_A`:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code - ensure reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"model_A_clone = tf.keras.models.clone_model(model_A)\n",
"model_A_clone.set_weights(model_A.get_weights())"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"# extra code - creating model_B_on_A just like in the previous cell\n",
"model_B_on_A = tf.keras.Sequential(model_A_clone.layers[:-1])\n",
"model_B_on_A.add(tf.keras.layers.Dense(1, activation=\"sigmoid\"))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"for layer in model_B_on_A.layers[:-1]:\n",
" layer.trainable = False\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model_B_on_A.compile(loss=\"binary_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"7/7 [==============================] - 0s 23ms/step - loss: 1.7893 - accuracy: 0.5550 - val_loss: 1.3324 - val_accuracy: 0.5084\n",
"Epoch 2/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 1.1235 - accuracy: 0.5350 - val_loss: 0.9199 - val_accuracy: 0.4807\n",
"Epoch 3/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.8836 - accuracy: 0.5000 - val_loss: 0.8266 - val_accuracy: 0.4837\n",
"Epoch 4/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.8202 - accuracy: 0.5250 - val_loss: 0.7795 - val_accuracy: 0.4985\n",
"Epoch 1/16\n",
"7/7 [==============================] - 0s 21ms/step - loss: 0.7348 - accuracy: 0.6050 - val_loss: 0.6372 - val_accuracy: 0.6914\n",
"Epoch 2/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6055 - accuracy: 0.7600 - val_loss: 0.5283 - val_accuracy: 0.8229\n",
"Epoch 3/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4992 - accuracy: 0.8400 - val_loss: 0.4742 - val_accuracy: 0.8180\n",
"Epoch 4/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4297 - accuracy: 0.8700 - val_loss: 0.4212 - val_accuracy: 0.8773\n",
"Epoch 5/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3825 - accuracy: 0.9050 - val_loss: 0.3797 - val_accuracy: 0.9031\n",
"Epoch 6/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.3438 - accuracy: 0.9250 - val_loss: 0.3534 - val_accuracy: 0.9149\n",
"Epoch 7/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3148 - accuracy: 0.9500 - val_loss: 0.3384 - val_accuracy: 0.9001\n",
"Epoch 8/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3012 - accuracy: 0.9450 - val_loss: 0.3179 - val_accuracy: 0.9209\n",
"Epoch 9/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2767 - accuracy: 0.9650 - val_loss: 0.3043 - val_accuracy: 0.9298\n",
"Epoch 10/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2623 - accuracy: 0.9550 - val_loss: 0.2929 - val_accuracy: 0.9308\n",
"Epoch 11/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2512 - accuracy: 0.9600 - val_loss: 0.2830 - val_accuracy: 0.9327\n",
"Epoch 12/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2397 - accuracy: 0.9600 - val_loss: 0.2744 - val_accuracy: 0.9318\n",
"Epoch 13/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2295 - accuracy: 0.9600 - val_loss: 0.2675 - val_accuracy: 0.9327\n",
"Epoch 14/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2225 - accuracy: 0.9600 - val_loss: 0.2598 - val_accuracy: 0.9347\n",
"Epoch 15/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2147 - accuracy: 0.9600 - val_loss: 0.2542 - val_accuracy: 0.9357\n",
"Epoch 16/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.2077 - accuracy: 0.9600 - val_loss: 0.2492 - val_accuracy: 0.9377\n"
]
}
],
"source": [
"history = model_B_on_A.fit(X_train_B, y_train_B, epochs=4,\n",
" validation_data=(X_valid_B, y_valid_B))\n",
"\n",
"for layer in model_B_on_A.layers[:-1]:\n",
" layer.trainable = True\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model_B_on_A.compile(loss=\"binary_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model_B_on_A.fit(X_train_B, y_train_B, epochs=16,\n",
" validation_data=(X_valid_B, y_valid_B))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, what's the final verdict?"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"63/63 [==============================] - 0s 667us/step - loss: 0.2546 - accuracy: 0.9385\n"
]
},
{
"data": {
"text/plain": [
"[0.2546142041683197, 0.9384999871253967]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_B_on_A.evaluate(X_test_B, y_test_B)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! We got a bit of transfer: the model's accuracy went up 2 percentage points, from 91.85% to 93.85%. This means the error rate dropped by almost 25%:"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.24539877300613477"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"1 - (100 - 93.85) / (100 - 91.85)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Faster Optimizers"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a little function to test an optimizer on Fashion MNIST\n",
"\n",
"def build_model(seed=42):\n",
" tf.random.set_seed(seed)\n",
" return tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
" ])\n",
"\n",
"def build_and_train_model(optimizer):\n",
" model = build_model()\n",
" model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
" return model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.6877 - accuracy: 0.7677 - val_loss: 0.4960 - val_accuracy: 0.8172\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.4619 - accuracy: 0.8378 - val_loss: 0.4421 - val_accuracy: 0.8404\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 868us/step - loss: 0.4179 - accuracy: 0.8525 - val_loss: 0.4188 - val_accuracy: 0.8538\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 1s 866us/step - loss: 0.3902 - accuracy: 0.8621 - val_loss: 0.3814 - val_accuracy: 0.8604\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 869us/step - loss: 0.3686 - accuracy: 0.8691 - val_loss: 0.3665 - val_accuracy: 0.8656\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 925us/step - loss: 0.3553 - accuracy: 0.8732 - val_loss: 0.3643 - val_accuracy: 0.8720\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 908us/step - loss: 0.3385 - accuracy: 0.8778 - val_loss: 0.3611 - val_accuracy: 0.8684\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 926us/step - loss: 0.3297 - accuracy: 0.8796 - val_loss: 0.3490 - val_accuracy: 0.8726\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 893us/step - loss: 0.3200 - accuracy: 0.8850 - val_loss: 0.3625 - val_accuracy: 0.8666\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 886us/step - loss: 0.3097 - accuracy: 0.8881 - val_loss: 0.3656 - val_accuracy: 0.8672\n"
]
}
],
"source": [
"history_sgd = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Momentum optimization"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 941us/step - loss: 0.6877 - accuracy: 0.7677 - val_loss: 0.4960 - val_accuracy: 0.8172\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.4619 - accuracy: 0.8378 - val_loss: 0.4421 - val_accuracy: 0.8404\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 898us/step - loss: 0.4179 - accuracy: 0.8525 - val_loss: 0.4188 - val_accuracy: 0.8538\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 934us/step - loss: 0.3902 - accuracy: 0.8621 - val_loss: 0.3814 - val_accuracy: 0.8604\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 910us/step - loss: 0.3686 - accuracy: 0.8691 - val_loss: 0.3665 - val_accuracy: 0.8656\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3553 - accuracy: 0.8732 - val_loss: 0.3643 - val_accuracy: 0.8720\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 893us/step - loss: 0.3385 - accuracy: 0.8778 - val_loss: 0.3611 - val_accuracy: 0.8684\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 968us/step - loss: 0.3297 - accuracy: 0.8796 - val_loss: 0.3490 - val_accuracy: 0.8726\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3200 - accuracy: 0.8850 - val_loss: 0.3625 - val_accuracy: 0.8666\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 1s 858us/step - loss: 0.3097 - accuracy: 0.8881 - val_loss: 0.3656 - val_accuracy: 0.8672\n"
]
}
],
"source": [
"history_momentum = build_and_train_model(optimizer) # extra code"
]
},
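{
"cell_type": "markdown",
"metadata": {},
"source": [
"Under the hood, momentum optimization keeps a velocity vector **m** and, at each step, computes **m** ← β**m** − η∇J(**θ**) and then **θ** ← **θ** + **m** (in Keras, `momentum` plays the role of β and `learning_rate` that of η). The NumPy sketch below runs this rule and plain gradient descent on a toy, elongated quadratic cost chosen for illustration; the function names are hypothetical. With the same learning rate, momentum should end up much closer to the minimum at the origin:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a NumPy sketch of momentum optimization on a toy cost function\n",
"import numpy as np\n",
"\n",
"def toy_gradient(theta):\n",
"    # gradient of J(theta) = 0.5 * (theta[0]**2 + 10 * theta[1]**2)\n",
"    return np.array([1.0, 10.0]) * theta\n",
"\n",
"def run_gd(theta0, learning_rate=0.01, n_steps=200):\n",
"    theta = np.array(theta0, dtype=float)\n",
"    for _ in range(n_steps):\n",
"        theta = theta - learning_rate * toy_gradient(theta)\n",
"    return theta\n",
"\n",
"def run_momentum(theta0, learning_rate=0.01, momentum=0.9, n_steps=200):\n",
"    theta = np.array(theta0, dtype=float)\n",
"    m = np.zeros_like(theta)  # the velocity starts at zero\n",
"    for _ in range(n_steps):\n",
"        m = momentum * m - learning_rate * toy_gradient(theta)\n",
"        theta = theta + m\n",
"    return theta\n",
"\n",
"print(\"plain GD distance to optimum:\", np.linalg.norm(run_gd([5.0, 5.0])))\n",
"print(\"momentum distance to optimum:\", np.linalg.norm(run_momentum([5.0, 5.0])))"
]
},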
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nesterov Accelerated Gradient"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9,\n",
" nesterov=True)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 907us/step - loss: 0.6777 - accuracy: 0.7711 - val_loss: 0.4796 - val_accuracy: 0.8260\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 898us/step - loss: 0.4570 - accuracy: 0.8398 - val_loss: 0.4358 - val_accuracy: 0.8396\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 872us/step - loss: 0.4140 - accuracy: 0.8537 - val_loss: 0.4013 - val_accuracy: 0.8566\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 902us/step - loss: 0.3882 - accuracy: 0.8629 - val_loss: 0.3802 - val_accuracy: 0.8616\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3666 - accuracy: 0.8703 - val_loss: 0.3689 - val_accuracy: 0.8638\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 882us/step - loss: 0.3531 - accuracy: 0.8732 - val_loss: 0.3681 - val_accuracy: 0.8688\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 958us/step - loss: 0.3375 - accuracy: 0.8784 - val_loss: 0.3658 - val_accuracy: 0.8670\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 895us/step - loss: 0.3278 - accuracy: 0.8815 - val_loss: 0.3598 - val_accuracy: 0.8682\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.3183 - accuracy: 0.8855 - val_loss: 0.3472 - val_accuracy: 0.8720\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 921us/step - loss: 0.3081 - accuracy: 0.8891 - val_loss: 0.3624 - val_accuracy: 0.8708\n"
]
}
],
"source": [
"history_nesterov = build_and_train_model(optimizer) # extra code"
]
},
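{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal NumPy sketch (an illustration, not Keras's implementation) contrasting plain momentum with Nesterov momentum on a hypothetical ill-conditioned quadratic, `J(theta) = 0.5 * theta @ A_toy @ theta`. The only difference is where the gradient is measured: at `theta` for plain momentum, and slightly ahead, at `theta + beta * m`, for Nesterov. The matrix, starting point, and hyperparameters are arbitrary choices; on this particular problem, the Nesterov run should typically end up closer to the minimum."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a toy comparison of momentum and Nesterov momentum (sketch)\n",
"import numpy as np\n",
"\n",
"A_toy = np.diag([1.0, 25.0])  # hypothetical ill-conditioned quadratic bowl\n",
"\n",
"def toy_grad(theta):\n",
"    return A_toy @ theta  # gradient of J(theta) = 0.5 * theta @ A_toy @ theta\n",
"\n",
"def run_momentum(nesterov, lr=0.02, beta=0.9, n_steps=100):\n",
"    theta = np.array([1.0, 1.0])\n",
"    m = np.zeros_like(theta)  # momentum vector, starts at zero\n",
"    for _ in range(n_steps):\n",
"        # Nesterov measures the gradient slightly ahead, at theta + beta * m\n",
"        g = toy_grad(theta + beta * m) if nesterov else toy_grad(theta)\n",
"        m = beta * m - lr * g\n",
"        theta = theta + m\n",
"    return theta\n",
"\n",
"for nesterov in (False, True):\n",
"    dist = np.linalg.norm(run_momentum(nesterov))\n",
"    print(f\"nesterov={nesterov}: distance to the minimum = {dist:.5f}\")"
]
},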
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## AdaGrad"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 1.0003 - accuracy: 0.6822 - val_loss: 0.6876 - val_accuracy: 0.7744\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 912us/step - loss: 0.6389 - accuracy: 0.7904 - val_loss: 0.5837 - val_accuracy: 0.8048\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 930us/step - loss: 0.5682 - accuracy: 0.8105 - val_loss: 0.5379 - val_accuracy: 0.8154\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.5316 - accuracy: 0.8215 - val_loss: 0.5135 - val_accuracy: 0.8244\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 855us/step - loss: 0.5076 - accuracy: 0.8295 - val_loss: 0.4937 - val_accuracy: 0.8288\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 1s 868us/step - loss: 0.4905 - accuracy: 0.8338 - val_loss: 0.4821 - val_accuracy: 0.8312\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 940us/step - loss: 0.4776 - accuracy: 0.8371 - val_loss: 0.4705 - val_accuracy: 0.8348\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 966us/step - loss: 0.4674 - accuracy: 0.8409 - val_loss: 0.4611 - val_accuracy: 0.8362\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 892us/step - loss: 0.4587 - accuracy: 0.8435 - val_loss: 0.4548 - val_accuracy: 0.8406\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 873us/step - loss: 0.4511 - accuracy: 0.8458 - val_loss: 0.4469 - val_accuracy: 0.8424\n"
]
}
],
"source": [
"history_adagrad = build_and_train_model(optimizer) # extra code"
]
},
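{
"cell_type": "markdown",
"metadata": {},
"source": [
"Below is a minimal NumPy sketch of a single AdaGrad step, in its textbook form (Keras's `Adagrad` differs in details; for example, it starts the accumulator at `initial_accumulator_value` rather than at zero). The squared gradients are accumulated in `s`, and each parameter's step is divided by `sqrt(s)`, so steep dimensions get smaller steps. The hypothetical gradient below is 25 times steeper along the second axis, yet after the first step both coordinates have moved by about the same amount. Since `s` only grows, the effective learning rate keeps shrinking."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a minimal sketch of one AdaGrad step\n",
"import numpy as np\n",
"\n",
"def adagrad_step(theta, g, s, lr=0.1, eps=1e-7):\n",
"    s = s + g * g  # accumulate squared gradients (s never decays)\n",
"    theta = theta - lr * g / (np.sqrt(s) + eps)\n",
"    return theta, s\n",
"\n",
"g_toy = np.array([1.0, 25.0])  # hypothetical gradient\n",
"theta_ag, s_ag = adagrad_step(np.array([1.0, 1.0]), g_toy, np.zeros(2))\n",
"print(theta_ag)  # both coordinates moved by about lr = 0.1"
]
},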
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RMSProp"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001, rho=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5138 - accuracy: 0.8135 - val_loss: 0.4413 - val_accuracy: 0.8338\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 942us/step - loss: 0.3932 - accuracy: 0.8590 - val_loss: 0.4518 - val_accuracy: 0.8370\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.3711 - accuracy: 0.8692 - val_loss: 0.3914 - val_accuracy: 0.8686\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.3643 - accuracy: 0.8735 - val_loss: 0.4176 - val_accuracy: 0.8644\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 970us/step - loss: 0.3578 - accuracy: 0.8769 - val_loss: 0.3874 - val_accuracy: 0.8696\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3561 - accuracy: 0.8775 - val_loss: 0.4650 - val_accuracy: 0.8590\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3528 - accuracy: 0.8783 - val_loss: 0.4122 - val_accuracy: 0.8774\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 989us/step - loss: 0.3491 - accuracy: 0.8811 - val_loss: 0.5151 - val_accuracy: 0.8586\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3479 - accuracy: 0.8829 - val_loss: 0.4457 - val_accuracy: 0.8856\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1000us/step - loss: 0.3437 - accuracy: 0.8830 - val_loss: 0.4781 - val_accuracy: 0.8636\n"
]
}
],
"source": [
"history_rmsprop = build_and_train_model(optimizer) # extra code"
]
},
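{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the RMSProp update, in its textbook form (the library implementation may differ in small details, such as where `epsilon` is added). Unlike AdaGrad, `s` is an exponentially decaying average of the squared gradients, so old gradients are gradually forgotten: with a constant (hypothetical) gradient, the step size settles around `lr` for every parameter instead of shrinking forever."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a minimal sketch of the RMSProp update\n",
"import numpy as np\n",
"\n",
"def rmsprop_step(theta, g, s, lr=0.001, rho=0.9, eps=1e-7):\n",
"    s = rho * s + (1 - rho) * g * g  # decaying average of squared gradients\n",
"    step = lr * g / (np.sqrt(s) + eps)\n",
"    return theta - step, s, step\n",
"\n",
"theta_rms, s_rms = np.zeros(2), np.zeros(2)\n",
"g_toy = np.array([1.0, 25.0])  # hypothetical constant gradient\n",
"for _ in range(200):\n",
"    theta_rms, s_rms, step_rms = rmsprop_step(theta_rms, g_toy, s_rms)\n",
"print(step_rms)  # close to [0.001, 0.001], whatever the gradient's scale"
]
},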
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adam Optimization"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4949 - accuracy: 0.8220 - val_loss: 0.4110 - val_accuracy: 0.8428\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3727 - accuracy: 0.8637 - val_loss: 0.4153 - val_accuracy: 0.8370\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3372 - accuracy: 0.8756 - val_loss: 0.3600 - val_accuracy: 0.8708\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3126 - accuracy: 0.8833 - val_loss: 0.3498 - val_accuracy: 0.8760\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2965 - accuracy: 0.8901 - val_loss: 0.3264 - val_accuracy: 0.8794\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2821 - accuracy: 0.8947 - val_loss: 0.3295 - val_accuracy: 0.8782\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2672 - accuracy: 0.8993 - val_loss: 0.3473 - val_accuracy: 0.8790\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2587 - accuracy: 0.9020 - val_loss: 0.3230 - val_accuracy: 0.8818\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2500 - accuracy: 0.9057 - val_loss: 0.3676 - val_accuracy: 0.8744\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2428 - accuracy: 0.9073 - val_loss: 0.3879 - val_accuracy: 0.8696\n"
]
}
],
"source": [
"history_adam = build_and_train_model(optimizer) # extra code"
]
},
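{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of one Adam step, in its textbook form with bias correction (Keras's implementation may differ in small details, such as where `epsilon` is added). Since `m` and `s` are initialized at zero, they are biased toward zero early in training; dividing them by `1 - beta_1 ** t` and `1 - beta_2 ** t` compensates for this. The cell compares the very first step, for a hypothetical gradient, with and without the correction: with it, each parameter moves by about `lr`; without it, the step is over three times larger with these default betas."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a minimal sketch of the Adam update\n",
"import numpy as np\n",
"\n",
"def adam_step(theta, g, m, s, t, lr=0.001, beta_1=0.9, beta_2=0.999,\n",
"              eps=1e-7, bias_correction=True):\n",
"    m = beta_1 * m + (1 - beta_1) * g  # 1st moment: momentum-like average\n",
"    s = beta_2 * s + (1 - beta_2) * g * g  # 2nd moment: RMSProp-like average\n",
"    if bias_correction:  # t is the step number, starting at 1\n",
"        m_hat, s_hat = m / (1 - beta_1 ** t), s / (1 - beta_2 ** t)\n",
"    else:\n",
"        m_hat, s_hat = m, s\n",
"    theta = theta - lr * m_hat / (np.sqrt(s_hat) + eps)\n",
"    return theta, m, s\n",
"\n",
"g_toy, zeros = np.array([1.0, 25.0]), np.zeros(2)  # hypothetical first gradient\n",
"theta_corr, _, _ = adam_step(zeros, g_toy, zeros, zeros, t=1)\n",
"theta_raw, _, _ = adam_step(zeros, g_toy, zeros, zeros, t=1,\n",
"                            bias_correction=False)\n",
"print(theta_corr)  # about [-0.001, -0.001]\n",
"print(theta_raw)   # about [-0.0032, -0.0032]"
]
},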
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Adamax Optimization**"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adamax(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5327 - accuracy: 0.8151 - val_loss: 0.4402 - val_accuracy: 0.8340\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 935us/step - loss: 0.3950 - accuracy: 0.8591 - val_loss: 0.3907 - val_accuracy: 0.8512\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 933us/step - loss: 0.3563 - accuracy: 0.8715 - val_loss: 0.3730 - val_accuracy: 0.8676\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 942us/step - loss: 0.3335 - accuracy: 0.8797 - val_loss: 0.3453 - val_accuracy: 0.8738\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 993us/step - loss: 0.3129 - accuracy: 0.8853 - val_loss: 0.3270 - val_accuracy: 0.8792\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 926us/step - loss: 0.2986 - accuracy: 0.8913 - val_loss: 0.3396 - val_accuracy: 0.8772\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 939us/step - loss: 0.2854 - accuracy: 0.8949 - val_loss: 0.3390 - val_accuracy: 0.8770\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.2757 - accuracy: 0.8984 - val_loss: 0.3147 - val_accuracy: 0.8854\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 952us/step - loss: 0.2662 - accuracy: 0.9020 - val_loss: 0.3341 - val_accuracy: 0.8760\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 957us/step - loss: 0.2542 - accuracy: 0.9063 - val_loss: 0.3282 - val_accuracy: 0.8780\n"
]
}
],
"source": [
"history_adamax = build_and_train_model(optimizer) # extra code"
]
},
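{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch of the Adamax update, in its textbook form. Adamax keeps Adam's first moment `m`, but replaces the decaying average of squared gradients with a decaying maximum of absolute gradients, `u = max(beta_2 * u, |g|)`. With the hypothetical gradient sequence below, which contains a single spike, `u` jumps to the spike's magnitude and then shrinks only by a factor of `beta_2` per step, so the following steps remain small for a while."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a minimal sketch of the Adamax update\n",
"import numpy as np\n",
"\n",
"def adamax_step(theta, g, m, u, t, lr=0.001, beta_1=0.9, beta_2=0.999,\n",
"                eps=1e-7):\n",
"    m = beta_1 * m + (1 - beta_1) * g  # same 1st moment as Adam\n",
"    u = np.maximum(beta_2 * u, np.abs(g))  # decaying max of |g|, not an average\n",
"    theta = theta - lr / (1 - beta_1 ** t) * m / (u + eps)\n",
"    return theta, m, u\n",
"\n",
"theta_max, m_max, u_max = np.zeros(1), np.zeros(1), np.zeros(1)\n",
"grads = [1.0] * 4 + [100.0] + [1.0] * 5  # hypothetical gradients, one spike\n",
"for t, g in enumerate(grads, start=1):\n",
"    theta_max, m_max, u_max = adamax_step(theta_max, np.array([g]),\n",
"                                          m_max, u_max, t)\n",
"print(u_max)  # about 100 * 0.999 ** 5: the spike still dominates"
]
},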
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"**Nadam Optimization**"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Nadam(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4826 - accuracy: 0.8284 - val_loss: 0.4092 - val_accuracy: 0.8456\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3610 - accuracy: 0.8667 - val_loss: 0.3893 - val_accuracy: 0.8592\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3270 - accuracy: 0.8784 - val_loss: 0.3653 - val_accuracy: 0.8712\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3049 - accuracy: 0.8874 - val_loss: 0.3444 - val_accuracy: 0.8726\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2897 - accuracy: 0.8905 - val_loss: 0.3174 - val_accuracy: 0.8810\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2753 - accuracy: 0.8981 - val_loss: 0.3389 - val_accuracy: 0.8830\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2652 - accuracy: 0.9000 - val_loss: 0.3725 - val_accuracy: 0.8734\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2563 - accuracy: 0.9034 - val_loss: 0.3229 - val_accuracy: 0.8828\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2463 - accuracy: 0.9079 - val_loss: 0.3353 - val_accuracy: 0.8818\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2402 - accuracy: 0.9091 - val_loss: 0.3813 - val_accuracy: 0.8740\n"
]
}
],
"source": [
"history_nadam = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**AdamW Optimization**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: Since TF 2.12, `AdamW` is no longer experimental: it is available as `tf.keras.optimizers.AdamW` instead of `tf.keras.optimizers.experimental.AdamW`."
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.AdamW(weight_decay=1e-5, learning_rate=0.001,\n",
" beta_1=0.9, beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4945 - accuracy: 0.8220 - val_loss: 0.4203 - val_accuracy: 0.8424\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3735 - accuracy: 0.8629 - val_loss: 0.4014 - val_accuracy: 0.8474\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3391 - accuracy: 0.8753 - val_loss: 0.3347 - val_accuracy: 0.8760\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3155 - accuracy: 0.8827 - val_loss: 0.3441 - val_accuracy: 0.8720\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2989 - accuracy: 0.8892 - val_loss: 0.3218 - val_accuracy: 0.8786\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2862 - accuracy: 0.8931 - val_loss: 0.3423 - val_accuracy: 0.8814\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2738 - accuracy: 0.8970 - val_loss: 0.3593 - val_accuracy: 0.8764\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2648 - accuracy: 0.8993 - val_loss: 0.3263 - val_accuracy: 0.8856\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2583 - accuracy: 0.9035 - val_loss: 0.3642 - val_accuracy: 0.8680\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2483 - accuracy: 0.9054 - val_loss: 0.3696 - val_accuracy: 0.8702\n"
]
}
],
"source": [
"history_adamw = build_and_train_model(optimizer) # extra code"
]
},
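{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal NumPy sketch (first step only, not Keras's exact code) of why AdamW *decouples* weight decay from the gradient. If an L2 penalty is simply added to the gradient, Adam's normalization rescales it, so on the first step every weight shrinks by roughly `lr`, whatever its size. AdamW instead shrinks each weight directly by `lr * weight_decay * theta`, i.e., in proportion to its value. The two weights and the flat loss gradient are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: L2 penalty inside Adam vs decoupled weight decay (AdamW)\n",
"import numpy as np\n",
"\n",
"def first_adam_direction(g, eps=1e-7):\n",
"    # on the first step, bias-corrected Adam moves along g / (|g| + eps)\n",
"    return g / (np.abs(g) + eps)\n",
"\n",
"lr_toy, weight_decay = 0.001, 0.01\n",
"theta0 = np.array([0.01, 10.0])  # hypothetical small and large weights\n",
"loss_grad = np.zeros(2)  # hypothetical: the loss is flat at this point\n",
"\n",
"# (a) L2 penalty folded into the gradient, then normalized by Adam\n",
"theta_l2 = theta0 - lr_toy * first_adam_direction(loss_grad + weight_decay * theta0)\n",
"# (b) AdamW: Adam step on the loss gradient, plus a separate proportional decay\n",
"theta_adamw = (theta0 - lr_toy * first_adam_direction(loss_grad)\n",
"               - lr_toy * weight_decay * theta0)\n",
"print(theta0 - theta_l2)     # roughly [0.001, 0.001]: same shrinkage for both\n",
"print(theta0 - theta_adamw)  # about [1e-07, 1e-04]: proportional to each weight"
]
},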
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 864x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: visualize the learning curves of all the optimizers\n",
"\n",
"for loss in (\"loss\", \"val_loss\"):\n",
" plt.figure(figsize=(12, 8))\n",
" opt_names = \"SGD Momentum Nesterov AdaGrad RMSProp Adam Adamax Nadam AdamW\"\n",
" for history, opt_name in zip((history_sgd, history_momentum, history_nesterov,\n",
" history_adagrad, history_rmsprop, history_adam,\n",
" history_adamax, history_nadam, history_adamw),\n",
" opt_names.split()):\n",
" plt.plot(history.history[loss], label=f\"{opt_name}\", linewidth=3)\n",
"\n",
" plt.grid()\n",
" plt.xlabel(\"Epochs\")\n",
" plt.ylabel({\"loss\": \"Training loss\", \"val_loss\": \"Validation loss\"}[loss])\n",
" plt.legend(loc=\"upper left\")\n",
" plt.axis([0, 9, 0.1, 0.7])\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Rate Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Power Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"learning_rate = initial_learning_rate / (1 + step / decay_steps)**power\n",
"```\n",
"\n",
"Keras's `InverseTimeDecay` schedule uses `power = 1` (and so does the legacy `decay` argument)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: The `decay` argument in optimizers is deprecated. The old optimizers that implement the `decay` argument are still available in `tf.keras.optimizers.legacy`, but you should use the learning rate schedules in `tf.keras.optimizers.schedules` instead."
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"# DEPRECATED:\n",
"optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=0.01, decay=1e-4)"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"# RECOMMENDED:\n",
"lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(\n",
" initial_learning_rate=0.01,\n",
" decay_steps=10_000,\n",
" decay_rate=1.0,\n",
" staircase=False\n",
")\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `InverseTimeDecay` schedule uses `learning_rate = initial_learning_rate / (1 + decay_rate * step / decay_steps)`. If you set `staircase=True`, then it replaces `step / decay_steps` with `floor(step / decay_steps)`."
]
},
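{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (pure Python, no TensorFlow), the next cell re-implements both formulas. The legacy `decay` argument computes `learning_rate / (1 + decay * step)`, so `learning_rate=0.01` with `decay=1e-4` should give the same learning rate at every step as `InverseTimeDecay` with `initial_learning_rate=0.01`, `decay_steps=10_000`, and `decay_rate=1.0`. The helper functions are hypothetical re-implementations, not Keras APIs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: check that the DEPRECATED and RECOMMENDED optimizers above\n",
"# follow the same learning rate schedule (pure Python re-implementation)\n",
"import math\n",
"\n",
"def legacy_decay_lr(step, learning_rate=0.01, decay=1e-4):\n",
"    return learning_rate / (1 + decay * step)\n",
"\n",
"def inverse_time_decay_lr(step, initial_learning_rate=0.01, decay_steps=10_000,\n",
"                          decay_rate=1.0, staircase=False):\n",
"    p = step / decay_steps\n",
"    if staircase:\n",
"        p = math.floor(p)  # constant within each block of decay_steps steps\n",
"    return initial_learning_rate / (1 + decay_rate * p)\n",
"\n",
"for step in (0, 1_000, 10_000, 50_000):\n",
"    assert math.isclose(legacy_decay_lr(step), inverse_time_decay_lr(step))\n",
"\n",
"print(inverse_time_decay_lr(10_000))  # 0.005: halved after decay_steps steps\n",
"print(inverse_time_decay_lr(15_000, staircase=True))  # still 0.005"
]
},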
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.7004 - accuracy: 0.7588 - val_loss: 0.4991 - val_accuracy: 0.8206\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4781 - accuracy: 0.8316 - val_loss: 0.4477 - val_accuracy: 0.8372\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4293 - accuracy: 0.8487 - val_loss: 0.4177 - val_accuracy: 0.8498\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4053 - accuracy: 0.8563 - val_loss: 0.3987 - val_accuracy: 0.8602\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3864 - accuracy: 0.8633 - val_loss: 0.3859 - val_accuracy: 0.8612\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3720 - accuracy: 0.8675 - val_loss: 0.3942 - val_accuracy: 0.8584\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3616 - accuracy: 0.8709 - val_loss: 0.3706 - val_accuracy: 0.8670\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3529 - accuracy: 0.8741 - val_loss: 0.3758 - val_accuracy: 0.8638\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3452 - accuracy: 0.8765 - val_loss: 0.3587 - val_accuracy: 0.8680\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3379 - accuracy: 0.8793 - val_loss: 0.3569 - val_accuracy: 0.8714\n"
]
}
],
"source": [
"history_power_scheduling = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell plots power scheduling with staircase=True and staircase=False\n",
"\n",
"initial_learning_rate = 0.01\n",
"decay_rate = 1.0\n",
"decay_steps = 10_000\n",
"\n",
"steps = np.arange(100_000)\n",
"lrs = initial_learning_rate / (1 + decay_rate * steps / decay_steps)\n",
"lrs2 = initial_learning_rate / (1 + decay_rate * np.floor(steps / decay_steps))\n",
"\n",
"plt.plot(steps, lrs, \"-\", label=\"staircase=False\")\n",
"plt.plot(steps, lrs2, \"-\", label=\"staircase=True\")\n",
"plt.axis([0, steps.max(), 0, 0.0105])\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Power Scheduling\", fontsize=14)\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exponential Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```python\n",
"learning_rate = initial_learning_rate * decay_rate ** (step / decay_steps)\n",
"```"
]
},
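{
"cell_type": "markdown",
"metadata": {},
"source": [
"A pure-Python sketch of the formula above, using a hypothetical helper whose arguments mirror the `ExponentialDecay` arguments in the next cell: the learning rate is multiplied by `decay_rate` every `decay_steps` steps, so with `decay_rate=0.1` and `decay_steps=20_000` it drops by a factor of 10 every 20,000 steps. With `staircase=True`, the exponent is floored, which gives a piecewise-constant schedule."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code: a pure-Python version of exponential scheduling\n",
"import math\n",
"\n",
"def exponential_decay_lr(step, initial_learning_rate=0.01, decay_steps=20_000,\n",
"                         decay_rate=0.1, staircase=False):\n",
"    p = step / decay_steps\n",
"    if staircase:\n",
"        p = math.floor(p)  # constant within each block of decay_steps steps\n",
"    return initial_learning_rate * decay_rate ** p\n",
"\n",
"for step in (0, 10_000, 20_000, 30_000, 40_000):\n",
"    print(step, exponential_decay_lr(step),\n",
"          exponential_decay_lr(step, staircase=True))"
]
},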
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
"lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(\n",
" initial_learning_rate=0.01,\n",
" decay_steps=20_000,\n",
" decay_rate=0.1,\n",
" staircase=False\n",
")\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.6916 - accuracy: 0.7632 - val_loss: 0.5030 - val_accuracy: 0.8254\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4832 - accuracy: 0.8311 - val_loss: 0.4601 - val_accuracy: 0.8358\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4372 - accuracy: 0.8449 - val_loss: 0.4256 - val_accuracy: 0.8524\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4131 - accuracy: 0.8546 - val_loss: 0.4037 - val_accuracy: 0.8568\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3952 - accuracy: 0.8596 - val_loss: 0.3950 - val_accuracy: 0.8598\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3825 - accuracy: 0.8640 - val_loss: 0.4010 - val_accuracy: 0.8584\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3739 - accuracy: 0.8667 - val_loss: 0.3851 - val_accuracy: 0.8650\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3664 - accuracy: 0.8696 - val_loss: 0.3811 - val_accuracy: 0.8616\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3606 - accuracy: 0.8720 - val_loss: 0.3749 - val_accuracy: 0.8662\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3555 - accuracy: 0.8743 - val_loss: 0.3706 - val_accuracy: 0.8662\n"
]
}
],
"source": [
"history_exponential_scheduling = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHNCAYAAADolfQeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAB99klEQVR4nO3dd1wU194G8GcLLB2kF2n23iB2xQpGTTTRRE1CNJZ7DdGgpFhiYklRY/TFEjUmRjRFzVVjiWjEAliwIsZuVBRFEBGlCizsvH8gGze74LIs7ALP9/Ph3nD27MxvZ0Qe55w5IxIEQQARERERPZfY0AUQERER1RQMTkRERERaYnAiIiIi0hKDExEREZGWGJyIiIiItMTgRERERKQlBiciIiIiLTE4EREREWmJwYmIiIhISwxORFRr+fj4wMfHp1LbmDNnDkQiEaKjo/VSU2Xo4/PoSl/HQdNniIiIgEgkQkRERKW2TVQdGJyIqsmtW7cgEonK/WrXrp2hy6xRxowZA5FIhFu3bhm6FCVBEPDzzz+jT58+cHBwgKmpKVxcXNC+fXuEhIQgJibG0CUSUSVIDV0AUV3TsGFDvPXWWxpfc3V1reZqarcDBw5U+z7Hjh2LiIgI1KtXD4MHD4a7uzvS09Nx7do1rF27FllZWQgICKj2uozZK6+8gs6dO8PNzc3QpRA9F4MTUTVr1KgR5syZY+gy6oSGDRtW6/4OHz6MiIgItGvXDjExMbCxsVF5/fHjx7h06VK11lQT2NrawtbW1tBlEGmFQ3VERuqLL76ASCTC5MmT1V4rnW8ydepUtbbo6Gh8//33aNmyJczMzODl5YUZM2YgPz9f437++OMP9O7dG7a2tjA3N0e7du0QHh6O4uJilX6lQ41jxozBzZs3MXz4cNSrVw+Wlpbo168fzp07p3H7aWlpmDp1Kho1agSZTAZHR0cMGzYMFy5cUOtbOv8lNzcXYWFh8PDwgEwmQ5s2bbBlyxa1vuvXrwcA+Pr6Koc7e/Xqpba9Z927dw+zZ89G586d4ezsDJlMBh8fH4SEhCAtLU3jZ9BWXFwcAGD06NFqoQkA7Ozs0LVrV7X2wsJCLF26FB07doS1tTWsrKzQokULhIWF4dGjR2r9tTk+z257yZIl6NChAywtLWFtbY0ePXpg586dGvvfuXMHo0aNgr29PaysrBAQEIDY2FiNfcubmxQdHQ2RSKTVPxLK2k7p+Xzw4AHGjh0LZ2dnmJubo3PnzmXOtfrrr78wcOBAWFtbw9bWFgMHDsSFCxeMcliXaiZecSIyUjNnzkRUVBRWrFiBwMBAvPTSSwCAo0eP4osvvkCbNm2wYMECtfctXrwY0dHRGDFiBAYPHozIyEgsWLAAZ8+exZ49eyASiZR9ly5diilTpsDe3h5vvPEGLC0tsWvXLkydOhWHDx/Gli1bVPoDJQGqU6dOaNGiBcaOHYsbN25gx44d6N27Ny5fvgwXFxdl3xs3bqBXr15ITk5GYGAghg4dirS0NGzduhV//vknDhw4gE6dOqlsXy6XIzAwEBkZGXj11VeRl5eHTZs24fXXX8fevXsRGBgIAJgyZQoiIiJw7tw5hIaGws7ODgCeO3k6NjYWixcvRt++fdGpUyeYmJjg7NmzWLVqFf7880/Ex8frfPXD3t4eAHD9+nWt35Ofn4+goCDExsaicePGeOeddyCTyfD3339j9erVePvtt1GvXj1lf22PDwAUFBRgwIABiI6ORvv27TFu3DjI5XLs3r0bQ4YMwfLlyzFp0iRl/5SUFHTp0gXJyckICgpChw4dcPnyZfTv3x+9e/fW6ZhU1uPHj9GtWzfY2NjgzTffRFpaGjZv3oygoCCcOXMGrVq1UvY9d+4cevTogby8PLz66qto1KgRzpw5g+7du6Nt27YGqZ9qIYGIqkViYqIAQGjYsKEwe/ZsjV979uxReU9SUpJQr149wdHRUbh3757w+PFjwcfHRzA3Nx
cuXryo0nf27NkCAMHMzEy4cOGCsl0ulwv9+/cXAAgbNmxQtt+4cUOQSqWCs7OzkJSUpGwvKCgQAgICBADCTz/9pFY/AGHBggUq+541a5YAQJg/f75Ke9euXQWpVCrs27dPpf3q1auCtbW10Lp1a5V2b29vAYAwZMgQoaCgQNm+f/9+AYAQFBSk0n/06NECACExMVHteJduz9vbW6Xt/v37QnZ2tlrf9evXCwCEL774QqW99LgeOnRI4z6elZSUJFhbWwtisVh4++23hd9//13l2Gry0UcfCQCE4OBgoaioSOW1x48fq9Ra0eMzc+ZMAYAwZ84cQaFQKNuzsrIEf39/wdTUVEhOTla2lx7Pfx+D7777Tnnunz0O69atEwAI69atU/tchw4dEgAIs2fPVmnXdE7K2k7pPkNCQoTi4mJl+w8//CAAEP773/+q9O/evbsAQPjf//6n0l56Dsv7s0KkLQYnomrybPAo6ys0NFTtfVu2bBEACP369RNGjhwpABBWrlyp1q/0l8OECRPUXjt16pQAQOjbt6+ybd68eQIAYeHChWr94+Li1PqX1u/r66vyS+zZ11599VVlW3x8vABAGDdunMbjERYWJgAQzp8/r2wrDQY3b95U6+/t7S3Y29urtOkSnMqiUCgEGxsboVevXirtFQlOgiAIe/fuFTw9PVXOq5OTk/D6668LBw4cUOlbVFQk2NjYCLa2tkJGRsZzt12R41NcXCzUq1dPaNSokUpoKrVz504BgLB8+XJBEEoCs5mZmeDs7Cw8efJEpW9xcbHQpEkTgwQnS0tLtaArl8sFqVQqdOjQQdl269YtAYDQvn17tVpyc3MFe3t7BifSCw7VEVWzoKAg7N27V+v+w4YNw/jx4/HDDz8AAIYMGYJ33323zP49evRQa/P394e5uTkSEhKUbWfPngUAlTlBpTp37qzWv1Tbtm0hFqtOj6xfvz6AkmGVUsePHwcApKamapzncuXKFeX/PzvcYmdnB19fX7X+9evXV84hqqxt27bhu+++Q3x8PB49eqQyn+vevXuV2nZQUBBu3ryJ6OhoxMbG4syZMzhy5Ah+++03/Pbbb5gxYwa++uorACWfPSsrC/369VMZjiuPtsfn6tWrePToEdzd3TF37ly1/g8ePFDWUNo/Pz8fffr0gZmZmUpfsViMrl274tq1a9odBD1q3LgxrKysVNqkUilcXFxU/ryVzrHTNIfMwsICbdu2xaFDh6q0VqobGJyIaoBXX31VGZzee++9cvs6OzuX2Z6cnKz8PisrCwBU5iSV17+Upvk/UmnJXyXPBpCMjAwAwO7du7F79+4y683NzX3u9kv3oVAoytyOthYvXowPP/wQTk5OCAwMRP369WFubg4ACA8PR0FBQaX3IZVK0a9fP/Tr1w8AUFRUhIiICLz77ruYP38+hg8fjg4dOih/8Xt4eGi9bW2PT+nxv3jxIi5evFjm9kqPf2ZmJoCy//yU9eekqpX3eZ/981b659nJyUljf0PVT7UPgxORkcvIyMB//vMfWFlZQS6XY9KkSYiPj4elpaXG/mXdGZaWlqbyS6j0rq/79+/D29tbY39Nd4Zpq/S9/56AbEhFRUX4/PPP4e7ujoSEBJVfsoIg4Ouvv66S/UqlUowfPx6HDx/Ghg0bcOjQIXTo0EE5oV1TQK2s0uM/bNiwMu+4e1bpn42y/vzcv39fra30ymNRUZHaa6VBrLqUft7SK2n/pql+Il1wOQIiIzdhwgTcvXsXK1aswIIFC3Dt2jWEhoaW2f/w4cNqbadPn8aTJ09UViZv3749AGi8rfvkyZNq/Suq9G45fQ2vaSKRSABAbemEsqSnpyMzMxOdO3dWuzJReoyq0r/DbtOmTWFjY4NTp05pXHagMpo3bw4bGxucPn0acrn8uf2bNm0KMzMznD59Wm3pCoVCgWPHjqm9p3R4UVPwKx0Kri6ld81pqjMvL6/M5TKIKorBiciIff/999i2bRtGjBiB0a
NHIzQ0FEFBQVi7dm2ZVxF++uknlaGZoqIizJw5E0DJ+kKl3njjDUilUixZskRlXo9cLsf06dMBlDzSRFcdO3ZEp06
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell plots exponential scheduling\n",
"\n",
"initial_learning_rate = 0.01\n",
"decay_rate = 0.1\n",
"decay_steps = 20_000\n",
"\n",
"steps = np.arange(100_000)\n",
"lrs = initial_learning_rate * decay_rate ** (steps / decay_steps)\n",
"lrs2 = initial_learning_rate * decay_rate ** np.floor(steps / decay_steps)\n",
"\n",
"plt.plot(steps, lrs, \"-\", label=\"staircase=False\")\n",
"plt.plot(steps, lrs2, \"-\", label=\"staircase=True\")\n",
"plt.axis([0, steps.max(), 0, 0.0105])\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Exponential Scheduling\", fontsize=14)\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keras also provides a `LearningRateScheduler` callback class that lets you define your own scheduling function. Let's see how you could use it to implement exponential decay. Note that in this case the learning rate only changes at each epoch, not at each step:"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay_fn(epoch):\n",
" return 0.01 * 0.1 ** (epoch / 20)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay(lr0, s):\n",
" def exponential_decay_fn(epoch):\n",
" return lr0 * 0.1 ** (epoch / s)\n",
" return exponential_decay_fn\n",
"\n",
"exponential_decay_fn = exponential_decay(lr0=0.01, s=20)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"# extra code build and compile a model for Fashion MNIST\n",
"\n",
"tf.random.set_seed(42)\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.6905 - accuracy: 0.7643 - val_loss: 0.4814 - val_accuracy: 0.8330 - lr: 0.0100\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4672 - accuracy: 0.8357 - val_loss: 0.4488 - val_accuracy: 0.8374 - lr: 0.0089\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4212 - accuracy: 0.8503 - val_loss: 0.4118 - val_accuracy: 0.8532 - lr: 0.0079\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3975 - accuracy: 0.8593 - val_loss: 0.3884 - val_accuracy: 0.8636 - lr: 0.0071\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3781 - accuracy: 0.8657 - val_loss: 0.3772 - val_accuracy: 0.8642 - lr: 0.0063\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3634 - accuracy: 0.8710 - val_loss: 0.3779 - val_accuracy: 0.8662 - lr: 0.0056\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3530 - accuracy: 0.8744 - val_loss: 0.3674 - val_accuracy: 0.8652 - lr: 0.0050\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3437 - accuracy: 0.8771 - val_loss: 0.3616 - val_accuracy: 0.8686 - lr: 0.0045\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3359 - accuracy: 0.8801 - val_loss: 0.3509 - val_accuracy: 0.8728 - lr: 0.0040\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3290 - accuracy: 0.8826 - val_loss: 0.3504 - val_accuracy: 0.8720 - lr: 0.0035\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3236 - accuracy: 0.8844 - val_loss: 0.3458 - val_accuracy: 0.8736 - lr: 0.0032\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3186 - accuracy: 0.8869 - val_loss: 0.3459 - val_accuracy: 0.8752 - lr: 0.0028\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3147 - accuracy: 0.8878 - val_loss: 0.3359 - val_accuracy: 0.8770 - lr: 0.0025\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3109 - accuracy: 0.8890 - val_loss: 0.3404 - val_accuracy: 0.8762 - lr: 0.0022\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3076 - accuracy: 0.8902 - val_loss: 0.3398 - val_accuracy: 0.8790 - lr: 0.0020\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3043 - accuracy: 0.8915 - val_loss: 0.3331 - val_accuracy: 0.8784 - lr: 0.0018\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3020 - accuracy: 0.8924 - val_loss: 0.3363 - val_accuracy: 0.8774 - lr: 0.0016\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2998 - accuracy: 0.8927 - val_loss: 0.3356 - val_accuracy: 0.8778 - lr: 0.0014\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2979 - accuracy: 0.8935 - val_loss: 0.3309 - val_accuracy: 0.8796 - lr: 0.0013\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2961 - accuracy: 0.8940 - val_loss: 0.3308 - val_accuracy: 0.8782 - lr: 0.0011\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2944 - accuracy: 0.8951 - val_loss: 0.3286 - val_accuracy: 0.8802 - lr: 0.0010\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2930 - accuracy: 0.8953 - val_loss: 0.3313 - val_accuracy: 0.8804 - lr: 8.9125e-04\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2916 - accuracy: 0.8957 - val_loss: 0.3285 - val_accuracy: 0.8796 - lr: 7.9433e-04\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2904 - accuracy: 0.8961 - val_loss: 0.3313 - val_accuracy: 0.8786 - lr: 7.0795e-04\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2896 - accuracy: 0.8962 - val_loss: 0.3296 - val_accuracy: 0.8812 - lr: 6.3096e-04\n"
]
}
],
"source": [
"n_epochs = 20\n",
"\n",
"lr_scheduler = tf.keras.callbacks.LearningRateScheduler(exponential_decay_fn)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[lr_scheduler])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, the schedule function can take the current learning rate as a second argument:"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay_fn(epoch, lr):\n",
" return lr * 0.1 ** (1 / 20)"
]
},
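{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra material**: this two-argument version starts from the optimizer's current learning rate, so the schedule it produces depends on the optimizer's initial learning rate. The following pure-Python sketch simulates it. It assumes that the callback calls the function at the start of every epoch (including the first one) with the rate it set during the previous epoch, and that the optimizer starts at 0.01. `two_arg_decay_fn()` is just a copy of the function above under a hypothetical name. Under these assumptions, the rate is already multiplied once before epoch 0, so epoch `e` uses `0.01 * 0.1 ** ((e + 1) / 20)`, which is one epoch ahead of the single-argument version:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def two_arg_decay_fn(epoch, lr):\n",
"    # same as exponential_decay_fn(epoch, lr) above\n",
"    return lr * 0.1 ** (1 / 20)\n",
"\n",
"current_lr = 0.01  # assumed initial learning rate of the optimizer\n",
"simulated_lrs = []\n",
"for epoch in range(25):\n",
"    current_lr = two_arg_decay_fn(epoch, current_lr)  # start of each epoch\n",
"    simulated_lrs.append(current_lr)\n",
"\n",
"# check against the closed-form expression, shifted by one epoch\n",
"for epoch, rate in enumerate(simulated_lrs):\n",
"    assert math.isclose(rate, 0.01 * 0.1 ** ((epoch + 1) / 20))\n",
"\n",
"print([round(rate, 4) for rate in simulated_lrs[:3]])"
]
},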
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra material**: if you want to use a custom scheduling function that updates the learning rate at each iteration rather than at each epoch, you can write your own callback class like this:"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"K = tf.keras.backend\n",
"\n",
"class ExponentialDecay(tf.keras.callbacks.Callback):\n",
" def __init__(self, n_steps=40_000):\n",
" super().__init__()\n",
" self.n_steps = n_steps\n",
"\n",
" def on_batch_begin(self, batch, logs=None):\n",
" # Note: the `batch` argument is reset at each epoch\n",
" lr = K.get_value(self.model.optimizer.learning_rate)\n",
" new_learning_rate = lr * 0.1 ** (1 / self.n_steps)\n",
" K.set_value(self.model.optimizer.learning_rate, new_learning_rate)\n",
"\n",
" def on_epoch_end(self, epoch, logs=None):\n",
" logs = logs or {}\n",
" logs['lr'] = K.get_value(self.model.optimizer.learning_rate)"
]
},
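{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since this callback multiplies the learning rate by `0.1 ** (1 / n_steps)` before every batch, the rate ends up divided by 10 after `n_steps` batches. Here is a small pure-Python check of that arithmetic. The helper function is hypothetical, and it assumes 55,000 training instances and a batch size of 32, i.e., the 1,719 steps per epoch shown in the training logs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def per_step_decay_check(lr0=0.01, n_epochs=25, n_train=55_000, batch_size=32):\n",
"    # Keras runs ceil(n_train / batch_size) batches per epoch\n",
"    steps_per_epoch = math.ceil(n_train / batch_size)\n",
"    n_steps = n_epochs * steps_per_epoch\n",
"    factor = 0.1 ** (1 / n_steps)  # applied once before each batch\n",
"    lr_after_first_epoch = lr0 * factor ** steps_per_epoch\n",
"    lr_at_end = lr0 * factor ** n_steps\n",
"    return steps_per_epoch, lr_after_first_epoch, lr_at_end\n",
"\n",
"decay_check = per_step_decay_check()\n",
"print(decay_check)"
]
},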
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"lr0 = 0.01\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.6947 - accuracy: 0.7635 - val_loss: 0.5014 - val_accuracy: 0.8224 - lr: 0.0091\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4718 - accuracy: 0.8349 - val_loss: 0.4530 - val_accuracy: 0.8382 - lr: 0.0083\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4255 - accuracy: 0.8500 - val_loss: 0.4216 - val_accuracy: 0.8526 - lr: 0.0076\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4025 - accuracy: 0.8587 - val_loss: 0.3954 - val_accuracy: 0.8618 - lr: 0.0069\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3840 - accuracy: 0.8643 - val_loss: 0.3847 - val_accuracy: 0.8612 - lr: 0.0063\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3696 - accuracy: 0.8689 - val_loss: 0.3908 - val_accuracy: 0.8558 - lr: 0.0058\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3590 - accuracy: 0.8722 - val_loss: 0.3744 - val_accuracy: 0.8670 - lr: 0.0052\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3498 - accuracy: 0.8749 - val_loss: 0.3754 - val_accuracy: 0.8640 - lr: 0.0048\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3415 - accuracy: 0.8783 - val_loss: 0.3592 - val_accuracy: 0.8700 - lr: 0.0044\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3340 - accuracy: 0.8803 - val_loss: 0.3575 - val_accuracy: 0.8724 - lr: 0.0040\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3281 - accuracy: 0.8833 - val_loss: 0.3573 - val_accuracy: 0.8718 - lr: 0.0036\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3228 - accuracy: 0.8847 - val_loss: 0.3579 - val_accuracy: 0.8688 - lr: 0.0033\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3182 - accuracy: 0.8865 - val_loss: 0.3421 - val_accuracy: 0.8756 - lr: 0.0030\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3138 - accuracy: 0.8882 - val_loss: 0.3468 - val_accuracy: 0.8766 - lr: 0.0028\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3101 - accuracy: 0.8889 - val_loss: 0.3471 - val_accuracy: 0.8766 - lr: 0.0025\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3064 - accuracy: 0.8898 - val_loss: 0.3386 - val_accuracy: 0.8752 - lr: 0.0023\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3035 - accuracy: 0.8903 - val_loss: 0.3417 - val_accuracy: 0.8758 - lr: 0.0021\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3005 - accuracy: 0.8919 - val_loss: 0.3398 - val_accuracy: 0.8768 - lr: 0.0019\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2983 - accuracy: 0.8929 - val_loss: 0.3357 - val_accuracy: 0.8766 - lr: 0.0017\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2959 - accuracy: 0.8939 - val_loss: 0.3370 - val_accuracy: 0.8752 - lr: 0.0016\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2940 - accuracy: 0.8938 - val_loss: 0.3346 - val_accuracy: 0.8782 - lr: 0.0014\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2917 - accuracy: 0.8949 - val_loss: 0.3361 - val_accuracy: 0.8766 - lr: 0.0013\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2902 - accuracy: 0.8955 - val_loss: 0.3349 - val_accuracy: 0.8796 - lr: 0.0012\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2884 - accuracy: 0.8959 - val_loss: 0.3364 - val_accuracy: 0.8796 - lr: 0.0011\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2871 - accuracy: 0.8969 - val_loss: 0.3352 - val_accuracy: 0.8802 - lr: 1.0000e-03\n"
]
}
],
"source": [
"import math\n",
"\n",
"batch_size = 32\n",
"n_steps = n_epochs * math.ceil(len(X_train) / batch_size)\n",
"exp_decay = ExponentialDecay(n_steps)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[exp_decay])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Piecewise Constant Scheduling"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\n",
" boundaries=[50_000, 80_000],\n",
" values=[0.01, 0.005, 0.001]\n",
")\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule)"
]
},
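{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a minimal pure-Python sketch of the rule `PiecewiseConstantDecay` applies according to its documentation: the learning rate is `values[0]` while `step <= boundaries[0]`, `values[1]` while `boundaries[0] < step <= boundaries[1]`, and `values[-1]` once `step > boundaries[-1]`. The function name is hypothetical, and `bisect_left()` is just one convenient way to find the right interval. (The plotting cell below switches values one step earlier, which makes no visible difference on the plot.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from bisect import bisect_left\n",
"\n",
"def piecewise_constant_lr(step, boundaries=(50_000, 80_000),\n",
"                          values=(0.01, 0.005, 0.001)):\n",
"    # bisect_left() counts the boundaries strictly smaller than step,\n",
"    # which is the index of the value that applies\n",
"    return values[bisect_left(boundaries, step)]\n",
"\n",
"for step in (0, 50_000, 50_001, 80_000, 80_001):\n",
"    print(step, piecewise_constant_lr(step))"
]
},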
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.6942 - accuracy: 0.7617 - val_loss: 0.4892 - val_accuracy: 0.8318\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4751 - accuracy: 0.8340 - val_loss: 0.4603 - val_accuracy: 0.8346\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4280 - accuracy: 0.8500 - val_loss: 0.4245 - val_accuracy: 0.8542\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4035 - accuracy: 0.8581 - val_loss: 0.3867 - val_accuracy: 0.8626\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3828 - accuracy: 0.8650 - val_loss: 0.3827 - val_accuracy: 0.8634\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3665 - accuracy: 0.8700 - val_loss: 0.3880 - val_accuracy: 0.8608\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.3539 - accuracy: 0.8730 - val_loss: 0.3669 - val_accuracy: 0.8688\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3423 - accuracy: 0.8773 - val_loss: 0.3583 - val_accuracy: 0.8708\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3322 - accuracy: 0.8807 - val_loss: 0.3447 - val_accuracy: 0.8758\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3218 - accuracy: 0.8832 - val_loss: 0.3488 - val_accuracy: 0.8716\n"
]
}
],
"source": [
"history_piecewise_scheduling = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAk4AAAHNCAYAAADolfQeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAABTFUlEQVR4nO3deVxU1f8/8NfIMoAKKiiIsrokRLmAKcqikbgvpUlqpLn0JVJR+pjikssnJXP5kCmSRi4tSqWlFZqoiBtuLGauH01FCURcQCVhgPP7wx/34zQDXkZkBng9Hw8fOWfec++ZOYO8uvfccxVCCAEiIiIieqJ6+u4AERERUU3B4EREREQkE4MTERERkUwMTkREREQyMTgRERERycTgRERERCQTgxMRERGRTAxORERERDIxOBERERHJxOBEBMDZ2RnOzs767sYz06NHDygUCn13g2qZefPmQaFQYN++fdW+7ytXrkChUGDMmDFPtZ3y3oNCoUCPHj2eattUOzE4Ua1U9o/q439MTU3h4OCAkSNH4vfff9d3F+u0CxcuYNKkSXj++edhaWkJpVIJR0dHDBs2DFu2bEFpaam+u6ihqn5R62r9+vVQKBRYv359pV/7xx9/YPTo0XB2doZSqYSVlRVat26N1157DZ9++il45y0i+Yz13QGiZ6lVq1Z48803AQD379/HkSNHsGnTJmzduhV79+5Ft27dAAB79uzRZzefuY0bN6KgoEDf3QAALFu2DNOnT0dpaSl8fHzQq1cvWFhY4Nq1a9i9eze2bNmCsWPHIjY2Vt9drRUSEhIwYMAAFBcXIyAgAK+++ioA4M8//8ShQ4fw448/4r333oOxMX8dPO7s2bOwsLDQdzfIAPEnhWq11q1bY968eWpts2fPxsKFCzFr1iwkJiYCeBSwajNHR0d9dwEAsGbNGvzrX/+Cs7MztmzZgk6dOqk9X1xcjA0bNuDAgQN66mHt8+6776KkpAS7d+9Gz5491Z4TQmDXrl0wMjLSU+8MV7t27fTdBTJQPFVHdc6kSZMAAMePH5faypvjJITAl19+ie7du8PS0hIWFhbw8vLCl19+qXXbQghs2LABfn5+aNSoESwsLNCmTRuEhIQgIyNDrfbevXuYO3cunn/+eZibm6NRo0bo06cPDh48qFY3ZcoUKBQKpKenq7X3798fCoUC48ePV2vfsWMHFAoFFi9eLLVpm+NUWlqKL774Ai+99BKaNGkCCwsLODs7Y8iQIdi/f7/Ge9u/fz8GDhwIGxsbKJVKtGnTBrNnz5Z9JCsvLw/Tpk2Dqakpfv31V43QBADGxsYYN24cPv/8c7X2goICzJs3D+3atYOZmRmaNGmC/v374/DhwxrbeHzOynfffYdOnTrB3NwczZs3x+TJk/H3339rvGbLli3w9/dHs2bNYGZmBgcHB/Tp0wc//fQTgEenyVxcXAAAGzZsUDsFXDY35q+//sLcuXPRtWtXNGvWDEqlEs7OzggNDUVOTo7GPseMGQOFQoErV64gOjoabm5uMDMzg5OTE+bPn692unLMmDF4++23AQBvv/222v4rkpOTg0uXLsHDw0MjNAGP5vH07t1b63YOHDiAV199Fba2tlAqlXBwcMBrr72m8f0sI/ezBir3XSopKcHixYvRunVrmJmZoXXr1oiMjCz3dG5Fc5MqM5dR23YqM2ZlCgoK8MEHH8DBwQFmZmbw8PDA2rVrsW/fPigUCo3/sSPDxyNOVOfInSQthMCbb76Jb7/9Fm3btsXIkSNhamqKhIQEjBs3DmfOnMHSpUvV6keMGIG4uDi0aNECI0aMgKWlJa5cuYK4uDj06dNHOvJz+/Zt+Pn54fTp0/D19UXv3r2Rl5eHbdu2oWfPnvj+++8xZMgQAEDPnj3x6aefIjExER06dADw6JdJ2S+wsqNmZcp+kWv7Rfm4iIgIfPLJJ2jVqhVGjhyJhg0bIjMzEwcOHMDevXvh5+cn1cbExC
A0NBSNGzfGwIED0bRpUxw/fhwLFy5EYmIiEhMTYWpqWuH+vv/+e+Tn52PkyJFwd3evsFapVEp/LywsREBAAI4cOYJOnTphypQpyMnJQVxcHHbt2oW4uDi89tprGttYtWoVduzYgcGDB6NHjx7YuXMnPvvsM9y6dQvffPONVLd69WqEhoaiefPmePXVV2FtbY2srCwcO3YMP/30E4YMGYIOHTogLCwMn376Kdq3by+NDQDpF/H+/fuxbNkyBAQEoEuXLjAxMUFaWhpWr16N3377DampqbCystLo57Rp07Bv3z4MGDAAgYGB+OmnnzBv3jwUFRVh4cKFAIAhQ4bg7t272LZtGwYPHix9D57EysoKRkZGyMrKwoMHD1C/fn1Zr1u1ahUmTZoEc3NzvPrqq3B0dERmZiYOHjyIH374AT4+Pjp91kDlv0vvvPMOvvzyS7i4uOC9997Dw4cPsXz5cq2hubrIGTPg0c/pgAEDkJiYiPbt22PkyJG4ffs23n//fU48r8kEUS10+fJlAUD07t1b47lZs2YJAKJHjx5Sm5OTk3ByclKrW7NmjQAgxo0bJ1QqldReWFgoBg4cKACIEydOSO2rVq0SAERAQIAoKChQ21ZBQYG4deuW9HjkyJECgPjyyy/V6rKzs4WDg4No2rSp+Pvvv4UQQty5c0fUq1dPDBw4UKo7evSotC8A4urVq9JznTt3Fg0bNhTFxcVSm7+/v/jnj3uTJk1EixYtxIMHD9TaS0tL1fp6+vRpYWxsLDp27KjWLoQQkZGRAoBYunSpeJIxY8YIAOKLL754Yu3jFixYIACIUaNGidLSUqn95MmTQqlUisaNG4v8/Hypfe7cuQKAsLKyEufOnZPaCwoKRNu2bYVCoRCZmZlSe6dOnYSpqanIycnR2Hdubq7097Lv1OjRo7X288aNG+LevXsa7Rs2bBAAxEcffaTWPnr0aAFAuLi4iL/++ktqv3nzpmjUqJFo2LChKCwslNrXrVsnAIh169Zp3X95hgwZIgCIDh06iOjoaJGeni6KiorKrf/999+FkZGRsLe3F5cvX1Z7rrS0VO2zq+xnXdnvUmJiogAg2rdvL+7fvy+1X79+XdjY2GgdDwDC399f63vT9nNe9h4SExOfuJ3KjtkXX3whAIhBgwaJkpISqf3s2bPCzMxMABBz587V2lcyXDxVR7XaxYsXMW/ePMybNw//+te/4OPjg4ULF8LMzAyLFi2q8LUrV65E/fr1sXLlSrWJs6amptL/VW7atElqX7VqFYyMjLB69WqYm5urbcvc3BxNmjQBAOTm5iIuLg4BAQHS6Zcytra2mDZtGm7evIndu3cDABo1aoT27dtj//79KCkpAfDoKNPjh/n37t0LAMjPz0dqaip8fX1lzVsxNTXVmBSsUCikvgLA559/juLiYqxYsUKtHQA++OADNG3aVO1zKE92djYAoGXLlk+sfdz69ethYmKCjz/+WO1o4YsvvogxY8bgzp072LZtm8brwsLC8Nxzz0mPzc3NMWLECAghkJKSolZrYmICExMTjW1YW1vL7mezZs3QoEEDjfbg4GBYWlpK4/lPc+bMQfPmzaXHNjY2GDx4MO7du4fz58/L3n951q5di/79+yM9PR2hoaHo0KEDGjRogO7du2PFihUap9NiYmJQUlKCjz76SOO0lkKhgL29vcY+5H7Wlf0ubdy4EQDw4Ycfqh0ta9GiBcLCwir/YVQRuWP29ddfAwD+/e9/o169//26bdeuHUaPHl19HaYqxVN1VKtdunQJ8+fPB/Dol6OtrS1GjhyJGTNm4IUXXij3dQUFBTh16hTs7e3x8ccfazyvUqkAAOfOnQMAPHjwAGfOnEHr1q3Rpk2bCvt0/PhxlJSU4OHDh1rnN/z3v/+Vtj1gwAAAj067paWlITU1FZ07d5YO/fv4+MDOzg6JiYkYM2aMFK6edJoOAIYPH46YmBh4eHggKCgI/v7+8Pb21jidc+TIEQDAzp07tf
7yNzExkT6Hqpafn48///wTbm5uWgNXjx498PnnnyM9PV26erKMtjlUZdu4e/eu1DZ8+HDMmDEDHh4eeOONN9CjRw/
2022-02-19 10:24:54 +01:00
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# extra code this cell plots piecewise constant scheduling\n",
"\n",
"boundaries = [50_000, 80_000]\n",
"values = [0.01, 0.005, 0.001]\n",
"\n",
"steps = np.arange(100_000)\n",
"\n",
"lrs = np.full(len(steps), values[0])\n",
"for boundary, value in zip(boundaries, values[1:]):\n",
" lrs[boundary:] = value\n",
"\n",
"plt.plot(steps, lrs, \"-\")\n",
"plt.axis([0, steps.max(), 0, 0.0105])\n",
"plt.xlabel(\"Step\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Piecewise Constant Scheduling\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Just like we did with exponential scheduling, we could also implement piecewise constant scheduling manually:"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"def piecewise_constant_fn(epoch):\n",
" if epoch < 5:\n",
" return 0.01\n",
" elif epoch < 15:\n",
" return 0.005\n",
" else:\n",
" return 0.001"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"# extra code this cell demonstrates a more general way to define\n",
"# piecewise constant scheduling.\n",
"\n",
"def piecewise_constant(boundaries, values):\n",
" boundaries = np.array([0] + boundaries)\n",
" values = np.array(values)\n",
" def piecewise_constant_fn(epoch):\n",
" return values[(boundaries > epoch).argmax() - 1]\n",
" return piecewise_constant_fn\n",
"\n",
"piecewise_constant_fn = piecewise_constant([5, 15], [0.01, 0.005, 0.001])"
]
},
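{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `argmax()` trick above relies on a subtle detail: once `epoch` reaches the last boundary, the comparison is all `False`, so `argmax()` returns 0 and index `-1` selects the last value. If you prefer something more explicit, here is an alternative sketch based on `bisect_right()`. It is not from the original code, and both function names are hypothetical. The cell also checks that both versions agree on the first 25 epochs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from bisect import bisect_right\n",
"import numpy as np\n",
"\n",
"def piecewise_constant_argmax(boundaries, values):\n",
"    # same logic as piecewise_constant() above\n",
"    boundaries = np.array([0] + boundaries)\n",
"    values = np.array(values)\n",
"    def schedule(epoch):\n",
"        return values[(boundaries > epoch).argmax() - 1]\n",
"    return schedule\n",
"\n",
"def piecewise_constant_bisect(boundaries, values):\n",
"    def schedule(epoch):\n",
"        # bisect_right() counts the boundaries <= epoch, so the value\n",
"        # switches as soon as epoch reaches a boundary\n",
"        return values[bisect_right(boundaries, epoch)]\n",
"    return schedule\n",
"\n",
"argmax_fn = piecewise_constant_argmax([5, 15], [0.01, 0.005, 0.001])\n",
"bisect_fn = piecewise_constant_bisect([5, 15], [0.01, 0.005, 0.001])\n",
"mismatches = [epoch for epoch in range(25) if argmax_fn(epoch) != bisect_fn(epoch)]\n",
"print(mismatches)"
]
},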
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 5s 2ms/step - loss: 0.5433 - accuracy: 0.8087 - val_loss: 0.4586 - val_accuracy: 0.8288 - lr: 0.0100\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4487 - accuracy: 0.8439 - val_loss: 0.4608 - val_accuracy: 0.8350 - lr: 0.0100\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4263 - accuracy: 0.8502 - val_loss: 0.4234 - val_accuracy: 0.8568 - lr: 0.0100\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4241 - accuracy: 0.8537 - val_loss: 0.4359 - val_accuracy: 0.8490 - lr: 0.0100\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.4080 - accuracy: 0.8584 - val_loss: 0.4165 - val_accuracy: 0.8560 - lr: 0.0100\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3544 - accuracy: 0.8738 - val_loss: 0.3830 - val_accuracy: 0.8662 - lr: 0.0050\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3464 - accuracy: 0.8761 - val_loss: 0.4026 - val_accuracy: 0.8652 - lr: 0.0050\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3426 - accuracy: 0.8772 - val_loss: 0.4212 - val_accuracy: 0.8544 - lr: 0.0050\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3417 - accuracy: 0.8793 - val_loss: 0.4116 - val_accuracy: 0.8612 - lr: 0.0050\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3339 - accuracy: 0.8804 - val_loss: 0.4090 - val_accuracy: 0.8618 - lr: 0.0050\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3309 - accuracy: 0.8819 - val_loss: 0.4033 - val_accuracy: 0.8746 - lr: 0.0050\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.3270 - accuracy: 0.8826 - val_loss: 0.4518 - val_accuracy: 0.8630 - lr: 0.0050\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3270 - accuracy: 0.8837 - val_loss: 0.3714 - val_accuracy: 0.8674 - lr: 0.0050\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3247 - accuracy: 0.8844 - val_loss: 0.4026 - val_accuracy: 0.8652 - lr: 0.0050\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3204 - accuracy: 0.8852 - val_loss: 0.3993 - val_accuracy: 0.8724 - lr: 0.0050\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2859 - accuracy: 0.8963 - val_loss: 0.3930 - val_accuracy: 0.8736 - lr: 0.0010\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2781 - accuracy: 0.8978 - val_loss: 0.4021 - val_accuracy: 0.8714 - lr: 0.0010\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2743 - accuracy: 0.8984 - val_loss: 0.3955 - val_accuracy: 0.8754 - lr: 0.0010\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2704 - accuracy: 0.8999 - val_loss: 0.4015 - val_accuracy: 0.8756 - lr: 0.0010\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2683 - accuracy: 0.9015 - val_loss: 0.4161 - val_accuracy: 0.8756 - lr: 0.0010\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2655 - accuracy: 0.9020 - val_loss: 0.4207 - val_accuracy: 0.8740 - lr: 0.0010\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2646 - accuracy: 0.9020 - val_loss: 0.4497 - val_accuracy: 0.8746 - lr: 0.0010\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2626 - accuracy: 0.9032 - val_loss: 0.4429 - val_accuracy: 0.8762 - lr: 0.0010\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2608 - accuracy: 0.9038 - val_loss: 0.4566 - val_accuracy: 0.8748 - lr: 0.0010\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2587 - accuracy: 0.9038 - val_loss: 0.4726 - val_accuracy: 0.8770 - lr: 0.0010\n"
]
}
],
"source": [
"# extra code – use a tf.keras.callbacks.LearningRateScheduler like earlier\n",
"\n",
"n_epochs = 25\n",
"\n",
"lr_scheduler = tf.keras.callbacks.LearningRateScheduler(piecewise_constant_fn)\n",
"\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[lr_scheduler])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've looked at `InverseTimeDecay`, `ExponentialDecay`, and `PiecewiseConstantDecay`. A few more schedulers are available in `tf.keras.optimizers.schedules`; here is the full list:"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"• CosineDecay A LearningRateSchedule that uses a cosine decay with optional warmup.\n",
"• CosineDecayRestarts A LearningRateSchedule that uses a cosine decay schedule with restarts.\n",
"• ExponentialDecay A LearningRateSchedule that uses an exponential decay schedule.\n",
"• InverseTimeDecay A LearningRateSchedule that uses an inverse time decay schedule.\n",
"• LearningRateSchedule The learning rate schedule base class.\n",
"• PiecewiseConstantDecay A LearningRateSchedule that uses a piecewise constant decay schedule.\n",
"• PolynomialDecay A LearningRateSchedule that uses a polynomial decay schedule.\n"
]
}
],
"source": [
"for name in sorted(dir(tf.keras.optimizers.schedules)):\n",
" if name[0] == name[0].lower(): # skip non-class names (classes start with a capital letter)\n",
" continue\n",
" scheduler_class = getattr(tf.keras.optimizers.schedules, name)\n",
" print(f\"• {name} {scheduler_class.__doc__.splitlines()[0]}\")"
]
},
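As an illustration of one of these schedules, here is a minimal pure-Python sketch of the cosine decay formula described in the TensorFlow documentation for `CosineDecay`, ignoring its optional warmup. The function name `cosine_decay_lr()` and the example values are assumptions for illustration only; in practice you would pass a `tf.keras.optimizers.schedules.CosineDecay` object as the optimizer's `learning_rate`.

```python
import math

def cosine_decay_lr(step, initial_lr, decay_steps, alpha=0.0):
    """Learning rate after `step` steps of cosine decay (no warmup).

    `alpha` is the final learning rate, as a fraction of `initial_lr`.
    """
    step = min(step, decay_steps)  # the rate stays at its minimum afterwards
    cosine = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
    decayed = (1 - alpha) * cosine + alpha
    return initial_lr * decayed

# e.g., decay from 0.01 over 10,000 steps, down to 1% of the initial rate
for step in (0, 2_500, 5_000, 7_500, 10_000, 12_000):
    print(step, cosine_decay_lr(step, initial_lr=0.01, decay_steps=10_000,
                                alpha=0.01))
```

The rate falls slowly at first, fastest around the middle of the decay, and slowly again near the end; `CosineDecayRestarts` repeats such cycles.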
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Performance Scheduling"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"# extra code – build and compile the model\n",
"\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.6807 - accuracy: 0.7679 - val_loss: 0.4814 - val_accuracy: 0.8310 - lr: 0.0100\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4659 - accuracy: 0.8343 - val_loss: 0.4615 - val_accuracy: 0.8306 - lr: 0.0100\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4201 - accuracy: 0.8505 - val_loss: 0.4199 - val_accuracy: 0.8490 - lr: 0.0100\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3957 - accuracy: 0.8590 - val_loss: 0.3845 - val_accuracy: 0.8614 - lr: 0.0100\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3754 - accuracy: 0.8658 - val_loss: 0.3742 - val_accuracy: 0.8614 - lr: 0.0100\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3588 - accuracy: 0.8709 - val_loss: 0.3853 - val_accuracy: 0.8628 - lr: 0.0100\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3469 - accuracy: 0.8740 - val_loss: 0.3627 - val_accuracy: 0.8690 - lr: 0.0100\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3346 - accuracy: 0.8785 - val_loss: 0.3574 - val_accuracy: 0.8680 - lr: 0.0100\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3244 - accuracy: 0.8828 - val_loss: 0.3410 - val_accuracy: 0.8748 - lr: 0.0100\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3149 - accuracy: 0.8850 - val_loss: 0.3410 - val_accuracy: 0.8720 - lr: 0.0100\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.3074 - accuracy: 0.8879 - val_loss: 0.3629 - val_accuracy: 0.8678 - lr: 0.0100\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2990 - accuracy: 0.8920 - val_loss: 0.3379 - val_accuracy: 0.8746 - lr: 0.0100\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.2929 - accuracy: 0.8938 - val_loss: 0.3223 - val_accuracy: 0.8808 - lr: 0.0100\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2867 - accuracy: 0.8947 - val_loss: 0.3405 - val_accuracy: 0.8754 - lr: 0.0100\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2807 - accuracy: 0.8972 - val_loss: 0.3480 - val_accuracy: 0.8730 - lr: 0.0100\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2743 - accuracy: 0.8998 - val_loss: 0.3350 - val_accuracy: 0.8766 - lr: 0.0100\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2694 - accuracy: 0.9019 - val_loss: 0.3421 - val_accuracy: 0.8764 - lr: 0.0100\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2631 - accuracy: 0.9032 - val_loss: 0.3360 - val_accuracy: 0.8772 - lr: 0.0100\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2445 - accuracy: 0.9110 - val_loss: 0.3162 - val_accuracy: 0.8874 - lr: 0.0050\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2410 - accuracy: 0.9131 - val_loss: 0.3221 - val_accuracy: 0.8812 - lr: 0.0050\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2380 - accuracy: 0.9137 - val_loss: 0.3166 - val_accuracy: 0.8828 - lr: 0.0050\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2351 - accuracy: 0.9148 - val_loss: 0.3146 - val_accuracy: 0.8854 - lr: 0.0050\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2330 - accuracy: 0.9160 - val_loss: 0.3191 - val_accuracy: 0.8836 - lr: 0.0050\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.2300 - accuracy: 0.9161 - val_loss: 0.3175 - val_accuracy: 0.8878 - lr: 0.0050\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.2276 - accuracy: 0.9174 - val_loss: 0.3205 - val_accuracy: 0.8868 - lr: 0.0050\n"
]
}
],
"source": [
"lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[lr_scheduler])"
]
},
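To make the plateau logic concrete, here is a simplified pure-Python sketch of what `ReduceLROnPlateau(factor=0.5, patience=5)` does at the end of each epoch. It assumes Keras's default `min_delta=1e-4` and ignores the callback's `cooldown` and `min_lr` options; the helper name `plateau_lrs()` is hypothetical. Replaying the validation losses logged by the training run above reproduces the single drop from 0.01 to 0.005 at epoch 19.

```python
def plateau_lrs(val_losses, lr0, factor=0.5, patience=5, min_delta=1e-4):
    """Return the learning rate used during each epoch (simplified sketch)."""
    lr = lr0
    best = float("inf")
    wait = 0  # epochs since the last significant improvement
    lrs = []
    for val_loss in val_losses:
        lrs.append(lr)  # LR used during this epoch
        if val_loss < best - min_delta:  # significant improvement
            best = val_loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:  # plateau: reduce the LR for the next epoch
                lr *= factor
                wait = 0
    return lrs

# validation losses copied from the 25-epoch run above
val_losses = [0.4814, 0.4615, 0.4199, 0.3845, 0.3742, 0.3853, 0.3627, 0.3574,
              0.3410, 0.3410, 0.3629, 0.3379, 0.3223, 0.3405, 0.3480, 0.3350,
              0.3421, 0.3360, 0.3162, 0.3221, 0.3166, 0.3146, 0.3191, 0.3175,
              0.3205]
lrs = plateau_lrs(val_losses, lr0=0.01)
print(lrs)
```

The best validation loss is reached at epoch 13, and epochs 14 to 18 bring no improvement, so the fifth non-improving epoch triggers the reduction.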
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell plots performance scheduling\n",
"\n",
"plt.plot(history.epoch, history.history[\"lr\"], \"bo-\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Learning Rate\", color='b')\n",
"plt.tick_params('y', colors='b')\n",
"plt.gca().set_xlim(0, n_epochs - 1)\n",
"plt.grid(True)\n",
"\n",
"ax2 = plt.gca().twinx()\n",
"ax2.plot(history.epoch, history.history[\"val_loss\"], \"r^-\")\n",
"ax2.set_ylabel('Validation Loss', color='r')\n",
"ax2.tick_params('y', colors='r')\n",
"\n",
"plt.title(\"Reduce LR on Plateau\", fontsize=14)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1Cycle scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ExponentialLearningRate` custom callback updates the learning rate during training, at the end of each batch, by multiplying it by a constant `factor`. It also records the learning rate and the loss at each batch. Since `logs[\"loss\"]` is actually the mean loss since the start of the epoch, and we want to record the batch loss instead, we multiply this mean by the number of batches seen so far in the epoch to get the total loss so far, then subtract the total loss at the previous batch to get the current batch's loss."
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"K = tf.keras.backend\n",
"\n",
"class ExponentialLearningRate(tf.keras.callbacks.Callback):\n",
" def __init__(self, factor):\n",
" self.factor = factor\n",
" self.rates = []\n",
" self.losses = []\n",
"\n",
" def on_epoch_begin(self, epoch, logs=None):\n",
" self.sum_of_epoch_losses = 0\n",
"\n",
" def on_batch_end(self, batch, logs=None):\n",
" mean_epoch_loss = logs[\"loss\"] # the epoch's mean loss so far\n",
" new_sum_of_epoch_losses = mean_epoch_loss * (batch + 1)\n",
" batch_loss = new_sum_of_epoch_losses - self.sum_of_epoch_losses\n",
" self.sum_of_epoch_losses = new_sum_of_epoch_losses\n",
" self.rates.append(K.get_value(self.model.optimizer.learning_rate))\n",
" self.losses.append(batch_loss)\n",
" K.set_value(self.model.optimizer.learning_rate,\n",
" self.model.optimizer.learning_rate * self.factor)"
]
},
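The batch-loss bookkeeping in `on_batch_end()` can be checked in isolation. Here is a small pure-Python sketch (the function name `batch_losses_from_running_means()` is hypothetical) that recovers per-batch losses from a sequence of running epoch means the same way the callback does: multiply each mean by the number of batches seen so far, then subtract the previous total.

```python
def batch_losses_from_running_means(running_means):
    """Recover each batch's loss from the epoch's running mean loss."""
    batch_losses = []
    sum_of_epoch_losses = 0.0
    for batch, mean_epoch_loss in enumerate(running_means):
        new_sum_of_epoch_losses = mean_epoch_loss * (batch + 1)
        batch_losses.append(new_sum_of_epoch_losses - sum_of_epoch_losses)
        sum_of_epoch_losses = new_sum_of_epoch_losses
    return batch_losses

true_losses = [2.0, 1.0, 3.0, 0.5]
# running means, as they would be reported in logs["loss"] after each batch
running_means = [sum(true_losses[:i + 1]) / (i + 1)
                 for i in range(len(true_losses))]
print(running_means)                                   # [2.0, 1.5, 2.0, 1.625]
print(batch_losses_from_running_means(running_means))  # [2.0, 1.0, 3.0, 0.5]
```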
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `find_learning_rate()` function trains the model using the `ExponentialLearningRate` callback, and it returns the learning rates and corresponding batch losses. At the end, it restores the model and its optimizer to their initial state."
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=1e-4,\n",
" max_rate=1):\n",
" init_weights = model.get_weights()\n",
" iterations = math.ceil(len(X) / batch_size) * epochs\n",
" factor = (max_rate / min_rate) ** (1 / iterations)\n",
" init_lr = K.get_value(model.optimizer.learning_rate)\n",
" K.set_value(model.optimizer.learning_rate, min_rate)\n",
" exp_lr = ExponentialLearningRate(factor)\n",
" history = model.fit(X, y, epochs=epochs, batch_size=batch_size,\n",
" callbacks=[exp_lr])\n",
" K.set_value(model.optimizer.learning_rate, init_lr)\n",
" model.set_weights(init_weights)\n",
" return exp_lr.rates, exp_lr.losses"
]
},
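Why does this produce an exponential sweep? Multiplying the learning rate by `factor = (max_rate / min_rate) ** (1 / iterations)` after each of the `iterations` batches takes it from `min_rate` to `max_rate`. Here is the arithmetic for one epoch with `batch_size=128`, assuming 55,000 training instances (an assumption, but it matches the 430 steps per epoch in the training logs).

```python
import math

n_instances = 55_000  # assumption: len(X_train) for this Fashion MNIST split
batch_size = 128
min_rate, max_rate = 1e-4, 1.0

iterations = math.ceil(n_instances / batch_size) * 1  # 1 epoch
factor = (max_rate / min_rate) ** (1 / iterations)
# the callback records the rate *before* multiplying it, batch after batch
recorded_rates = [min_rate * factor ** i for i in range(iterations)]

print(iterations)         # 430
print(round(factor, 3))   # 1.022
print(recorded_rates[0], recorded_rates[-1])
```

Since each rate is recorded before the multiplication, the last recorded rate is `max_rate / factor`, just under `max_rate`.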
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `plot_lr_vs_loss()` function plots the learning rates vs the losses. The optimal learning rate to use as the maximum learning rate in 1cycle is near the bottom of the curve."
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"def plot_lr_vs_loss(rates, losses):\n",
" plt.plot(rates, losses, \"b\")\n",
" plt.gca().set_xscale('log')\n",
" max_loss = losses[0] + min(losses)\n",
" plt.hlines(min(losses), min(rates), max(rates), color=\"k\")\n",
" plt.axis([min(rates), max(rates), 0, max_loss])\n",
" plt.xlabel(\"Learning rate\")\n",
" plt.ylabel(\"Loss\")\n",
" plt.grid()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build a simple Fashion MNIST model and compile it:"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"model = build_model()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's find the optimal max learning rate for 1cycle:"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"430/430 [==============================] - 1s 1ms/step - loss: 1.7725 - accuracy: 0.4122\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"batch_size = 128\n",
"rates, losses = find_learning_rate(model, X_train, y_train, epochs=1,\n",
" batch_size=batch_size)\n",
"plot_lr_vs_loss(rates, losses)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks like the max learning rate to use for 1cycle is around 10<sup>-1</sup> (i.e., 0.1)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `OneCycleScheduler` custom callback updates the learning rate at the beginning of each batch. It applies the logic described in the book: increase the learning rate linearly during about half of training, then decrease it linearly back to the initial learning rate, and finally decrease it linearly to nearly zero during the last part of training (about the last 10% of iterations, by default)."
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"class OneCycleScheduler(tf.keras.callbacks.Callback):\n",
" def __init__(self, iterations, max_lr=1e-3, start_lr=None,\n",
" last_iterations=None, last_lr=None):\n",
" self.iterations = iterations\n",
" self.max_lr = max_lr\n",
" self.start_lr = start_lr or max_lr / 10\n",
" self.last_iterations = last_iterations or iterations // 10 + 1\n",
" self.half_iteration = (iterations - self.last_iterations) // 2\n",
" self.last_lr = last_lr or self.start_lr / 1000\n",
" self.iteration = 0\n",
"\n",
" def _interpolate(self, iter1, iter2, lr1, lr2):\n",
" return (lr2 - lr1) * (self.iteration - iter1) / (iter2 - iter1) + lr1\n",
"\n",
" def on_batch_begin(self, batch, logs=None):\n",
" if self.iteration < self.half_iteration:\n",
" lr = self._interpolate(0, self.half_iteration, self.start_lr,\n",
" self.max_lr)\n",
" elif self.iteration < 2 * self.half_iteration:\n",
" lr = self._interpolate(self.half_iteration, 2 * self.half_iteration,\n",
" self.max_lr, self.start_lr)\n",
" else:\n",
" lr = self._interpolate(2 * self.half_iteration, self.iterations,\n",
" self.start_lr, self.last_lr)\n",
" self.iteration += 1\n",
" K.set_value(self.model.optimizer.learning_rate, lr)"
]
},
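To see the shape of this schedule without training anything, here is a pure-Python sketch that mirrors the callback's piecewise-linear logic as a plain function of the iteration number (the name `one_cycle_lr()` is hypothetical). With the settings used for training (430 batches per epoch, 25 epochs, `max_lr=0.1`), the rate starts at 0.01, peaks at 0.1, comes back to 0.01, then falls toward 1e-5.

```python
def one_cycle_lr(iteration, iterations, max_lr, start_lr=None,
                 last_iterations=None, last_lr=None):
    """Learning rate at a given iteration, mirroring OneCycleScheduler."""
    start_lr = start_lr or max_lr / 10
    last_iterations = last_iterations or iterations // 10 + 1
    half_iteration = (iterations - last_iterations) // 2
    last_lr = last_lr or start_lr / 1000

    def interpolate(iter1, iter2, lr1, lr2):
        return (lr2 - lr1) * (iteration - iter1) / (iter2 - iter1) + lr1

    if iteration < half_iteration:  # phase 1: ramp up to max_lr
        return interpolate(0, half_iteration, start_lr, max_lr)
    elif iteration < 2 * half_iteration:  # phase 2: back down to start_lr
        return interpolate(half_iteration, 2 * half_iteration, max_lr, start_lr)
    else:  # phase 3: anneal toward last_lr
        return interpolate(2 * half_iteration, iterations, start_lr, last_lr)

iterations = 430 * 25  # 430 batches per epoch (batch_size=128) x 25 epochs
lrs = [one_cycle_lr(i, iterations, max_lr=0.1) for i in range(iterations)]
print(lrs[0], max(lrs), lrs[-1])
```

Note that the first phase actually lasts a bit less than half of training (about 45% here), because the final annealing phase takes roughly 10% of the iterations.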
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build and compile a simple Fashion MNIST model, then train it using the `OneCycleScheduler` callback:"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"430/430 [==============================] - 1s 2ms/step - loss: 0.9502 - accuracy: 0.6913 - val_loss: 0.6003 - val_accuracy: 0.7874\n",
"Epoch 2/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.5695 - accuracy: 0.8025 - val_loss: 0.4918 - val_accuracy: 0.8248\n",
"Epoch 3/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4954 - accuracy: 0.8252 - val_loss: 0.4762 - val_accuracy: 0.8264\n",
"Epoch 4/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4515 - accuracy: 0.8402 - val_loss: 0.4261 - val_accuracy: 0.8478\n",
"Epoch 5/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4225 - accuracy: 0.8492 - val_loss: 0.4066 - val_accuracy: 0.8486\n",
"Epoch 6/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3958 - accuracy: 0.8571 - val_loss: 0.4787 - val_accuracy: 0.8224\n",
"Epoch 7/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3787 - accuracy: 0.8626 - val_loss: 0.3917 - val_accuracy: 0.8566\n",
"Epoch 8/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3630 - accuracy: 0.8683 - val_loss: 0.4719 - val_accuracy: 0.8296\n",
"Epoch 9/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3512 - accuracy: 0.8724 - val_loss: 0.3673 - val_accuracy: 0.8652\n",
"Epoch 10/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3360 - accuracy: 0.8766 - val_loss: 0.4957 - val_accuracy: 0.8466\n",
"Epoch 11/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3287 - accuracy: 0.8786 - val_loss: 0.4187 - val_accuracy: 0.8370\n",
"Epoch 12/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3173 - accuracy: 0.8815 - val_loss: 0.3425 - val_accuracy: 0.8728\n",
"Epoch 13/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2961 - accuracy: 0.8910 - val_loss: 0.3217 - val_accuracy: 0.8792\n",
"Epoch 14/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2818 - accuracy: 0.8958 - val_loss: 0.3734 - val_accuracy: 0.8692\n",
"Epoch 15/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2675 - accuracy: 0.9003 - val_loss: 0.3261 - val_accuracy: 0.8844\n",
"Epoch 16/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2558 - accuracy: 0.9055 - val_loss: 0.3205 - val_accuracy: 0.8820\n",
"Epoch 17/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2464 - accuracy: 0.9091 - val_loss: 0.3089 - val_accuracy: 0.8894\n",
"Epoch 18/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2368 - accuracy: 0.9115 - val_loss: 0.3130 - val_accuracy: 0.8870\n",
"Epoch 19/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2292 - accuracy: 0.9145 - val_loss: 0.3078 - val_accuracy: 0.8854\n",
"Epoch 20/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2205 - accuracy: 0.9186 - val_loss: 0.3092 - val_accuracy: 0.8886\n",
"Epoch 21/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2138 - accuracy: 0.9209 - val_loss: 0.3022 - val_accuracy: 0.8914\n",
"Epoch 22/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2073 - accuracy: 0.9232 - val_loss: 0.3054 - val_accuracy: 0.8914\n",
"Epoch 23/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2020 - accuracy: 0.9261 - val_loss: 0.3026 - val_accuracy: 0.8896\n",
"Epoch 24/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.1989 - accuracy: 0.9273 - val_loss: 0.3020 - val_accuracy: 0.8922\n",
"Epoch 25/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.1967 - accuracy: 0.9276 - val_loss: 0.3016 - val_accuracy: 0.8920\n"
]
}
],
"source": [
"model = build_model()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(),\n",
" metrics=[\"accuracy\"])\n",
"n_epochs = 25\n",
"onecycle = OneCycleScheduler(math.ceil(len(X_train) / batch_size) * n_epochs,\n",
" max_lr=0.1)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[onecycle])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Avoiding Overfitting Through Regularization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## $\\ell_1$ and $\\ell_2$ regularization"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"layer = tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\",\n",
" kernel_regularizer=tf.keras.regularizers.l2(0.01))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or use `l1(0.1)` for ℓ<sub>1</sub> regularization with a factor of 0.1, or `l1_l2(0.1, 0.01)` for both ℓ<sub>1</sub> and ℓ<sub>2</sub> regularization, with factors 0.1 and 0.01 respectively."
]
},
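For intuition about what these regularizers add to the loss, here is a minimal pure-Python sketch of the penalty terms, assuming Keras's definitions: `l2(f)` adds f times the sum of squared weights, `l1(f)` adds f times the sum of absolute weights, and `l1_l2` adds both. The helper names and weight values are made up for illustration.

```python
def l1_penalty(weights, factor):
    return factor * sum(abs(w) for w in weights)

def l2_penalty(weights, factor):
    # no 1/2 coefficient here, unlike some textbook definitions
    return factor * sum(w ** 2 for w in weights)

def l1_l2_penalty(weights, l1=0.01, l2=0.01):
    return l1_penalty(weights, l1) + l2_penalty(weights, l2)

weights = [0.5, -1.0, 2.0]  # made-up kernel weights
print(l2_penalty(weights, 0.01))          # 0.01 * 5.25
print(l1_penalty(weights, 0.1))           # 0.1 * 3.5
print(l1_l2_penalty(weights, 0.1, 0.01))  # sum of both penalties
```

Keras adds these penalties to the training loss (they are collected in `model.losses`), which likely explains most of why the losses reported for the regularized model are much larger than in the unregularized runs.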
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code – for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"RegularizedDense = partial(tf.keras.layers.Dense,\n",
" activation=\"relu\",\n",
" kernel_initializer=\"he_normal\",\n",
" kernel_regularizer=tf.keras.regularizers.l2(0.01))\n",
"\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" RegularizedDense(100),\n",
" RegularizedDense(100),\n",
" RegularizedDense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 3.1224 - accuracy: 0.7748 - val_loss: 1.8602 - val_accuracy: 0.8264\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 1s 814us/step - loss: 1.4263 - accuracy: 0.8159 - val_loss: 1.1269 - val_accuracy: 0.8182\n"
]
}
],
"source": [
"# extra code – compile and train the model\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.02)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=2,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dropout"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code – for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.6703 - accuracy: 0.7536 - val_loss: 0.4498 - val_accuracy: 0.8342\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 996us/step - loss: 0.5103 - accuracy: 0.8136 - val_loss: 0.4401 - val_accuracy: 0.8296\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 998us/step - loss: 0.4712 - accuracy: 0.8263 - val_loss: 0.3806 - val_accuracy: 0.8554\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 977us/step - loss: 0.4488 - accuracy: 0.8337 - val_loss: 0.3711 - val_accuracy: 0.8608\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4342 - accuracy: 0.8409 - val_loss: 0.3672 - val_accuracy: 0.8606\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 983us/step - loss: 0.4245 - accuracy: 0.8427 - val_loss: 0.3706 - val_accuracy: 0.8600\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 995us/step - loss: 0.4131 - accuracy: 0.8467 - val_loss: 0.3582 - val_accuracy: 0.8650\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 959us/step - loss: 0.4074 - accuracy: 0.8484 - val_loss: 0.3478 - val_accuracy: 0.8708\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 997us/step - loss: 0.4024 - accuracy: 0.8533 - val_loss: 0.3556 - val_accuracy: 0.8690\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 998us/step - loss: 0.3903 - accuracy: 0.8552 - val_loss: 0.3453 - val_accuracy: 0.8732\n"
]
}
],
"source": [
"# extra code – compile and train the model\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training accuracy looks lower than the validation accuracy, but that's only because dropout is active during training and inactive during evaluation. If we evaluate the model on the training set after training (i.e., with dropout turned off), we get the \"real\" training accuracy, which is slightly higher than both the validation accuracy and the test accuracy:"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 1s 578us/step - loss: 0.3082 - accuracy: 0.8849\n"
]
},
{
"data": {
"text/plain": [
"[0.30816400051116943, 0.8849090933799744]"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 588us/step - loss: 0.3629 - accuracy: 0.8700\n"
]
},
{
"data": {
"text/plain": [
"[0.3628920316696167, 0.8700000047683716]"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_test, y_test)"
]
},
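{
"cell_type": "markdown",
"metadata": {},
"source": [
"The gap between the two training accuracies comes from the way dropout behaves at training and inference time. Keras's `Dropout` layer implements *inverted* dropout: during training, it zeroes each input with probability `rate` and divides the surviving inputs by 1 − `rate`, so the expected activation is unchanged; at inference time, it passes its inputs through untouched. Here is a minimal NumPy sketch of that behavior (the `dropout()` function is illustrative, not the actual Keras implementation):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def dropout(inputs, rate, training, rng):\n",
"    # illustrative inverted dropout, not the actual Keras code\n",
"    if not training or rate == 0.0:\n",
"        return inputs  # inference: the layer does nothing\n",
"    keep_mask = rng.random(inputs.shape) >= rate  # keep each input with probability 1 - rate\n",
"    return inputs * keep_mask / (1.0 - rate)  # rescale so the expected value is unchanged\n",
"\n",
"rng = np.random.default_rng(42)\n",
"x = np.ones((1000, 100))\n",
"train_out = dropout(x, rate=0.2, training=True, rng=rng)\n",
"print((train_out == 0).mean())  # fraction of dropped inputs, close to 0.2\n",
"print(train_out.mean())  # close to 1.0, the mean of the inputs\n",
"print(np.array_equal(dropout(x, rate=0.2, training=False, rng=rng), x))  # True"
]
},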
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: make sure to use `AlphaDropout` instead of `Dropout` if you want to build a self-normalizing neural net using SELU."
]
},
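{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why not regular dropout? On zero-mean, unit-variance inputs, inverted dropout keeps the mean at 0 but inflates the variance by a factor of 1 / (1 − `rate`), which breaks self-normalization. Alpha dropout instead sets dropped inputs to −λα (the value SELU saturates to for very negative inputs), then applies an affine transformation that restores the original mean and variance. Below is a NumPy sketch, assuming the formulation from the paper that introduced SELU, on which Keras's `AlphaDropout` is based. The `alpha_dropout()` function is illustrative only and covers training mode (at inference time, the layer passes its inputs through unchanged)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# SELU constants: alpha and scale (often written as lambda)\n",
"SELU_ALPHA = 1.6732632423543772\n",
"SELU_SCALE = 1.0507009873554805\n",
"\n",
"def alpha_dropout(inputs, rate, rng):\n",
"    # illustrative sketch of alpha dropout in training mode\n",
"    alpha_p = -SELU_ALPHA * SELU_SCALE  # SELU's output for very negative inputs\n",
"    keep_mask = rng.random(inputs.shape) >= rate\n",
"    a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5\n",
"    b = -a * alpha_p * rate\n",
"    dropped = np.where(keep_mask, inputs, alpha_p)  # dropped inputs become alpha_p, not 0\n",
"    return a * dropped + b  # affine correction restores mean 0 and variance 1\n",
"\n",
"rng = np.random.default_rng(42)\n",
"x = rng.standard_normal(1_000_000)\n",
"out = alpha_dropout(x, rate=0.2, rng=rng)\n",
"print(out.mean(), out.std())  # close to 0 and 1, respectively, like the inputs\n",
"kept = rng.random(x.shape) >= 0.2\n",
"inverted = x * kept / 0.8  # regular (inverted) dropout, for comparison\n",
"print(inverted.std())  # about sqrt(1 / 0.8), i.e., roughly 1.118: variance no longer 1"
]
},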
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MC Dropout"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"y_probas = np.stack([model(X_test, training=True)\n",
" for sample in range(100)])\n",
"y_proba = y_probas.mean(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.024, 0. , 0.132, 0. ,\n",
" 0.844]], dtype=float32)"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict(X_test[:1]).round(3)"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0. , 0. , 0. , 0. , 0. , 0.067, 0. , 0.209, 0.001,\n",
" 0.723], dtype=float32)"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_proba[0].round(3)"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0. , 0. , 0. , 0.001, 0. , 0.096, 0. , 0.162, 0.001,\n",
" 0.183], dtype=float32)"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_std = y_probas.std(axis=0)\n",
"y_std[0].round(3)"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8717"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred = y_proba.argmax(axis=1)\n",
"accuracy = (y_pred == y_test).sum() / len(y_test)\n",
"accuracy"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
"class MCDropout(tf.keras.layers.Dropout):\n",
" def call(self, inputs, training=None):\n",
" return super().call(inputs, training=True)"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
"# extra code – shows how to convert Dropout to MCDropout in a Sequential model\n",
"Dropout = tf.keras.layers.Dropout\n",
"mc_model = tf.keras.Sequential([\n",
" MCDropout(layer.rate) if isinstance(layer, Dropout) else layer\n",
" for layer in model.layers\n",
"])\n",
"mc_model.set_weights(model.get_weights())"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_25\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten_22 (Flatten) (None, 784) 0 \n",
"_________________________________________________________________\n",
"mc_dropout (MCDropout) (None, 784) 0 \n",
"_________________________________________________________________\n",
"dense_89 (Dense) (None, 100) 78500 \n",
"_________________________________________________________________\n",
"mc_dropout_1 (MCDropout) (None, 100) 0 \n",
"_________________________________________________________________\n",
"dense_90 (Dense) (None, 100) 10100 \n",
"_________________________________________________________________\n",
"mc_dropout_2 (MCDropout) (None, 100) 0 \n",
"_________________________________________________________________\n",
"dense_91 (Dense) (None, 10) 1010 \n",
"=================================================================\n",
"Total params: 89,610\n",
"Trainable params: 89,610\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"mc_model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use the model with MC Dropout:"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.07, 0. , 0.17, 0. , 0.76]],\n",
" dtype=float32)"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows that the model works without retraining\n",
"tf.random.set_seed(42)\n",
"np.mean([mc_model.predict(X_test[:1])\n",
" for sample in range(100)], axis=0).round(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Max norm"
]
},
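{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `max_norm(1.)`, after each training step Keras rescales the vector of incoming weights of every neuron (each column of the kernel, since the constraint's `axis` defaults to 0) whose ℓ<sub>2</sub> norm exceeds 1, so that its norm is at most 1. Here is a minimal NumPy sketch of that rescaling, assuming the formula used by Keras's `MaxNorm` constraint (`w * clip(norm, 0, max_value) / (epsilon + norm)`); `max_norm_constraint()` is a hypothetical helper:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def max_norm_constraint(kernel, max_value=1.0, axis=0, epsilon=1e-7):\n",
"    # hypothetical helper: scale down every weight vector along `axis` whose\n",
"    # L2 norm exceeds max_value; smaller vectors are left (almost) unchanged\n",
"    norms = np.sqrt(np.sum(np.square(kernel), axis=axis, keepdims=True))\n",
"    desired = np.clip(norms, 0, max_value)\n",
"    return kernel * (desired / (epsilon + norms))\n",
"\n",
"rng = np.random.default_rng(42)\n",
"kernel = rng.normal(scale=0.1, size=(784, 100))  # shaped like a Dense(100) kernel with 784 inputs\n",
"constrained = max_norm_constraint(kernel, max_value=1.0)\n",
"print(np.linalg.norm(kernel, axis=0).mean())  # before: roughly 0.1 * sqrt(784) = 2.8\n",
"print(np.linalg.norm(constrained, axis=0).max())  # after: at most 1.0"
]
},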
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(\n",
" 100, activation=\"relu\", kernel_initializer=\"he_normal\",\n",
" kernel_constraint=tf.keras.constraints.max_norm(1.))"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5500 - accuracy: 0.8015 - val_loss: 0.4510 - val_accuracy: 0.8242\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 960us/step - loss: 0.4089 - accuracy: 0.8499 - val_loss: 0.3956 - val_accuracy: 0.8504\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 974us/step - loss: 0.3777 - accuracy: 0.8604 - val_loss: 0.3693 - val_accuracy: 0.8680\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 943us/step - loss: 0.3581 - accuracy: 0.8690 - val_loss: 0.3517 - val_accuracy: 0.8716\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.3416 - accuracy: 0.8729 - val_loss: 0.3433 - val_accuracy: 0.8682\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 951us/step - loss: 0.3368 - accuracy: 0.8756 - val_loss: 0.4045 - val_accuracy: 0.8582\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 935us/step - loss: 0.3293 - accuracy: 0.8767 - val_loss: 0.4168 - val_accuracy: 0.8476\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 951us/step - loss: 0.3258 - accuracy: 0.8779 - val_loss: 0.3570 - val_accuracy: 0.8674\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 970us/step - loss: 0.3269 - accuracy: 0.8787 - val_loss: 0.3702 - val_accuracy: 0.8578\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.3169 - accuracy: 0.8809 - val_loss: 0.3907 - val_accuracy: 0.8578\n"
]
}
],
"source": [
"# extra code – shows how to apply max norm to every hidden layer in a model\n",
"\n",
"MaxNormDense = partial(tf.keras.layers.Dense,\n",
" activation=\"relu\", kernel_initializer=\"he_normal\",\n",
" kernel_constraint=tf.keras.constraints.max_norm(1.))\n",
"\n",
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" MaxNormDense(100),\n",
" MaxNormDense(100),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 7."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Glorot initialization and He initialization were designed to make the output standard deviation as close as possible to the input standard deviation, at least at the beginning of training. This reduces the vanishing/exploding gradients problem.\n",
"2. No, all weights should be sampled independently; they should not all have the same initial value. One important goal of sampling weights randomly is to break symmetry: if all the weights have the same initial value, even if that value is not zero, then symmetry is not broken (i.e., all neurons in a given layer are equivalent), and backpropagation will be unable to break it. Concretely, this means that all the neurons in any given layer will always have the same weights. It's like having just one neuron per layer, and much slower. It is virtually impossible for such a configuration to converge to a good solution.\n",
"3. It is perfectly fine to initialize the bias terms to zero. Some people like to initialize them just like weights, and that's OK too; it does not make much difference.\n",
"4. ReLU is usually a good default for the hidden layers, as it is fast and yields good results. Its ability to output precisely zero can also be useful in some cases (e.g., see Chapter 17). Moreover, it can sometimes benefit from optimized implementations as well as from hardware acceleration. The leaky variants of ReLU can improve the model's quality without hindering its speed too much compared to ReLU. For large neural nets and more complex problems, GELU, Swish and Mish can give you a slightly higher-quality model, but they have a computational cost. The hyperbolic tangent (tanh) can be useful in the output layer if you need to output a number in a fixed range (by default between −1 and 1), but nowadays it is not used much in hidden layers, except in recurrent nets. The sigmoid activation function is also useful in the output layer when you need to estimate a probability (e.g., for binary classification), but it is rarely used in hidden layers (there are exceptions, for example the coding layer of variational autoencoders; see Chapter 17). The softplus activation function is useful in the output layer when you need to ensure that the output will always be positive. The softmax activation function is useful in the output layer to estimate probabilities for mutually exclusive classes, but it is rarely (if ever) used in hidden layers.\n",
"5. If you set the `momentum` hyperparameter too close to 1 (e.g., 0.99999) when using an `SGD` optimizer, then the algorithm will likely pick up a lot of speed, hopefully moving roughly toward the global minimum, but its momentum will carry it right past the minimum. Then it will slow down and come back, accelerate again, overshoot again, and so on. It may oscillate this way many times before converging, so overall it will take much longer to converge than with a smaller `momentum` value.\n",
"6. One way to produce a sparse model (i.e., with most weights equal to zero) is to train the model normally, then zero out tiny weights. For more sparsity, you can apply ℓ<sub>1</sub> regularization during training, which pushes the optimizer toward sparsity. A third option is to use the TensorFlow Model Optimization Toolkit.\n",
"7. Yes, dropout does slow down training, in general roughly by a factor of two. However, it has no impact on inference speed since it is only turned on during training. MC Dropout is exactly like dropout during training, but it is still active during inference, so each inference is slowed down slightly. More importantly, when using MC Dropout you generally want to run inference 10 times or more to get better predictions. This means that making predictions is slowed down by a factor of 10 or more."
]
},
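{
"cell_type": "markdown",
"metadata": {},
"source": [
"A couple of the answers above can be checked numerically. First, answer 2: the NumPy sketch below builds a hypothetical toy network (one ReLU hidden layer with 3 neurons, a linear output, no bias terms, MSE loss, gradients computed by hand), initializes every weight to the same value, and shows that all hidden neurons then receive exactly the same gradients, so gradient descent can never make them differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X = rng.normal(size=(32, 4))  # toy batch: 32 instances, 4 features\n",
"y = rng.normal(size=(32, 1))  # toy regression targets\n",
"W1 = np.full((4, 3), 0.5)  # hidden layer: every weight has the same value\n",
"W2 = np.full((3, 1), 0.5)  # output layer: same here\n",
"\n",
"# forward pass: ReLU hidden layer, linear output, MSE loss\n",
"Z1 = X @ W1\n",
"A1 = np.maximum(Z1, 0)\n",
"y_pred = A1 @ W2\n",
"\n",
"# backward pass (gradients of the MSE with respect to W1 and W2)\n",
"d_pred = 2 * (y_pred - y) / len(X)\n",
"dW2 = A1.T @ d_pred\n",
"dZ1 = (d_pred @ W2.T) * (Z1 > 0)\n",
"dW1 = X.T @ dZ1\n",
"\n",
"# all hidden neurons get identical gradients, so they stay identical after the update\n",
"print(np.allclose(dW1, dW1[:, :1]))  # True: every column of dW1 is the same\n",
"print(np.allclose(dW2, dW2[0]))  # True: every row of dW2 is the same"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, answer 5: the sketch below runs SGD with momentum on the one-dimensional function f(θ) = θ²/2 (whose gradient is θ), using the update rule of Keras's `SGD` optimizer without Nesterov momentum: `velocity = momentum * velocity - learning_rate * gradient`, then `theta += velocity`. The exact step counts depend on this toy problem and on the learning rate, but the closer `momentum` gets to 1, the longer the parameter keeps oscillating around the minimum; with 0.99999, it does not converge within the 100,000-step budget at all."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def steps_to_converge(momentum, learning_rate=0.1, tol=1e-3, max_steps=100_000):\n",
"    # minimize f(theta) = theta**2 / 2, whose gradient is simply theta\n",
"    theta, velocity = 1.0, 0.0\n",
"    for step in range(1, max_steps + 1):\n",
"        velocity = momentum * velocity - learning_rate * theta\n",
"        theta += velocity\n",
"        if abs(theta) < tol and abs(velocity) < tol:\n",
"            return step\n",
"    return None  # did not converge within max_steps\n",
"\n",
"for momentum in (0.9, 0.99, 0.999, 0.99999):\n",
"    print(momentum, steps_to_converge(momentum))"
]
},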
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Deep Learning on CIFAR10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### a.\n",
"*Exercise: Build a DNN with 20 hidden layers of 100 neurons each (that's too many, but it's the point of this exercise). Use He initialization and the Swish activation function.*"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" activation=\"swish\",\n",
" kernel_initializer=\"he_normal\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### b.\n",
"*Exercise: Using Nadam optimization and early stopping, train the network on the CIFAR10 dataset. You can load it with `tf.keras.datasets.cifar10.load_data()`. The dataset is composed of 60,000 color images of 32 × 32 pixels (50,000 for training, 10,000 for testing) with 10 classes, so you'll need a softmax output layer with 10 neurons. Remember to search for the right learning rate each time you change the model's architecture or hyperparameters.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's add the output layer to the model:"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
2021-10-17 04:04:08 +02:00
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use a Nadam optimizer with a learning rate of 5e-5. I tried learning rates 1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3 and 1e-2, and I compared their learning curves for 10 epochs each (using the TensorBoard callback, below). The learning rates 3e-5 and 1e-4 were pretty good, so I tried 5e-5, which turned out to be slightly better."
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-5)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the CIFAR10 dataset. We also want to use early stopping, so we need a validation set. Let's use the first 5,000 images of the original training set as the validation set:"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"cifar10 = tf.keras.datasets.cifar10.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = cifar10\n",
"\n",
"X_train = X_train_full[5000:]\n",
"y_train = y_train_full[5000:]\n",
"X_valid = X_train_full[:5000]\n",
"y_valid = y_train_full[:5000]"
]
},
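{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning-rate search mentioned earlier could look like the following sketch. It is a reconstruction, not necessarily the code that was actually used: `build_model()` and `run_lr_sweep` are hypothetical names, and each run logs to its own subdirectory of `my_cifar10_logs` so that TensorBoard can display all the learning curves side by side. Since it trains one model per learning rate for 10 epochs each, it only runs if you set `run_lr_sweep = True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import tensorflow as tf\n",
"\n",
"def build_model():\n",
"    # hypothetical helper: rebuilds the 20-layer Swish network from scratch\n",
"    new_model = tf.keras.Sequential()\n",
"    new_model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"    for _ in range(20):\n",
"        new_model.add(tf.keras.layers.Dense(100, activation=\"swish\",\n",
"                                            kernel_initializer=\"he_normal\"))\n",
"    new_model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"    return new_model\n",
"\n",
"run_lr_sweep = False  # set to True to run the sweep (slow: 10 epochs per learning rate)\n",
"\n",
"if run_lr_sweep:\n",
"    for learning_rate in [1e-5, 3e-5, 5e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2]:\n",
"        tf.random.set_seed(42)\n",
"        lr_model = build_model()  # a separate model, so `model` above is not overwritten\n",
"        lr_model.compile(loss=\"sparse_categorical_crossentropy\",\n",
"                         optimizer=tf.keras.optimizers.Nadam(learning_rate=learning_rate),\n",
"                         metrics=[\"accuracy\"])\n",
"        lr_logdir = Path() / \"my_cifar10_logs\" / f\"lr_{learning_rate:.0e}\"\n",
"        lr_model.fit(X_train, y_train, epochs=10,\n",
"                     validation_data=(X_valid, y_valid),\n",
"                     callbacks=[tf.keras.callbacks.TensorBoard(lr_logdir)])"
]
},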
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create the callbacks we need and train the model:"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,\n",
" restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_cifar10_model\",\n",
" save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-d05c16b556c70d97\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-d05c16b556c70d97\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6006;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir=./my_cifar10_logs"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 4.0493 - accuracy: 0.1598INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 17s 10ms/step - loss: 4.0462 - accuracy: 0.1597 - val_loss: 2.1441 - val_accuracy: 0.2036\n",
"Epoch 2/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 2.0667 - accuracy: 0.2320INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 2.0667 - accuracy: 0.2320 - val_loss: 2.0134 - val_accuracy: 0.2472\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.9472 - accuracy: 0.2819INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.9472 - accuracy: 0.2819 - val_loss: 1.9427 - val_accuracy: 0.2796\n",
"Epoch 4/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.8636 - accuracy: 0.3182INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.8637 - accuracy: 0.3182 - val_loss: 1.8934 - val_accuracy: 0.3222\n",
"Epoch 5/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.7975 - accuracy: 0.3464INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.7974 - accuracy: 0.3465 - val_loss: 1.8389 - val_accuracy: 0.3284\n",
"Epoch 6/100\n",
"1407/1407 [==============================] - 9s 7ms/step - loss: 1.7446 - accuracy: 0.3664 - val_loss: 2.0006 - val_accuracy: 0.3030\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.6974 - accuracy: 0.3852INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6974 - accuracy: 0.3852 - val_loss: 1.7075 - val_accuracy: 0.3738\n",
"Epoch 8/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.6605 - accuracy: 0.3984INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6604 - accuracy: 0.3984 - val_loss: 1.6788 - val_accuracy: 0.3836\n",
"Epoch 9/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.6322 - accuracy: 0.4114INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.6321 - accuracy: 0.4114 - val_loss: 1.6477 - val_accuracy: 0.4014\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6065 - accuracy: 0.4205 - val_loss: 1.6623 - val_accuracy: 0.3980\n",
"Epoch 11/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.5843 - accuracy: 0.4287INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.5845 - accuracy: 0.4285 - val_loss: 1.6032 - val_accuracy: 0.4198\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 9s 6ms/step - loss: 1.5634 - accuracy: 0.4367 - val_loss: 1.6063 - val_accuracy: 0.4258\n",
"Epoch 13/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.5443 - accuracy: 0.4420INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"<<47 more lines>>\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.3247 - accuracy: 0.5256 - val_loss: 1.5130 - val_accuracy: 0.4616\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.3164 - accuracy: 0.5286 - val_loss: 1.5284 - val_accuracy: 0.4686\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.3091 - accuracy: 0.5303 - val_loss: 1.5208 - val_accuracy: 0.4682\n",
"Epoch 35/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.3026 - accuracy: 0.5319 - val_loss: 1.5479 - val_accuracy: 0.4604\n",
"Epoch 36/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2930 - accuracy: 0.5378 - val_loss: 1.5443 - val_accuracy: 0.4580\n",
"Epoch 37/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.2833 - accuracy: 0.5406 - val_loss: 1.5165 - val_accuracy: 0.4710\n",
"Epoch 38/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2763 - accuracy: 0.5433 - val_loss: 1.5345 - val_accuracy: 0.4672\n",
"Epoch 39/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.2687 - accuracy: 0.5437 - val_loss: 1.5162 - val_accuracy: 0.4712\n",
"Epoch 40/100\n",
"1407/1407 [==============================] - 11s 7ms/step - loss: 1.2623 - accuracy: 0.5490 - val_loss: 1.5717 - val_accuracy: 0.4566\n",
"Epoch 41/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.2580 - accuracy: 0.5467 - val_loss: 1.5296 - val_accuracy: 0.4738\n",
"Epoch 42/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2469 - accuracy: 0.5532 - val_loss: 1.5179 - val_accuracy: 0.4690\n",
"Epoch 43/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2404 - accuracy: 0.5542 - val_loss: 1.5542 - val_accuracy: 0.4566\n",
"Epoch 44/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.2292 - accuracy: 0.5605 - val_loss: 1.5536 - val_accuracy: 0.4608\n",
"Epoch 45/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.2276 - accuracy: 0.5606 - val_loss: 1.5522 - val_accuracy: 0.4624\n",
"Epoch 46/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2200 - accuracy: 0.5637 - val_loss: 1.5339 - val_accuracy: 0.4794\n",
"Epoch 47/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2080 - accuracy: 0.5677 - val_loss: 1.5451 - val_accuracy: 0.4688\n",
"Epoch 48/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.2050 - accuracy: 0.5675 - val_loss: 1.5209 - val_accuracy: 0.4770\n",
"Epoch 49/100\n",
"1407/1407 [==============================] - 10s 7ms/step - loss: 1.1947 - accuracy: 0.5718 - val_loss: 1.5435 - val_accuracy: 0.4736\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fb9f02fc070>"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=callbacks)"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"157/157 [==============================] - 0s 2ms/step - loss: 1.5062 - accuracy: 0.4676\n"
]
},
{
"data": {
"text/plain": [
"[1.5061508417129517, 0.4675999879837036]"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_valid, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model with the lowest validation loss gets about 46.8% accuracy on the validation set. It took 29 epochs to reach the lowest validation loss, with roughly 10 seconds per epoch on my laptop (without a GPU). Let's see if we can improve the model using Batch Normalization."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### c.\n",
"*Exercise: Now try adding Batch Normalization and compare the learning curves: Is it converging faster than before? Does it produce a better model? How does it affect training speed?*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code below is very similar to the code above, with a few changes:\n",
"\n",
"* I added a BN layer after every Dense layer (before the activation function), except for the output layer.\n",
"* I changed the learning rate to 5e-4. I experimented with 1e-5, 3e-5, 5e-5, 1e-4, 3e-4, 5e-4, 1e-3 and 3e-3, and I chose the one with the best validation performance after 20 epochs.\n",
"* I renamed the run directories to `run_bn_*` and the model file name to `my_cifar10_bn_model`."
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 2.0377 - accuracy: 0.2523INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 32s 18ms/step - loss: 2.0374 - accuracy: 0.2525 - val_loss: 1.8766 - val_accuracy: 0.3154\n",
"Epoch 2/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.7874 - accuracy: 0.3542 - val_loss: 1.8784 - val_accuracy: 0.3268\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - 20s 15ms/step - loss: 1.6806 - accuracy: 0.3969 - val_loss: 1.9764 - val_accuracy: 0.3252\n",
"Epoch 4/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.6111 - accuracy: 0.4229INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 24s 17ms/step - loss: 1.6112 - accuracy: 0.4228 - val_loss: 1.7087 - val_accuracy: 0.3750\n",
"Epoch 5/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.5520 - accuracy: 0.4478INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 21s 15ms/step - loss: 1.5521 - accuracy: 0.4476 - val_loss: 1.6272 - val_accuracy: 0.4176\n",
"Epoch 6/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5030 - accuracy: 0.4659INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 23s 16ms/step - loss: 1.5030 - accuracy: 0.4660 - val_loss: 1.5401 - val_accuracy: 0.4452\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.4559 - accuracy: 0.4812 - val_loss: 1.6990 - val_accuracy: 0.3952\n",
"Epoch 8/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.4169 - accuracy: 0.4987INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 21s 15ms/step - loss: 1.4168 - accuracy: 0.4987 - val_loss: 1.5078 - val_accuracy: 0.4652\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.3863 - accuracy: 0.5123 - val_loss: 1.5513 - val_accuracy: 0.4470\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.3514 - accuracy: 0.5216 - val_loss: 1.5208 - val_accuracy: 0.4562\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.3220 - accuracy: 0.5314 - val_loss: 1.7301 - val_accuracy: 0.4206\n",
"Epoch 12/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.2933 - accuracy: 0.5410INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 25s 18ms/step - loss: 1.2931 - accuracy: 0.5410 - val_loss: 1.4909 - val_accuracy: 0.4734\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.2702 - accuracy: 0.5490 - val_loss: 1.5256 - val_accuracy: 0.4636\n",
"Epoch 14/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.2424 - accuracy: 0.5591 - val_loss: 1.5569 - val_accuracy: 0.4624\n",
"Epoch 15/100\n",
"<<12 more lines>>\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.1174 - accuracy: 0.6066 - val_loss: 1.5241 - val_accuracy: 0.4828\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0978 - accuracy: 0.6128 - val_loss: 1.5313 - val_accuracy: 0.4772\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.0844 - accuracy: 0.6198 - val_loss: 1.4993 - val_accuracy: 0.4924\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.0677 - accuracy: 0.6244 - val_loss: 1.4622 - val_accuracy: 0.5078\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0571 - accuracy: 0.6297 - val_loss: 1.4917 - val_accuracy: 0.4990\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 1.0395 - accuracy: 0.6327 - val_loss: 1.4888 - val_accuracy: 0.4896\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0298 - accuracy: 0.6370 - val_loss: 1.5358 - val_accuracy: 0.5024\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0150 - accuracy: 0.6444 - val_loss: 1.5219 - val_accuracy: 0.5030\n",
"Epoch 29/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.0100 - accuracy: 0.6456 - val_loss: 1.4933 - val_accuracy: 0.5098\n",
"Epoch 30/100\n",
"1407/1407 [==============================] - 20s 14ms/step - loss: 0.9956 - accuracy: 0.6492 - val_loss: 1.4756 - val_accuracy: 0.5012\n",
"Epoch 31/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 0.9787 - accuracy: 0.6576 - val_loss: 1.5181 - val_accuracy: 0.4936\n",
"Epoch 32/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9710 - accuracy: 0.6565 - val_loss: 1.7510 - val_accuracy: 0.4568\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 20s 14ms/step - loss: 0.9613 - accuracy: 0.6628 - val_loss: 1.5576 - val_accuracy: 0.4910\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 0.9530 - accuracy: 0.6651 - val_loss: 1.5087 - val_accuracy: 0.5046\n",
"Epoch 35/100\n",
"1407/1407 [==============================] - 19s 13ms/step - loss: 0.9388 - accuracy: 0.6701 - val_loss: 1.5534 - val_accuracy: 0.4950\n",
"Epoch 36/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9331 - accuracy: 0.6743 - val_loss: 1.5033 - val_accuracy: 0.5046\n",
"Epoch 37/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 0.9144 - accuracy: 0.6808 - val_loss: 1.5679 - val_accuracy: 0.5028\n",
"157/157 [==============================] - 0s 2ms/step - loss: 1.4236 - accuracy: 0.5074\n"
]
},
{
"data": {
"text/plain": [
"[1.4236289262771606, 0.5073999762535095]"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100, kernel_initializer=\"he_normal\"))\n",
" model.add(tf.keras.layers.BatchNormalization())\n",
" model.add(tf.keras.layers.Activation(\"swish\"))\n",
"\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-4)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,\n",
" restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_cifar10_bn_model\",\n",
" save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_bn_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* *Is the model converging faster than before?* Much faster! The previous model took 29 epochs to reach the lowest validation loss, while the new model achieved that same loss in just 12 epochs and continued to make progress until the 17th epoch. The BN layers stabilized training and allowed us to use a much larger learning rate, so convergence was faster.\n",
"* *Does BN produce a better model?* Yes! The final model is also much better, with 50.7% validation accuracy instead of 46.7%. It's still not a very good model, but at least it's much better than before (a Convolutional Neural Network would do much better, but that's a different topic; see chapter 14).\n",
"* *How does BN affect training speed?* Although the model converged much faster, each epoch took about 15s instead of 10s because of the extra computations required by the BN layers. Overall, though, the training time (wall time) to reach the best model was shortened by about 10%."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### d.\n",
"*Exercise: Try replacing Batch Normalization with SELU, and make the necessary adjustments to ensure the network self-normalizes (i.e., standardize the input features, use LeCun normal initialization, make sure the DNN contains only a sequence of dense layers, etc.).*"
]
},
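{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why do these conditions matter? The following minimal NumPy sketch (an illustration added here, not part of the exercise solution) sends standardized random inputs through 20 dense SELU layers of 100 units with LeCun normal weights and zero biases, and checks that the final activations remain roughly standardized. It assumes the SELU constants from the original SELU paper, and it samples the weights from an untruncated normal distribution, whereas Keras's `lecun_normal` initializer uses a truncated normal distribution.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# SELU constants from the paper that introduced SELU (Klambauer et al., 2017)\n",
"SELU_ALPHA = 1.6732632423543772\n",
"SELU_SCALE = 1.0507009873554805\n",
"\n",
"def selu(z):\n",
"    return SELU_SCALE * np.where(z > 0, z, SELU_ALPHA * (np.exp(z) - 1))\n",
"\n",
"rng = np.random.default_rng(42)\n",
"Z = rng.standard_normal((1000, 100))  # standardized inputs: mean 0, std 1\n",
"for _ in range(20):\n",
"    fan_in = Z.shape[1]\n",
"    # LeCun normal initialization: std = sqrt(1 / fan_in); biases are zero\n",
"    W = rng.normal(0, np.sqrt(1 / fan_in), size=(fan_in, 100))\n",
"    Z = selu(Z @ W)\n",
"\n",
"# Both values should stay close to 0 and 1, respectively\n",
"print(f\"mean: {Z.mean():.2f}, std: {Z.std():.2f}\")\n",
"```\n",
"\n",
"You can try changing the weight scale or swapping `selu()` for another activation function to see how sensitive these statistics are."
]
},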
{
"cell_type": "code",
"execution_count": 127,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.9386 - accuracy: 0.3045INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 20s 13ms/step - loss: 1.9385 - accuracy: 0.3046 - val_loss: 1.8175 - val_accuracy: 0.3510\n",
"Epoch 2/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.7241 - accuracy: 0.3869INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.7241 - accuracy: 0.3869 - val_loss: 1.7677 - val_accuracy: 0.3614\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.6272 - accuracy: 0.4263INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.6272 - accuracy: 0.4263 - val_loss: 1.6878 - val_accuracy: 0.4054\n",
"Epoch 4/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5644 - accuracy: 0.4492INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.5643 - accuracy: 0.4492 - val_loss: 1.6589 - val_accuracy: 0.4304\n",
"Epoch 5/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.5080 - accuracy: 0.4712INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.5080 - accuracy: 0.4712 - val_loss: 1.5651 - val_accuracy: 0.4538\n",
"Epoch 6/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.4611 - accuracy: 0.4873INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.4613 - accuracy: 0.4872 - val_loss: 1.5305 - val_accuracy: 0.4678\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.4174 - accuracy: 0.5077 - val_loss: 1.5346 - val_accuracy: 0.4558\n",
"Epoch 8/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.3781 - accuracy: 0.5175INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.3781 - accuracy: 0.5175 - val_loss: 1.4773 - val_accuracy: 0.4882\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.3413 - accuracy: 0.5345 - val_loss: 1.5021 - val_accuracy: 0.4764\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.3182 - accuracy: 0.5422 - val_loss: 1.5709 - val_accuracy: 0.4762\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2832 - accuracy: 0.5571 - val_loss: 1.5345 - val_accuracy: 0.4868\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2557 - accuracy: 0.5667 - val_loss: 1.5024 - val_accuracy: 0.4900\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2373 - accuracy: 0.5710 - val_loss: 1.5114 - val_accuracy: 0.5028\n",
"Epoch 14/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.2071 - accuracy: 0.5846INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.2073 - accuracy: 0.5847 - val_loss: 1.4608 - val_accuracy: 0.5026\n",
"Epoch 15/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1843 - accuracy: 0.5940 - val_loss: 1.4962 - val_accuracy: 0.5038\n",
"Epoch 16/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.1617 - accuracy: 0.6026 - val_loss: 1.5255 - val_accuracy: 0.5062\n",
"Epoch 17/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1452 - accuracy: 0.6084 - val_loss: 1.5057 - val_accuracy: 0.5036\n",
"Epoch 18/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.1297 - accuracy: 0.6145 - val_loss: 1.5097 - val_accuracy: 0.5010\n",
"Epoch 19/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.1004 - accuracy: 0.6245 - val_loss: 1.5218 - val_accuracy: 0.5014\n",
"Epoch 20/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0971 - accuracy: 0.6304 - val_loss: 1.5253 - val_accuracy: 0.5090\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.0670 - accuracy: 0.6345 - val_loss: 1.5006 - val_accuracy: 0.5034\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.0544 - accuracy: 0.6407 - val_loss: 1.5244 - val_accuracy: 0.5010\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0338 - accuracy: 0.6502 - val_loss: 1.5355 - val_accuracy: 0.5096\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.0281 - accuracy: 0.6514 - val_loss: 1.5257 - val_accuracy: 0.5164\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.4097 - accuracy: 0.6478 - val_loss: 1.8203 - val_accuracy: 0.3514\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.3733 - accuracy: 0.5157 - val_loss: 1.5600 - val_accuracy: 0.4664\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2032 - accuracy: 0.5814 - val_loss: 1.5367 - val_accuracy: 0.4944\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1291 - accuracy: 0.6121 - val_loss: 1.5333 - val_accuracy: 0.4852\n",
"Epoch 29/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0734 - accuracy: 0.6317 - val_loss: 1.5475 - val_accuracy: 0.5032\n",
"Epoch 30/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0294 - accuracy: 0.6469 - val_loss: 1.5400 - val_accuracy: 0.5052\n",
"Epoch 31/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0081 - accuracy: 0.6605 - val_loss: 1.5617 - val_accuracy: 0.4856\n",
"Epoch 32/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0109 - accuracy: 0.6603 - val_loss: 1.5727 - val_accuracy: 0.5124\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9646 - accuracy: 0.6762 - val_loss: 1.5333 - val_accuracy: 0.5174\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 0.9597 - accuracy: 0.6789 - val_loss: 1.5601 - val_accuracy: 0.5016\n",
"157/157 [==============================] - 0s 1ms/step - loss: 1.4608 - accuracy: 0.5026\n"
]
},
{
"data": {
"text/plain": [
"[1.4607702493667603, 0.5026000142097473]"
]
},
"execution_count": 127,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=7e-4)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n",
" patience=20, restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n",
" \"my_cifar10_selu_model\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_selu_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"X_means = X_train.mean(axis=0)\n",
"X_stds = X_train.std(axis=0)\n",
"X_train_scaled = (X_train - X_means) / X_stds\n",
"X_valid_scaled = (X_valid - X_means) / X_stds\n",
"X_test_scaled = (X_test - X_means) / X_stds\n",
"\n",
"model.fit(X_train_scaled, y_train, epochs=100,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid_scaled, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model reached the first model's validation loss in just 8 epochs. After 14 epochs, it reached its lowest validation loss, with about 50.3% accuracy, which is better than the original model (46.7%) but not quite as good as the model using batch normalization (50.7%). Each epoch took only 9 seconds, so it's the fastest model to train so far."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### e.\n",
"*Exercise: Try regularizing the model with alpha dropout. Then, without retraining your model, see if you can achieve better accuracy using MC Dropout.*"
]
},
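{
"cell_type": "markdown",
"metadata": {},
"source": [
"Regular dropout would break self-normalization, because setting activations to zero shifts their mean and variance. Alpha dropout instead replaces dropped units with SELU's negative saturation value (about -1.758) and then applies an affine transformation chosen so that standardized inputs keep a mean of 0 and a standard deviation of 1. Here is a minimal NumPy sketch of that formula; the `alpha_dropout()` function is a hypothetical helper for illustration, not Keras's implementation:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"SELU_ALPHA = 1.6732632423543772\n",
"SELU_SCALE = 1.0507009873554805\n",
"ALPHA_P = -SELU_ALPHA * SELU_SCALE  # SELU's negative saturation value, about -1.758\n",
"\n",
"def alpha_dropout(x, rate, rng):\n",
"    keep = rng.random(x.shape) >= rate  # keep each unit with probability 1 - rate\n",
"    a = ((1 - rate) * (1 + rate * ALPHA_P ** 2)) ** -0.5  # rescales to unit variance\n",
"    b = -a * ALPHA_P * rate  # shifts back to zero mean\n",
"    return a * np.where(keep, x, ALPHA_P) + b\n",
"\n",
"rng = np.random.default_rng(42)\n",
"x = rng.standard_normal(1_000_000)  # standardized inputs\n",
"y = alpha_dropout(x, rate=0.1, rng=rng)\n",
"# Both values should stay close to 0 and 1, respectively\n",
"print(f\"mean: {y.mean():.3f}, std: {y.std():.3f}\")\n",
"```\n",
"\n",
"At inference time, Keras's `AlphaDropout` layer simply passes its inputs through unchanged, which is why the MC Dropout part of this exercise needs a subclass that forces `training=True`."
]
},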
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.8953 - accuracy: 0.3240INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 18s 11ms/step - loss: 1.8950 - accuracy: 0.3239 - val_loss: 1.7556 - val_accuracy: 0.3812\n",
"Epoch 2/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.6618 - accuracy: 0.4129INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.6618 - accuracy: 0.4130 - val_loss: 1.6563 - val_accuracy: 0.4114\n",
"Epoch 3/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.5772 - accuracy: 0.4431INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.5770 - accuracy: 0.4432 - val_loss: 1.6507 - val_accuracy: 0.4232\n",
"Epoch 4/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5081 - accuracy: 0.4673INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.5081 - accuracy: 0.4672 - val_loss: 1.5892 - val_accuracy: 0.4566\n",
"Epoch 5/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.4560 - accuracy: 0.4902INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.4561 - accuracy: 0.4902 - val_loss: 1.5382 - val_accuracy: 0.4696\n",
"Epoch 6/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.4095 - accuracy: 0.5050INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.4094 - accuracy: 0.5050 - val_loss: 1.5236 - val_accuracy: 0.4818\n",
"Epoch 7/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.3634 - accuracy: 0.5234INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.3636 - accuracy: 0.5232 - val_loss: 1.5139 - val_accuracy: 0.4840\n",
"Epoch 8/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.3297 - accuracy: 0.5377INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.3296 - accuracy: 0.5378 - val_loss: 1.4780 - val_accuracy: 0.4982\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2907 - accuracy: 0.5485 - val_loss: 1.5151 - val_accuracy: 0.4854\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 1.2559 - accuracy: 0.5646 - val_loss: 1.4980 - val_accuracy: 0.4976\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.2221 - accuracy: 0.5767 - val_loss: 1.5199 - val_accuracy: 0.4990\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.1960 - accuracy: 0.5870 - val_loss: 1.5167 - val_accuracy: 0.5030\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.1684 - accuracy: 0.5955 - val_loss: 1.5815 - val_accuracy: 0.5014\n",
"Epoch 14/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.1463 - accuracy: 0.6025 - val_loss: 1.5427 - val_accuracy: 0.5112\n",
"Epoch 15/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.1125 - accuracy: 0.6169 - val_loss: 1.5868 - val_accuracy: 0.5212\n",
"Epoch 16/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.0854 - accuracy: 0.6243 - val_loss: 1.6234 - val_accuracy: 0.5090\n",
"Epoch 17/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0668 - accuracy: 0.6328 - val_loss: 1.6162 - val_accuracy: 0.5072\n",
"Epoch 18/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.0440 - accuracy: 0.6442 - val_loss: 1.5748 - val_accuracy: 0.5162\n",
"Epoch 19/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.0272 - accuracy: 0.6477 - val_loss: 1.6518 - val_accuracy: 0.5200\n",
"Epoch 20/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 1.0007 - accuracy: 0.6594 - val_loss: 1.6224 - val_accuracy: 0.5186\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 0.9824 - accuracy: 0.6639 - val_loss: 1.6972 - val_accuracy: 0.5136\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 0.9660 - accuracy: 0.6714 - val_loss: 1.7210 - val_accuracy: 0.5278\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 0.9472 - accuracy: 0.6780 - val_loss: 1.6436 - val_accuracy: 0.5006\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 0.9314 - accuracy: 0.6819 - val_loss: 1.7059 - val_accuracy: 0.5160\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 0.9172 - accuracy: 0.6888 - val_loss: 1.6926 - val_accuracy: 0.5200\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 0.8990 - accuracy: 0.6947 - val_loss: 1.7705 - val_accuracy: 0.5148\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 0.8758 - accuracy: 0.7028 - val_loss: 1.7023 - val_accuracy: 0.5198\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 0.8622 - accuracy: 0.7090 - val_loss: 1.7567 - val_accuracy: 0.5184\n",
"157/157 [==============================] - 0s 1ms/step - loss: 1.4780 - accuracy: 0.4982\n"
]
},
{
"data": {
"text/plain": [
"[1.4779616594314575, 0.498199999332428]"
]
},
"execution_count": 128,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-4)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n",
" patience=20, restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n",
" \"my_cifar10_alpha_dropout_model\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_alpha_dropout_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"X_means = X_train.mean(axis=0)\n",
"X_stds = X_train.std(axis=0)\n",
"X_train_scaled = (X_train - X_means) / X_stds\n",
"X_valid_scaled = (X_valid - X_means) / X_stds\n",
"X_test_scaled = (X_test - X_means) / X_stds\n",
"\n",
"model.fit(X_train_scaled, y_train, epochs=100,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid_scaled, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model reaches 49.8% accuracy on the validation set. That's slightly worse than without dropout (50.3%). With an extensive hyperparameter search, it might be possible to do better (I tried dropout rates of 5%, 10%, 20%, and 40%, and learning rates of 1e-4, 3e-4, 5e-4, and 1e-3), but probably not much better in this case."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use MC Dropout now. We will need the `MCAlphaDropout` class we used earlier, so let's just copy it here for convenience:"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"class MCAlphaDropout(tf.keras.layers.AlphaDropout):\n",
" def call(self, inputs):\n",
" return super().call(inputs, training=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's create a new model, identical to the one we just trained (with the same weights), but with `MCAlphaDropout` dropout layers instead of `AlphaDropout` layers:"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
"mc_model = tf.keras.Sequential([\n",
" (\n",
" MCAlphaDropout(layer.rate)\n",
" if isinstance(layer, tf.keras.layers.AlphaDropout)\n",
" else layer\n",
" )\n",
" for layer in model.layers\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then let's add a couple of utility functions. The first will run the model many times (10 by default) and return the mean predicted class probabilities. The second will use these mean probabilities to predict the most likely class for each instance:"
]
},
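{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same averaging logic can be sketched without Keras. In the minimal NumPy example below, `noisy_predict()` is a hypothetical stand-in for `mc_model.predict()`: a tiny linear softmax classifier with random weights, whose inputs go through regular (inverted) dropout on every call; it uses regular dropout rather than alpha dropout only to keep the sketch short. Each call therefore returns slightly different probabilities; averaging many calls gives the MC Dropout estimate, and `argmax()` picks the most likely class:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"W = rng.standard_normal((5, 3))  # hypothetical weights: 5 features, 3 classes\n",
"\n",
"def softmax(logits):\n",
"    exp = np.exp(logits - logits.max(axis=1, keepdims=True))\n",
"    return exp / exp.sum(axis=1, keepdims=True)\n",
"\n",
"def noisy_predict(X, rate=0.2):\n",
"    # Dropout stays active at prediction time, like MCAlphaDropout\n",
"    keep = rng.random(X.shape) >= rate\n",
"    return softmax((X * keep / (1 - rate)) @ W)\n",
"\n",
"def mc_predict_probas(X, n_samples=10):\n",
"    Y_probas = np.stack([noisy_predict(X) for _ in range(n_samples)])\n",
"    return Y_probas.mean(axis=0)  # average over the samples axis\n",
"\n",
"X = rng.standard_normal((4, 5))\n",
"y_probas = mc_predict_probas(X, n_samples=100)\n",
"y_pred = y_probas.argmax(axis=1)\n",
"print(y_probas.round(2))\n",
"print(y_pred)\n",
"```\n",
"\n",
"The `mc_dropout_predict_probas()` function in the next cell does the same thing with a Python list: `np.mean(Y_probas, axis=0)` stacks the `n_samples` prediction arrays and averages them along the first axis."
]
},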
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"def mc_dropout_predict_probas(mc_model, X, n_samples=10):\n",
" Y_probas = [mc_model.predict(X) for sample in range(n_samples)]\n",
" return np.mean(Y_probas, axis=0)\n",
"\n",
"def mc_dropout_predict_classes(mc_model, X, n_samples=10):\n",
" Y_probas = mc_dropout_predict_probas(mc_model, X, n_samples)\n",
" return Y_probas.argmax(axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's make predictions for all the instances in the validation set, and compute the accuracy:"
]
},
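{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code below compares `y_pred` with `y_valid[:, 0]` rather than with `y_valid`, because the CIFAR10 labels returned by `tf.keras.datasets.cifar10.load_data()` have shape `[n, 1]`, while `y_pred` has shape `[n]`. Comparing them directly would not raise an error: NumPy would silently broadcast the comparison to an `[n, n]` matrix and produce a meaningless accuracy. Here is a toy illustration with made-up labels:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"y_pred = np.array([3, 1, 4])        # predicted classes, shape (3,)\n",
"y_true = np.array([[3], [1], [5]])  # made-up labels, shape (3, 1)\n",
"\n",
"wrong = (y_pred == y_true).mean()        # broadcasts to a (3, 3) boolean matrix\n",
"right = (y_pred == y_true[:, 0]).mean()  # element-wise comparison, shape (3,)\n",
"print(wrong, right)  # about 0.22 (wrong) vs. 0.67 (right)\n",
"```"
]
},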
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.4984"
]
},
"execution_count": 132,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"y_pred = mc_dropout_predict_classes(mc_model, X_valid_scaled)\n",
"accuracy = (y_pred == y_valid[:, 0]).mean()\n",
"accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, MC Dropout doesn't really help: it gives 49.8% accuracy, still slightly below the model without dropout (50.3%).\n",
"\n",
"So the best model we got in this exercise is the Batch Normalization model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### f.\n",
"*Exercise: Retrain your model using 1cycle scheduling and see if it improves training speed and model accuracy.*"
]
},
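{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder of what 1cycle scheduling does, here is a minimal, framework-free sketch of one common variant. The learning rate grows linearly from `max_lr / 10` to `max_lr` during the first half of the cycle, decreases linearly back to `max_lr / 10` during the second half, and then drops linearly to `max_lr / 10_000` during the last 10% of the iterations. The `one_cycle_lr()` function, its default values, and the `max_lr` value used here are illustrative assumptions. In Keras, such a function would typically be called once per batch from a custom callback that updates the optimizer's learning rate.\n",
"\n",
"```python\n",
"def one_cycle_lr(iteration, iterations, max_lr, start_lr=None, last_lr=None,\n",
"                 last_fraction=0.1):\n",
"    # Learning rate for a given iteration under a 1cycle schedule (sketch)\n",
"    start_lr = max_lr / 10 if start_lr is None else start_lr\n",
"    last_lr = start_lr / 1000 if last_lr is None else last_lr\n",
"    half = int(iterations * (1 - last_fraction)) // 2  # length of each half cycle\n",
"    if iteration < half:  # phase 1: increase linearly up to max_lr\n",
"        return start_lr + (max_lr - start_lr) * iteration / half\n",
"    if iteration < 2 * half:  # phase 2: decrease linearly back to start_lr\n",
"        return max_lr - (max_lr - start_lr) * (iteration - half) / half\n",
"    # phase 3: anneal linearly from start_lr down to last_lr\n",
"    remaining = max(iterations - 2 * half - 1, 1)\n",
"    return start_lr - (start_lr - last_lr) * (iteration - 2 * half) / remaining\n",
"\n",
"iterations = 15 * 352  # e.g., 15 epochs of 352 batches, as in the run below\n",
"lrs = [one_cycle_lr(i, iterations, max_lr=2e-2) for i in range(iterations)]\n",
"print(f\"first: {lrs[0]:.0e}, peak: {max(lrs):.0e}, last: {lrs[-1]:.0e}\")\n",
"```"
]
},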
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"352/352 [==============================] - 3s 8ms/step - loss: nan - accuracy: 0.1706\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"batch_size = 128\n",
"rates, losses = find_learning_rate(model, X_train_scaled, y_train, epochs=1,\n",
" batch_size=batch_size)\n",
"plot_lr_vs_loss(rates, losses)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=2e-2)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 2.0559 - accuracy: 0.2839 - val_loss: 1.7917 - val_accuracy: 0.3768\n",
"Epoch 2/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.7596 - accuracy: 0.3797 - val_loss: 1.6566 - val_accuracy: 0.4258\n",
"Epoch 3/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.6199 - accuracy: 0.4247 - val_loss: 1.6395 - val_accuracy: 0.4260\n",
"Epoch 4/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.5451 - accuracy: 0.4524 - val_loss: 1.6202 - val_accuracy: 0.4408\n",
"Epoch 5/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.4952 - accuracy: 0.4691 - val_loss: 1.5981 - val_accuracy: 0.4488\n",
"Epoch 6/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.4541 - accuracy: 0.4842 - val_loss: 1.5720 - val_accuracy: 0.4490\n",
"Epoch 7/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.4171 - accuracy: 0.4967 - val_loss: 1.6035 - val_accuracy: 0.4470\n",
"Epoch 8/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.3497 - accuracy: 0.5194 - val_loss: 1.4918 - val_accuracy: 0.4864\n",
"Epoch 9/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.2788 - accuracy: 0.5459 - val_loss: 1.5597 - val_accuracy: 0.4672\n",
"Epoch 10/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.2070 - accuracy: 0.5707 - val_loss: 1.5845 - val_accuracy: 0.4864\n",
"Epoch 11/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 1.1433 - accuracy: 0.5926 - val_loss: 1.5293 - val_accuracy: 0.4998\n",
"Epoch 12/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.0745 - accuracy: 0.6182 - val_loss: 1.5118 - val_accuracy: 0.5072\n",
"Epoch 13/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 1.0030 - accuracy: 0.6413 - val_loss: 1.5388 - val_accuracy: 0.5204\n",
"Epoch 14/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 0.9388 - accuracy: 0.6654 - val_loss: 1.5547 - val_accuracy: 0.5210\n",
"Epoch 15/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 0.8989 - accuracy: 0.6805 - val_loss: 1.5835 - val_accuracy: 0.5242\n"
]
}
],
"source": [
"n_epochs = 15\n",
"n_iterations = math.ceil(len(X_train_scaled) / batch_size) * n_epochs\n",
"onecycle = OneCycleScheduler(n_iterations, max_lr=0.05)\n",
"history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=[onecycle])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 1cycle schedule allowed us to train the model in just 15 epochs, each taking only about 3 seconds (thanks to the larger batch size). This is several times faster than the fastest model we trained so far. Moreover, it improved the model's performance (from 50.7% to about 52.4% validation accuracy)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"nav_menu": {
"height": "360px",
"width": "416px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}