{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 11: Training Deep Neural Networks**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 11._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/11_training_deep_neural_networks.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/11_training_deep_neural_networks.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"assert sklearn.__version__ >= \"1.0.1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And TensorFlow ≥ 2.6:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"assert tf.__version__ >= \"2.6.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As we did in previous chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create the `images/deep` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"deep\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vanishing/Exploding Gradients Problem"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAEQCAYAAAD2/KAsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABNlUlEQVR4nO3de3yO9f/A8ddn56PzDHNYclw5HxIxpxWRYyo5ppIi9ZUvRSchfSUmyuEnEXKMnJVkkiijEWE5m83ZZrOT3ffn98c1y2xj7N6ue9v7+Xhcj+2+Pp/7+rx3ue2967o+B6W1RgghhMhrDmYHIIQQonCSBCSEEMIUkoCEEEKYQhKQEEIIU0gCEkIIYQpJQEIIIUxhkwSklJqrlLqglDqQRXkvpdT+1O03pVQdW7QrhBAi/7LVFdA8oN0dyk8AgVrr2sBYYLaN2hVCCJFPOdniIFrrX5RS/nco/+2Wl7uA8rZoVwghRP5lxjOgF4GNJrQrhBDCjtjkCii7lFKtMBLQY3eoMxAYCODu7t6gQoUKeRRd9litVhwcpO/G3ch5urszZ86gtaZixYpmh2L38vLzdMN6g9MJpwGo6F4RZwfnPGnXFuzx/114ePglrbVPZmV5loCUUrWBOUB7rfXlrOpprWeT+oyoYcOGOjQ0NI8izJ6QkBBatmxpdhh2T87T3bVs2ZLo6GjCwsLMDsXu5dXn6XzceR6Z8wjFkouxrf82Hir9UK63aUv2+P9OKXUqq7I8SZVKqYrASqCP1jo8L9oUQoh7VdKjJB2qduCH3j/ku+STH9nkCkgptRhoCZRSSkUAHwDOAFrrmcD7QEngS6UUQIrWuqEt2hZCiJy6mnCVxJREynqX5YsOX5gdTqFhq15wPe9S/hLwki3aEkIIW4pNiuXJb58kNimWsEFhODnk6aPxQk3OtBCi0Eq4kUDnJZ3ZfXY3K55ZIcknj8nZFkIUSsmWZHos70HIyRAWdF1AlxpdzA6p0JEEJIQolN7f+j7r/1nPzA4z6VW7l9nhFEr5PgFdu3aNCxcucOPGjTxpr2jRohw6dChP2srPCuN5cnZ2pnTp0hQpUsTsUEQ2DG86nId8HqJPnT5mh1Jo5esEdO3aNc6fP4+fnx/u7u6k9rDLVbGxsXh7e+d6O/ldYTtPWmsSEhI4e/YsgCQhO6W1Zv6++fR8uCelPEpJ8jGZfQ2ZvUcXLlzAz88PDw+PPEk+QmRFKYWHhwd+fn5cuHDB7HBEFt7f+j4vrH6Bb/Z9Y3YognyegG7cuIG7u7vZYQiRxt3dPc9uB4t7M3HHRMZtH8dL9V7ipfoyKsQe5OsEBMiVj7Ar8nm0T1/u/pKRP42k58M9mdlxpvw72Yl8n4CEEOJOriRc4b2t7/FUtaeY32U+jg6OZockUuXrTghCCHE3JdxLsGPADvyL+ePsmH9mti4M5ApICFEgbTq6if/9+j8AapSqgZuTm8kRidtJAhL5TkREBK+//jqPPvpoWg/IkydPmh2WsCO/nPqFrku7suTgEhJuJJgdjsiCJCCR7xw9epRly5ZRvHhxmjdvbnY4ws7sPrubjt92xL+YPz/2/hF3Z+kpa68kAYl8p0WLFpw/f54NGzbQo0cPs8MRduTAhQO0W9SOUh6l+KnPT/h4ZroQp7ATkoBEvmNvSw4L+xF2LgxPZ09+6vsTfkX8zA5H3IX0ghNC5HsWqwVHB0d61+5N1xpd8XTxNDskkQ3yp6QQIl87F3eOBrMbsOnoJgBJPvmIJKB87vXXX+epp57Kdv0pU6ZQu3ZtrFZrLkYlRN64knCFxxc8ztErRyniKhPA5jeSgPKxY8eOMWvWLD744INsv2fQoEFcuHCB+fPn52JkQuS+2KRY2i9qz5HLR1j93GqaVmhqdkjiHkkCyseCg4OpU6cODRs2zPZ73N3d6du3L5MmTcrFyITIXYkpiTy1+Cn2RO5heY/ltKncxuyQxH2QBGSHrl+/zsiRI6lSpQ
ouLi4opdJtn332GUlJSSxcuJDnn38+3Xt37NiRof7N7dVXXwXgueee4++//+a3334z48cTIsecHZypWaomC7ouoFP1TmaHI+6T9IKzM1prunXrxo4dOxg9ejQNGzZk586djBkzBn9/f3r27MmTTz7Jrl27iI6OzjAQs0aNGuzcuTPdvk8++YSNGzfyzDPPAFC3bl2KFCnCpk2baNo089sWWmssFstd41VK4eiY95M7rlixAoA9e/YAsHHjRnx8fPDx8SEwMDDP4xF5I8WawqX4S5TxKsOMjjPMDkfklNY6xxswF7gAHMiiXAGfA0eB/UD97By3QYMG+k7+/vvvO5bnhmvXruXq8b/44gutlNI//vhjuv1du3bVpUqV0larVWut9SeffKKVUjopKemOx3vvvfe0m5ub3rBhQ7r9jz32mA4KCsryfVu3btXAXbfAwMBM35/b5+le48lL2f1cBgYG6jp16uRuMAXE1q1btcVq0f1W9dMVp1TU0QnRZodkl7Zu3Wp2CBkAoTqL3/G2ugKaB0wHslpmsD1QNXV7BJiR+lXc5uuvvyYoKIigoKB0+2vUqMGaNWvS1jGJjIykSJEiuLi4ZHmst99+m88//5y1a9fStm3bdGU+Pj6Eh4dn+d4GDRqwe/fuu8Zr1rLbxudaFBZaa4ZuHMr8ffMZ03IMRd2Kmh2SsAGbJCCt9S9KKf87VOkMfJOaDXcppYoppcpqraNs0f6tcn+dqez9wr2f34/nz58nNDSUKVOmZCiLioqiTJkyaa8TExNxdXXN8lhvvfUWs2bNYsOGDbRs2TJDubu7OwkJWU/S6OXlRd26de8asyzsJfLCnBNz+PbMtwx/dDjvtXjP7HCEjeTVMyA/4MwtryNS990xAR05ciTDL89nnnmG1157jfj4eM6dO5fJtCzVbRBuzsXFxXH27NkM+ytUqICHhwfXrl0jKir9j79//34AypYtS3R0NOfPnwfAYrGwbt06Hn/8cZKTk3FxccHDw4OrV69y5MiRdMeoXLkyw4YNY968ecyePZuyZcumq1OlShUcHR05d+4cXl5eGd5fvbpx/latWkX37t3v+nM2atSIBQsWAODo6EiVKlUAuHjxIpGRkenqOjs7U7lyZQBOnz6dIQG6urri7+8PwMmTJ0lKSkpX7u7uTsWKFQE4fvx4hqWvPT09KV++PGBMWHr7Myxvb2/KlSsHQHh4eIarqKJFi6Yl+dvPC0Dx4sUpXbo0FouFo0ePZigvWbIkpUqVwmKxZJr0X331VZ599lnOnDlDnz59CAsLIyUlJa3uW2+9xVNPPcWRI0d45ZVXMrz/3XffpW3btoSFhfHmm29mKP/4449p2rQpv/32G6NGjcpQHhwcTN26dfnpp58YN25chvJZs2ZRvXp11q5dy2effZahfMGCBVSoUIGlS5cyY0bG5y8rVqygVKlSzJs3j3nz5mUo37BhAx4eHnz55ZcsW7YsQ3lISAgAkyZNYt26dWn7o8pGEV4jnFcavMLEoImMGzeOLVu2pHtvyZIl+e677wB45513MjwHLV++PAsXLgTgzTffJCwsLF15tWrVmD17NgADBw7McHegbt26BAcHA9C7d28iIiLSlT/66KNMmDABgO7du3P58uV05W3atOG994zE2b59+wyf/Y4dOzJ8+HCATD87t/7ee/LJJzOU9+/fn/79+xMTE5Otz97tcvuzl5W8SkCZ/Zmc6TWCUmogMBCMX1jR0dHpysPDwwkJCSExMREXFxdSUlLSlZ85E0HRokWxWCyZJoBixYpRpEgRbty4kSEBAJQoUQIvLy+SkpLSEsCtihcvjre3N4mJiVy4cCFDuY+PD+7u7sTHJ2SIDYwebhaLhfj4+AzlHh4eAOzbt48WLVqklc+ePZuYmBi6d+9OXFwczs7O+Pv7c+PGDSIiItJ+aWqteeWVV1i5ciWLFi3iwQcfzNBGXFwcDg4OnDp1ioCAgAzlsbGxgJGIlixZkq7MwcGB0q
VLAxATE0NCQgKenp5px9BaExsbS5EihWNA4IEDB9K9TkxMJDY2lqSkpAyfW4CDBw8SEhLChQsXiI6OJiUlBa11Wt2
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 11-1\n",
"\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
" return 1 / (1 + np.exp(-z))\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([-5, 5], [1, 1], 'k--')\n",
"plt.plot([0, 0], [-0.2, 1.2], 'k-')\n",
"plt.plot([-5, 5], [-3/4, 7/4], 'g--')\n",
"plt.plot(z, sigmoid(z), \"b-\", linewidth=2,\n",
" label=r\"$\\sigma(z) = \\dfrac{1}{1+e^{-z}}$\")\n",
"props = dict(facecolor='black', shrink=0.1)\n",
"plt.annotate('Saturating', xytext=(3.5, 0.7), xy=(5, 1), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.annotate('Saturating', xytext=(-3.5, 0.3), xy=(-5, 0), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.annotate('Linear', xytext=(2, 0.2), xy=(0, 0.5), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.grid(True)\n",
"plt.axis([-5, 5, -0.2, 1.2])\n",
"plt.xlabel(\"$z$\")\n",
"plt.legend(loc=\"upper left\", fontsize=16)\n",
"\n",
"save_fig(\"sigmoid_saturation_plot\")\n",
"plt.show()"
]
},
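The figure shows the sigmoid flattening at both ends. What matters for backpropagation is its derivative, σ'(z) = σ(z)(1 - σ(z)), which peaks at 0.25 at z = 0 and shrinks toward 0 in the saturating regions. The short NumPy sketch below is an illustration added here (not part of the book's code); it evaluates that derivative at a few points and shows how small the product of such factors becomes across 10 stacked sigmoid layers, ignoring the weights.

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def sigmoid_grad(z):
    # Derivative of the logistic function: sigma'(z) = sigma(z) * (1 - sigma(z))
    s = sigmoid(z)
    return s * (1 - s)

for z in [0.0, 2.0, 5.0, -5.0]:
    print(f"z={z:+.1f}  sigma'(z)={sigmoid_grad(z):.4f}")

# Upper bound on the gradient factor contributed by 10 stacked sigmoid
# activations (weights ignored): each one contributes at most 0.25
print(0.25 ** 10)
```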
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Xavier and He Initialization"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\")"
]
},
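As a rough sanity check of why He initialization suits ReLU, the sketch below propagates standardized inputs through 20 randomly initialized ReLU layers and compares the mean squared activation when weights have variance 2 / fan_in (He) versus a small fixed standard deviation of 0.01. It assumes plain normal weights; Keras's `"he_normal"` samples from a truncated normal distribution, so treat this as an approximation of that initializer. The width, depth, and seed are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
n_units, n_layers = 256, 20

def relu(z):
    return np.maximum(0, z)

def second_moment_after(weight_std):
    # Propagate standardized inputs through n_layers ReLU layers whose weights
    # are drawn from N(0, weight_std**2); return the mean squared activation
    Z = rng.normal(size=(1000, n_units))
    for _ in range(n_layers):
        W = rng.normal(scale=weight_std, size=(n_units, n_units))
        Z = relu(Z @ W)
    return np.mean(Z ** 2)

he_std = np.sqrt(2 / n_units)  # He initialization: Var(w) = 2 / fan_in
he_moment = second_moment_after(he_std)
tiny_moment = second_moment_after(0.01)  # naive small constant std
print(f"He init: {he_moment:.3f}, std=0.01: {tiny_moment:.3e}")
```

With He initialization the signal keeps roughly the same scale from layer to layer, while the naive initialization shrinks it by a constant factor per layer until it all but vanishes.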
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"he_avg_init = tf.keras.initializers.VarianceScaling(scale=2., mode=\"fan_avg\",\n",
" distribution=\"uniform\")\n",
"dense = tf.keras.layers.Dense(50, activation=\"sigmoid\",\n",
" kernel_initializer=he_avg_init)"
]
},
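According to the Keras documentation, `VarianceScaling` with `distribution="uniform"` samples from U(-limit, limit) with limit = sqrt(3 * scale / n), where n is fan_in, fan_out, or their average depending on `mode`. The helper below is a hypothetical illustration of that formula, not a Keras function; the 100-input fan-in is an assumption, since the cell above does not fix the layer's input size. Because a uniform distribution on [-limit, limit] has variance limit² / 3, the resulting weight variance is scale / fan_avg.

```python
import math

def variance_scaling_uniform_limit(scale, fan_in, fan_out, mode="fan_avg"):
    # limit = sqrt(3 * scale / n), with n chosen according to `mode`
    n = {"fan_in": fan_in, "fan_out": fan_out,
         "fan_avg": (fan_in + fan_out) / 2}[mode]
    return math.sqrt(3 * scale / n)

# Hypothetical layer shape: 100 inputs feeding a Dense(50) layer
limit = variance_scaling_uniform_limit(scale=2.0, fan_in=100, fan_out=50)
variance = limit ** 2 / 3  # variance of U(-limit, limit)
print(f"limit={limit:.4f}, variance={variance:.4f}")
```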
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nonsaturating Activation Functions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Leaky ReLU"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAADkCAYAAADeiPCXAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAg9ElEQVR4nO3deXhU9fk28PuZJEASEkBCggEFDLuILBGoCA2kUMrqqwgIBamWaHHtJQj059JXARdc3rpc1pZaqlB/gtVWLVqhJULQBhITQJIQliIFQ1LAGLLQJDPP+wdkZJjJZCaznDMz9+e65oI5y/c888zJ3DkzJ2dEVUFERBRsFqMLICKiyMQAIiIiQzCAiIjIEAwgIiIyBAOIiIgMwQAiIiJDRBux0aSkJO3Zs6cRm3arpqYG8fHxRpcRMtgvzx04cABWqxUDBw40upSQEan7V00NUFoK2GxAp05Ar16AiCfrmbNf+fn5p1S1i6t5hgRQz549kZeXZ8Sm3crOzkZGRobRZYQM9stzGRkZqKysNOV+b1aRuH/985/AxInnw2f2bGD9eiDaw1dps/ZLRL5qbh7fgiMiMoGm8Dl71vvwCVUMICIig0Vi+AAMICIiQ0Vq+AAMICIiw0Ry+AAMICIiQ0R6+AAGnQXXkqqqKlRUVKChoSGo2+3QoQOKi4uDus1Qxn557rHHHoOqhlW/YmJikJycjMTERKNLCTkMn/NM95CrqqpQXl6Obt26ITY2FuLJCfB+cvbsWSQkJARte6GO/fKcxWJBY2MjBgwYYHQpfqGqqKurw4kTJwCAIeQFhs93TPcWXEVFBbp164a4uLighg8ReU5EEBcXh27duqGiosLockIGw8eR6QKooaEBsbGxRpdBRB6IjY0N+lvloYrh48x0AQSARz5EIYI/q55h+LhmygAiIgoXDJ/mMYCIiAKE4eMeA4iIKAAYPi1jAIWYZcuWYcKECUaX4VfffPMNUlJScPjwYY+WnzlzJp5//vkAVxV83vYh0MK1z8HA8PEMA8jPxo8fjwULFgRs/MLCQgwZMsSnMcaPHw8RgYggJiYGffr0wW9/+1uvx5k6darbx9q5c2esWbPGafojjzyCHj162O+vXr0akydPRlpamkfbfeyxx7By5Up8++23XtdsZt72IdDCtc+BxvDxHAPIzwoKCjB8+PCAjb9nzx4MHTrUpzEKCgqwevVqlJWV4dChQ5gzZw7uvPNOFBQUeDXO3r17m32sR48exZkzZ5Cenu40Ly8vz75ebW0t1q5dizvuuMPj7V5zzTW46qqrsH79eq/qNbPW9CHQwrHPgcbw8Q4DyI8OHz6MysrKZl+UT5w4gQULFqBz587o2LEjbr75ZpSXlzsss3LlSgwePBjt27dHly5dsHDhQtTV1QEATp48ifLycvsRUE1NDebMmYNhw4bh6NGjAIDu3bs7vW2yb98+tGvXDkVFRfYaJ02ahK5du6JHjx648847oarYv3+/x/W29Fjz8/MhIhg2bJjLeU3BtHnzZlgsFowePdphmWeeecZ+lHbx7dFHHwUATJ8+HW+99ZbLbfvi+PHjEBG8/fbbGD9+POLi4nDttdeipKQEeXl5GDt2LOLi4jBixAgcO3bMvp675w0APv74YwwdOhRfffXdd3Pdf//9SEtLQ3l5ebN9AM73KzMzE7Gxsejduze2b9+OjRs32pdtqVfN8WS9QPU5HDF8WkFVg34bPny4NqeoqKjZeYFWVVXl0/pvv/22WiwWPXv2rNO8I0eOaHJysi5fvlyLioq0oKBAx44dqzfeeKPDco899pjm5OTo0aNHdcuWLZqamqqrV69WVdXNmzdrbGysNjY2aklJiQ4cOFDnzZuntbW19vVnzpypc+bMcRhz/Pjxevfdd9trTExM1MbGRlVVLSsr0zlz5qjFYtH9+/d7XK+7x6qqumLFCu3bt6/T9KNHjyoA/fjjj1VV9b777tMJEyY4LVdVVaVlZW
X224MPPqhdu3bVgwcPqqrqRx99pDExMQ6PvcmqVas0Pj7e7W379u0u6/7ggw8UgH7/+9/X7Oxs3bt3r/br109Hjhyp48aN0x07dmhhYaH26tVL77//fvt67p43VdXi4mIdOHCg/vSnP1VV1TVr1miXLl20tLTUbR927dqlsbGx+sQTT2hpaanOmzdPMzIydPDgwfr3v//do141x5P13PW5SaB+Zrdt2xaQcQPh889VExJUAdXZs1UbGoJfg1n7BSBPm8mCkAggwJibtx566CHt37+/y3kTJ07UFStWOEzbsmWLJiQkuB1z0aJFumDBAlVVXb16tY4YMULfeecdveyyy/SFF15wWv65557TtLQ0+/333ntPO3XqpKdOnbLXaLFYND4+XmNjYxWAtmnTxmmslup96KGHXAbMxevPnTvXafo777yjAOz1zJgxw/74mvPUU09pamqqlpSU2Kft2bNHAeihQ4eclj99+rQePHjQ7a25F9SVK1dqhw4d9OTJk/Zp99xzjyYlJdlrVlVduHChzpo1q9maL37eVFVLSkr0tdde0+joaH3yySe1ffv2umvXLvv85vowZswYh+389a9/VYvFohkZGS6366pXnmhuPXd9bhLpAWSG8FE1b78YQEEKoMzMTJ03b57T9K+++koBaGxsrMNv4e3atdPLLrvMvtyxY8f03nvv1UGDBmmnTp00Pj5eo6Oj9Re/+IWqqs6aNUs7deqkiYmJmp2d7bKGnTt3KgA9ffq0njt3TtPS0hzCJTMzU7OysvTgwYOan5+vP/zhD+1HR97Um5mZ6fYFOCkpSZ9//nmn6StWrNAePXrY70+cOFGzsrKaHWf16tWampqqBw4ccJheWlqqAHTfvn3NrtsaM2fOdHoOp0+f7lTj+PHjddmyZara8vOmej6AvvzyS/3e976nUVFRunnzZofxXPWhrKxMATi8sGzZskUB6I4dO5xqb65XLXG3nid9juQAMkv4qJq3X+4CKCTeoVQNznZ8vbpzQUEBHn74YafphYWFSExMRH5+vtO8Nm3aAABOnz6N6667DmPHjsWzzz6L7t27IyoqCtddd539M5/CwkLcdNNN+OMf/4jTp0+7rGH48OFo06YN8vLyUFBQgOjoaNx9990ONc6fPx+9e/cGALz22mvo1asX7rzzTlxzzTUe11tQUIAlS5a4rOH48eM4deoUBg0a5DRvy5YtDp9zJCUl4ZtvvnE5zqpVq/DrX/8an376qb3eJmfOnAEAdOnSxWm91atXY/Xq1S7HbPLRRx9hzJgxTtP37NmD++67z2FaQUEBHn/8caflsrKyPHremuTm5mLPnj1QVaSkpDjMc9WHpq9uuO666+zTDhw4gH79+uGGG25wWNZdr9xpaT13fY50/MzHD5pLJk9vANoB2AVgD4D9AP5vS+uE42dAR44cUQAuj0w2b96sUVFRzX5eoqr6hz/8QTt06KA2m80+bd26dQpAS0tLtaamRi0Wi+7atUs3bNig8fHxmp+f73KskSNH6j333KMJCQn6wQcfONV46XrDhg3TpUuXelxv0ziX/hZ/6fy//OUvDtNzc3NVRBymr1mzRq+++mqnMR5//HG94oormn3rZ+3atZqamupyXmvfgquurlaLxaI5OTkOYwHQwsJC+7Rjx44pAD1w4ECLz1uTP//5z5qQkKC/+93v9KabbtKJEyc6bNtVH959910VEXutVVVVevnll+vQoUO96lVzPFnPXZ+bROIRkJmOfJqYtV8I5FtwAARA+wv/jwGQC2CUu3XCMYA2bdqkAHTnzp26b98++62oqEjPnDmjSUlJeuONN+oXX3yhhw4d0k8++UQXL16sVqtVVVU//PBDjYqK0nfffVcPHjyoL774onbt2lUTEhLUZrPpZ599plFRUfYXo4cfflhTU1P1+PHjTrU88MADKiJOL3KbNm1Si8Xi9OK7bNky7d27t/1+S/U2PdYtW7Y4PVZVVZvNpgMGDNBBgwbpJ5
98onv27NHXX39dU1NTddq0aQ7b3rt3r1osFofPV1auXKmdO3fWnTt3OnxIXldXZ1/mtttu09tvv701T1WzPvvsM6c
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 11-2\n",
"\n",
"def leaky_relu(z, alpha):\n",
" return np.maximum(alpha * z, z)\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"plt.plot(z, leaky_relu(z, 0.1), \"b-\", linewidth=2, label=r\"$LeakyReLU(z) = max(\\alpha z, z)$\")\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([0, 0], [-1, 3.7], 'k-')\n",
"plt.grid(True)\n",
"props = dict(facecolor='black', shrink=0.1)\n",
"plt.annotate('Leak', xytext=(-3.5, 0.5), xy=(-5, -0.3), arrowprops=props,\n",
" fontsize=14, ha=\"center\")\n",
"plt.xlabel(\"$z$\")\n",
"plt.axis([-5, 5, -1, 3.7])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.legend()\n",
"\n",
"save_fig(\"leaky_relu_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"leaky_relu = tf.keras.layers.LeakyReLU(alpha=0.2) # defaults to alpha=0.3\n",
"dense = tf.keras.layers.Dense(50, activation=leaky_relu,\n",
" kernel_initializer=\"he_normal\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-16 11:22:41.636848: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"model = tf.keras.models.Sequential([\n",
" # [...] # more layers\n",
" tf.keras.layers.Dense(50, kernel_initializer=\"he_normal\"), # no activation\n",
" tf.keras.layers.LeakyReLU(alpha=0.2), # activation as a separate layer\n",
" # [...] # more layers\n",
"])"
]
},
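The point of the "leak" is the gradient: a plain ReLU passes zero gradient for every negative input, so a neuron whose inputs are all negative stops learning, while leaky ReLU keeps a small slope `alpha`. The NumPy sketch below (an illustration, using `alpha=0.2` as in the layers above) compares the two derivatives; the value chosen at exactly z = 0 is a convention.

```python
import numpy as np

def relu_grad(z):
    # ReLU passes no gradient for negative inputs
    return np.where(z > 0, 1.0, 0.0)

def leaky_relu_grad(z, alpha=0.2):
    # Leaky ReLU keeps a small slope alpha for negative inputs
    return np.where(z > 0, 1.0, alpha)

z = np.array([-3.0, -0.5, 0.5, 3.0])
print(relu_grad(z))
print(leaky_relu_grad(z))
```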
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ELU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Implementing ELU in TensorFlow is trivial: just specify the activation function when building each layer, and use He initialization:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"elu\",\n",
" kernel_initializer=\"he_normal\")"
]
},
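Two properties distinguish ELU from the ReLU variants: its derivative for negative inputs is alpha * exp(z), so with alpha = 1 the slope approaches 1 from both sides of the origin (the function is smooth there), and for very negative inputs it saturates toward -alpha instead of growing without bound. The sketch below checks both numerically; it is an illustration, not TensorFlow's implementation.

```python
import numpy as np

def elu(z, alpha=1.0):
    return np.where(z < 0, alpha * (np.exp(z) - 1), z)

def elu_grad(z, alpha=1.0):
    # d/dz ELU(z) = alpha * exp(z) for z < 0, and 1 otherwise
    return np.where(z < 0, alpha * np.exp(z), 1.0)

# With alpha=1, the slope is ~1 just left of 0 and exactly 1 just right of it
left, right = elu_grad(-1e-8), elu_grad(1e-8)
print(left, right)

# For very negative z, ELU saturates toward -alpha
print(elu(-10.0))
```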
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"### SELU"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, the SELU hyperparameters (`scale` and `alpha`) are tuned so that the mean output of each neuron remains close to 0 and the standard deviation remains close to 1 (assuming the inputs are also standardized to mean 0 and standard deviation 1, and that the other constraints explained in the book are respected). With this activation function, even a 1,000-layer deep neural network roughly preserves mean 0 and standard deviation 1 across all layers, avoiding the exploding/vanishing gradients problem:"
]
},
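The self-normalization claim is easy to probe numerically. The sketch below is a NumPy illustration added here: it feeds standardized inputs through 100 dense layers with LeCun normal weights (variance 1 / fan_in) and SELU, then reports the overall mean and standard deviation. The constants are the published SELU values, which match what the next cell derives from equation 14 of the paper; width, depth, batch size, and seed are arbitrary.

```python
import numpy as np

# Published SELU constants (Klambauer et al., 2017)
SELU_ALPHA = 1.6732632423543772
SELU_SCALE = 1.0507009873554805

def selu(z):
    return SELU_SCALE * np.where(z < 0, SELU_ALPHA * (np.exp(z) - 1), z)

rng = np.random.default_rng(42)
Z = rng.normal(size=(500, 100))  # standardized inputs
for layer in range(100):
    # LeCun normal initialization: Var(w) = 1 / fan_in
    W = rng.normal(scale=np.sqrt(1 / 100), size=(100, 100))
    Z = selu(Z @ W)

mean, std = Z.mean(), Z.std()
print(f"after 100 layers: mean={mean:.2f}, std={std:.2f}")
```

Swapping `selu` for a sigmoid or an unscaled activation in this loop is a quick way to see the statistics drift instead.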
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAD+CAYAAAB8xdFqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA5cElEQVR4nO3deXgUVfbw8e/NDmQjBEKIbIogIMgmIo4QDJuIuKCijAqjA4KM44bj6wY6DKg4MuMAIowO4k9Fx3EQVEBhIICKQoCIrBJWCSAQCCQkIUuf949KmjTphIR0Ut3J+TxPPemuulV96nanT1fVrXuNiKCUUkpVNz+7A1BKKVU7aQJSSillC01ASimlbKEJSCmllC00ASmllLKFJiCllFK2CLDjRaOjo6VFixZ2vHSZzpw5Q7169ewOw2dofZXfzp07KSgooF27dnaH4jO88vPlcICfX4lZKSmQkQEBAdCmDYSEVH9oXllfwIYNG46LSEN3y2xJQC1atCApKcmOly5TYmIi8fHxdofhM7S+yi8+Pp709HSv/Nx7K6/7fE2YAAsXwpIlEBMDQHY2DBkCmzZZs1asALt+Y3hdfRUyxuwvbZmeglNKqQuZOhUmTYKtW2HDBuBc8lm+3P7k46s0ASmlVFnefBOefhqMgXffhUGDNPl4SKUTkDEmxBizzhjzozFmqzHmJU8EppRStps3D8aNsx7PmgX33qvJx4M8cQ3oLHCDiGQaYwKBb4wxS0Tkew9sWyml7PHJJ/DAA9bjv/4VHnpIk4+HVToBidWbaWbh08DCSXs4VUr5trVrrSZuL74ITz6pyacKGE/0hm2M8Qc2AK2AmSLytJsyo4HRADExMV0/+uijUrcXFBRESEgIxphKx1YRIlLtr+nLtL7K7+TJk4gIUVFRdofiM6rq8yUi5OTkkJube6GCRK1bx4nu3Tmb689zz13Jhg1R1K+fy7RpybRokeXx2CojMzOT0NBQu8MooU+fPhtEpJu7ZR5JQM6NGRMJLAAeEZEtpZXr1q2blNYc9cCBAxhjiImJITAwsFq/4DIyMggLC6u21/N1Wl/lt3PnTvLz82nfvr3dofiMqvh8iQh5eXn8+uuviAjNmjVzLbBxIzRtCg3P3bbiK0c+XtwMu9QE5NFWcCKSDiQCAy92G2fOnCEuLo6goCD9da2U8ihjDEFBQcTFxXHmzBnXhcnJkJAAvXvDsWOA7yQfX+WJVnANC498MMbUAfoCOyoVlJ+2DldKVZ0S3zHbt0P//pCeDm3bQv36mnyqgSdawcUC8wqvA/kB/xaRLzywXaWUqnp79kDfvtZRz8CB8OGHZOcFaPKpBp5oBbcZ6OyBWJRSqnodPGiddjt0yDr19umnZDuCNflUEz3XpZSqnRwO68hn3z7o3h0+/5xsU1eTTzXSBFQDnTx5kpiYGHbv3l2u8nfccQfTpk2r4qjKVtGYi/Tu3ZtOnTrRqVMnAgICWLt2bRVFWFJ5623kyJE89NBDzucOh4OHHnqIBg0aYIwhMTGxCqP0DiNHjmTw4MF2h+HKzw8GD4aOHWHJErIDwjT5VDcRqfapa9euUppt27aVuqyqnT59ulLrjxgxQrBuwnWZrrnmGufym266qdT1e/fuLePGjSsxf+7cuVKvXr1yxzF+/HgZOXJkuctv3rxZ6tevL+np6eVeR6Ty9VVcRWM+30svvSQPP/ywx+IREVm1apXcfPPN0qRJEwFk7ty5LsvLW2/p6emybt062bJli4iIfP755xIYGCjffvutHD58WM6ePevRuCti5syZ0qJFCwkODpYuXbrI6tWrq+R1LvTZd8eTny93tm3bJuJwiGRkSFaWSN++IiASEyOydWuVvnSVWLlypd0huAUkSSm5QI+APKxv374cPnzYZVq8eHG1vX5WVhZvv/02Dz74YLnX6dChA5deeinvv/9+FUZWuouJubjp06eza9cuZsyY4dG4MjMzufLKK3njjTeoU6
dOieXlrbeIiAjCw8Odz1NSUoiNjaVnz540btyYoKAgj8Z9vtOnT5Oenl5i/scff8yjjz7Ks88+y6ZNm+jZsyc33ngjBw4cqNJ4bFVQAPv3Q16e9dwYsv1D9cjHJpqAPCw4OJjGjRu7TNV59/vixYvx8/Pjuuuuc5k/depUjDElpgkTJgAwZMgQ5s+fXyUxbdiwgYSEBOrUqUOrVq1YvXo1//73v50xlhYzQGpqKvfffz8NGjQgMjKSoUOH8uuvvzqXv/fee6xYsYK5c+d6/L6xQYMGMWXKFO64445Sbw0oT70VPwU3cuRIHn/8cecN12UNzHih96wsBQUFfPXVVwwfPpzGjRvz448/ligzbdo0Ro4cyahRo2jbti3Tp08nNjaWWbNmXXD7xYkIU6dO5bLLLqNOnTp06NChzKS8evVqevToQWhoKBEREVxzzTVs2bKlxPY6duxYru1BOevK4bBavB07Zv1F7/OxmyagGmbNmjV07dq1xJfx2LFjXY7KnnzySRo3bsz9998PQPfu3Vm3bh3Z2dkltjllyhRCQ0NLTLGxsc7Ha9ascRvP+vXruf766+nTpw+bN2+mR48eTJw4kcmTJzNp0qQyY967dy9dunQhLi6Ob775hsTERI4fP86YMWMAWLBgAR988AEfffQRAQG2jK1YZr2588YbbzBhwgQuueQSDh8+zPr160ste6H3zJ2tW7fypz/9iWbNmjFs2DDq1avH0qVL6dWrl0u53NxcNmzYQP/+/V3m9+/fn++++65c+1Lk+eef55133mHmzJls27aNZ555hoceeogvv/yyRNn8/HxuueUWfvOb3/Djjz/yww8/8Oijj+Lv719ie6+//voFt1fkgnUlAnv3wqlT1rClzZohosnHbvb819ZgS5cuLdEf07hx43j11Ver5fX3799PbGxsiflhYWHObk1effVV5s+fT2JiIq1atQKgSZMm5OXlcejQIS677DKXdceMGcNdd91VYpvF+56Ki4tzG8+TTz7JzTffzPPPPw/A8OHDufnmm+nVqxc33HBDmTGPGTOGBx98kClTpjjnvfDCC9x+++0APPDAAzRs2JBrrrkGgIkTJ3LbbbeVUTueV1a9uRMREUFYWBj+/v40bty4zLIXes+KpKWl8cEHH/Dee++xefNmBg4cyN///neGDBlCcHCw220fP36cgoICYgpH9iwSExPD8uXLL7gfRc6cOcO0adP4+uuvuf766wFo2bIl69atY+bMmdx0000u5YtOB958883O+rriiivcbq9Tp06EhYWVub1y1ZWI1dLt5Enw94fWrXEE1+HoUU0+dvOJBFR9PfK49jt1Md3k9erVizlz5rjMi4yMrERMFZOdnV3iS6W4l19+mRkzZrBy5Upat27tnF90jcPdL/moqCi3pxEv1FfXkSNHWLNmDStXrnTOCwoKwuFwOI9+Sov5wIEDfP3116xZs4Z//OMfzvkFBQXUrVsXsFrOXcjzzz/P5MmTyyyzcuXKi+5Dq6x685TS3rMi06dP56WXXqJnz57s2rWL5s2bl3vb5x91SgU7AN22bRs5OTkMHDjQZb28vDy3pxejoqIYOXIkAwYMICEhgYSEBO68806aNm16Uds7X4m6EoEDByAtzWr1dvnlOELqkpICOTmafOzmEwnIl9StW7fEL9TyCg8P59SpUyXmp6enExERUa5tREdHl/rFPHnyZN566y1WrVpVIsYTJ04A0LBYJ4xFpkyZ4nIU4s6SJUucv4CLbN++HYCrr77aOW/nzp20adOG3/zmN2XGnJycTHh4OBsKhz8uriIX7R977DHuvffeMsuU6JCyAsqqN08o6z0rMnr0aAIDA3nvvfdo3749t912G/fddx8JCQkup7aKi46Oxt/fnyNHjrjMP3r0aJk/YM7ncDgA+Pzzz0vUY2BgoNt15s6dy2OPPcbSpUtZtGgRzz33HJ999hkDBgxw2V5UVJTL2YTStlfEbV2dOmVd8zEGWrXCUT
eUlBQ4fdrKR5p87OUTCciDHXaXye7endu0acPixYtL/ArduHEjbdq0cSm7adMmHnvsMVJTUxk/fjwrVqzg73//O50
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 11-3\n",
"\n",
"from scipy.special import erfc\n",
"\n",
"# alpha and scale to self normalize with mean 0 and standard deviation 1\n",
"# (see equation 14 in the paper):\n",
"alpha_0_1 = -np.sqrt(2 / np.pi) / (erfc(1 / np.sqrt(2)) * np.exp(1 / 2) - 1)\n",
"scale_0_1 = (\n",
" (1 - erfc(1 / np.sqrt(2)) * np.sqrt(np.e))\n",
" * np.sqrt(2 * np.pi)\n",
" * (\n",
" 2 * erfc(np.sqrt(2)) * np.e ** 2\n",
" + np.pi * erfc(1 / np.sqrt(2)) ** 2 * np.e\n",
" - 2 * (2 + np.pi) * erfc(1 / np.sqrt(2)) * np.sqrt(np.e)\n",
" + np.pi\n",
" + 2\n",
" ) ** (-1 / 2)\n",
")\n",
"\n",
"def elu(z, alpha=1):\n",
" return np.where(z < 0, alpha * (np.exp(z) - 1), z)\n",
"\n",
"def selu(z, scale=scale_0_1, alpha=alpha_0_1):\n",
" return scale * elu(z, alpha)\n",
"\n",
"z = np.linspace(-5, 5, 200)\n",
"plt.plot(z, elu(z), \"b-\", linewidth=2, label=r\"ELU$_\\alpha(z) = \\alpha (e^z - 1)$ if $z < 0$, else $z$\")\n",
"plt.plot(z, selu(z), \"r--\", linewidth=2, label=r\"SELU$(z) = 1.05 \\, $ELU$_{1.67}(z)$\")\n",
"plt.plot([-5, 5], [0, 0], 'k-')\n",
"plt.plot([-5, 5], [-1, -1], 'k:', linewidth=2)\n",
"plt.plot([-5, 5], [-1.758, -1.758], 'k:', linewidth=2)\n",
"plt.plot([0, 0], [-2.2, 3.2], 'k-')\n",
"plt.grid(True)\n",
"plt.axis([-5, 5, -2.2, 3.2])\n",
"plt.xlabel(\"$z$\")\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.legend()\n",
"\n",
"save_fig(\"elu_selu_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Using SELU is straightforward:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(50, activation=\"selu\",\n",
" kernel_initializer=\"lecun_normal\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra material: an example of a self-normalizing network using SELU**\n",
"\n",
"Let's create a neural net for Fashion MNIST with 100 hidden layers, using the SELU activation function:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[28, 28]))\n",
"for layer in range(100):\n",
" model.add(tf.keras.layers.Dense(100, activation=\"selu\",\n",
" kernel_initializer=\"lecun_normal\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's train it. Do not forget to scale the inputs to mean 0 and standard deviation 1:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]\n",
"X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]\n",
"X_train, X_valid, X_test = X_train / 255, X_valid / 255, X_test / 255"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"class_names = [\"T-shirt/top\", \"Trouser\", \"Pullover\", \"Dress\", \"Coat\",\n",
" \"Sandal\", \"Shirt\", \"Sneaker\", \"Bag\", \"Ankle boot\"]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"pixel_means = X_train.mean(axis=0, keepdims=True)\n",
"pixel_stds = X_train.std(axis=0, keepdims=True)\n",
"X_train_scaled = (X_train - pixel_means) / pixel_stds\n",
"X_valid_scaled = (X_valid - pixel_means) / pixel_stds\n",
"X_test_scaled = (X_test - pixel_means) / pixel_stds"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-16 11:22:44.499697: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1719/1719 [==============================] - 13s 7ms/step - loss: 1.3735 - accuracy: 0.4548 - val_loss: 0.9599 - val_accuracy: 0.6444\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.7783 - accuracy: 0.7073 - val_loss: 0.6529 - val_accuracy: 0.7664\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.6462 - accuracy: 0.7611 - val_loss: 0.6048 - val_accuracy: 0.7748\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.5821 - accuracy: 0.7863 - val_loss: 0.5737 - val_accuracy: 0.7944\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.5401 - accuracy: 0.8041 - val_loss: 0.5333 - val_accuracy: 0.8046\n"
]
}
],
"source": [
"history = model.fit(X_train_scaled, y_train, epochs=5,\n",
" validation_data=(X_valid_scaled, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The network managed to learn, despite how deep it is. Now look at what happens if we try to use the ReLU activation function instead:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[28, 28]))\n",
"for layer in range(100):\n",
" model.add(tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1719/1719 [==============================] - 12s 6ms/step - loss: 1.6932 - accuracy: 0.3071 - val_loss: 1.2058 - val_accuracy: 0.5106\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 1.1132 - accuracy: 0.5297 - val_loss: 0.9682 - val_accuracy: 0.5718\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 10s 6ms/step - loss: 0.9480 - accuracy: 0.6117 - val_loss: 1.0552 - val_accuracy: 0.5102\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 10s 6ms/step - loss: 0.9763 - accuracy: 0.6003 - val_loss: 0.7764 - val_accuracy: 0.7070\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.7892 - accuracy: 0.6875 - val_loss: 0.7485 - val_accuracy: 0.7054\n"
]
}
],
"source": [
"history = model.fit(X_train_scaled, y_train, epochs=5,\n",
" validation_data=(X_valid_scaled, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not great at all: we suffered from the vanishing/exploding gradients problem."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### GELU, Swish and Mish"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAADsCAYAAAAy23L6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAABbDElEQVR4nO2dd3hURdfAf5PeSCMkEECKIUiXYoiAEHpVFLChIIIiCIIKSrGAH0XBV1+VgIKgqPBSFUSqoCRUKUoLHamBQEhICOltvj9udrObbJINKZsyv+e5z+6de+7Mmbt377kzc+aMkFKiUCgUCkVpY2VpBRQKhUJROVEGSKFQKBQWQRkghUKhUFgEZYAUCoVCYRGUAVIoFAqFRVAGSKFQKBQWocgGSAhRWwixUwhxWghxUggx3oSMEEJ8JYS4IIQ4LoRoVdRyFQqFQlG+sSmGPNKBCVLKf4QQVYC/hRDbpZSnDGR6Aw2ytrbA11mfCoVCoaikFLkFJKWMkFL+k/X9HnAaqJlDrD/wo9T4C3AXQtQoatkKhUKhKL8URwtIjxCiLtASOJDjUE3gmsF+eFZahIk8RgIjARwcHFo/8MADxalimSczMxMrq8o1NKfqXPG5du0aUkoq2/8ZKt9vDXDu3LkoKWW1guSKzQAJIVyAn4E3pZRxOQ+bOMVkDCAp5SJgEUDDhg3l2bNni0vFckFISAhBQUGWVqNUUXWu+AQFBREbG8vRo0ctrUqpU9l+awAhxBVz5IrFLAshbNGMz3Ip5S8mRMKB2gb7tYAbxVG2QqFQKMoOUVHmyxaHF5wAlgCnpZSf5yG2ARia5Q0XCNyVUubqflMoFApF+eXqVWjZ0nz54uiCaw8MAU4IIY5mpU0FHgCQUn4DbAb6ABeARODlYihXoVAoFGWEyEjo3h3Cw80/p8gGSEq5B9NjPIYyEhhT1LIUCoVCUba4GHORLadDWTL2Zc6dgxYt4Ngx884tVi84hUKhUFQeIu5F0O3H7lyKvQjOt/Dzm8y2bVC9unnni7K8IF1+XnCZmZmEh4eTkJBQylqVLMnJyTg4OFhajVJF1bnic/PmTaSU1KhhPP3P1tYWb29vXF1dLaRZyVORveCG/PISy078qO2kOxD6zEk6NquPEOJvKWWbgs4vty2gqKgohBA0bNiwQvnY37t3jypVqlhajVJF1bniY2VlRXp6Oo0aNdKnSSlJSkri+vXrABXaCFVEMjMh/ddgENegzi6CO66hY7P6hcqj3D65Y2Nj8fHxqVDGR6GoTAghcHJyombNmkRGRlpaHUUhkBLeegtW/lgFp3WbmR/4O2O69yt0PuW2BZSRkYGtra2l1VAoFEXE0dGRtLQ0S6uhKAQzZsBXX4GdHfz6swPdunW5r3zKrQEC7Q1KoVCUb9T/uHwgpWTKH1NIOdGPL6Z1wMoKVqyAbt3uP0/Vf6VQKBSKAvko9CPm7J3DF9E9oMFmFi2CAQOKlqcyQAqFQqHIl+jEaL7c+422Y5tEq5dWMmJE0fNVBkihUCgU+XL2aFWS5u+BmLrUz+jF/smLiyVfZYAUCoVCkSdhYdC3L6RE+PF84j6OvbcWO2u73IJ//glnzhQqb2WAFADExMTg4+PDv//+a5b8oEGD+PzzvGLPlj0qev0UiuJGSsmVK9CzJ8TGwpNPwo8LauBi75xb+K+/4PHHoUMHuHzZ7DKUAbIAt27d4q233qJBgwY4ODjg7e1Nu3btmDdvHvHx8QAMGzYMIUSuLTAwUJ/PsGHD6Ncvb9/7oKAgxo4dmyt96dKluLi4GKXNnj2bPn368OCDD5pVh2nTpjFz5kzu3r1rlnxJcvjwYYQQXM7nxi/P9VMoSpujN48SsLA9nfuHc+MGdOqkebzZmPKbPnkS+vSBxETo1w8KseigMkClzOXLl2nVqhVbt25lxowZ/PPPP/z5559MnDiRP/74g82bN+tlu3XrRkREhNFmeLy4SExMZPHixYwoxKhis2
bNqF+/PsuWLSt2fczl1KlTDB06lEGDBgHa9Ro+fDg5wzeV1/opFJbgwp0L9PypF4dv7edS5/Y81P4cv/4KJiNHXb4MPXpATAw88QQsXgyFCA5QYQyQEJbZCsvo0aOxsrLi8OHDPPfcczRu3JimTZsyYMAA1q9fz9NPP62Xtbe3p3r16kabp6dnMV41jc2bN2NlZUX79u2N0ufOnWuyFfbhhx8C8MQTT7BixYpi16egcgF++eUXWrRoQUJCAm+++SYAU6ZMITo6mmbNmrF+/foyWz+Foixz9PpJbsffAcDK6S7BC5NwczMheOuWtv6Crom0cmUeTaS8qTAGqDxw584dtm3bxpgxY3B2NtGPimUm5e3evZvWrVvnKnv06NFGra8JEyZQvXp1hg4dCkBAQAAHDx4kKSkpV56zZ8/GxcUl32337t0m9Smo3NTUVF577TV69uzJzz//TIcOHQDo2rUrv/76K926dWPkyJH62fUlUT+FoiKSkQGrZ/RHLv8NkeTF//ptpGuTFrkF09K0brcLF7QV6H79FRwdC11euY6EYEgZDuqt5/z580gpadiwoVF6rVq1iI2NBeDZZ59lyZIlAGzdujXXWM2YMWOYM2dOsep15cqVXFGKAapUqaIPmDlnzhxWrFhBSEgIfn5+APj6+pKWlsaNGzdyja2MGjWKZ555Jt9ya9asaTK9oHJPnDhBVFQUQ4YMMXn+kCFD2LJlCydOnKBVq1YlUj+FoqIhJYwbB2vWgKtrT7YMukS7Ni6mhW1t4fXX4T//ga1bMd1EKpgKY4DKM7t37yYjI4ORI0eSnJysT+/YsSOLFi0yknV3dy/28pOSkvDx8cnz+Mcff0xwcDA7d+7E399fn+6Y9cZjqoXg6elZ5O7CvMrVLSGSV2tRF6BWJ1cS9VMoKgqpGalYCStmzbBhwQKwt4cNG8jb+OgYMQJefFE74T5RXXCliJ+fH0IIzuTwla9Xrx5+fn44OTkZpTs5OeHn52e0eXl5mV2eq6urSS+u2NhY3AzeWLy8vIiJiTGZx6xZs1iwYAGhoaFGD2fQuhQBqlWrluu8onTBFVRu8+bNqVq1KsuXLzd57rJly/Dy8qJZs2YlVj+FoiKQkZnB0HVDaTVnINNnJmFlpQ3ldOpkQjgzEyZMgOPHs9OKYHxAtYBKlapVq9KjRw+Cg4N54403cnWvFTcNGzZk8+bNSCmNWgv//POPUTdgy5YtWbp0aa7zZ8yYwbfffktISIjJLqiwsDB8fX1Nti6K0gVXULl2dnYsWLCAwYMH8+yzz/LYY48BEBoayttvv83WrVtZuXIldnZ2JVY/haK8I6XkjS1vsOrkKi3hxd4Et9/Ck0+aGMuRUjM+X3wBq1ZpYz/FsKCiMkClzIIFC2jfvj2tW7dm+vTptGjRAhsbG/7++2+OHTtG586d9bIpKSncvHnT6Hxra2ujN/K4uDiOHj1qJOPu7k7dunUZPXq03ti9+uqrODg4sHnzZlasWMGvv/6ql+/ZsyeTJk0iOjqaqlWrAlrL4Msvv2TDhg04Ozvr9XB3d9ev5Ll792569eplsp732wVnTrkAzzzzDI0aNeKTTz5h7ty5gDZ3p2PHjhw7dozGjRuXaP0UiorAnYjsRQAfrd+MUSPyMCqzZ2vGx9YWvvuuWIwPoFnBsrr5+/vLvDh16lSex8o6ERERcty4cfLBBx+UdnZ20tnZWbZp00bOnj1bXr9+XUop5UsvvSSBXFvNmjX1+eQlM3DgQL3MwYMHZY8ePaS3t7d0dXWVAQEBct26dbl0CgwMlMHBwVJKKTMzM6Wrq6vJvHfs2CGllDIpKUm6urrK/fv3F/l6xMXFmV2uKQ4dOiQBeenSpTxlLFk/U+jqXFk4c+aMDAsLy/N4ef4/F8TOnTstrYJJDh6U0tlZStrPkQ9NHSzTMzJMC379tZQgpRBSrlplVt7AYWnGM97iRia/raIaoPyw1INpy5Yt0t/fX6anp5slHxwcLLt371
4sZZdGnS1ZP1MoA2RMRf0/S1k2DdCZM1J6eWkW4MUXpUxPzzQtuHKlZnhAym++KTjjjAwpk5PNNkDF4oQghPhOCBE
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code - this cell generates and saves Figure 11-4\n",
"\n",
"def swish(z, beta=1):\n",
" return z * sigmoid(beta * z)\n",
"\n",
"def approx_gelu(z):\n",
" return swish(z, beta=1.702)\n",
"\n",
"def softplus(z):\n",
" return np.log(1 + np.exp(z))\n",
"\n",
"def mish(z):\n",
" return z * np.tanh(softplus(z))\n",
"\n",
"z = np.linspace(-4, 2, 200)\n",
"\n",
"beta = 0.6\n",
"plt.plot(z, approx_gelu(z), \"b-\", linewidth=2,\n",
" label=r\"GELU$(z) = z\\,\\Phi(z)$\")\n",
"plt.plot(z, swish(z), \"r--\", linewidth=2,\n",
" label=r\"Swish$(z) = z\\,\\sigma(z)$\")\n",
"plt.plot(z, swish(z, beta), \"r:\", linewidth=2,\n",
" label=fr\"Swish$_{{\\beta={beta}}}(z)=z\\,\\sigma({beta}\\,z)$\")\n",
"plt.plot(z, mish(z), \"g:\", linewidth=3,\n",
" label=fr\"Mish$(z) = z\\,\\tanh($softplus$(z))$\")\n",
"plt.plot([-4, 2], [0, 0], 'k-')\n",
"plt.plot([0, 0], [-2.2, 3.2], 'k-')\n",
"plt.grid(True)\n",
"plt.axis([-4, 2, -1, 2])\n",
"plt.gca().set_aspect(\"equal\")\n",
"plt.xlabel(\"$z$\")\n",
"plt.legend(loc=\"upper left\")\n",
"\n",
"save_fig(\"gelu_swish_mish_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Batch Normalization"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"# extra code - clear the name counters and set the random seed\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(300, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten (Flatten) (None, 784) 0 \n",
"_________________________________________________________________\n",
"batch_normalization (BatchNo (None, 784) 3136 \n",
"_________________________________________________________________\n",
"dense (Dense) (None, 300) 235500 \n",
"_________________________________________________________________\n",
"batch_normalization_1 (Batch (None, 300) 1200 \n",
"_________________________________________________________________\n",
"dense_1 (Dense) (None, 100) 30100 \n",
"_________________________________________________________________\n",
"batch_normalization_2 (Batch (None, 100) 400 \n",
"_________________________________________________________________\n",
"dense_2 (Dense) (None, 10) 1010 \n",
"=================================================================\n",
"Total params: 271,346\n",
"Trainable params: 268,978\n",
"Non-trainable params: 2,368\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[('batch_normalization/gamma:0', True),\n",
" ('batch_normalization/beta:0', True),\n",
" ('batch_normalization/moving_mean:0', False),\n",
" ('batch_normalization/moving_variance:0', False)]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"[(var.name, var.trainable) for var in model.layers[1].variables]"
]
},
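{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each `BatchNormalization` layer holds four vectors with one entry per input feature: the scale γ and offset β, which are trainable, plus the moving mean and moving variance, which are not (they are updated during training, but not by gradient descent). As a quick worked check (not part of the original notebook), this accounts for the parameter counts in the model summary above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - recompute the parameter counts reported by model.summary()\n",
"bn_features = [784, 300, 100]  # number of inputs to each BatchNormalization layer\n",
"bn_params = [4 * n for n in bn_features]  # gamma, beta, moving_mean, moving_variance\n",
"dense_params = [784 * 300 + 300, 300 * 100 + 100, 100 * 10 + 10]  # weights + biases\n",
"non_trainable = sum(2 * n for n in bn_features)  # moving_mean and moving_variance\n",
"total = sum(bn_params) + sum(dense_params)\n",
"print(bn_params, total, total - non_trainable, non_trainable)\n",
"# prints: [3136, 1200, 400] 271346 268978 2368"
]
},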
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 3s 2ms/step - loss: 0.5559 - accuracy: 0.8094 - val_loss: 0.4016 - val_accuracy: 0.8558\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4083 - accuracy: 0.8561 - val_loss: 0.3676 - val_accuracy: 0.8650\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa5d11505b0>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - just show that the model works! 😊\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\",\n",
" metrics=\"accuracy\")\n",
"model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sometimes applying BN before the activation function works better (there's a debate on this topic). Moreover, the layer before a `BatchNormalization` layer does not need bias terms: the `BatchNormalization` layer has its own offset parameters (β), so a bias would just be a waste of parameters. You can therefore set `use_bias=False` when creating those layers:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# extra code - clear the name counters and set the random seed\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(300, kernel_initializer=\"he_normal\", use_bias=False),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Activation(\"relu\"),\n",
" tf.keras.layers.Dense(100, kernel_initializer=\"he_normal\", use_bias=False),\n",
" tf.keras.layers.BatchNormalization(),\n",
" tf.keras.layers.Activation(\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.6063 - accuracy: 0.7993 - val_loss: 0.4296 - val_accuracy: 0.8418\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4275 - accuracy: 0.8500 - val_loss: 0.3752 - val_accuracy: 0.8646\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa5fdd309d0>"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - just show that the model works! 😊\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\",\n",
" metrics=\"accuracy\")\n",
"model.fit(X_train, y_train, epochs=2, validation_data=(X_valid, y_valid))"
]
},
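{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why is the bias redundant? In training mode, BN subtracts the batch mean from each feature, so any constant bias added by the previous layer is subtracted right back out. Here is a minimal NumPy sketch of the training-time forward pass (assuming γ = 1, β = 0 and ε = 0.001, which is Keras's default `epsilon`; this is an illustration, not Keras's implementation) showing that the output is unchanged when a bias is added:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a minimal NumPy sketch of the BN training-time forward pass\n",
"import numpy as np\n",
"\n",
"def batch_norm(X, gamma=1.0, beta=0.0, eps=0.001):\n",
"    # standardize each feature using the batch mean and variance, then scale and shift\n",
"    mu = X.mean(axis=0)\n",
"    var = X.var(axis=0)\n",
"    return gamma * (X - mu) / np.sqrt(var + eps) + beta\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X_batch = rng.normal(size=(32, 5))  # a toy batch: 32 instances, 5 features\n",
"bias = rng.normal(size=5)  # a hypothetical per-feature bias vector\n",
"print(np.allclose(batch_norm(X_batch), batch_norm(X_batch + bias)))  # True: the bias cancels out"
]
},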
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Gradient Clipping"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All `tf.keras.optimizers` accept `clipnorm` or `clipvalue` arguments:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(clipvalue=1.0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(clipnorm=1.0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer)"
]
},
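{
"cell_type": "markdown",
"metadata": {},
"source": [
"These two arguments behave differently: `clipvalue` clips every component of the gradient independently, which can change its direction, while `clipnorm` rescales a gradient whose ℓ2 norm exceeds the threshold, which preserves its direction (in Keras, `clipnorm` is applied to each weight's gradient individually). The following NumPy sketch (an illustration of the two rules on a single gradient vector, not Keras's implementation) shows the difference on the gradient vector [0.9, 100.0] with a threshold of 1.0:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a minimal NumPy sketch of the two clipping rules\n",
"import numpy as np\n",
"\n",
"def clip_by_value(grad, clip_value):\n",
"    # clip each component independently to [-clip_value, clip_value]\n",
"    return np.clip(grad, -clip_value, clip_value)\n",
"\n",
"def clip_by_norm(grad, clip_norm):\n",
"    # rescale the whole vector only if its L2 norm exceeds clip_norm\n",
"    norm = np.linalg.norm(grad)\n",
"    return grad * clip_norm / norm if norm > clip_norm else grad\n",
"\n",
"toy_grad = np.array([0.9, 100.0])\n",
"print(clip_by_value(toy_grad, 1.0))  # [0.9 1. ]: the direction has changed a lot\n",
"print(clip_by_norm(toy_grad, 1.0))  # roughly [0.009 1.0]: same direction as toy_grad"
]
},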
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Reusing Pretrained Layers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reusing a Keras model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's split the Fashion MNIST training set in two:\n",
"* `X_train_A`: all images of all items except for T-shirts/tops and pullovers (classes 0 and 2).\n",
"* `X_train_B`: a much smaller training set of just the first 200 images of T-shirts/tops and pullovers.\n",
"\n",
"The validation set and the test set are also split this way, but without restricting the number of images.\n",
"\n",
"We will train a model on set A (classification task with 8 classes), and try to reuse it to tackle set B (binary classification). We hope to transfer a little bit of knowledge from task A to task B, since classes in set A (trousers, dresses, coats, sandals, shirts, sneakers, bags, and ankle boots) are somewhat similar to classes in set B (T-shirts/tops and pullovers). However, since we are using `Dense` layers, only patterns that occur at the same location can be reused (in contrast, convolutional layers will transfer much better, since learned patterns can be detected anywhere on the image, as we will see in Chapter 14)."
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"1376/1376 [==============================] - 1s 908us/step - loss: 1.1385 - accuracy: 0.6260 - val_loss: 0.7101 - val_accuracy: 0.7603\n",
"Epoch 2/20\n",
"1376/1376 [==============================] - 1s 869us/step - loss: 0.6221 - accuracy: 0.7911 - val_loss: 0.5293 - val_accuracy: 0.8315\n",
"Epoch 3/20\n",
"1376/1376 [==============================] - 1s 852us/step - loss: 0.5016 - accuracy: 0.8394 - val_loss: 0.4515 - val_accuracy: 0.8581\n",
"Epoch 4/20\n",
"1376/1376 [==============================] - 1s 852us/step - loss: 0.4381 - accuracy: 0.8583 - val_loss: 0.4055 - val_accuracy: 0.8669\n",
"Epoch 5/20\n",
"1376/1376 [==============================] - 1s 844us/step - loss: 0.3979 - accuracy: 0.8692 - val_loss: 0.3748 - val_accuracy: 0.8706\n",
"Epoch 6/20\n",
"1376/1376 [==============================] - 1s 882us/step - loss: 0.3693 - accuracy: 0.8782 - val_loss: 0.3538 - val_accuracy: 0.8787\n",
"Epoch 7/20\n",
"1376/1376 [==============================] - 1s 863us/step - loss: 0.3487 - accuracy: 0.8825 - val_loss: 0.3376 - val_accuracy: 0.8834\n",
"Epoch 8/20\n",
"1376/1376 [==============================] - 2s 1ms/step - loss: 0.3324 - accuracy: 0.8879 - val_loss: 0.3315 - val_accuracy: 0.8847\n",
"Epoch 9/20\n",
"1376/1376 [==============================] - 1s 1ms/step - loss: 0.3198 - accuracy: 0.8920 - val_loss: 0.3174 - val_accuracy: 0.8879\n",
"Epoch 10/20\n",
"1376/1376 [==============================] - 2s 1ms/step - loss: 0.3088 - accuracy: 0.8947 - val_loss: 0.3118 - val_accuracy: 0.8904\n",
"Epoch 11/20\n",
"1376/1376 [==============================] - 1s 1ms/step - loss: 0.2994 - accuracy: 0.8979 - val_loss: 0.3039 - val_accuracy: 0.8925\n",
"Epoch 12/20\n",
"1376/1376 [==============================] - 1s 837us/step - loss: 0.2918 - accuracy: 0.8999 - val_loss: 0.2998 - val_accuracy: 0.8952\n",
"Epoch 13/20\n",
"1376/1376 [==============================] - 1s 840us/step - loss: 0.2852 - accuracy: 0.9016 - val_loss: 0.2932 - val_accuracy: 0.8980\n",
"Epoch 14/20\n",
"1376/1376 [==============================] - 1s 799us/step - loss: 0.2788 - accuracy: 0.9034 - val_loss: 0.2865 - val_accuracy: 0.8990\n",
"Epoch 15/20\n",
"1376/1376 [==============================] - 1s 922us/step - loss: 0.2736 - accuracy: 0.9052 - val_loss: 0.2824 - val_accuracy: 0.9015\n",
"Epoch 16/20\n",
"1376/1376 [==============================] - 1s 835us/step - loss: 0.2686 - accuracy: 0.9068 - val_loss: 0.2796 - val_accuracy: 0.9015\n",
"Epoch 17/20\n",
"1376/1376 [==============================] - 1s 863us/step - loss: 0.2641 - accuracy: 0.9085 - val_loss: 0.2748 - val_accuracy: 0.9015\n",
"Epoch 18/20\n",
"1376/1376 [==============================] - 1s 913us/step - loss: 0.2596 - accuracy: 0.9101 - val_loss: 0.2729 - val_accuracy: 0.9037\n",
"Epoch 19/20\n",
"1376/1376 [==============================] - 1s 909us/step - loss: 0.2558 - accuracy: 0.9119 - val_loss: 0.2715 - val_accuracy: 0.9040\n",
"Epoch 20/20\n",
"1376/1376 [==============================] - 1s 859us/step - loss: 0.2520 - accuracy: 0.9125 - val_loss: 0.2728 - val_accuracy: 0.9027\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-12-15 16:22:23.274500: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_model_A/assets\n"
]
}
],
"source": [
"# extra code - split Fashion MNIST into tasks A and B, then train and save\n",
"# model A to \"my_model_A\".\n",
"\n",
"pos_class_id = class_names.index(\"Pullover\")\n",
"neg_class_id = class_names.index(\"T-shirt/top\")\n",
"\n",
"def split_dataset(X, y):\n",
" y_for_B = (y == pos_class_id) | (y == neg_class_id)\n",
" y_A = y[~y_for_B]\n",
" y_B = (y[y_for_B] == pos_class_id).astype(np.float32)\n",
" old_class_ids = list(set(range(10)) - set([neg_class_id, pos_class_id]))\n",
" for old_class_id, new_class_id in zip(old_class_ids, range(8)):\n",
" y_A[y_A == old_class_id] = new_class_id # reorder class ids for A\n",
" return ((X[~y_for_B], y_A), (X[y_for_B], y_B))\n",
"\n",
"(X_train_A, y_train_A), (X_train_B, y_train_B) = split_dataset(X_train, y_train)\n",
"(X_valid_A, y_valid_A), (X_valid_B, y_valid_B) = split_dataset(X_valid, y_valid)\n",
"(X_test_A, y_test_A), (X_test_B, y_test_B) = split_dataset(X_test, y_test)\n",
"X_train_B = X_train_B[:200]\n",
"y_train_B = y_train_B[:200]\n",
"\n",
"tf.random.set_seed(42)\n",
"\n",
"model_A = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(8, activation=\"softmax\")\n",
"])\n",
"\n",
"model_A.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])\n",
"history = model_A.fit(X_train_A, y_train_A, epochs=20,\n",
" validation_data=(X_valid_A, y_valid_A))\n",
"model_A.save(\"my_model_A\")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"7/7 [==============================] - 0s 20ms/step - loss: 0.7167 - accuracy: 0.5450 - val_loss: 0.7052 - val_accuracy: 0.5272\n",
"Epoch 2/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6805 - accuracy: 0.5800 - val_loss: 0.6758 - val_accuracy: 0.6004\n",
"Epoch 3/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6532 - accuracy: 0.6650 - val_loss: 0.6530 - val_accuracy: 0.6746\n",
"Epoch 4/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.6289 - accuracy: 0.7150 - val_loss: 0.6317 - val_accuracy: 0.7517\n",
"Epoch 5/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6079 - accuracy: 0.7800 - val_loss: 0.6105 - val_accuracy: 0.8091\n",
"Epoch 6/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5866 - accuracy: 0.8400 - val_loss: 0.5913 - val_accuracy: 0.8447\n",
"Epoch 7/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.5670 - accuracy: 0.8850 - val_loss: 0.5728 - val_accuracy: 0.8833\n",
"Epoch 8/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5499 - accuracy: 0.8900 - val_loss: 0.5571 - val_accuracy: 0.8971\n",
"Epoch 9/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5331 - accuracy: 0.9150 - val_loss: 0.5427 - val_accuracy: 0.9050\n",
"Epoch 10/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.5180 - accuracy: 0.9250 - val_loss: 0.5290 - val_accuracy: 0.9080\n",
"Epoch 11/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.5038 - accuracy: 0.9350 - val_loss: 0.5160 - val_accuracy: 0.9189\n",
"Epoch 12/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4903 - accuracy: 0.9350 - val_loss: 0.5032 - val_accuracy: 0.9228\n",
"Epoch 13/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4770 - accuracy: 0.9400 - val_loss: 0.4925 - val_accuracy: 0.9228\n",
"Epoch 14/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4656 - accuracy: 0.9450 - val_loss: 0.4817 - val_accuracy: 0.9258\n",
"Epoch 15/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4546 - accuracy: 0.9550 - val_loss: 0.4708 - val_accuracy: 0.9298\n",
"Epoch 16/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4435 - accuracy: 0.9550 - val_loss: 0.4608 - val_accuracy: 0.9318\n",
"Epoch 17/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4330 - accuracy: 0.9600 - val_loss: 0.4510 - val_accuracy: 0.9337\n",
"Epoch 18/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4226 - accuracy: 0.9600 - val_loss: 0.4406 - val_accuracy: 0.9367\n",
"Epoch 19/20\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4119 - accuracy: 0.9600 - val_loss: 0.4311 - val_accuracy: 0.9377\n",
"Epoch 20/20\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4025 - accuracy: 0.9600 - val_loss: 0.4225 - val_accuracy: 0.9367\n",
"63/63 [==============================] - 0s 728us/step - loss: 0.4317 - accuracy: 0.9185\n"
]
},
{
"data": {
"text/plain": [
"[0.43168652057647705, 0.9185000061988831]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code - train and evaluate model B, without reusing model A\n",
"\n",
"tf.random.set_seed(42)\n",
"model_B = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"\n",
"model_B.compile(loss=\"binary_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])\n",
"history = model_B.fit(X_train_B, y_train_B, epochs=20,\n",
" validation_data=(X_valid_B, y_valid_B))\n",
"model_B.evaluate(X_test_B, y_test_B)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Model B reaches 91.85% accuracy on the test set. Now let's try reusing the pretrained model A."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"model_A = tf.keras.models.load_model(\"my_model_A\")\n",
"model_B_on_A = tf.keras.Sequential(model_A.layers[:-1])\n",
"model_B_on_A.add(tf.keras.layers.Dense(1, activation=\"sigmoid\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `model_B_on_A` and `model_A` actually share layers now, so when we train one, it will update both models. If we want to avoid that, we need to build `model_B_on_A` on top of a *clone* of `model_A`:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code - ensure reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"model_A_clone = tf.keras.models.clone_model(model_A)\n",
"model_A_clone.set_weights(model_A.get_weights())"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"# extra code - creating model_B_on_A just like in the previous cell\n",
"model_B_on_A = tf.keras.Sequential(model_A_clone.layers[:-1])\n",
"model_B_on_A.add(tf.keras.layers.Dense(1, activation=\"sigmoid\"))"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"for layer in model_B_on_A.layers[:-1]:\n",
" layer.trainable = False\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model_B_on_A.compile(loss=\"binary_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/4\n",
"7/7 [==============================] - 0s 23ms/step - loss: 1.7893 - accuracy: 0.5550 - val_loss: 1.3324 - val_accuracy: 0.5084\n",
"Epoch 2/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 1.1235 - accuracy: 0.5350 - val_loss: 0.9199 - val_accuracy: 0.4807\n",
"Epoch 3/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.8836 - accuracy: 0.5000 - val_loss: 0.8266 - val_accuracy: 0.4837\n",
"Epoch 4/4\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.8202 - accuracy: 0.5250 - val_loss: 0.7795 - val_accuracy: 0.4985\n",
"Epoch 1/16\n",
"7/7 [==============================] - 0s 21ms/step - loss: 0.7348 - accuracy: 0.6050 - val_loss: 0.6372 - val_accuracy: 0.6914\n",
"Epoch 2/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.6055 - accuracy: 0.7600 - val_loss: 0.5283 - val_accuracy: 0.8229\n",
"Epoch 3/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.4992 - accuracy: 0.8400 - val_loss: 0.4742 - val_accuracy: 0.8180\n",
"Epoch 4/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.4297 - accuracy: 0.8700 - val_loss: 0.4212 - val_accuracy: 0.8773\n",
"Epoch 5/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3825 - accuracy: 0.9050 - val_loss: 0.3797 - val_accuracy: 0.9031\n",
"Epoch 6/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.3438 - accuracy: 0.9250 - val_loss: 0.3534 - val_accuracy: 0.9149\n",
"Epoch 7/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3148 - accuracy: 0.9500 - val_loss: 0.3384 - val_accuracy: 0.9001\n",
"Epoch 8/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.3012 - accuracy: 0.9450 - val_loss: 0.3179 - val_accuracy: 0.9209\n",
"Epoch 9/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2767 - accuracy: 0.9650 - val_loss: 0.3043 - val_accuracy: 0.9298\n",
"Epoch 10/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2623 - accuracy: 0.9550 - val_loss: 0.2929 - val_accuracy: 0.9308\n",
"Epoch 11/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2512 - accuracy: 0.9600 - val_loss: 0.2830 - val_accuracy: 0.9327\n",
"Epoch 12/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2397 - accuracy: 0.9600 - val_loss: 0.2744 - val_accuracy: 0.9318\n",
"Epoch 13/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2295 - accuracy: 0.9600 - val_loss: 0.2675 - val_accuracy: 0.9327\n",
"Epoch 14/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2225 - accuracy: 0.9600 - val_loss: 0.2598 - val_accuracy: 0.9347\n",
"Epoch 15/16\n",
"7/7 [==============================] - 0s 6ms/step - loss: 0.2147 - accuracy: 0.9600 - val_loss: 0.2542 - val_accuracy: 0.9357\n",
"Epoch 16/16\n",
"7/7 [==============================] - 0s 7ms/step - loss: 0.2077 - accuracy: 0.9600 - val_loss: 0.2492 - val_accuracy: 0.9377\n"
]
}
],
"source": [
"history = model_B_on_A.fit(X_train_B, y_train_B, epochs=4,\n",
" validation_data=(X_valid_B, y_valid_B))\n",
"\n",
"for layer in model_B_on_A.layers[:-1]:\n",
" layer.trainable = True\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model_B_on_A.compile(loss=\"binary_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model_B_on_A.fit(X_train_B, y_train_B, epochs=16,\n",
" validation_data=(X_valid_B, y_valid_B))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"So, what's the final verdict?"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"63/63 [==============================] - 0s 667us/step - loss: 0.2546 - accuracy: 0.9385\n"
]
},
{
"data": {
"text/plain": [
"[0.2546142041683197, 0.9384999871253967]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_B_on_A.evaluate(X_test_B, y_test_B)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Great! We got a bit of transfer: the model's accuracy went up 2 percentage points, from 91.85% to 93.85%. This means the error rate dropped by almost 25%:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.24539877300613477"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"1 - (100 - 93.85) / (100 - 91.85)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Faster Optimizers"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a little function to test an optimizer on Fashion MNIST\n",
"\n",
"def build_model(seed=42):\n",
" tf.random.set_seed(seed)\n",
" return tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
" ])\n",
"\n",
"def build_and_train_model(optimizer):\n",
" model = build_model()\n",
" model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
" return model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.6877 - accuracy: 0.7677 - val_loss: 0.4960 - val_accuracy: 0.8172\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.4619 - accuracy: 0.8378 - val_loss: 0.4421 - val_accuracy: 0.8404\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 868us/step - loss: 0.4179 - accuracy: 0.8525 - val_loss: 0.4188 - val_accuracy: 0.8538\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 1s 866us/step - loss: 0.3902 - accuracy: 0.8621 - val_loss: 0.3814 - val_accuracy: 0.8604\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 869us/step - loss: 0.3686 - accuracy: 0.8691 - val_loss: 0.3665 - val_accuracy: 0.8656\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 925us/step - loss: 0.3553 - accuracy: 0.8732 - val_loss: 0.3643 - val_accuracy: 0.8720\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 908us/step - loss: 0.3385 - accuracy: 0.8778 - val_loss: 0.3611 - val_accuracy: 0.8684\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 926us/step - loss: 0.3297 - accuracy: 0.8796 - val_loss: 0.3490 - val_accuracy: 0.8726\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 893us/step - loss: 0.3200 - accuracy: 0.8850 - val_loss: 0.3625 - val_accuracy: 0.8666\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 886us/step - loss: 0.3097 - accuracy: 0.8881 - val_loss: 0.3656 - val_accuracy: 0.8672\n"
]
}
],
"source": [
"history_sgd = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Momentum optimization"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 941us/step - loss: 0.6877 - accuracy: 0.7677 - val_loss: 0.4960 - val_accuracy: 0.8172\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.4619 - accuracy: 0.8378 - val_loss: 0.4421 - val_accuracy: 0.8404\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 898us/step - loss: 0.4179 - accuracy: 0.8525 - val_loss: 0.4188 - val_accuracy: 0.8538\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 934us/step - loss: 0.3902 - accuracy: 0.8621 - val_loss: 0.3814 - val_accuracy: 0.8604\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 910us/step - loss: 0.3686 - accuracy: 0.8691 - val_loss: 0.3665 - val_accuracy: 0.8656\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3553 - accuracy: 0.8732 - val_loss: 0.3643 - val_accuracy: 0.8720\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 893us/step - loss: 0.3385 - accuracy: 0.8778 - val_loss: 0.3611 - val_accuracy: 0.8684\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 968us/step - loss: 0.3297 - accuracy: 0.8796 - val_loss: 0.3490 - val_accuracy: 0.8726\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3200 - accuracy: 0.8850 - val_loss: 0.3625 - val_accuracy: 0.8666\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 1s 858us/step - loss: 0.3097 - accuracy: 0.8881 - val_loss: 0.3656 - val_accuracy: 0.8672\n"
]
}
],
"source": [
"history_momentum = build_and_train_model(optimizer) # extra code"
]
},
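{
"cell_type": "markdown",
"metadata": {},
"source": [
"Under the hood, `SGD` with `momentum=β` uses the update m ← βm − η∇J(θ), then θ ← θ + m (as described in the Keras `SGD` documentation). The minimal NumPy sketch below (assuming a toy, ill-conditioned quadratic cost J(θ) = ½ θᵀAθ, which is not part of the original notebook) compares plain gradient descent with momentum optimization for the same learning rate and number of steps:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code - a minimal NumPy sketch of momentum optimization on a toy cost\n",
"import numpy as np\n",
"\n",
"A_toy = np.diag([1.0, 50.0])  # toy ill-conditioned cost: J(theta) = 0.5 * theta @ A_toy @ theta\n",
"\n",
"def toy_grad_J(theta):\n",
"    return A_toy @ theta\n",
"\n",
"def run_momentum(eta=0.01, beta=0.0, n_steps=100):\n",
"    theta = np.array([1.0, 1.0])\n",
"    m = np.zeros_like(theta)\n",
"    for _ in range(n_steps):\n",
"        m = beta * m - eta * toy_grad_J(theta)  # m accumulates past gradients\n",
"        theta = theta + m\n",
"    return theta\n",
"\n",
"print(np.linalg.norm(run_momentum(beta=0.0)))  # plain GD: roughly 0.37 from the optimum\n",
"print(np.linalg.norm(run_momentum(beta=0.9)))  # momentum: below 0.01"
]
},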
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nesterov Accelerated Gradient"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9,\n",
" nesterov=True)"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 907us/step - loss: 0.6777 - accuracy: 0.7711 - val_loss: 0.4796 - val_accuracy: 0.8260\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 898us/step - loss: 0.4570 - accuracy: 0.8398 - val_loss: 0.4358 - val_accuracy: 0.8396\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 872us/step - loss: 0.4140 - accuracy: 0.8537 - val_loss: 0.4013 - val_accuracy: 0.8566\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 902us/step - loss: 0.3882 - accuracy: 0.8629 - val_loss: 0.3802 - val_accuracy: 0.8616\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.3666 - accuracy: 0.8703 - val_loss: 0.3689 - val_accuracy: 0.8638\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 882us/step - loss: 0.3531 - accuracy: 0.8732 - val_loss: 0.3681 - val_accuracy: 0.8688\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 958us/step - loss: 0.3375 - accuracy: 0.8784 - val_loss: 0.3658 - val_accuracy: 0.8670\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 895us/step - loss: 0.3278 - accuracy: 0.8815 - val_loss: 0.3598 - val_accuracy: 0.8682\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.3183 - accuracy: 0.8855 - val_loss: 0.3472 - val_accuracy: 0.8720\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 921us/step - loss: 0.3081 - accuracy: 0.8891 - val_loss: 0.3624 - val_accuracy: 0.8708\n"
]
}
],
"source": [
"history_nesterov = build_and_train_model(optimizer) # extra code"
]
},
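{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a sketch of both updates in their usual textbook form. $\\boldsymbol{\\theta}$ denotes the model parameters, $J$ the cost function, $\\mathbf{m}$ the momentum vector, $\\eta$ the learning rate, and $\\beta$ the momentum hyperparameter (`momentum=0.9` above). This is only the conceptual formulation; Keras may implement an equivalent rearrangement internally. Regular momentum evaluates the gradient at $\\boldsymbol{\\theta}$:\n",
"\n",
"$$\\mathbf{m} \\leftarrow \\beta\\,\\mathbf{m} - \\eta\\,\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}), \\qquad \\boldsymbol{\\theta} \\leftarrow \\boldsymbol{\\theta} + \\mathbf{m}$$\n",
"\n",
"Nesterov accelerated gradient instead evaluates it slightly ahead, at $\\boldsymbol{\\theta} + \\beta\\,\\mathbf{m}$:\n",
"\n",
"$$\\mathbf{m} \\leftarrow \\beta\\,\\mathbf{m} - \\eta\\,\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta} + \\beta\\,\\mathbf{m}), \\qquad \\boldsymbol{\\theta} \\leftarrow \\boldsymbol{\\theta} + \\mathbf{m}$$"
]
},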
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## AdaGrad"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.001)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 1.0003 - accuracy: 0.6822 - val_loss: 0.6876 - val_accuracy: 0.7744\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 912us/step - loss: 0.6389 - accuracy: 0.7904 - val_loss: 0.5837 - val_accuracy: 0.8048\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 930us/step - loss: 0.5682 - accuracy: 0.8105 - val_loss: 0.5379 - val_accuracy: 0.8154\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.5316 - accuracy: 0.8215 - val_loss: 0.5135 - val_accuracy: 0.8244\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 855us/step - loss: 0.5076 - accuracy: 0.8295 - val_loss: 0.4937 - val_accuracy: 0.8288\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 1s 868us/step - loss: 0.4905 - accuracy: 0.8338 - val_loss: 0.4821 - val_accuracy: 0.8312\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 940us/step - loss: 0.4776 - accuracy: 0.8371 - val_loss: 0.4705 - val_accuracy: 0.8348\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 966us/step - loss: 0.4674 - accuracy: 0.8409 - val_loss: 0.4611 - val_accuracy: 0.8362\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 892us/step - loss: 0.4587 - accuracy: 0.8435 - val_loss: 0.4548 - val_accuracy: 0.8406\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 873us/step - loss: 0.4511 - accuracy: 0.8458 - val_loss: 0.4469 - val_accuracy: 0.8424\n"
]
}
],
"source": [
"history_adagrad = build_and_train_model(optimizer) # extra code"
]
},
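{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a sketch of the AdaGrad update in its usual textbook form. $\\boldsymbol{\\theta}$ denotes the model parameters, $J$ the cost function, $\\eta$ the learning rate, and $\\mathbf{s}$ the accumulated squared gradients. $\\otimes$ and $\\oslash$ are element-wise multiplication and division, and $\\varepsilon$ is a small smoothing term:\n",
"\n",
"$$\\mathbf{s} \\leftarrow \\mathbf{s} + \\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}) \\otimes \\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}), \\qquad \\boldsymbol{\\theta} \\leftarrow \\boldsymbol{\\theta} - \\eta\\,\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}) \\oslash \\sqrt{\\mathbf{s} + \\varepsilon}$$\n",
"\n",
"Because $\\mathbf{s}$ only ever grows, the effective learning rate keeps shrinking. This is consistent with the slower progress in the run above, which reaches a validation accuracy of about 0.842 after 10 epochs."
]
},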
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## RMSProp"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.RMSprop(learning_rate=0.001, rho=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5138 - accuracy: 0.8135 - val_loss: 0.4413 - val_accuracy: 0.8338\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 942us/step - loss: 0.3932 - accuracy: 0.8590 - val_loss: 0.4518 - val_accuracy: 0.8370\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.3711 - accuracy: 0.8692 - val_loss: 0.3914 - val_accuracy: 0.8686\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.3643 - accuracy: 0.8735 - val_loss: 0.4176 - val_accuracy: 0.8644\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 970us/step - loss: 0.3578 - accuracy: 0.8769 - val_loss: 0.3874 - val_accuracy: 0.8696\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3561 - accuracy: 0.8775 - val_loss: 0.4650 - val_accuracy: 0.8590\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3528 - accuracy: 0.8783 - val_loss: 0.4122 - val_accuracy: 0.8774\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 989us/step - loss: 0.3491 - accuracy: 0.8811 - val_loss: 0.5151 - val_accuracy: 0.8586\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3479 - accuracy: 0.8829 - val_loss: 0.4457 - val_accuracy: 0.8856\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1000us/step - loss: 0.3437 - accuracy: 0.8830 - val_loss: 0.4781 - val_accuracy: 0.8636\n"
]
}
],
"source": [
"history_rmsprop = build_and_train_model(optimizer) # extra code"
]
},
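{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a sketch of the RMSProp update in its usual textbook form. Unlike AdaGrad, $\\mathbf{s}$ is an exponentially decaying average of the squared gradients rather than their running sum, with decay rate $\\rho$ (`rho=0.9` above). $\\boldsymbol{\\theta}$ denotes the model parameters, $J$ the cost function, and $\\eta$ the learning rate. $\\otimes$ and $\\oslash$ are element-wise multiplication and division, and $\\varepsilon$ is a small smoothing term:\n",
"\n",
"$$\\mathbf{s} \\leftarrow \\rho\\,\\mathbf{s} + (1 - \\rho)\\,\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}) \\otimes \\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}), \\qquad \\boldsymbol{\\theta} \\leftarrow \\boldsymbol{\\theta} - \\eta\\,\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}) \\oslash \\sqrt{\\mathbf{s} + \\varepsilon}$$"
]
},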
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Adam Optimization"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4949 - accuracy: 0.8220 - val_loss: 0.4110 - val_accuracy: 0.8428\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3727 - accuracy: 0.8637 - val_loss: 0.4153 - val_accuracy: 0.8370\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3372 - accuracy: 0.8756 - val_loss: 0.3600 - val_accuracy: 0.8708\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3126 - accuracy: 0.8833 - val_loss: 0.3498 - val_accuracy: 0.8760\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2965 - accuracy: 0.8901 - val_loss: 0.3264 - val_accuracy: 0.8794\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2821 - accuracy: 0.8947 - val_loss: 0.3295 - val_accuracy: 0.8782\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2672 - accuracy: 0.8993 - val_loss: 0.3473 - val_accuracy: 0.8790\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2587 - accuracy: 0.9020 - val_loss: 0.3230 - val_accuracy: 0.8818\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2500 - accuracy: 0.9057 - val_loss: 0.3676 - val_accuracy: 0.8744\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2428 - accuracy: 0.9073 - val_loss: 0.3879 - val_accuracy: 0.8696\n"
]
}
],
"source": [
"history_adam = build_and_train_model(optimizer) # extra code"
]
},
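{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a sketch of the Adam update in its usual textbook form. Adam combines two decaying averages: a momentum-like average of past gradients (decay $\\beta_1$, `beta_1=0.9` above) and an RMSProp-like average of past squared gradients (decay $\\beta_2$, `beta_2=0.999` above). It also applies a bias correction that matters mostly during the first iterations. $t$ is the iteration number (starting at 1), $\\boldsymbol{\\theta}$ denotes the model parameters, $J$ the cost function, and $\\eta$ the learning rate. $\\otimes$ and $\\oslash$ are element-wise multiplication and division, and $\\varepsilon$ is a small smoothing term:\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\mathbf{m} &\\leftarrow \\beta_1\\,\\mathbf{m} - (1 - \\beta_1)\\,\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}) \\\\\n",
"\\mathbf{s} &\\leftarrow \\beta_2\\,\\mathbf{s} + (1 - \\beta_2)\\,\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}) \\otimes \\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta}) \\\\\n",
"\\widehat{\\mathbf{m}} &\\leftarrow \\frac{\\mathbf{m}}{1 - \\beta_1^{\\,t}} \\\\\n",
"\\widehat{\\mathbf{s}} &\\leftarrow \\frac{\\mathbf{s}}{1 - \\beta_2^{\\,t}} \\\\\n",
"\\boldsymbol{\\theta} &\\leftarrow \\boldsymbol{\\theta} + \\eta\\,\\widehat{\\mathbf{m}} \\oslash \\sqrt{\\widehat{\\mathbf{s}} + \\varepsilon}\n",
"\\end{aligned}\n",
"$$"
]
},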
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Adamax Optimization**"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Adamax(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5327 - accuracy: 0.8151 - val_loss: 0.4402 - val_accuracy: 0.8340\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 935us/step - loss: 0.3950 - accuracy: 0.8591 - val_loss: 0.3907 - val_accuracy: 0.8512\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 933us/step - loss: 0.3563 - accuracy: 0.8715 - val_loss: 0.3730 - val_accuracy: 0.8676\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 942us/step - loss: 0.3335 - accuracy: 0.8797 - val_loss: 0.3453 - val_accuracy: 0.8738\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 993us/step - loss: 0.3129 - accuracy: 0.8853 - val_loss: 0.3270 - val_accuracy: 0.8792\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 926us/step - loss: 0.2986 - accuracy: 0.8913 - val_loss: 0.3396 - val_accuracy: 0.8772\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 939us/step - loss: 0.2854 - accuracy: 0.8949 - val_loss: 0.3390 - val_accuracy: 0.8770\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.2757 - accuracy: 0.8984 - val_loss: 0.3147 - val_accuracy: 0.8854\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 952us/step - loss: 0.2662 - accuracy: 0.9020 - val_loss: 0.3341 - val_accuracy: 0.8760\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 957us/step - loss: 0.2542 - accuracy: 0.9063 - val_loss: 0.3282 - val_accuracy: 0.8780\n"
]
}
],
"source": [
"history_adamax = build_and_train_model(optimizer) # extra code"
]
},
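{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a sketch of how Adamax differs from Adam, in the usual textbook form. $\\mathbf{s}$ becomes an exponentially decaying element-wise *max* of the absolute gradients (an $\\ell_\\infty$ norm) instead of a decaying average of squared gradients. The update then divides by $\\mathbf{s}$ directly, with no square root. $\\widehat{\\mathbf{m}}$ is Adam's bias-corrected momentum vector, $\\boldsymbol{\\theta}$ the model parameters, $J$ the cost function, $\\eta$ the learning rate, and $\\oslash$ element-wise division. Implementations typically also add a small $\\varepsilon$ to the denominator:\n",
"\n",
"$$\\mathbf{s} \\leftarrow \\max\\left(\\beta_2\\,\\mathbf{s},\\ \\left|\\nabla_{\\boldsymbol{\\theta}}J(\\boldsymbol{\\theta})\\right|\\right), \\qquad \\boldsymbol{\\theta} \\leftarrow \\boldsymbol{\\theta} + \\eta\\,\\widehat{\\mathbf{m}} \\oslash \\mathbf{s}$$"
]
},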
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"**Nadam Optimization**"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Nadam(learning_rate=0.001, beta_1=0.9,\n",
" beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4826 - accuracy: 0.8284 - val_loss: 0.4092 - val_accuracy: 0.8456\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3610 - accuracy: 0.8667 - val_loss: 0.3893 - val_accuracy: 0.8592\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3270 - accuracy: 0.8784 - val_loss: 0.3653 - val_accuracy: 0.8712\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3049 - accuracy: 0.8874 - val_loss: 0.3444 - val_accuracy: 0.8726\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2897 - accuracy: 0.8905 - val_loss: 0.3174 - val_accuracy: 0.8810\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2753 - accuracy: 0.8981 - val_loss: 0.3389 - val_accuracy: 0.8830\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2652 - accuracy: 0.9000 - val_loss: 0.3725 - val_accuracy: 0.8734\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2563 - accuracy: 0.9034 - val_loss: 0.3229 - val_accuracy: 0.8828\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2463 - accuracy: 0.9079 - val_loss: 0.3353 - val_accuracy: 0.8818\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2402 - accuracy: 0.9091 - val_loss: 0.3813 - val_accuracy: 0.8740\n"
]
}
],
"source": [
"history_nadam = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**AdamW Optimization**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"On Colab or Kaggle, we need to install the TensorFlow-Addons library:"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"if \"google.colab\" in sys.modules or \"kaggle_secrets\" in sys.modules:\n",
" %pip install -q -U tensorflow-addons"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import tensorflow_addons as tfa\n",
"\n",
"optimizer = tfa.optimizers.AdamW(weight_decay=1e-5, learning_rate=0.001,\n",
" beta_1=0.9, beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.4945 - accuracy: 0.8220 - val_loss: 0.4203 - val_accuracy: 0.8424\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3735 - accuracy: 0.8629 - val_loss: 0.4014 - val_accuracy: 0.8474\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3391 - accuracy: 0.8753 - val_loss: 0.3347 - val_accuracy: 0.8760\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3155 - accuracy: 0.8827 - val_loss: 0.3441 - val_accuracy: 0.8720\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2989 - accuracy: 0.8892 - val_loss: 0.3218 - val_accuracy: 0.8786\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2862 - accuracy: 0.8931 - val_loss: 0.3423 - val_accuracy: 0.8814\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2738 - accuracy: 0.8970 - val_loss: 0.3593 - val_accuracy: 0.8764\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2648 - accuracy: 0.8993 - val_loss: 0.3263 - val_accuracy: 0.8856\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2583 - accuracy: 0.9035 - val_loss: 0.3642 - val_accuracy: 0.8680\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2483 - accuracy: 0.9054 - val_loss: 0.3696 - val_accuracy: 0.8702\n"
]
}
],
"source": [
"history_adamw = build_and_train_model(optimizer) # extra code"
]
},
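{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a simplified sketch of the idea behind AdamW. With adaptive optimizers such as Adam, adding an $\\ell_2$ penalty to the loss is not equivalent to weight decay. AdamW therefore *decouples* weight decay from the gradient-based update and shrinks the weights directly at each step, alongside the regular Adam update.\n",
"\n",
"The sketch assumes that $\\lambda$ = `weight_decay` (`1e-5` above) is applied as-is. Implementations differ in details, such as whether $\\lambda$ is also scaled by the learning rate. $\\widehat{\\mathbf{m}}$ and $\\widehat{\\mathbf{s}}$ are Adam's bias-corrected moment estimates, $\\boldsymbol{\\theta}$ the model parameters, $\\eta$ the learning rate, $\\oslash$ element-wise division, and $\\varepsilon$ a small smoothing term:\n",
"\n",
"$$\\boldsymbol{\\theta} \\leftarrow \\boldsymbol{\\theta} - \\lambda\\,\\boldsymbol{\\theta} + \\eta\\,\\widehat{\\mathbf{m}} \\oslash \\sqrt{\\widehat{\\mathbf{s}} + \\varepsilon}$$"
]
},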
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 864x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"<Figure size 864x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: visualize the learning curves of all the optimizers\n",
"\n",
"for loss in (\"loss\", \"val_loss\"):\n",
" plt.figure(figsize=(12, 8))\n",
" opt_names = \"SGD Momentum Nesterov AdaGrad RMSProp Adam Adamax Nadam AdamW\"\n",
" for history, opt_name in zip((history_sgd, history_momentum, history_nesterov,\n",
" history_adagrad, history_rmsprop, history_adam,\n",
" history_adamax, history_nadam, history_adamw),\n",
" opt_names.split()):\n",
" plt.plot(history.history[loss], label=f\"{opt_name}\", linewidth=3)\n",
"\n",
" plt.grid()\n",
" plt.xlabel(\"Epochs\")\n",
" plt.ylabel({\"loss\": \"Training loss\", \"val_loss\": \"Validation loss\"}[loss])\n",
" plt.legend(loc=\"upper left\")\n",
" plt.axis([0, 9, 0.1, 0.7])\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning Rate Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Power Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```lr = lr0 / (1 + steps / s)**c```\n",
"* `steps` is the number of training steps (i.e., batches) processed so far\n",
"* Keras uses `c=1` and `s = 1 / decay`"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, decay=1e-4)"
]
},
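{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `learning_rate=0.01` and `decay=1e-4`, this gives $\\eta(t) = \\dfrac{0.01}{1 + t/10{,}000}$, where $t$ is the number of training steps so far. Each epoch here is 1,719 steps (see the training logs), so a worked example of the arithmetic is:\n",
"\n",
"$$\\eta(1{,}719) = \\frac{0.01}{1.1719} \\approx 0.00853, \\qquad \\eta(17{,}190) = \\frac{0.01}{2.719} \\approx 0.00368$$\n",
"\n",
"After 10 epochs, the learning rate has therefore dropped to roughly 37% of its initial value.\n",
"\n",
"If your TensorFlow version rejects the `decay` argument (the newer Keras optimizers no longer accept it), `tf.keras.optimizers.schedules.InverseTimeDecay(initial_learning_rate=0.01, decay_steps=10_000, decay_rate=1.0)` computes the same $\\eta_0 / (1 + t/s)$ schedule and can be passed as the optimizer's `learning_rate`."
]
},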
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 915us/step - loss: 0.6818 - accuracy: 0.7678 - val_loss: 0.4840 - val_accuracy: 0.8276\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 877us/step - loss: 0.4702 - accuracy: 0.8361 - val_loss: 0.4421 - val_accuracy: 0.8398\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 886us/step - loss: 0.4242 - accuracy: 0.8491 - val_loss: 0.4110 - val_accuracy: 0.8534\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 880us/step - loss: 0.4012 - accuracy: 0.8580 - val_loss: 0.3900 - val_accuracy: 0.8574\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 894us/step - loss: 0.3821 - accuracy: 0.8636 - val_loss: 0.3835 - val_accuracy: 0.8626\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 927us/step - loss: 0.3685 - accuracy: 0.8687 - val_loss: 0.3836 - val_accuracy: 0.8614\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 1s 854us/step - loss: 0.3580 - accuracy: 0.8706 - val_loss: 0.3709 - val_accuracy: 0.8646\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 1s 851us/step - loss: 0.3490 - accuracy: 0.8756 - val_loss: 0.3736 - val_accuracy: 0.8614\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 1s 852us/step - loss: 0.3413 - accuracy: 0.8786 - val_loss: 0.3536 - val_accuracy: 0.8698\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 1s 844us/step - loss: 0.3343 - accuracy: 0.8801 - val_loss: 0.3546 - val_accuracy: 0.8698\n"
]
}
],
"source": [
"history_power_scheduling = build_and_train_model(optimizer) # extra code"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell plots the power scheduling curve\n",
"\n",
"import math\n",
"\n",
"learning_rate = 0.01\n",
"decay = 1e-4\n",
"batch_size = 32\n",
"n_steps_per_epoch = math.ceil(len(X_train) / batch_size)\n",
"n_epochs = 25\n",
"\n",
"epochs = np.arange(n_epochs)\n",
"lrs = learning_rate / (1 + decay * epochs * n_steps_per_epoch)\n",
"\n",
"plt.plot(epochs, lrs, \"o-\")\n",
"plt.axis([0, n_epochs - 1, 0, 0.01])\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Power Scheduling\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Exponential Scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```lr = lr0 * 0.1 ** (epoch / s)```\n",
"* the learning rate drops by a factor of 10 every `s` epochs"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay_fn(epoch):\n",
" return 0.01 * 0.1 ** (epoch / 20)"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay(lr0, s):\n",
" def exponential_decay_fn(epoch):\n",
" return lr0 * 0.1 ** (epoch / s)\n",
" return exponential_decay_fn\n",
"\n",
"exponential_decay_fn = exponential_decay(lr0=0.01, s=20)"
]
},
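{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `lr0=0.01` and `s=20`, as above, here are a few worked values of this schedule:\n",
"\n",
"$$\\eta(0) = 0.01, \\qquad \\eta(10) = 0.01 \\times 0.1^{10/20} \\approx 0.00316, \\qquad \\eta(20) = 0.01 \\times 0.1^{20/20} = 0.001$$"
]
},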
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"# extra code: build and compile a model for Fashion MNIST\n",
"\n",
"tf.random.set_seed(42)\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 2s 913us/step - loss: 0.6795 - accuracy: 0.7682 - val_loss: 0.4782 - val_accuracy: 0.8314\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 1s 869us/step - loss: 0.4656 - accuracy: 0.8376 - val_loss: 0.4407 - val_accuracy: 0.8404\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 2s 873us/step - loss: 0.4194 - accuracy: 0.8505 - val_loss: 0.4118 - val_accuracy: 0.8520\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 1s 857us/step - loss: 0.3953 - accuracy: 0.8601 - val_loss: 0.3850 - val_accuracy: 0.8616\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 1s 868us/step - loss: 0.3754 - accuracy: 0.8667 - val_loss: 0.3769 - val_accuracy: 0.8620\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 0.3611 - accuracy: 0.8717 - val_loss: 0.3782 - val_accuracy: 0.8630\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 2s 895us/step - loss: 0.3501 - accuracy: 0.8743 - val_loss: 0.3665 - val_accuracy: 0.8690\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 2s 883us/step - loss: 0.3407 - accuracy: 0.8785 - val_loss: 0.3694 - val_accuracy: 0.8638\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 2s 879us/step - loss: 0.3328 - accuracy: 0.8814 - val_loss: 0.3477 - val_accuracy: 0.8708\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 2s 895us/step - loss: 0.3259 - accuracy: 0.8834 - val_loss: 0.3495 - val_accuracy: 0.8728\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 2s 884us/step - loss: 0.3200 - accuracy: 0.8855 - val_loss: 0.3483 - val_accuracy: 0.8722\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 2s 903us/step - loss: 0.3148 - accuracy: 0.8877 - val_loss: 0.3459 - val_accuracy: 0.8772\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 2s 894us/step - loss: 0.3107 - accuracy: 0.8892 - val_loss: 0.3366 - val_accuracy: 0.8766\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 2s 885us/step - loss: 0.3068 - accuracy: 0.8904 - val_loss: 0.3409 - val_accuracy: 0.8772\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 1s 853us/step - loss: 0.3034 - accuracy: 0.8921 - val_loss: 0.3404 - val_accuracy: 0.8766\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 2s 884us/step - loss: 0.3000 - accuracy: 0.8934 - val_loss: 0.3332 - val_accuracy: 0.8774\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 2s 887us/step - loss: 0.2978 - accuracy: 0.8933 - val_loss: 0.3342 - val_accuracy: 0.8788\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 2s 890us/step - loss: 0.2953 - accuracy: 0.8945 - val_loss: 0.3323 - val_accuracy: 0.8770\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 2s 918us/step - loss: 0.2934 - accuracy: 0.8951 - val_loss: 0.3291 - val_accuracy: 0.8774\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 2s 923us/step - loss: 0.2915 - accuracy: 0.8966 - val_loss: 0.3292 - val_accuracy: 0.8776\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 2s 888us/step - loss: 0.2899 - accuracy: 0.8968 - val_loss: 0.3273 - val_accuracy: 0.8766\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 2s 874us/step - loss: 0.2882 - accuracy: 0.8975 - val_loss: 0.3298 - val_accuracy: 0.8790\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 2s 879us/step - loss: 0.2870 - accuracy: 0.8978 - val_loss: 0.3287 - val_accuracy: 0.8780\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 2s 890us/step - loss: 0.2857 - accuracy: 0.8986 - val_loss: 0.3288 - val_accuracy: 0.8786\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 2s 877us/step - loss: 0.2849 - accuracy: 0.8985 - val_loss: 0.3285 - val_accuracy: 0.8792\n"
]
}
],
"source": [
"lr_scheduler = tf.keras.callbacks.LearningRateScheduler(exponential_decay_fn)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
"                    validation_data=(X_valid, y_valid),\n",
"                    callbacks=[lr_scheduler])"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell plots exponential scheduling\n",
"\n",
"plt.plot(history.epoch, history.history[\"lr\"], \"o-\")\n",
"plt.axis([0, n_epochs - 1, 0, 0.011])\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Exponential Scheduling\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
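"The schedule function can also take the current learning rate as a second argument; `LearningRateScheduler` passes it automatically, so the function can update the learning rate relative to its current value:"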
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"def exponential_decay_fn(epoch, lr):\n",
"    return lr * 0.1 ** (1 / 20)"
]
},
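{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how this two-argument version relates to the closed-form one, here is a minimal pure-Python sketch. It assumes the function is called once at the start of every epoch with the optimizer's current learning rate (which is how `LearningRateScheduler` uses it), and that the optimizer starts at 0.01. Under these assumptions the decay is already applied at epoch 0, so each value matches the closed form one epoch ahead. Also note that the result depends on the optimizer's initial learning rate, which is 0.001 in the model compiled above. The `closed_form_fn` helper is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def exponential_decay_fn(epoch, lr):\n",
"    return lr * 0.1 ** (1 / 20)\n",
"\n",
"def closed_form_fn(epoch, lr0=0.01, s=20):\n",
"    return lr0 * 0.1 ** (epoch / s)\n",
"\n",
"lr = 0.01  # assumed initial learning rate of the optimizer\n",
"recursive_lrs = []\n",
"for epoch in range(25):\n",
"    lr = exponential_decay_fn(epoch, lr)  # called at the start of each epoch\n",
"    recursive_lrs.append(lr)\n",
"\n",
"# each recursive value should match the closed form one epoch ahead\n",
"print(all(math.isclose(lr, closed_form_fn(epoch + 1))\n",
"          for epoch, lr in enumerate(recursive_lrs)))"
]
},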
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra material**: if you want to update the learning rate at each iteration rather than at each epoch, you can write your own callback class:"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"K = tf.keras.backend\n",
"\n",
"class ExponentialDecay(tf.keras.callbacks.Callback):\n",
"    def __init__(self, n_steps=40_000):\n",
"        super().__init__()\n",
"        self.n_steps = n_steps\n",
"\n",
"    def on_batch_begin(self, batch, logs=None):\n",
"        # Note: the `batch` argument is reset at each epoch\n",
"        lr = K.get_value(self.model.optimizer.learning_rate)\n",
"        new_learning_rate = lr * 0.1 ** (1 / self.n_steps)\n",
"        K.set_value(self.model.optimizer.learning_rate, new_learning_rate)\n",
"\n",
"    def on_epoch_end(self, epoch, logs=None):\n",
"        logs = logs or {}\n",
"        logs['lr'] = K.get_value(self.model.optimizer.learning_rate)"
]
},
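{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the callback multiplies the learning rate by `0.1 ** (1 / n_steps)` at the start of every batch, after `n_steps` batches it has been divided by 10 (up to floating-point rounding). The pure-Python sketch below replays that update without Keras. It assumes 55,000 training instances and a batch size of 32, which is consistent with the 1,719 steps per epoch shown in the training logs; the function and variable names are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def replay_per_batch_decay(lr0, n_steps):\n",
"    # apply the same multiplicative update as on_batch_begin(), n_steps times\n",
"    lr = lr0\n",
"    for _ in range(n_steps):\n",
"        lr *= 0.1 ** (1 / n_steps)\n",
"    return lr\n",
"\n",
"steps_per_epoch = math.ceil(55_000 / 32)  # assumption: 55,000 instances, batch size 32\n",
"total_steps = 25 * steps_per_epoch\n",
"final_lr = replay_per_batch_decay(0.01, total_steps)\n",
"print(steps_per_epoch, total_steps, final_lr)  # final_lr should be very close to 0.001"
]
},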
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"lr0 = 0.01\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
"              metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.6804 - accuracy: 0.7679 - val_loss: 0.4803 - val_accuracy: 0.8276\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4683 - accuracy: 0.8361 - val_loss: 0.4410 - val_accuracy: 0.8412\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4216 - accuracy: 0.8493 - val_loss: 0.4108 - val_accuracy: 0.8536\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3974 - accuracy: 0.8591 - val_loss: 0.3858 - val_accuracy: 0.8584\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3770 - accuracy: 0.8657 - val_loss: 0.3784 - val_accuracy: 0.8624\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3625 - accuracy: 0.8713 - val_loss: 0.3784 - val_accuracy: 0.8626\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3512 - accuracy: 0.8736 - val_loss: 0.3662 - val_accuracy: 0.8674\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3414 - accuracy: 0.8779 - val_loss: 0.3699 - val_accuracy: 0.8638\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3333 - accuracy: 0.8810 - val_loss: 0.3470 - val_accuracy: 0.8714\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3260 - accuracy: 0.8827 - val_loss: 0.3463 - val_accuracy: 0.8718\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3197 - accuracy: 0.8852 - val_loss: 0.3509 - val_accuracy: 0.8718\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3140 - accuracy: 0.8877 - val_loss: 0.3463 - val_accuracy: 0.8764\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3093 - accuracy: 0.8893 - val_loss: 0.3345 - val_accuracy: 0.8762\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3049 - accuracy: 0.8907 - val_loss: 0.3397 - val_accuracy: 0.8778\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3010 - accuracy: 0.8925 - val_loss: 0.3400 - val_accuracy: 0.8788\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2969 - accuracy: 0.8934 - val_loss: 0.3318 - val_accuracy: 0.8792\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2942 - accuracy: 0.8939 - val_loss: 0.3337 - val_accuracy: 0.8780\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2913 - accuracy: 0.8960 - val_loss: 0.3290 - val_accuracy: 0.8766\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2889 - accuracy: 0.8962 - val_loss: 0.3264 - val_accuracy: 0.8778\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2865 - accuracy: 0.8970 - val_loss: 0.3262 - val_accuracy: 0.8794\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2845 - accuracy: 0.8980 - val_loss: 0.3226 - val_accuracy: 0.8798\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2822 - accuracy: 0.8985 - val_loss: 0.3262 - val_accuracy: 0.8814\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2807 - accuracy: 0.8998 - val_loss: 0.3254 - val_accuracy: 0.8790\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2788 - accuracy: 0.9006 - val_loss: 0.3258 - val_accuracy: 0.8816\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2776 - accuracy: 0.9008 - val_loss: 0.3249 - val_accuracy: 0.8808\n"
]
}
],
"source": [
"n_epochs = 25\n",
"batch_size = 32\n",
"n_steps = n_epochs * math.ceil(len(X_train) / batch_size)\n",
"exp_decay = ExponentialDecay(n_steps)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
"                    validation_data=(X_valid, y_valid),\n",
"                    callbacks=[exp_decay])"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_steps = n_epochs * math.ceil(len(X_train) / batch_size)\n",
"steps = np.arange(n_steps)\n",
"decay_rate = 0.1\n",
"lrs = lr0 * decay_rate ** (steps / n_steps)\n",
"\n",
"plt.plot(steps, lrs, \"-\", linewidth=2)\n",
"plt.axis([0, n_steps - 1, 0, lr0 * 1.1])\n",
"plt.xlabel(\"Batch\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Exponential Scheduling (per batch)\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Piecewise Constant Scheduling"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"def piecewise_constant_fn(epoch):\n",
"    if epoch < 5:\n",
"        return 0.01\n",
"    elif epoch < 15:\n",
"        return 0.005\n",
"    else:\n",
"        return 0.001"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"# extra code: this cell demonstrates a more general way to define\n",
"# piecewise constant scheduling.\n",
"\n",
"def piecewise_constant(boundaries, values):\n",
"    boundaries = np.array([0] + boundaries)\n",
"    values = np.array(values)\n",
"    def piecewise_constant_fn(epoch):\n",
"        return values[(boundaries > epoch).argmax() - 1]\n",
"    return piecewise_constant_fn\n",
"\n",
"piecewise_constant_fn = piecewise_constant([5, 15], [0.01, 0.005, 0.001])"
]
},
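{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `argmax` trick above works, but an arguably more readable alternative is Python's standard `bisect` module. Here is a minimal sketch (stdlib only, no NumPy) that checks it against the hand-written piecewise constant schedule at every epoch. The names `piecewise_constant_bisect` and `reference_fn` are hypothetical, chosen so the notebook's own `piecewise_constant_fn` is not overwritten."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from bisect import bisect_right\n",
"\n",
"def piecewise_constant_bisect(boundaries, values):\n",
"    # values[i] applies while boundaries[i - 1] <= epoch < boundaries[i]\n",
"    assert len(values) == len(boundaries) + 1\n",
"    def schedule_fn(epoch):\n",
"        return values[bisect_right(boundaries, epoch)]\n",
"    return schedule_fn\n",
"\n",
"def reference_fn(epoch):  # same logic as the first piecewise_constant_fn above\n",
"    if epoch < 5:\n",
"        return 0.01\n",
"    elif epoch < 15:\n",
"        return 0.005\n",
"    else:\n",
"        return 0.001\n",
"\n",
"bisect_fn = piecewise_constant_bisect([5, 15], [0.01, 0.005, 0.001])\n",
"print(all(bisect_fn(epoch) == reference_fn(epoch) for epoch in range(25)))"
]
},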
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 3s 1ms/step - loss: 0.5745 - accuracy: 0.7963 - val_loss: 0.4856 - val_accuracy: 0.8256\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4472 - accuracy: 0.8424 - val_loss: 0.4418 - val_accuracy: 0.8372\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4216 - accuracy: 0.8505 - val_loss: 0.4162 - val_accuracy: 0.8588\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4104 - accuracy: 0.8569 - val_loss: 0.4027 - val_accuracy: 0.8592\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3949 - accuracy: 0.8614 - val_loss: 0.4276 - val_accuracy: 0.8510\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3483 - accuracy: 0.8747 - val_loss: 0.3907 - val_accuracy: 0.8676\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3336 - accuracy: 0.8783 - val_loss: 0.3981 - val_accuracy: 0.8628\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3261 - accuracy: 0.8815 - val_loss: 0.4098 - val_accuracy: 0.8574\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3241 - accuracy: 0.8818 - val_loss: 0.4197 - val_accuracy: 0.8514\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3183 - accuracy: 0.8833 - val_loss: 0.3668 - val_accuracy: 0.8736\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3147 - accuracy: 0.8856 - val_loss: 0.3936 - val_accuracy: 0.8698\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3111 - accuracy: 0.8864 - val_loss: 0.3854 - val_accuracy: 0.8702\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3117 - accuracy: 0.8868 - val_loss: 0.4126 - val_accuracy: 0.8718\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3052 - accuracy: 0.8896 - val_loss: 0.3997 - val_accuracy: 0.8722\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.3018 - accuracy: 0.8905 - val_loss: 0.4556 - val_accuracy: 0.8678\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2640 - accuracy: 0.9022 - val_loss: 0.4124 - val_accuracy: 0.8782\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2567 - accuracy: 0.9040 - val_loss: 0.4121 - val_accuracy: 0.8804\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2525 - accuracy: 0.9053 - val_loss: 0.4173 - val_accuracy: 0.8776\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2497 - accuracy: 0.9065 - val_loss: 0.4279 - val_accuracy: 0.8816\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2478 - accuracy: 0.9066 - val_loss: 0.4330 - val_accuracy: 0.8774\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2452 - accuracy: 0.9086 - val_loss: 0.4400 - val_accuracy: 0.8816\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2430 - accuracy: 0.9099 - val_loss: 0.4522 - val_accuracy: 0.8796\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2405 - accuracy: 0.9096 - val_loss: 0.4538 - val_accuracy: 0.8812\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2390 - accuracy: 0.9102 - val_loss: 0.4929 - val_accuracy: 0.8830\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.2377 - accuracy: 0.9105 - val_loss: 0.4720 - val_accuracy: 0.8802\n"
]
}
],
"source": [
"# extra code: use a tf.keras.callbacks.LearningRateScheduler like earlier\n",
"\n",
"n_epochs = 25\n",
"\n",
"lr_scheduler = tf.keras.callbacks.LearningRateScheduler(piecewise_constant_fn)\n",
"\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
"              metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
"                    validation_data=(X_valid, y_valid),\n",
"                    callbacks=[lr_scheduler])"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell plots piecewise constant scheduling\n",
"\n",
"plt.plot(history.epoch, history.history[\"lr\"], \"o-\")\n",
"plt.axis([0, n_epochs - 1, 0, 0.011])\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Learning Rate\")\n",
"plt.title(\"Piecewise Constant Scheduling\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Performance Scheduling"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"# extra code: build and compile the model\n",
"\n",
"model = build_model()\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=lr0)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
"              metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"1719/1719 [==============================] - 2s 889us/step - loss: 0.6795 - accuracy: 0.7682 - val_loss: 0.4782 - val_accuracy: 0.8314\n",
"Epoch 2/25\n",
"1719/1719 [==============================] - 1s 841us/step - loss: 0.4670 - accuracy: 0.8368 - val_loss: 0.4440 - val_accuracy: 0.8396\n",
"Epoch 3/25\n",
"1719/1719 [==============================] - 1s 843us/step - loss: 0.4191 - accuracy: 0.8506 - val_loss: 0.4155 - val_accuracy: 0.8522\n",
"Epoch 4/25\n",
"1719/1719 [==============================] - 1s 842us/step - loss: 0.3934 - accuracy: 0.8601 - val_loss: 0.3799 - val_accuracy: 0.8606\n",
"Epoch 5/25\n",
"1719/1719 [==============================] - 1s 830us/step - loss: 0.3708 - accuracy: 0.8681 - val_loss: 0.3664 - val_accuracy: 0.8670\n",
"Epoch 6/25\n",
"1719/1719 [==============================] - 1s 824us/step - loss: 0.3541 - accuracy: 0.8730 - val_loss: 0.3752 - val_accuracy: 0.8654\n",
"Epoch 7/25\n",
"1719/1719 [==============================] - 1s 831us/step - loss: 0.3416 - accuracy: 0.8762 - val_loss: 0.3545 - val_accuracy: 0.8716\n",
"Epoch 8/25\n",
"1719/1719 [==============================] - 1s 832us/step - loss: 0.3300 - accuracy: 0.8807 - val_loss: 0.3597 - val_accuracy: 0.8678\n",
"Epoch 9/25\n",
"1719/1719 [==============================] - 1s 817us/step - loss: 0.3206 - accuracy: 0.8845 - val_loss: 0.3323 - val_accuracy: 0.8804\n",
"Epoch 10/25\n",
"1719/1719 [==============================] - 1s 859us/step - loss: 0.3108 - accuracy: 0.8869 - val_loss: 0.3406 - val_accuracy: 0.8766\n",
"Epoch 11/25\n",
"1719/1719 [==============================] - 1s 849us/step - loss: 0.3033 - accuracy: 0.8893 - val_loss: 0.3551 - val_accuracy: 0.8696\n",
"Epoch 12/25\n",
"1719/1719 [==============================] - 1s 822us/step - loss: 0.2954 - accuracy: 0.8931 - val_loss: 0.3324 - val_accuracy: 0.8810\n",
"Epoch 13/25\n",
"1719/1719 [==============================] - 1s 796us/step - loss: 0.2893 - accuracy: 0.8956 - val_loss: 0.3159 - val_accuracy: 0.8810\n",
"Epoch 14/25\n",
"1719/1719 [==============================] - 1s 806us/step - loss: 0.2826 - accuracy: 0.8969 - val_loss: 0.3435 - val_accuracy: 0.8792\n",
"Epoch 15/25\n",
"1719/1719 [==============================] - 1s 830us/step - loss: 0.2762 - accuracy: 0.8995 - val_loss: 0.3470 - val_accuracy: 0.8792\n",
"Epoch 16/25\n",
"1719/1719 [==============================] - 1s 813us/step - loss: 0.2701 - accuracy: 0.9019 - val_loss: 0.3276 - val_accuracy: 0.8794\n",
"Epoch 17/25\n",
"1719/1719 [==============================] - 1s 821us/step - loss: 0.2656 - accuracy: 0.9025 - val_loss: 0.3334 - val_accuracy: 0.8796\n",
"Epoch 18/25\n",
"1719/1719 [==============================] - 1s 814us/step - loss: 0.2608 - accuracy: 0.9040 - val_loss: 0.3246 - val_accuracy: 0.8844\n",
"Epoch 19/25\n",
"1719/1719 [==============================] - 1s 816us/step - loss: 0.2417 - accuracy: 0.9120 - val_loss: 0.3155 - val_accuracy: 0.8798\n",
"Epoch 20/25\n",
"1719/1719 [==============================] - 1s 821us/step - loss: 0.2387 - accuracy: 0.9135 - val_loss: 0.3149 - val_accuracy: 0.8826\n",
"Epoch 21/25\n",
"1719/1719 [==============================] - 1s 825us/step - loss: 0.2359 - accuracy: 0.9143 - val_loss: 0.3076 - val_accuracy: 0.8844\n",
"Epoch 22/25\n",
"1719/1719 [==============================] - 1s 824us/step - loss: 0.2328 - accuracy: 0.9148 - val_loss: 0.3156 - val_accuracy: 0.8854\n",
"Epoch 23/25\n",
"1719/1719 [==============================] - 1s 821us/step - loss: 0.2304 - accuracy: 0.9157 - val_loss: 0.3225 - val_accuracy: 0.8808\n",
"Epoch 24/25\n",
"1719/1719 [==============================] - 1s 827us/step - loss: 0.2274 - accuracy: 0.9176 - val_loss: 0.3158 - val_accuracy: 0.8834\n",
"Epoch 25/25\n",
"1719/1719 [==============================] - 1s 825us/step - loss: 0.2249 - accuracy: 0.9183 - val_loss: 0.3144 - val_accuracy: 0.8852\n"
]
}
],
"source": [
"lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[lr_scheduler])"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell plots performance scheduling\n",
"\n",
"plt.plot(history.epoch, history.history[\"lr\"], \"bo-\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Learning Rate\", color='b')\n",
"plt.tick_params('y', colors='b')\n",
"plt.gca().set_xlim(0, n_epochs - 1)\n",
"plt.grid(True)\n",
"\n",
"ax2 = plt.gca().twinx()\n",
"ax2.plot(history.epoch, history.history[\"val_loss\"], \"r^-\")\n",
"ax2.set_ylabel('Validation Loss', color='r')\n",
"ax2.tick_params('y', colors='r')\n",
"\n",
"plt.title(\"Reduce LR on Plateau\", fontsize=14)\n",
"plt.show()"
]
},
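{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rule that `ReduceLROnPlateau` applies can be sketched in plain Python. The `reduce_lr_on_plateau()` helper below is hypothetical and simplified: it assumes the monitored quantity is a loss to minimize, uses a `min_delta` of 1e-4 (which we believe matches the callback's default), and ignores the real callback's `cooldown` and `min_lr` options. Whenever the validation loss fails to improve for `patience` consecutive epochs, the learning rate is multiplied by `factor`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – simplified sketch of the ReduceLROnPlateau rule (hypothetical helper)\n",
"def reduce_lr_on_plateau(val_losses, initial_lr=0.01, factor=0.5, patience=5,\n",
"                         min_delta=1e-4):\n",
"    # Returns the learning rate used during each epoch\n",
"    lr = initial_lr\n",
"    best = float('inf')\n",
"    wait = 0  # number of consecutive epochs without improvement\n",
"    lrs = []\n",
"    for val_loss in val_losses:\n",
"        lrs.append(lr)\n",
"        if val_loss < best - min_delta:  # improvement: reset the counter\n",
"            best = val_loss\n",
"            wait = 0\n",
"        else:\n",
"            wait += 1\n",
"            if wait >= patience:  # plateau: shrink the learning rate\n",
"                lr *= factor\n",
"                wait = 0\n",
"    return lrs\n",
"\n",
"# the loss stalls at 0.9 for 5 epochs, so the rate is halved for the last epoch\n",
"reduce_lr_on_plateau([1.0, 0.9, 0.9, 0.9, 0.9, 0.9, 0.9, 0.8])"
]
},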
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### tf.keras schedulers"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"batch_size = 32\n",
"n_epochs = 25\n",
"n_steps = n_epochs * math.ceil(len(X_train) / batch_size)\n",
"scheduled_learning_rate = tf.keras.optimizers.schedules.ExponentialDecay(\n",
" initial_learning_rate=0.01, decay_steps=n_steps, decay_rate=0.1)\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=scheduled_learning_rate)"
]
},
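{
"cell_type": "markdown",
"metadata": {},
"source": [
"With its default `staircase=False`, `ExponentialDecay` computes `initial_learning_rate * decay_rate ** (step / decay_steps)`, so with `decay_rate=0.1` and `decay_steps=n_steps` the learning rate is divided by 10 after `n_steps` steps (25 epochs). Here is a plain-Python sketch of that formula; it assumes 55,000 training instances and a batch size of 32 (1,719 steps per epoch, consistent with the training logs). The `exponential_decay()` helper is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – plain-Python sketch of the ExponentialDecay formula\n",
"import math\n",
"\n",
"steps_per_epoch = math.ceil(55_000 / 32)  # assumption: 55,000 instances, batch size 32\n",
"total_steps = 25 * steps_per_epoch  # 42,975 steps for 25 epochs\n",
"\n",
"def exponential_decay(step, initial_learning_rate=0.01, decay_steps=total_steps,\n",
"                      decay_rate=0.1):\n",
"    # continuous decay (staircase=False): lr0 * decay_rate ** (step / decay_steps)\n",
"    return initial_learning_rate * decay_rate ** (step / decay_steps)\n",
"\n",
"for step in (0, total_steps // 2, total_steps):\n",
"    print(step, exponential_decay(step))"
]
},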
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 864us/step - loss: 0.6808 - accuracy: 0.7683 - val_loss: 0.4806 - val_accuracy: 0.8268\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 1s 812us/step - loss: 0.4686 - accuracy: 0.8359 - val_loss: 0.4420 - val_accuracy: 0.8408\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 809us/step - loss: 0.4221 - accuracy: 0.8494 - val_loss: 0.4108 - val_accuracy: 0.8530\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 1s 828us/step - loss: 0.3976 - accuracy: 0.8592 - val_loss: 0.3867 - val_accuracy: 0.8582\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 825us/step - loss: 0.3775 - accuracy: 0.8655 - val_loss: 0.3784 - val_accuracy: 0.8620\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 1s 817us/step - loss: 0.3633 - accuracy: 0.8705 - val_loss: 0.3796 - val_accuracy: 0.8624\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 1s 843us/step - loss: 0.3518 - accuracy: 0.8737 - val_loss: 0.3662 - val_accuracy: 0.8662\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 1s 805us/step - loss: 0.3422 - accuracy: 0.8779 - val_loss: 0.3707 - val_accuracy: 0.8628\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 1s 821us/step - loss: 0.3339 - accuracy: 0.8809 - val_loss: 0.3475 - val_accuracy: 0.8696\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 1s 829us/step - loss: 0.3266 - accuracy: 0.8826 - val_loss: 0.3473 - val_accuracy: 0.8710\n"
]
}
],
"source": [
"# extra code – build and train the model\n",
"model = build_and_train_model(optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For piecewise constant scheduling, try this:"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"# extra code – shows how to use PiecewiseConstantDecay\n",
"n_steps_per_epoch = math.ceil(len(X_train) / batch_size)\n",
"scheduled_learning_rate = tf.keras.optimizers.schedules.PiecewiseConstantDecay(\n",
"    boundaries=[5 * n_steps_per_epoch, 15 * n_steps_per_epoch],\n",
"    values=[0.01, 0.005, 0.001])"
]
},
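{
"cell_type": "markdown",
"metadata": {},
"source": [
"`PiecewiseConstantDecay` uses `values[0]` while the step is less than or equal to `boundaries[0]`, `values[1]` while it is greater than `boundaries[0]` and less than or equal to `boundaries[1]`, and `values[2]` afterwards. Here is a plain-Python sketch of that lookup, using the hypothetical `piecewise_constant()` helper and assuming 1,719 steps per epoch: the rate is 0.01 for the first 5 epochs, 0.005 until the end of epoch 15, then 0.001."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – plain-Python sketch of the PiecewiseConstantDecay lookup\n",
"from bisect import bisect_left\n",
"\n",
"def piecewise_constant(step, boundaries, values):\n",
"    # bisect_left() counts the boundaries strictly below `step`, so a step\n",
"    # equal to a boundary still gets the previous value\n",
"    return values[bisect_left(boundaries, step)]\n",
"\n",
"sketch_steps_per_epoch = 1719  # assumption: 55,000 instances / batch size 32\n",
"sketch_boundaries = [5 * sketch_steps_per_epoch, 15 * sketch_steps_per_epoch]\n",
"sketch_values = [0.01, 0.005, 0.001]\n",
"[piecewise_constant(step, sketch_boundaries, sketch_values)\n",
" for step in (0, 8595, 8596, 25785, 25786)]  # [0.01, 0.01, 0.005, 0.005, 0.001]"
]
},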
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1Cycle scheduling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ExponentialLearningRate` custom callback updates the learning rate during training, at the end of each batch: it multiplies it by a constant `factor`. It also saves the learning rate and loss at each batch. Since `logs[\"loss\"]` is the mean loss since the start of the epoch, but we want to record each batch's own loss, we multiply that mean by the number of batches processed so far in the epoch to get the total loss so far, then subtract the total at the previous batch to get the current batch's loss."
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
"K = tf.keras.backend\n",
"\n",
"class ExponentialLearningRate(tf.keras.callbacks.Callback):\n",
"    def __init__(self, factor):\n",
"        super().__init__()\n",
"        self.factor = factor\n",
"        self.rates = []\n",
"        self.losses = []\n",
"\n",
"    def on_epoch_begin(self, epoch, logs=None):\n",
"        self.sum_of_epoch_losses = 0\n",
"\n",
"    def on_batch_end(self, batch, logs=None):\n",
"        mean_epoch_loss = logs[\"loss\"]  # the epoch's mean loss so far\n",
"        new_sum_of_epoch_losses = mean_epoch_loss * (batch + 1)\n",
"        batch_loss = new_sum_of_epoch_losses - self.sum_of_epoch_losses\n",
"        self.sum_of_epoch_losses = new_sum_of_epoch_losses\n",
"        self.rates.append(K.get_value(self.model.optimizer.learning_rate))\n",
"        self.losses.append(batch_loss)\n",
"        K.set_value(self.model.optimizer.learning_rate,\n",
"                    self.model.optimizer.learning_rate * self.factor)"
]
},
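{
"cell_type": "markdown",
"metadata": {},
"source": [
"To check the arithmetic in `on_batch_end()`, here is a plain-Python sketch (the `batch_losses_from_running_means()` helper and the loss values are made up): starting from known per-batch losses, we compute the running means that `logs[\"loss\"]` would contain, then invert them the same way the callback does."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – recovering batch losses from running means (hypothetical helper)\n",
"def batch_losses_from_running_means(running_means):\n",
"    losses = []\n",
"    previous_total = 0.0\n",
"    for batch, mean_loss in enumerate(running_means):\n",
"        total = mean_loss * (batch + 1)  # sum of the losses so far this epoch\n",
"        losses.append(total - previous_total)  # this batch's own loss\n",
"        previous_total = total\n",
"    return losses\n",
"\n",
"true_losses = [2.0, 1.0, 3.0]  # made-up batch losses\n",
"running_means = [sum(true_losses[:k + 1]) / (k + 1)\n",
"                 for k in range(len(true_losses))]\n",
"print(running_means)  # [2.0, 1.5, 2.0]\n",
"batch_losses_from_running_means(running_means)  # [2.0, 1.0, 3.0]"
]
},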
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `find_learning_rate()` function trains the model using the `ExponentialLearningRate` callback, and it returns the learning rates and corresponding batch losses. At the end, it restores the model and its optimizer to their initial state."
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
"def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=1e-4,\n",
" max_rate=1):\n",
" init_weights = model.get_weights()\n",
" iterations = math.ceil(len(X) / batch_size) * epochs\n",
" factor = (max_rate / min_rate) ** (1 / iterations)\n",
" init_lr = K.get_value(model.optimizer.learning_rate)\n",
" K.set_value(model.optimizer.learning_rate, min_rate)\n",
" exp_lr = ExponentialLearningRate(factor)\n",
" history = model.fit(X, y, epochs=epochs, batch_size=batch_size,\n",
" callbacks=[exp_lr])\n",
" K.set_value(model.optimizer.learning_rate, init_lr)\n",
" model.set_weights(init_weights)\n",
" return exp_lr.rates, exp_lr.losses"
]
},
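{
"cell_type": "markdown",
"metadata": {},
"source": [
"`find_learning_rate()` picks `factor` so that the learning rate grows geometrically from `min_rate` to `max_rate` over all iterations, i.e., `min_rate * factor ** iterations == max_rate`. Here is a quick numeric check of that choice, assuming 55,000 training instances, a batch size of 128, and a single epoch (430 iterations); the `sketch_*` names are ours:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – checks the growth factor used by find_learning_rate()\n",
"import math\n",
"\n",
"sketch_min_rate, sketch_max_rate = 1e-4, 1\n",
"sketch_iterations = math.ceil(55_000 / 128)  # assumption: 55,000 instances, 1 epoch\n",
"sketch_factor = (sketch_max_rate / sketch_min_rate) ** (1 / sketch_iterations)\n",
"print(sketch_iterations, sketch_factor)  # about 2% growth per batch\n",
"print(sketch_min_rate * sketch_factor ** sketch_iterations)  # max_rate, up to rounding"
]
},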
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `plot_lr_vs_loss()` function plots the learning rates vs the losses. The optimal learning rate to use as the maximum learning rate in 1cycle is near the bottom of the curve."
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
"def plot_lr_vs_loss(rates, losses):\n",
" plt.plot(rates, losses, \"b\")\n",
" plt.gca().set_xscale('log')\n",
" max_loss = losses[0] + min(losses)\n",
" plt.hlines(min(losses), min(rates), max(rates), color=\"k\")\n",
" plt.axis([min(rates), max(rates), 0, max_loss])\n",
" plt.xlabel(\"Learning rate\")\n",
" plt.ylabel(\"Loss\")\n",
" plt.grid()"
]
},
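{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you prefer to locate the bottom of the curve programmatically, a simple heuristic (our own assumption, not something the callback provides) is to return the rate at which the recorded loss is lowest. Since the curve is noisy and the loss shoots up soon after that point, picking a somewhat smaller rate is usually safer. The `rate_at_min_loss()` helper and the data below are made up:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – hypothetical helper to find the rate with the lowest loss\n",
"def rate_at_min_loss(rates, losses):\n",
"    best_index = min(range(len(losses)), key=losses.__getitem__)\n",
"    return rates[best_index]\n",
"\n",
"# made-up curve: the loss decreases, bottoms out, then diverges\n",
"toy_rates = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]\n",
"toy_losses = [2.3, 2.0, 1.2, 0.9, 5.0]\n",
"rate_at_min_loss(toy_rates, toy_losses)  # 0.1"
]
},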
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build a simple Fashion MNIST model and compile it:"
]
},
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"model = build_model()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's find the optimal max learning rate for 1cycle:"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"430/430 [==============================] - 1s 1ms/step - loss: 1.7725 - accuracy: 0.4122\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"batch_size = 128\n",
"rates, losses = find_learning_rate(model, X_train, y_train, epochs=1,\n",
" batch_size=batch_size)\n",
"plot_lr_vs_loss(rates, losses)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks like the max learning rate to use for 1cycle is around 10<sup>-1</sup>, i.e., 0.1."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `OneCycleScheduler` custom callback updates the learning rate at the beginning of each batch. It applies the logic described in the book: increase the learning rate linearly from `start_lr` to `max_lr` during the first half of training (excluding the last few iterations), then decrease it linearly back to `start_lr` during the second half, and finally reduce it linearly to `last_lr`, close to zero, during the last iterations (roughly the final 10% of training by default)."
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"class OneCycleScheduler(tf.keras.callbacks.Callback):\n",
"    def __init__(self, iterations, max_lr=1e-3, start_lr=None,\n",
"                 last_iterations=None, last_lr=None):\n",
"        super().__init__()\n",
"        self.iterations = iterations\n",
"        self.max_lr = max_lr\n",
"        self.start_lr = start_lr or max_lr / 10\n",
"        self.last_iterations = last_iterations or iterations // 10 + 1\n",
"        self.half_iteration = (iterations - self.last_iterations) // 2\n",
"        self.last_lr = last_lr or self.start_lr / 1000\n",
"        self.iteration = 0\n",
"\n",
"    def _interpolate(self, iter1, iter2, lr1, lr2):\n",
"        return (lr2 - lr1) * (self.iteration - iter1) / (iter2 - iter1) + lr1\n",
"\n",
"    def on_batch_begin(self, batch, logs=None):\n",
"        if self.iteration < self.half_iteration:\n",
"            lr = self._interpolate(0, self.half_iteration, self.start_lr,\n",
"                                   self.max_lr)\n",
"        elif self.iteration < 2 * self.half_iteration:\n",
"            lr = self._interpolate(self.half_iteration, 2 * self.half_iteration,\n",
"                                   self.max_lr, self.start_lr)\n",
"        else:\n",
"            lr = self._interpolate(2 * self.half_iteration, self.iterations,\n",
"                                   self.start_lr, self.last_lr)\n",
"        self.iteration += 1\n",
"        K.set_value(self.model.optimizer.learning_rate, lr)"
]
},
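{
"cell_type": "markdown",
"metadata": {},
"source": [
"To inspect the shape of this schedule without training anything, here is a plain-Python function (hypothetical, written for this sketch) that mirrors the callback's defaults and interpolation for a given iteration. It assumes 430 iterations per epoch (55,000 instances, batch size 128), 25 epochs, and `max_lr=0.1`, as in the training cell that follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – plain-Python mirror of OneCycleScheduler's schedule\n",
"def one_cycle_lr(iteration, iterations, max_lr=1e-3, start_lr=None,\n",
"                 last_iterations=None, last_lr=None):\n",
"    start_lr = start_lr or max_lr / 10\n",
"    last_iterations = last_iterations or iterations // 10 + 1\n",
"    half = (iterations - last_iterations) // 2\n",
"    last_lr = last_lr or start_lr / 1000\n",
"\n",
"    def interpolate(iter1, iter2, lr1, lr2):\n",
"        return (lr2 - lr1) * (iteration - iter1) / (iter2 - iter1) + lr1\n",
"\n",
"    if iteration < half:  # phase 1: linear warm-up from start_lr to max_lr\n",
"        return interpolate(0, half, start_lr, max_lr)\n",
"    if iteration < 2 * half:  # phase 2: linear decrease back to start_lr\n",
"        return interpolate(half, 2 * half, max_lr, start_lr)\n",
"    return interpolate(2 * half, iterations, start_lr, last_lr)  # phase 3\n",
"\n",
"total_iterations = 430 * 25  # assumption: 430 batches per epoch, 25 epochs\n",
"for it in (0, 4837, 9674, total_iterations - 1):  # 4837 = half, 9674 = 2 * half\n",
"    print(it, one_cycle_lr(it, total_iterations, max_lr=0.1))"
]
},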
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build and compile a simple Fashion MNIST model, then train it using the `OneCycleScheduler` callback:"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"430/430 [==============================] - 1s 2ms/step - loss: 0.9502 - accuracy: 0.6913 - val_loss: 0.6003 - val_accuracy: 0.7874\n",
"Epoch 2/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.5695 - accuracy: 0.8025 - val_loss: 0.4918 - val_accuracy: 0.8248\n",
"Epoch 3/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4954 - accuracy: 0.8252 - val_loss: 0.4762 - val_accuracy: 0.8264\n",
"Epoch 4/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4515 - accuracy: 0.8402 - val_loss: 0.4261 - val_accuracy: 0.8478\n",
"Epoch 5/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.4225 - accuracy: 0.8492 - val_loss: 0.4066 - val_accuracy: 0.8486\n",
"Epoch 6/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3958 - accuracy: 0.8571 - val_loss: 0.4787 - val_accuracy: 0.8224\n",
"Epoch 7/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3787 - accuracy: 0.8626 - val_loss: 0.3917 - val_accuracy: 0.8566\n",
"Epoch 8/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3630 - accuracy: 0.8683 - val_loss: 0.4719 - val_accuracy: 0.8296\n",
"Epoch 9/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3512 - accuracy: 0.8724 - val_loss: 0.3673 - val_accuracy: 0.8652\n",
"Epoch 10/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3360 - accuracy: 0.8766 - val_loss: 0.4957 - val_accuracy: 0.8466\n",
"Epoch 11/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3287 - accuracy: 0.8786 - val_loss: 0.4187 - val_accuracy: 0.8370\n",
"Epoch 12/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.3173 - accuracy: 0.8815 - val_loss: 0.3425 - val_accuracy: 0.8728\n",
"Epoch 13/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2961 - accuracy: 0.8910 - val_loss: 0.3217 - val_accuracy: 0.8792\n",
"Epoch 14/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2818 - accuracy: 0.8958 - val_loss: 0.3734 - val_accuracy: 0.8692\n",
"Epoch 15/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2675 - accuracy: 0.9003 - val_loss: 0.3261 - val_accuracy: 0.8844\n",
"Epoch 16/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2558 - accuracy: 0.9055 - val_loss: 0.3205 - val_accuracy: 0.8820\n",
"Epoch 17/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2464 - accuracy: 0.9091 - val_loss: 0.3089 - val_accuracy: 0.8894\n",
"Epoch 18/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2368 - accuracy: 0.9115 - val_loss: 0.3130 - val_accuracy: 0.8870\n",
"Epoch 19/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2292 - accuracy: 0.9145 - val_loss: 0.3078 - val_accuracy: 0.8854\n",
"Epoch 20/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2205 - accuracy: 0.9186 - val_loss: 0.3092 - val_accuracy: 0.8886\n",
"Epoch 21/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2138 - accuracy: 0.9209 - val_loss: 0.3022 - val_accuracy: 0.8914\n",
"Epoch 22/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2073 - accuracy: 0.9232 - val_loss: 0.3054 - val_accuracy: 0.8914\n",
"Epoch 23/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.2020 - accuracy: 0.9261 - val_loss: 0.3026 - val_accuracy: 0.8896\n",
"Epoch 24/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.1989 - accuracy: 0.9273 - val_loss: 0.3020 - val_accuracy: 0.8922\n",
"Epoch 25/25\n",
"430/430 [==============================] - 1s 1ms/step - loss: 0.1967 - accuracy: 0.9276 - val_loss: 0.3016 - val_accuracy: 0.8920\n"
]
}
],
"source": [
"model = build_model()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=tf.keras.optimizers.SGD(),\n",
" metrics=[\"accuracy\"])\n",
"n_epochs = 25\n",
"onecycle = OneCycleScheduler(math.ceil(len(X_train) / batch_size) * n_epochs,\n",
" max_lr=0.1)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=[onecycle])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Avoiding Overfitting Through Regularization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## $\\ell_1$ and $\\ell_2$ regularization"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"layer = tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\",\n",
" kernel_regularizer=tf.keras.regularizers.l2(0.01))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or use `l1(0.1)` for ℓ<sub>1</sub> regularization with a factor of 0.1, or `l1_l2(0.1, 0.01)` for both ℓ<sub>1</sub> and ℓ<sub>2</sub> regularization, with factors 0.1 and 0.01 respectively."
]
},
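{
"cell_type": "markdown",
"metadata": {},
"source": [
"In current Keras versions, `l2(factor)` adds `factor * sum(w ** 2)` over the regularized weights to the loss (with no 1/2 coefficient), `l1(factor)` adds `factor * sum(abs(w))`, and `l1_l2()` adds both terms. Here is a plain-Python sketch on made-up weights; the `l1_l2_penalty()` helper is hypothetical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – what the l1/l2 regularizers add to the loss (hypothetical helper)\n",
"def l1_l2_penalty(weights, l1=0.0, l2=0.0):\n",
"    return l1 * sum(abs(w) for w in weights) + l2 * sum(w * w for w in weights)\n",
"\n",
"toy_weights = [1.0, -2.0, 3.0]  # made-up kernel weights\n",
"print(l1_l2_penalty(toy_weights, l2=0.01))  # like l2(0.01): 0.01 * 14\n",
"print(l1_l2_penalty(toy_weights, l1=0.1))  # like l1(0.1): 0.1 * 6\n",
"print(l1_l2_penalty(toy_weights, l1=0.1, l2=0.01))  # like l1_l2(0.1, 0.01)"
]
},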
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"RegularizedDense = partial(tf.keras.layers.Dense,\n",
" activation=\"relu\",\n",
" kernel_initializer=\"he_normal\",\n",
" kernel_regularizer=tf.keras.regularizers.l2(0.01))\n",
"\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" RegularizedDense(100),\n",
" RegularizedDense(100),\n",
" RegularizedDense(10, activation=\"softmax\")\n",
"])"
]
},
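{
"cell_type": "markdown",
"metadata": {},
"source": [
"`functools.partial()` freezes some keyword arguments, but a keyword passed at call time overrides the frozen one. That is why `RegularizedDense(10, activation=\"softmax\")` gets a softmax activation while keeping the He initialization and the ℓ<sub>2</sub> regularizer. A tiny sketch, with a hypothetical `make_layer()` function standing in for a real Keras layer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – how partial() defaults and call-time overrides interact\n",
"from functools import partial\n",
"\n",
"def make_layer(units, activation=None, kernel_initializer='glorot_uniform'):\n",
"    # hypothetical stand-in for tf.keras.layers.Dense: just returns its config\n",
"    return {'units': units, 'activation': activation,\n",
"            'kernel_initializer': kernel_initializer}\n",
"\n",
"ToyDense = partial(make_layer, activation='relu', kernel_initializer='he_normal')\n",
"print(ToyDense(100))  # uses the frozen defaults\n",
"print(ToyDense(10, activation='softmax'))  # the call-time keyword wins"
]
},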
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1719/1719 [==============================] - 2s 878us/step - loss: 3.1224 - accuracy: 0.7748 - val_loss: 1.8602 - val_accuracy: 0.8264\n",
"Epoch 2/2\n",
"1719/1719 [==============================] - 1s 814us/step - loss: 1.4263 - accuracy: 0.8159 - val_loss: 1.1269 - val_accuracy: 0.8182\n"
]
}
],
"source": [
"# extra code – compile and train the model\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.02)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=2,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dropout"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(100, activation=\"relu\",\n",
" kernel_initializer=\"he_normal\"),\n",
" tf.keras.layers.Dropout(rate=0.2),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])"
]
},
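{
"cell_type": "markdown",
"metadata": {},
"source": [
"During training, each `Dropout(rate=0.2)` layer sets about 20% of its inputs to zero at random and scales the remaining ones by 1 / (1 - 0.2) = 1.25, so the expected activation is unchanged; at inference time it passes its inputs through untouched. Here is a NumPy sketch of this \"inverted dropout\" behavior; the `inverted_dropout()` helper is hypothetical and is not the Keras implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – NumPy sketch of inverted dropout (hypothetical helper)\n",
"import numpy as np\n",
"\n",
"def inverted_dropout(inputs, rate, training, rng):\n",
"    if not training or rate == 0:  # inference: identity\n",
"        return inputs\n",
"    keep_mask = rng.random(inputs.shape) >= rate  # keep each unit w.p. 1 - rate\n",
"    return np.where(keep_mask, inputs / (1 - rate), 0.0)\n",
"\n",
"dropout_rng = np.random.default_rng(42)\n",
"toy_inputs = np.ones((1000, 100))\n",
"toy_outputs = inverted_dropout(toy_inputs, rate=0.2, training=True, rng=dropout_rng)\n",
"print((toy_outputs == 0).mean())  # fraction dropped: close to 0.2\n",
"print(toy_outputs.mean())  # close to 1.0: the rescaling preserves the mean"
]
},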
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.6703 - accuracy: 0.7536 - val_loss: 0.4498 - val_accuracy: 0.8342\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 996us/step - loss: 0.5103 - accuracy: 0.8136 - val_loss: 0.4401 - val_accuracy: 0.8296\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 998us/step - loss: 0.4712 - accuracy: 0.8263 - val_loss: 0.3806 - val_accuracy: 0.8554\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 977us/step - loss: 0.4488 - accuracy: 0.8337 - val_loss: 0.3711 - val_accuracy: 0.8608\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.4342 - accuracy: 0.8409 - val_loss: 0.3672 - val_accuracy: 0.8606\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 983us/step - loss: 0.4245 - accuracy: 0.8427 - val_loss: 0.3706 - val_accuracy: 0.8600\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 995us/step - loss: 0.4131 - accuracy: 0.8467 - val_loss: 0.3582 - val_accuracy: 0.8650\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 959us/step - loss: 0.4074 - accuracy: 0.8484 - val_loss: 0.3478 - val_accuracy: 0.8708\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 997us/step - loss: 0.4024 - accuracy: 0.8533 - val_loss: 0.3556 - val_accuracy: 0.8690\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 998us/step - loss: 0.3903 - accuracy: 0.8552 - val_loss: 0.3453 - val_accuracy: 0.8732\n"
]
}
],
"source": [
"# extra code – compile and train the model\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training accuracy reported during training looks lower than the validation accuracy, but that's just because dropout is only active during training (and the reported training metrics are averaged over the epoch while the weights are still changing). If we evaluate the model on the training set after training (i.e., with dropout turned off), we get the \"real\" training accuracy, which is slightly higher than the validation accuracy and the test accuracy:"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 1s 578us/step - loss: 0.3082 - accuracy: 0.8849\n"
]
},
{
"data": {
"text/plain": [
"[0.30816400051116943, 0.8849090933799744]"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_train, y_train)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"313/313 [==============================] - 0s 588us/step - loss: 0.3629 - accuracy: 0.8700\n"
]
},
{
"data": {
"text/plain": [
"[0.3628920316696167, 0.8700000047683716]"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_test, y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: make sure to use `AlphaDropout` instead of `Dropout` if you want to build a self-normalizing neural net using SELU."
2019-02-28 12:48:06 +01:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MC Dropout"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – for reproducibility"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"y_probas = np.stack([model(X_test, training=True)\n",
" for sample in range(100)])\n",
"y_proba = y_probas.mean(axis=0)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.024, 0. , 0.132, 0. ,\n",
" 0.844]], dtype=float32)"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.predict(X_test[:1]).round(3)"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0. , 0. , 0. , 0. , 0. , 0.067, 0. , 0.209, 0.001,\n",
" 0.723], dtype=float32)"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_proba[0].round(3)"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0. , 0. , 0. , 0.001, 0. , 0.096, 0. , 0.162, 0.001,\n",
" 0.183], dtype=float32)"
]
},
"execution_count": 110,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_std = y_probas.std(axis=0)\n",
"y_std[0].round(3)"
]
},
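{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `model(X_test, training=True)` does under the hood, here is a minimal NumPy sketch of MC Dropout on a tiny, randomly initialized one-hidden-layer softmax classifier. The weights `W1` and `W2`, the fake inputs `X_demo`, and the helper functions are all hypothetical, and the variables use new names so they don't overwrite `y_proba` and `y_std` above. Every forward pass samples a fresh dropout mask, so averaging 100 passes gives the MC estimate of the class probabilities, and their standard deviation gives a rough per-class measure of uncertainty."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"W1 = rng.normal(scale=0.1, size=(784, 100))  # hypothetical hidden-layer weights\n",
"W2 = rng.normal(scale=0.1, size=(100, 10))  # hypothetical output-layer weights\n",
"\n",
"def softmax(logits):\n",
"    exps = np.exp(logits - logits.max(axis=-1, keepdims=True))  # subtract the max for stability\n",
"    return exps / exps.sum(axis=-1, keepdims=True)\n",
"\n",
"def stochastic_forward(X, rate=0.2):\n",
"    hidden = np.maximum(X @ W1, 0.0)  # ReLU hidden layer\n",
"    mask = rng.random(hidden.shape) >= rate  # a fresh dropout mask on every call\n",
"    hidden = hidden * mask / (1.0 - rate)\n",
"    return softmax(hidden @ W2)\n",
"\n",
"X_demo = rng.random((5, 784))  # 5 fake flattened 28x28 images\n",
"demo_probas = np.stack([stochastic_forward(X_demo) for sample in range(100)])  # shape [100, 5, 10]\n",
"demo_proba = demo_probas.mean(axis=0)  # MC estimate, shape [5, 10]\n",
"demo_std = demo_probas.std(axis=0)  # spread of the 100 estimates\n",
"print(demo_proba[0].round(3))\n",
"print(demo_std[0].round(3))"
]
},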
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.8717"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred = y_proba.argmax(axis=1)\n",
"accuracy = (y_pred == y_test).sum() / len(y_test)\n",
"accuracy"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
"class MCDropout(tf.keras.layers.Dropout):\n",
" def call(self, inputs, training=None):\n",
" return super().call(inputs, training=True)"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
"# extra code – shows how to convert Dropout to MCDropout in a Sequential model\n",
"Dropout = tf.keras.layers.Dropout\n",
"mc_model = tf.keras.Sequential([\n",
" MCDropout(layer.rate) if isinstance(layer, Dropout) else layer\n",
" for layer in model.layers\n",
"])\n",
"mc_model.set_weights(model.get_weights())"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_25\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"flatten_22 (Flatten) (None, 784) 0 \n",
"_________________________________________________________________\n",
"mc_dropout (MCDropout) (None, 784) 0 \n",
"_________________________________________________________________\n",
"dense_89 (Dense) (None, 100) 78500 \n",
"_________________________________________________________________\n",
"mc_dropout_1 (MCDropout) (None, 100) 0 \n",
"_________________________________________________________________\n",
"dense_90 (Dense) (None, 100) 10100 \n",
"_________________________________________________________________\n",
"mc_dropout_2 (MCDropout) (None, 100) 0 \n",
"_________________________________________________________________\n",
"dense_91 (Dense) (None, 10) 1010 \n",
"=================================================================\n",
"Total params: 89,610\n",
"Trainable params: 89,610\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n"
]
}
],
"source": [
"mc_model.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use the model with MC Dropout:"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[0. , 0. , 0. , 0. , 0. , 0.07, 0. , 0.17, 0. , 0.76]],\n",
" dtype=float32)"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows that the model works without retraining\n",
"tf.random.set_seed(42)\n",
"np.mean([mc_model.predict(X_test[:1])\n",
" for sample in range(100)], axis=0).round(2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Max norm"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"dense = tf.keras.layers.Dense(\n",
" 100, activation=\"relu\", kernel_initializer=\"he_normal\",\n",
" kernel_constraint=tf.keras.constraints.max_norm(1.))"
]
},
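{
"cell_type": "markdown",
"metadata": {},
"source": [
"Conceptually, after each training step the `max_norm(1.)` constraint rescales the incoming weight vector of any neuron whose ℓ<sub>2</sub> norm exceeds 1. For a `Dense` kernel of shape [*n*<sub>inputs</sub>, *n*<sub>neurons</sub>], the incoming weights of a neuron form one column, which corresponds to the constraint's default `axis=0`. Here is a minimal NumPy sketch of that rescaling; `apply_max_norm()` is a hypothetical helper, and the tiny `eps` term is an assumption meant to mirror the small epsilon Keras adds to avoid dividing by zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def apply_max_norm(kernel, max_value=1.0, axis=0, eps=1e-7):\n",
"    # hypothetical helper: cap the L2 norm of each column (axis=0) at max_value\n",
"    norms = np.sqrt(np.sum(np.square(kernel), axis=axis, keepdims=True))\n",
"    desired = np.clip(norms, 0, max_value)\n",
"    return kernel * (desired / (eps + norms))  # columns within the limit are (almost) unchanged\n",
"\n",
"rng = np.random.default_rng(42)\n",
"kernel = rng.normal(size=(784, 100))  # column norms are roughly sqrt(784) = 28\n",
"constrained = apply_max_norm(kernel)\n",
"print(np.linalg.norm(kernel, axis=0).max())\n",
"print(np.linalg.norm(constrained, axis=0).max())  # at most 1.0"
]
},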
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 2s 1ms/step - loss: 0.5500 - accuracy: 0.8015 - val_loss: 0.4510 - val_accuracy: 0.8242\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 2s 960us/step - loss: 0.4089 - accuracy: 0.8499 - val_loss: 0.3956 - val_accuracy: 0.8504\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 2s 974us/step - loss: 0.3777 - accuracy: 0.8604 - val_loss: 0.3693 - val_accuracy: 0.8680\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 2s 943us/step - loss: 0.3581 - accuracy: 0.8690 - val_loss: 0.3517 - val_accuracy: 0.8716\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 2s 949us/step - loss: 0.3416 - accuracy: 0.8729 - val_loss: 0.3433 - val_accuracy: 0.8682\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 2s 951us/step - loss: 0.3368 - accuracy: 0.8756 - val_loss: 0.4045 - val_accuracy: 0.8582\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 2s 935us/step - loss: 0.3293 - accuracy: 0.8767 - val_loss: 0.4168 - val_accuracy: 0.8476\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 2s 951us/step - loss: 0.3258 - accuracy: 0.8779 - val_loss: 0.3570 - val_accuracy: 0.8674\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 2s 970us/step - loss: 0.3269 - accuracy: 0.8787 - val_loss: 0.3702 - val_accuracy: 0.8578\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 2s 948us/step - loss: 0.3169 - accuracy: 0.8809 - val_loss: 0.3907 - val_accuracy: 0.8578\n"
]
}
],
"source": [
"# extra code – shows how to apply max norm to every hidden layer in a model\n",
"\n",
"MaxNormDense = partial(tf.keras.layers.Dense,\n",
" activation=\"relu\", kernel_initializer=\"he_normal\",\n",
" kernel_constraint=tf.keras.constraints.max_norm(1.))\n",
"\n",
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=[28, 28]),\n",
" MaxNormDense(100),\n",
" MaxNormDense(100),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=10,\n",
" validation_data=(X_valid, y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercises"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 7."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Glorot initialization and He initialization were designed to make the output standard deviation as close as possible to the input standard deviation, at least at the beginning of training. This reduces the vanishing/exploding gradients problem.\n",
"2. No, all weights should be sampled independently; they should not all have the same initial value. One important goal of sampling weights randomly is to break symmetry: if all the weights have the same initial value, even if that value is not zero, then symmetry is not broken (i.e., all neurons in a given layer are equivalent), and backpropagation will be unable to break it. Concretely, this means that all the neurons in any given layer will always have the same weights. It's like having just one neuron per layer, and much slower. It is virtually impossible for such a configuration to converge to a good solution.\n",
"3. It is perfectly fine to initialize the bias terms to zero. Some people like to initialize them just like weights, and that's OK too; it does not make much difference.\n",
"4. ReLU is usually a good default for the hidden layers, as it is fast and yields good results. Its ability to output precisely zero can also be useful in some cases (e.g., see Chapter 17). Moreover, it can sometimes benefit from optimized implementations as well as from hardware acceleration. Leaky ReLU and its variants can improve the model's quality without hindering its speed too much compared to ReLU. For large neural nets and more complex problems, GLU, Swish and Mish can give you a slightly higher quality model, but they have a computational cost. The hyperbolic tangent (tanh) can be useful in the output layer if you need to output a number in a fixed range (by default between -1 and 1), but nowadays it is not used much in hidden layers, except in recurrent nets. The sigmoid activation function is also useful in the output layer when you need to estimate a probability (e.g., for binary classification), but it is rarely used in hidden layers (there are exceptions, such as the coding layer of variational autoencoders; see Chapter 17). The softplus activation function is useful in the output layer when you need to ensure that the output will always be positive. The softmax activation function is useful in the output layer to estimate probabilities for mutually exclusive classes, but it is rarely (if ever) used in hidden layers.\n",
"5. If you set the `momentum` hyperparameter too close to 1 (e.g., 0.99999) when using an `SGD` optimizer, then the algorithm will likely pick up a lot of speed, hopefully moving roughly toward the global minimum, but its momentum will carry it right past the minimum. Then it will slow down and come back, accelerate again, overshoot again, and so on. It may oscillate this way many times before converging, so overall it will take much longer to converge than with a smaller `momentum` value.\n",
"6. One way to produce a sparse model (i.e., with most weights equal to zero) is to train the model normally, then zero out tiny weights. For more sparsity, you can apply ℓ<sub>1</sub> regularization during training, which pushes the optimizer toward sparsity. A third option is to use the TensorFlow Model Optimization Toolkit.\n",
"7. Yes, dropout does slow down training, in general roughly by a factor of two. However, it has no impact on inference speed since it is only turned on during training. MC Dropout is exactly like dropout during training, but it is still active during inference, so each inference is slowed down slightly. More importantly, when using MC Dropout you generally want to run inference 10 times or more to get better predictions. This means that making predictions is slowed down by a factor of 10 or more."
]
},
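{
"cell_type": "markdown",
"metadata": {},
"source": [
"The claim in answer 2 can be checked numerically. The NumPy sketch below builds a hypothetical two-layer regression network on random data (not part of the chapter's code), initializes every weight in each layer to the same constant, computes the gradients of the MSE loss by hand, and confirms that all hidden neurons receive the same gradients, so gradient descent can never make them differ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X_demo = rng.normal(size=(32, 5))  # 32 random instances, 5 features\n",
"y_demo = rng.normal(size=(32, 1))  # random regression targets\n",
"W1 = np.full((5, 4), 0.5)  # every hidden weight starts with the same value\n",
"W2 = np.full((4, 1), 0.5)  # every output weight starts with the same value\n",
"\n",
"# forward pass: tanh hidden layer, linear output, MSE loss\n",
"hidden = np.tanh(X_demo @ W1)\n",
"y_pred_demo = hidden @ W2\n",
"# backward pass: gradients of the MSE loss, derived by hand\n",
"d_y_pred = 2 * (y_pred_demo - y_demo) / len(X_demo)\n",
"grad_W2 = hidden.T @ d_y_pred\n",
"d_hidden = (d_y_pred @ W2.T) * (1 - hidden ** 2)\n",
"grad_W1 = X_demo.T @ d_hidden\n",
"\n",
"# each column of grad_W1 (one per hidden neuron) is the same, and so is each row of grad_W2\n",
"print(np.allclose(grad_W1, grad_W1[:, :1]))  # True\n",
"print(np.allclose(grad_W2, grad_W2[0]))  # True"
]
},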
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8. Deep Learning on CIFAR10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### a.\n",
"*Exercise: Build a DNN with 20 hidden layers of 100 neurons each (that's too many, but it's the point of this exercise). Use He initialization and the Swish activation function.*"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" activation=\"swish\",\n",
" kernel_initializer=\"he_normal\"))"
]
},
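{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, Swish is defined as swish(*z*) = *z* σ(*z*), where σ is the logistic sigmoid function (some formulations add a scaling hyperparameter β inside the sigmoid; with β = 1 this is the function Keras calls `swish`, also known as SiLU). A minimal NumPy sketch shows that it behaves like ReLU for large positive inputs, but is smooth and dips slightly below zero for negative inputs:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def swish(z):\n",
"    # swish(z) = z * sigmoid(z), i.e., z / (1 + exp(-z))\n",
"    return z / (1.0 + np.exp(-z))\n",
"\n",
"z = np.array([-5.0, -1.0, 0.0, 1.0, 5.0])\n",
"print(swish(z).round(3))"
]
},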
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### b.\n",
"*Exercise: Using Nadam optimization and early stopping, train the network on the CIFAR10 dataset. You can load it with `tf.keras.datasets.cifar10.load_data()`. The dataset is composed of 60,000 32 × 32-pixel color images (50,000 for training, 10,000 for testing) with 10 classes, so you'll need a softmax output layer with 10 neurons. Remember to search for the right learning rate each time you change the model's architecture or hyperparameters.*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's add the output layer to the model:"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))"
]
},
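{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the architecture built above, the parameter count can be worked out by hand: the flattened input has 32 × 32 × 3 = 3,072 features, and each `Dense` layer has one weight per input per neuron plus one bias per neuron. The short computation below gives the total that `model.summary()` should report for this 20-hidden-layer model (plain arithmetic, not output copied from an actual run):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def dense_params(n_inputs, n_neurons):\n",
"    # weights (n_inputs * n_neurons) + biases (n_neurons)\n",
"    return n_inputs * n_neurons + n_neurons\n",
"\n",
"n_inputs = 32 * 32 * 3  # size of the Flatten layer's output: 3072\n",
"total = dense_params(n_inputs, 100)  # first hidden layer\n",
"total += 19 * dense_params(100, 100)  # remaining 19 hidden layers\n",
"total += dense_params(100, 10)  # softmax output layer\n",
"print(total)  # 500210"
]
},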
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use a Nadam optimizer with a learning rate of 5e-5. I tried learning rates 1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3 and 1e-2, and I compared their learning curves for 10 epochs each (using the TensorBoard callback, below). The learning rates 3e-5 and 1e-4 were pretty good, so I tried 5e-5, which turned out to be slightly better."
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-5)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the CIFAR10 dataset. We also want to use early stopping, so we need a validation set. Let's use the first 5,000 images of the original training set as the validation set:"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"cifar10 = tf.keras.datasets.cifar10.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = cifar10\n",
"\n",
"X_train = X_train_full[5000:]\n",
"y_train = y_train_full[5000:]\n",
"X_valid = X_train_full[:5000]\n",
"y_valid = y_train_full[:5000]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create the callbacks we need and train the model:"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,\n",
" restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_cifar10_model\",\n",
" save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]"
]
},
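{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `EarlyStopping` callback monitors `val_loss` by default and stops training once the monitored value has failed to improve for `patience` consecutive epochs; with `restore_best_weights=True`, it then rolls the model back to the weights of the best epoch. The pure-Python sketch below mimics that bookkeeping on a list of validation losses. `early_stopping_epochs()` is a hypothetical helper that ignores the callback's other options (such as `min_delta`). With `patience=20`, a best epoch at 29 implies that training stops after epoch 49."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def early_stopping_epochs(val_losses, patience):\n",
"    # hypothetical helper: returns (best_epoch, stop_epoch), both 1-based\n",
"    best_loss, best_epoch, wait = float('inf'), 0, 0\n",
"    for epoch, loss in enumerate(val_losses, start=1):\n",
"        if loss < best_loss:  # improvement: remember it and reset the counter\n",
"            best_loss, best_epoch, wait = loss, epoch, 0\n",
"        else:\n",
"            wait += 1\n",
"            if wait >= patience:  # too many epochs without improvement\n",
"                return best_epoch, epoch\n",
"    return best_epoch, len(val_losses)  # ran out of epochs before stopping\n",
"\n",
"# toy validation curve: improves for 3 epochs, then plateaus\n",
"losses = [2.0, 1.8, 1.7, 1.75, 1.72, 1.71, 1.74]\n",
"print(early_stopping_epochs(losses, patience=3))  # (3, 6)"
]
},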
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-d05c16b556c70d97\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-d05c16b556c70d97\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6006;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir=./my_cifar10_logs"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 4.0493 - accuracy: 0.1598INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 17s 10ms/step - loss: 4.0462 - accuracy: 0.1597 - val_loss: 2.1441 - val_accuracy: 0.2036\n",
"Epoch 2/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 2.0667 - accuracy: 0.2320INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 2.0667 - accuracy: 0.2320 - val_loss: 2.0134 - val_accuracy: 0.2472\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.9472 - accuracy: 0.2819INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.9472 - accuracy: 0.2819 - val_loss: 1.9427 - val_accuracy: 0.2796\n",
"Epoch 4/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.8636 - accuracy: 0.3182INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.8637 - accuracy: 0.3182 - val_loss: 1.8934 - val_accuracy: 0.3222\n",
"Epoch 5/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.7975 - accuracy: 0.3464INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.7974 - accuracy: 0.3465 - val_loss: 1.8389 - val_accuracy: 0.3284\n",
"Epoch 6/100\n",
"1407/1407 [==============================] - 9s 7ms/step - loss: 1.7446 - accuracy: 0.3664 - val_loss: 2.0006 - val_accuracy: 0.3030\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.6974 - accuracy: 0.3852INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6974 - accuracy: 0.3852 - val_loss: 1.7075 - val_accuracy: 0.3738\n",
"Epoch 8/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.6605 - accuracy: 0.3984INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6604 - accuracy: 0.3984 - val_loss: 1.6788 - val_accuracy: 0.3836\n",
"Epoch 9/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.6322 - accuracy: 0.4114INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.6321 - accuracy: 0.4114 - val_loss: 1.6477 - val_accuracy: 0.4014\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.6065 - accuracy: 0.4205 - val_loss: 1.6623 - val_accuracy: 0.3980\n",
"Epoch 11/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.5843 - accuracy: 0.4287INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.5845 - accuracy: 0.4285 - val_loss: 1.6032 - val_accuracy: 0.4198\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 9s 6ms/step - loss: 1.5634 - accuracy: 0.4367 - val_loss: 1.6063 - val_accuracy: 0.4258\n",
"Epoch 13/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.5443 - accuracy: 0.4420INFO:tensorflow:Assets written to: my_cifar10_model/assets\n",
"<<47 more lines>>\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.3247 - accuracy: 0.5256 - val_loss: 1.5130 - val_accuracy: 0.4616\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.3164 - accuracy: 0.5286 - val_loss: 1.5284 - val_accuracy: 0.4686\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.3091 - accuracy: 0.5303 - val_loss: 1.5208 - val_accuracy: 0.4682\n",
"Epoch 35/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.3026 - accuracy: 0.5319 - val_loss: 1.5479 - val_accuracy: 0.4604\n",
"Epoch 36/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2930 - accuracy: 0.5378 - val_loss: 1.5443 - val_accuracy: 0.4580\n",
"Epoch 37/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.2833 - accuracy: 0.5406 - val_loss: 1.5165 - val_accuracy: 0.4710\n",
"Epoch 38/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2763 - accuracy: 0.5433 - val_loss: 1.5345 - val_accuracy: 0.4672\n",
"Epoch 39/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.2687 - accuracy: 0.5437 - val_loss: 1.5162 - val_accuracy: 0.4712\n",
"Epoch 40/100\n",
"1407/1407 [==============================] - 11s 7ms/step - loss: 1.2623 - accuracy: 0.5490 - val_loss: 1.5717 - val_accuracy: 0.4566\n",
"Epoch 41/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.2580 - accuracy: 0.5467 - val_loss: 1.5296 - val_accuracy: 0.4738\n",
"Epoch 42/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2469 - accuracy: 0.5532 - val_loss: 1.5179 - val_accuracy: 0.4690\n",
"Epoch 43/100\n",
"1407/1407 [==============================] - 11s 8ms/step - loss: 1.2404 - accuracy: 0.5542 - val_loss: 1.5542 - val_accuracy: 0.4566\n",
"Epoch 44/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.2292 - accuracy: 0.5605 - val_loss: 1.5536 - val_accuracy: 0.4608\n",
"Epoch 45/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.2276 - accuracy: 0.5606 - val_loss: 1.5522 - val_accuracy: 0.4624\n",
"Epoch 46/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2200 - accuracy: 0.5637 - val_loss: 1.5339 - val_accuracy: 0.4794\n",
"Epoch 47/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.2080 - accuracy: 0.5677 - val_loss: 1.5451 - val_accuracy: 0.4688\n",
"Epoch 48/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.2050 - accuracy: 0.5675 - val_loss: 1.5209 - val_accuracy: 0.4770\n",
"Epoch 49/100\n",
"1407/1407 [==============================] - 10s 7ms/step - loss: 1.1947 - accuracy: 0.5718 - val_loss: 1.5435 - val_accuracy: 0.4736\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fb9f02fc070>"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=callbacks)"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"157/157 [==============================] - 0s 2ms/step - loss: 1.5062 - accuracy: 0.4676\n"
]
},
{
"data": {
"text/plain": [
"[1.5061508417129517, 0.4675999879837036]"
]
},
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.evaluate(X_valid, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model with the lowest validation loss gets about 46.8% accuracy on the validation set. It took 29 epochs to reach the lowest validation loss, with roughly 10 seconds per epoch on my laptop (without a GPU). Let's see if we can improve the model using Batch Normalization."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### c.\n",
"*Exercise: Now try adding Batch Normalization and compare the learning curves: Is it converging faster than before? Does it produce a better model? How does it affect training speed?*"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code below is very similar to the code above, with a few changes:\n",
"\n",
"* I added a BN layer after every Dense layer (before the activation function), except for the output layer.\n",
"* I changed the learning rate to 5e-4. I experimented with 1e-5, 3e-5, 5e-5, 1e-4, 3e-4, 5e-4, 1e-3 and 3e-3, and I chose the one with the best validation performance after 20 epochs.\n",
"* I renamed the run directories to `run_bn_*` and the model file name to `my_cifar10_bn_model`."
]
},
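{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each `BatchNormalization` layer placed between a `Dense` layer and its activation standardizes every feature using the mean and variance of the current mini-batch during training, then applies a learned per-feature scale γ and offset β. Below is a minimal NumPy sketch of that training-mode computation. It leaves out the moving averages Keras maintains for use at inference time, and `epsilon=1e-3` is assumed to match the Keras default."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def batch_norm_train(Z, gamma, beta, epsilon=1e-3):\n",
"    # standardize each feature (column) using the statistics of this mini-batch\n",
"    mu = Z.mean(axis=0)\n",
"    var = Z.var(axis=0)\n",
"    Z_hat = (Z - mu) / np.sqrt(var + epsilon)\n",
"    return gamma * Z_hat + beta  # learned per-feature scale and offset\n",
"\n",
"rng = np.random.default_rng(42)\n",
"Z = rng.normal(loc=5.0, scale=3.0, size=(32, 100))  # fake Dense outputs for a batch of 32\n",
"out = batch_norm_train(Z, gamma=np.ones(100), beta=np.zeros(100))\n",
"print(np.abs(out.mean(axis=0)).max(), out.std(axis=0).min())"
]
},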
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 2.0377 - accuracy: 0.2523INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 32s 18ms/step - loss: 2.0374 - accuracy: 0.2525 - val_loss: 1.8766 - val_accuracy: 0.3154\n",
"Epoch 2/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.7874 - accuracy: 0.3542 - val_loss: 1.8784 - val_accuracy: 0.3268\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - 20s 15ms/step - loss: 1.6806 - accuracy: 0.3969 - val_loss: 1.9764 - val_accuracy: 0.3252\n",
"Epoch 4/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.6111 - accuracy: 0.4229INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 24s 17ms/step - loss: 1.6112 - accuracy: 0.4228 - val_loss: 1.7087 - val_accuracy: 0.3750\n",
"Epoch 5/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.5520 - accuracy: 0.4478INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 21s 15ms/step - loss: 1.5521 - accuracy: 0.4476 - val_loss: 1.6272 - val_accuracy: 0.4176\n",
"Epoch 6/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5030 - accuracy: 0.4659INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 23s 16ms/step - loss: 1.5030 - accuracy: 0.4660 - val_loss: 1.5401 - val_accuracy: 0.4452\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.4559 - accuracy: 0.4812 - val_loss: 1.6990 - val_accuracy: 0.3952\n",
"Epoch 8/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.4169 - accuracy: 0.4987INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 21s 15ms/step - loss: 1.4168 - accuracy: 0.4987 - val_loss: 1.5078 - val_accuracy: 0.4652\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.3863 - accuracy: 0.5123 - val_loss: 1.5513 - val_accuracy: 0.4470\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.3514 - accuracy: 0.5216 - val_loss: 1.5208 - val_accuracy: 0.4562\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.3220 - accuracy: 0.5314 - val_loss: 1.7301 - val_accuracy: 0.4206\n",
"Epoch 12/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.2933 - accuracy: 0.5410INFO:tensorflow:Assets written to: my_cifar10_bn_model/assets\n",
"1407/1407 [==============================] - 25s 18ms/step - loss: 1.2931 - accuracy: 0.5410 - val_loss: 1.4909 - val_accuracy: 0.4734\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.2702 - accuracy: 0.5490 - val_loss: 1.5256 - val_accuracy: 0.4636\n",
"Epoch 14/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.2424 - accuracy: 0.5591 - val_loss: 1.5569 - val_accuracy: 0.4624\n",
"Epoch 15/100\n",
"<<12 more lines>>\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.1174 - accuracy: 0.6066 - val_loss: 1.5241 - val_accuracy: 0.4828\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0978 - accuracy: 0.6128 - val_loss: 1.5313 - val_accuracy: 0.4772\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.0844 - accuracy: 0.6198 - val_loss: 1.4993 - val_accuracy: 0.4924\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.0677 - accuracy: 0.6244 - val_loss: 1.4622 - val_accuracy: 0.5078\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0571 - accuracy: 0.6297 - val_loss: 1.4917 - val_accuracy: 0.4990\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 1.0395 - accuracy: 0.6327 - val_loss: 1.4888 - val_accuracy: 0.4896\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0298 - accuracy: 0.6370 - val_loss: 1.5358 - val_accuracy: 0.5024\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.0150 - accuracy: 0.6444 - val_loss: 1.5219 - val_accuracy: 0.5030\n",
"Epoch 29/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.0100 - accuracy: 0.6456 - val_loss: 1.4933 - val_accuracy: 0.5098\n",
"Epoch 30/100\n",
"1407/1407 [==============================] - 20s 14ms/step - loss: 0.9956 - accuracy: 0.6492 - val_loss: 1.4756 - val_accuracy: 0.5012\n",
"Epoch 31/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 0.9787 - accuracy: 0.6576 - val_loss: 1.5181 - val_accuracy: 0.4936\n",
"Epoch 32/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9710 - accuracy: 0.6565 - val_loss: 1.7510 - val_accuracy: 0.4568\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 20s 14ms/step - loss: 0.9613 - accuracy: 0.6628 - val_loss: 1.5576 - val_accuracy: 0.4910\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 0.9530 - accuracy: 0.6651 - val_loss: 1.5087 - val_accuracy: 0.5046\n",
"Epoch 35/100\n",
"1407/1407 [==============================] - 19s 13ms/step - loss: 0.9388 - accuracy: 0.6701 - val_loss: 1.5534 - val_accuracy: 0.4950\n",
"Epoch 36/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9331 - accuracy: 0.6743 - val_loss: 1.5033 - val_accuracy: 0.5046\n",
"Epoch 37/100\n",
"1407/1407 [==============================] - 19s 14ms/step - loss: 0.9144 - accuracy: 0.6808 - val_loss: 1.5679 - val_accuracy: 0.5028\n",
"157/157 [==============================] - 0s 2ms/step - loss: 1.4236 - accuracy: 0.5074\n"
]
},
{
"data": {
"text/plain": [
"[1.4236289262771606, 0.5073999762535095]"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100, kernel_initializer=\"he_normal\"))\n",
" model.add(tf.keras.layers.BatchNormalization())\n",
" model.add(tf.keras.layers.Activation(\"swish\"))\n",
"\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-4)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(patience=20,\n",
" restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_cifar10_bn_model\",\n",
" save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_bn_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"model.fit(X_train, y_train, epochs=100,\n",
" validation_data=(X_valid, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"* *Is the model converging faster than before?* Much faster! The previous model took 29 epochs to reach the lowest validation loss, while the new model achieved that same loss in just 12 epochs and continued to make progress until the 17th epoch. The BN layers stabilized training and allowed us to use a much larger learning rate, so convergence was faster.\n",
"* *Does BN produce a better model?* Yes! The final model is also much better, with 50.7% validation accuracy instead of 46.7%. It's still not a very good model, but at least it's much better than before (a Convolutional Neural Network would do much better, but that's a different topic, see chapter 14).\n",
"* *How does BN affect training speed?* Although the model converged much faster, each epoch took about 15s instead of 10s, because of the extra computations required by the BN layers. But overall the training time (wall time) to reach the best model was shortened by about 10%."
]
},
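{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the extra computations mentioned above more concrete, here is a minimal NumPy sketch of what a BN layer computes for one training batch: it standardizes each feature using the batch mean and variance, scales and shifts the result using the learned `gamma` and `beta` vectors, and updates the moving averages that replace the batch statistics at inference time. This is an illustration, not Keras' actual implementation: the function name `batch_norm_train` is hypothetical, and the `momentum=0.99` and `eps=1e-3` defaults are assumed to mirror those of `tf.keras.layers.BatchNormalization`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def batch_norm_train(X, gamma, beta, moving_mean, moving_var,\n",
"                     momentum=0.99, eps=1e-3):\n",
"    # Hypothetical BN forward pass for one training batch X of shape [batch, features]\n",
"    mu = X.mean(axis=0)  # per-feature batch mean\n",
"    var = X.var(axis=0)  # per-feature batch variance\n",
"    X_hat = (X - mu) / np.sqrt(var + eps)  # standardize each feature\n",
"    Z = gamma * X_hat + beta  # learned scale and offset\n",
"    # exponential moving averages, used instead of the batch stats at inference time\n",
"    moving_mean = momentum * moving_mean + (1 - momentum) * mu\n",
"    moving_var = momentum * moving_var + (1 - momentum) * var\n",
"    return Z, moving_mean, moving_var\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X = rng.normal(loc=5.0, scale=3.0, size=(128, 100))  # a batch that is far from standardized\n",
"Z, m_mean, m_var = batch_norm_train(X, gamma=np.ones(100), beta=np.zeros(100),\n",
"                                    moving_mean=np.zeros(100), moving_var=np.ones(100))\n",
"print(Z.mean(axis=0)[:3].round(3), Z.std(axis=0)[:3].round(3))"
]
},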
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### d.\n",
"*Exercise: Try replacing Batch Normalization with SELU, and make the necessary adjustments to ensure the network self-normalizes (i.e., standardize the input features, use LeCun normal initialization, make sure the DNN contains only a sequence of dense layers, etc.).*"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.9386 - accuracy: 0.3045INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 20s 13ms/step - loss: 1.9385 - accuracy: 0.3046 - val_loss: 1.8175 - val_accuracy: 0.3510\n",
"Epoch 2/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.7241 - accuracy: 0.3869INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.7241 - accuracy: 0.3869 - val_loss: 1.7677 - val_accuracy: 0.3614\n",
"Epoch 3/100\n",
"1407/1407 [==============================] - ETA: 0s - loss: 1.6272 - accuracy: 0.4263INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.6272 - accuracy: 0.4263 - val_loss: 1.6878 - val_accuracy: 0.4054\n",
"Epoch 4/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5644 - accuracy: 0.4492INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 18s 13ms/step - loss: 1.5643 - accuracy: 0.4492 - val_loss: 1.6589 - val_accuracy: 0.4304\n",
"Epoch 5/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.5080 - accuracy: 0.4712INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.5080 - accuracy: 0.4712 - val_loss: 1.5651 - val_accuracy: 0.4538\n",
"Epoch 6/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.4611 - accuracy: 0.4873INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.4613 - accuracy: 0.4872 - val_loss: 1.5305 - val_accuracy: 0.4678\n",
"Epoch 7/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.4174 - accuracy: 0.5077 - val_loss: 1.5346 - val_accuracy: 0.4558\n",
"Epoch 8/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.3781 - accuracy: 0.5175INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.3781 - accuracy: 0.5175 - val_loss: 1.4773 - val_accuracy: 0.4882\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.3413 - accuracy: 0.5345 - val_loss: 1.5021 - val_accuracy: 0.4764\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.3182 - accuracy: 0.5422 - val_loss: 1.5709 - val_accuracy: 0.4762\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2832 - accuracy: 0.5571 - val_loss: 1.5345 - val_accuracy: 0.4868\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2557 - accuracy: 0.5667 - val_loss: 1.5024 - val_accuracy: 0.4900\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2373 - accuracy: 0.5710 - val_loss: 1.5114 - val_accuracy: 0.5028\n",
"Epoch 14/100\n",
"1404/1407 [============================>.] - ETA: 0s - loss: 1.2071 - accuracy: 0.5846INFO:tensorflow:Assets written to: my_cifar10_selu_model/assets\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.2073 - accuracy: 0.5847 - val_loss: 1.4608 - val_accuracy: 0.5026\n",
"Epoch 15/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1843 - accuracy: 0.5940 - val_loss: 1.4962 - val_accuracy: 0.5038\n",
"Epoch 16/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.1617 - accuracy: 0.6026 - val_loss: 1.5255 - val_accuracy: 0.5062\n",
"Epoch 17/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1452 - accuracy: 0.6084 - val_loss: 1.5057 - val_accuracy: 0.5036\n",
"Epoch 18/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 1.1297 - accuracy: 0.6145 - val_loss: 1.5097 - val_accuracy: 0.5010\n",
"Epoch 19/100\n",
"1407/1407 [==============================] - 16s 12ms/step - loss: 1.1004 - accuracy: 0.6245 - val_loss: 1.5218 - val_accuracy: 0.5014\n",
"Epoch 20/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0971 - accuracy: 0.6304 - val_loss: 1.5253 - val_accuracy: 0.5090\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.0670 - accuracy: 0.6345 - val_loss: 1.5006 - val_accuracy: 0.5034\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.0544 - accuracy: 0.6407 - val_loss: 1.5244 - val_accuracy: 0.5010\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0338 - accuracy: 0.6502 - val_loss: 1.5355 - val_accuracy: 0.5096\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.0281 - accuracy: 0.6514 - val_loss: 1.5257 - val_accuracy: 0.5164\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.4097 - accuracy: 0.6478 - val_loss: 1.8203 - val_accuracy: 0.3514\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.3733 - accuracy: 0.5157 - val_loss: 1.5600 - val_accuracy: 0.4664\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2032 - accuracy: 0.5814 - val_loss: 1.5367 - val_accuracy: 0.4944\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.1291 - accuracy: 0.6121 - val_loss: 1.5333 - val_accuracy: 0.4852\n",
"Epoch 29/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0734 - accuracy: 0.6317 - val_loss: 1.5475 - val_accuracy: 0.5032\n",
"Epoch 30/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0294 - accuracy: 0.6469 - val_loss: 1.5400 - val_accuracy: 0.5052\n",
"Epoch 31/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0081 - accuracy: 0.6605 - val_loss: 1.5617 - val_accuracy: 0.4856\n",
"Epoch 32/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0109 - accuracy: 0.6603 - val_loss: 1.5727 - val_accuracy: 0.5124\n",
"Epoch 33/100\n",
"1407/1407 [==============================] - 17s 12ms/step - loss: 0.9646 - accuracy: 0.6762 - val_loss: 1.5333 - val_accuracy: 0.5174\n",
"Epoch 34/100\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 0.9597 - accuracy: 0.6789 - val_loss: 1.5601 - val_accuracy: 0.5016\n",
"157/157 [==============================] - 0s 1ms/step - loss: 1.4608 - accuracy: 0.5026\n"
]
},
{
"data": {
"text/plain": [
"[1.4607702493667603, 0.5026000142097473]"
]
},
"execution_count": 127,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=7e-4)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n",
" patience=20, restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n",
" \"my_cifar10_selu_model\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_selu_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"X_means = X_train.mean(axis=0)\n",
"X_stds = X_train.std(axis=0)\n",
"X_train_scaled = (X_train - X_means) / X_stds\n",
"X_valid_scaled = (X_valid - X_means) / X_stds\n",
"X_test_scaled = (X_test - X_means) / X_stds\n",
"\n",
"model.fit(X_train_scaled, y_train, epochs=100,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid_scaled, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model reached the first model's validation loss in just 8 epochs. After 14 epochs, it reached its lowest validation loss, with about 50.3% accuracy, which is better than the original model (46.7%), but not quite as good as the model using batch normalization (50.7%). In this run, each epoch took roughly 15 to 17 seconds (timings depend on the hardware), and since the model needed only 14 epochs to reach its best validation loss, it's the fastest model to train so far."
]
},
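{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why does this deep network still train without BN layers? With standardized inputs, LeCun normal initialization and SELU activations in a plain stack of dense layers, the activations tend to keep a mean close to 0 and a standard deviation close to 1 from one layer to the next. The following NumPy sketch illustrates this on random data rather than on CIFAR10. It assumes a plain (non-truncated) normal distribution with variance 1/fan_in for the LeCun initialization, and the `selu()` helper is written out by hand for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"SELU_ALPHA = 1.6732632423543772\n",
"SELU_SCALE = 1.0507009873554805\n",
"\n",
"def selu(z):\n",
"    return SELU_SCALE * np.where(z > 0, z, SELU_ALPHA * (np.exp(z) - 1))\n",
"\n",
"rng = np.random.default_rng(42)\n",
"Z = rng.normal(size=(1000, 100))  # standardized inputs: mean 0, std 1\n",
"for layer in range(20):  # 20 dense layers of 100 neurons, like the model above\n",
"    W = rng.normal(scale=np.sqrt(1 / 100), size=(100, 100))  # LeCun normal: var = 1 / fan_in\n",
"    Z = selu(Z @ W)\n",
"    if layer % 5 == 4:\n",
"        print(f\"Layer {layer + 1}: mean {Z.mean():.2f}, std {Z.std():.2f}\")"
]
},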
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### e.\n",
"*Exercise: Try regularizing the model with alpha dropout. Then, without retraining your model, see if you can achieve better accuracy using MC Dropout.*"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.8953 - accuracy: 0.3240INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 18s 11ms/step - loss: 1.8950 - accuracy: 0.3239 - val_loss: 1.7556 - val_accuracy: 0.3812\n",
"Epoch 2/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.6618 - accuracy: 0.4129INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.6618 - accuracy: 0.4130 - val_loss: 1.6563 - val_accuracy: 0.4114\n",
"Epoch 3/100\n",
"1402/1407 [============================>.] - ETA: 0s - loss: 1.5772 - accuracy: 0.4431INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.5770 - accuracy: 0.4432 - val_loss: 1.6507 - val_accuracy: 0.4232\n",
"Epoch 4/100\n",
"1406/1407 [============================>.] - ETA: 0s - loss: 1.5081 - accuracy: 0.4673INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.5081 - accuracy: 0.4672 - val_loss: 1.5892 - val_accuracy: 0.4566\n",
"Epoch 5/100\n",
"1403/1407 [============================>.] - ETA: 0s - loss: 1.4560 - accuracy: 0.4902INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.4561 - accuracy: 0.4902 - val_loss: 1.5382 - val_accuracy: 0.4696\n",
"Epoch 6/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.4095 - accuracy: 0.5050INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 16s 11ms/step - loss: 1.4094 - accuracy: 0.5050 - val_loss: 1.5236 - val_accuracy: 0.4818\n",
"Epoch 7/100\n",
"1401/1407 [============================>.] - ETA: 0s - loss: 1.3634 - accuracy: 0.5234INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.3636 - accuracy: 0.5232 - val_loss: 1.5139 - val_accuracy: 0.4840\n",
"Epoch 8/100\n",
"1405/1407 [============================>.] - ETA: 0s - loss: 1.3297 - accuracy: 0.5377INFO:tensorflow:Assets written to: my_cifar10_alpha_dropout_model/assets\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.3296 - accuracy: 0.5378 - val_loss: 1.4780 - val_accuracy: 0.4982\n",
"Epoch 9/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.2907 - accuracy: 0.5485 - val_loss: 1.5151 - val_accuracy: 0.4854\n",
"Epoch 10/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 1.2559 - accuracy: 0.5646 - val_loss: 1.4980 - val_accuracy: 0.4976\n",
"Epoch 11/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.2221 - accuracy: 0.5767 - val_loss: 1.5199 - val_accuracy: 0.4990\n",
"Epoch 12/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.1960 - accuracy: 0.5870 - val_loss: 1.5167 - val_accuracy: 0.5030\n",
"Epoch 13/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 1.1684 - accuracy: 0.5955 - val_loss: 1.5815 - val_accuracy: 0.5014\n",
"Epoch 14/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.1463 - accuracy: 0.6025 - val_loss: 1.5427 - val_accuracy: 0.5112\n",
"Epoch 15/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 1.1125 - accuracy: 0.6169 - val_loss: 1.5868 - val_accuracy: 0.5212\n",
"Epoch 16/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 1.0854 - accuracy: 0.6243 - val_loss: 1.6234 - val_accuracy: 0.5090\n",
"Epoch 17/100\n",
"1407/1407 [==============================] - 15s 11ms/step - loss: 1.0668 - accuracy: 0.6328 - val_loss: 1.6162 - val_accuracy: 0.5072\n",
"Epoch 18/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 1.0440 - accuracy: 0.6442 - val_loss: 1.5748 - val_accuracy: 0.5162\n",
"Epoch 19/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 1.0272 - accuracy: 0.6477 - val_loss: 1.6518 - val_accuracy: 0.5200\n",
"Epoch 20/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 1.0007 - accuracy: 0.6594 - val_loss: 1.6224 - val_accuracy: 0.5186\n",
"Epoch 21/100\n",
"1407/1407 [==============================] - 15s 10ms/step - loss: 0.9824 - accuracy: 0.6639 - val_loss: 1.6972 - val_accuracy: 0.5136\n",
"Epoch 22/100\n",
"1407/1407 [==============================] - 12s 9ms/step - loss: 0.9660 - accuracy: 0.6714 - val_loss: 1.7210 - val_accuracy: 0.5278\n",
"Epoch 23/100\n",
"1407/1407 [==============================] - 13s 10ms/step - loss: 0.9472 - accuracy: 0.6780 - val_loss: 1.6436 - val_accuracy: 0.5006\n",
"Epoch 24/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 0.9314 - accuracy: 0.6819 - val_loss: 1.7059 - val_accuracy: 0.5160\n",
"Epoch 25/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 0.9172 - accuracy: 0.6888 - val_loss: 1.6926 - val_accuracy: 0.5200\n",
"Epoch 26/100\n",
"1407/1407 [==============================] - 14s 10ms/step - loss: 0.8990 - accuracy: 0.6947 - val_loss: 1.7705 - val_accuracy: 0.5148\n",
"Epoch 27/100\n",
"1407/1407 [==============================] - 13s 9ms/step - loss: 0.8758 - accuracy: 0.7028 - val_loss: 1.7023 - val_accuracy: 0.5198\n",
"Epoch 28/100\n",
"1407/1407 [==============================] - 12s 8ms/step - loss: 0.8622 - accuracy: 0.7090 - val_loss: 1.7567 - val_accuracy: 0.5184\n",
"157/157 [==============================] - 0s 1ms/step - loss: 1.4780 - accuracy: 0.4982\n"
]
},
{
"data": {
"text/plain": [
"[1.4779616594314575, 0.498199999332428]"
]
},
"execution_count": 128,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-4)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
"early_stopping_cb = tf.keras.callbacks.EarlyStopping(\n",
" patience=20, restore_best_weights=True)\n",
"model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\n",
" \"my_cifar10_alpha_dropout_model\", save_best_only=True)\n",
"run_index = 1 # increment every time you train the model\n",
"run_logdir = Path() / \"my_cifar10_logs\" / f\"run_alpha_dropout_{run_index:03d}\"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)\n",
"callbacks = [early_stopping_cb, model_checkpoint_cb, tensorboard_cb]\n",
"\n",
"X_means = X_train.mean(axis=0)\n",
"X_stds = X_train.std(axis=0)\n",
"X_train_scaled = (X_train - X_means) / X_stds\n",
"X_valid_scaled = (X_valid - X_means) / X_stds\n",
"X_test_scaled = (X_test - X_means) / X_stds\n",
"\n",
"model.fit(X_train_scaled, y_train, epochs=100,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=callbacks)\n",
"\n",
"model.evaluate(X_valid_scaled, y_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model reaches 49.8% accuracy on the validation set. That's slightly worse than without dropout (50.3%). With an extensive hyperparameter search, it might be possible to do better (I tried dropout rates of 5%, 10%, 20% and 40%, and learning rates 1e-4, 3e-4, 5e-4, and 1e-3), but probably not much better in this case."
]
},
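{
"cell_type": "markdown",
"metadata": {},
"source": [
"A regular `Dropout` layer would break self-normalization, which is why this model uses `AlphaDropout`. Instead of setting dropped inputs to 0, alpha dropout sets them to SELU's negative saturation value, then applies an affine transformation chosen so that inputs with mean 0 and variance 1 keep that mean and variance. Below is a NumPy sketch of this idea, based on the formulas from the paper that introduced SELU; the function name `alpha_dropout` is hypothetical, and this is not Keras' exact code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"SELU_ALPHA = 1.6732632423543772\n",
"SELU_SCALE = 1.0507009873554805\n",
"\n",
"def alpha_dropout(x, rate, rng):\n",
"    # Training-time alpha dropout (sketch); returns an array with the same shape as x\n",
"    alpha_p = -SELU_SCALE * SELU_ALPHA  # SELU's negative saturation value\n",
"    keep = rng.uniform(size=x.shape) >= rate  # keep each input with probability 1 - rate\n",
"    x = np.where(keep, x, alpha_p)  # dropped inputs are set to alpha_p, not 0\n",
"    # affine correction that restores mean 0 and variance 1 for standardized inputs\n",
"    a = ((1 - rate) * (1 + rate * alpha_p ** 2)) ** -0.5\n",
"    b = -a * alpha_p * rate\n",
"    return a * x + b\n",
"\n",
"rng = np.random.default_rng(42)\n",
"x = rng.normal(size=200_000)  # standardized inputs\n",
"y = alpha_dropout(x, rate=0.1, rng=rng)\n",
"print(f\"mean {y.mean():.3f}, std {y.std():.3f}\")"
]
},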
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's use MC Dropout now. We will need the `MCAlphaDropout` class we used earlier, so let's just copy it here for convenience:"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"class MCAlphaDropout(tf.keras.layers.AlphaDropout):\n",
" def call(self, inputs):\n",
" return super().call(inputs, training=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's create a new model, identical to the one we just trained (with the same weights), but with `MCAlphaDropout` dropout layers instead of `AlphaDropout` layers:"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
"mc_model = tf.keras.Sequential([\n",
" (\n",
" MCAlphaDropout(layer.rate)\n",
" if isinstance(layer, tf.keras.layers.AlphaDropout)\n",
" else layer\n",
" )\n",
" for layer in model.layers\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then let's add a couple of utility functions. The first one runs the model several times (10 by default) and returns the mean predicted class probabilities. The second one uses these mean probabilities to predict the most likely class for each instance:"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"def mc_dropout_predict_probas(mc_model, X, n_samples=10):\n",
"    Y_probas = [mc_model.predict(X) for _ in range(n_samples)]\n",
"    return np.mean(Y_probas, axis=0)\n",
"\n",
"def mc_dropout_predict_classes(mc_model, X, n_samples=10):\n",
"    Y_probas = mc_dropout_predict_probas(mc_model, X, n_samples)\n",
"    return Y_probas.argmax(axis=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's make predictions for all the instances in the validation set, and compute the accuracy:"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.4984"
]
},
"execution_count": 132,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"y_pred = mc_dropout_predict_classes(mc_model, X_valid_scaled)\n",
"accuracy = (y_pred == y_valid[:, 0]).mean()\n",
"accuracy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
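"In this case, MC Dropout barely changes anything: it reaches 49.84% accuracy, essentially the same as the alpha dropout model evaluated without MC Dropout (49.82%), and still slightly below the model without dropout (50.3%).\n",
"\n",
"So the best model we got in this exercise is the Batch Normalization model."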
]
},
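{
"cell_type": "markdown",
"metadata": {},
"source": [
"Even when MC Dropout doesn't improve accuracy, the individual samples it produces can be used to estimate how uncertain the model is about each prediction. Here is a small NumPy sketch, assuming the MC samples are stacked into an array of shape [n_samples, n_instances, n_classes], like the list built inside `mc_dropout_predict_probas()`. The helper name `mc_dropout_uncertainty` and the toy probabilities are hypothetical. With the real model, you could pass it `np.stack([mc_model.predict(X_valid_scaled) for _ in range(10)])`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def mc_dropout_uncertainty(Y_probas):\n",
"    # Y_probas: MC samples of class probabilities, shape [n_samples, n_instances, n_classes]\n",
"    Y_probas = np.asarray(Y_probas)\n",
"    y_proba = Y_probas.mean(axis=0)  # mean probas, as in mc_dropout_predict_probas()\n",
"    y_std = Y_probas.std(axis=0)  # spread of each proba across the MC samples\n",
"    y_pred = y_proba.argmax(axis=1)  # predicted class for each instance\n",
"    pred_std = y_std[np.arange(len(y_pred)), y_pred]  # spread of the predicted class's proba\n",
"    entropy = -(y_proba * np.log(y_proba + 1e-12)).sum(axis=1)  # predictive entropy\n",
"    return y_pred, pred_std, entropy\n",
"\n",
"# Toy example: 3 MC samples, 2 instances, 3 classes\n",
"Y_probas = [\n",
"    [[0.9, 0.05, 0.05], [0.4, 0.5, 0.1]],\n",
"    [[0.8, 0.10, 0.10], [0.6, 0.3, 0.1]],\n",
"    [[0.9, 0.05, 0.05], [0.2, 0.7, 0.1]],\n",
"]\n",
"y_pred, pred_std, entropy = mc_dropout_uncertainty(Y_probas)\n",
"print(y_pred, pred_std.round(3), entropy.round(3))"
]
},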
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### f.\n",
"*Exercise: Retrain your model using 1cycle scheduling and see if it improves training speed and model accuracy.*"
]
},
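{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `OneCycleScheduler` callback used below was defined earlier in this notebook. As a reminder of the idea, here is a simplified, pure-Python sketch of a 1cycle learning rate schedule: the learning rate grows linearly from `start_lr` to `max_lr` during roughly the first half of training, decreases linearly back to `start_lr` during the second half, and finally drops by several orders of magnitude during the last few iterations. The function name `one_cycle_lr` and the 45%/45%/10% split are assumptions made for illustration; the callback defined earlier may differ in its details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def one_cycle_lr(iteration, n_iterations, max_lr, start_lr=None, last_lr=None):\n",
"    # Hypothetical 1cycle schedule: returns the learning rate for a given iteration\n",
"    if start_lr is None:\n",
"        start_lr = max_lr / 10\n",
"    if last_lr is None:\n",
"        last_lr = start_lr / 1000\n",
"    half = int(n_iterations * 0.45)\n",
"    if iteration < half:  # phase 1: linear warm-up\n",
"        return start_lr + (max_lr - start_lr) * iteration / half\n",
"    if iteration < 2 * half:  # phase 2: linear cool-down\n",
"        return max_lr - (max_lr - start_lr) * (iteration - half) / half\n",
"    # phase 3: drop towards last_lr during the remaining iterations\n",
"    remaining = n_iterations - 2 * half\n",
"    return start_lr - (start_lr - last_lr) * (iteration - 2 * half) / remaining\n",
"\n",
"n_iterations = 352 * 15  # 352 batches of 128 instances per epoch, 15 epochs, as in the run below\n",
"lrs = [one_cycle_lr(i, n_iterations, max_lr=0.05) for i in range(n_iterations)]\n",
"print(lrs[0], max(lrs), lrs[-1])"
]
},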
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"352/352 [==============================] - 3s 8ms/step - loss: nan - accuracy: 0.1706\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"batch_size = 128\n",
"rates, losses = find_learning_rate(model, X_train_scaled, y_train, epochs=1,\n",
" batch_size=batch_size)\n",
"plot_lr_vs_loss(rates, losses)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential()\n",
"model.add(tf.keras.layers.Flatten(input_shape=[32, 32, 3]))\n",
"for _ in range(20):\n",
" model.add(tf.keras.layers.Dense(100,\n",
" kernel_initializer=\"lecun_normal\",\n",
" activation=\"selu\"))\n",
"\n",
"model.add(tf.keras.layers.AlphaDropout(rate=0.1))\n",
"model.add(tf.keras.layers.Dense(10, activation=\"softmax\"))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=2e-2)\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=optimizer,\n",
" metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 2.0559 - accuracy: 0.2839 - val_loss: 1.7917 - val_accuracy: 0.3768\n",
"Epoch 2/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.7596 - accuracy: 0.3797 - val_loss: 1.6566 - val_accuracy: 0.4258\n",
"Epoch 3/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.6199 - accuracy: 0.4247 - val_loss: 1.6395 - val_accuracy: 0.4260\n",
"Epoch 4/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.5451 - accuracy: 0.4524 - val_loss: 1.6202 - val_accuracy: 0.4408\n",
"Epoch 5/15\n",
"352/352 [==============================] - 3s 8ms/step - loss: 1.4952 - accuracy: 0.4691 - val_loss: 1.5981 - val_accuracy: 0.4488\n",
"Epoch 6/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.4541 - accuracy: 0.4842 - val_loss: 1.5720 - val_accuracy: 0.4490\n",
"Epoch 7/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.4171 - accuracy: 0.4967 - val_loss: 1.6035 - val_accuracy: 0.4470\n",
"Epoch 8/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.3497 - accuracy: 0.5194 - val_loss: 1.4918 - val_accuracy: 0.4864\n",
"Epoch 9/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.2788 - accuracy: 0.5459 - val_loss: 1.5597 - val_accuracy: 0.4672\n",
"Epoch 10/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.2070 - accuracy: 0.5707 - val_loss: 1.5845 - val_accuracy: 0.4864\n",
"Epoch 11/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 1.1433 - accuracy: 0.5926 - val_loss: 1.5293 - val_accuracy: 0.4998\n",
"Epoch 12/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 1.0745 - accuracy: 0.6182 - val_loss: 1.5118 - val_accuracy: 0.5072\n",
"Epoch 13/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 1.0030 - accuracy: 0.6413 - val_loss: 1.5388 - val_accuracy: 0.5204\n",
"Epoch 14/15\n",
"352/352 [==============================] - 3s 10ms/step - loss: 0.9388 - accuracy: 0.6654 - val_loss: 1.5547 - val_accuracy: 0.5210\n",
"Epoch 15/15\n",
"352/352 [==============================] - 3s 9ms/step - loss: 0.8989 - accuracy: 0.6805 - val_loss: 1.5835 - val_accuracy: 0.5242\n"
]
}
],
"source": [
"n_epochs = 15\n",
"n_iterations = math.ceil(len(X_train_scaled) / batch_size) * n_epochs\n",
"onecycle = OneCycleScheduler(n_iterations, max_lr=0.05)\n",
"history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=[onecycle])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1cycle scheduling allowed us to train the model in just 15 epochs, each taking only about 3 seconds (thanks to the larger batch size). This is several times faster than the fastest model we trained so far. Moreover, it improved the model's performance (from 50.7% to 52.4% validation accuracy)."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"nav_menu": {
"height": "360px",
"width": "416px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}