{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 17: Representation Learning and Generative Learning with Autoencoders, GANs, and Diffusion Models**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 17._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/17_autoencoders_gans_and_diffusion_models.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/17_autoencoders_gans_and_diffusion_models.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dFXIv9qNpKzt",
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8IPbJEmZpKzu"
},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "TFSU3FCOpKzu"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TAlKky09pKzv"
},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "YqCwW7cMpKzw"
},
"outputs": [],
"source": [
"from packaging import version\n",
"import sklearn\n",
"\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GJtVEqxfpKzw"
},
"source": [
"And TensorFlow ≥ 2.8:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "0Piq5se2pKzx"
},
"outputs": [],
"source": [
"from packaging import version\n",
"import tensorflow as tf\n",
"\n",
"assert version.parse(tf.__version__) >= version.parse(\"2.8.0\")"
]
},
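{
"cell_type": "markdown",
"metadata": {},
"source": [
"A side note on version checks (this cell and the next are an illustrative sketch, not part of the original notebook): comparing version strings character by character can misorder releases, since `'2.10.0'` sorts before `'2.8.0'`. The hypothetical `parse_version()` helper below assumes purely numeric `major.minor.patch` strings; a real project would more likely rely on a dedicated version parser such as the one in the `packaging` library:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch: parse_version() is a hypothetical helper that assumes\n",
"# purely numeric 'major.minor.patch' strings (no 'rc1' or 'dev' suffixes).\n",
"def parse_version(text):\n",
"    return tuple(int(part) for part in text.split('.'))\n",
"\n",
"lexicographic = '2.10.0' >= '2.8.0'  # compares characters: '1' < '8'\n",
"numeric = parse_version('2.10.0') >= parse_version('2.8.0')  # compares integers\n",
"print(lexicographic, numeric)"
]
},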
{
"cell_type": "markdown",
"metadata": {
"id": "DDaDoLQTpKzx"
},
"source": [
"As we did in earlier chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "8d4TH3NbpKzx"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RcoUIRsvpKzy"
},
"source": [
"And let's create the `images/generative` folder (if it doesn't already exist), and define the `save_fig()` function which is used throughout this notebook to save the figures in high-res for the book:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "PQFH5Y9PpKzy"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"generative\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YTsawKlapKzy"
},
"source": [
"This chapter can be very slow without a GPU, so let's make sure there's one, or else issue a warning:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Ekxzo6pOpKzy"
},
"outputs": [],
"source": [
"if not tf.config.list_physical_devices('GPU'):\n",
"    print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
"    if \"google.colab\" in sys.modules:\n",
"        print(\"Go to Runtime > Change runtime and select a GPU hardware \"\n",
"              \"accelerator.\")\n",
"    if \"kaggle_secrets\" in sys.modules:\n",
"        print(\"Go to Settings > Accelerator and select GPU.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Performing PCA with an Undercomplete Linear Autoencoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build a simple linear autoencoder that projects the 3D data down to 2D:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"tf.random.set_seed(42) # extra code ensures reproducibility on CPU\n",
"\n",
"encoder = tf.keras.Sequential([tf.keras.layers.Dense(2)])\n",
"decoder = tf.keras.Sequential([tf.keras.layers.Dense(3)])\n",
"autoencoder = tf.keras.Sequential([encoder, decoder])\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.5)\n",
"autoencoder.compile(loss=\"mse\", optimizer=optimizer)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's generate the same 3D dataset as we used in Chapter 8:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# extra code builds the same 3D dataset as in Chapter 8\n",
"\n",
"import numpy as np\n",
"from scipy.spatial.transform import Rotation\n",
"\n",
"m = 60\n",
"X = np.zeros((m, 3)) # initialize 3D dataset\n",
"np.random.seed(42)\n",
"angles = (np.random.rand(m) ** 3 + 0.5) * 2 * np.pi # uneven distribution\n",
"X[:, 0], X[:, 1] = np.cos(angles), np.sin(angles) * 0.5 # oval\n",
"X += 0.28 * np.random.randn(m, 3) # add more noise\n",
"X = Rotation.from_rotvec([np.pi / 29, -np.pi / 20, np.pi / 4]).apply(X)\n",
"X_train = X + [0.2, 0, 0.2] # shift a bit"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"history = autoencoder.fit(X_train, X_train, epochs=500, verbose=False)\n",
"codings = encoder.predict(X_train)"
]
},
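{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is a hedged NumPy sketch (not part of the original notebook) of what PCA itself would compute. It uses a hypothetical toy dataset, `X_demo`, rather than the dataset above. On centered data, the 2D PCA projection gives the lowest mean squared reconstruction error that any linear 3D → 2D → 3D mapping can reach, so this is the value a linear autoencoder trained with the MSE loss can approach at best. The autoencoder's codings generally won't match the PCA codings axis for axis, since it only needs to find the same plane, not the same basis:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"# Hypothetical toy data: 60 points close to a random 2D plane in 3D\n",
"X_demo = rng.normal(size=(60, 2)) @ rng.normal(size=(2, 3))\n",
"X_demo += 0.05 * rng.normal(size=(60, 3))\n",
"\n",
"X_centered = X_demo - X_demo.mean(axis=0)\n",
"U, s, Vt = np.linalg.svd(X_centered, full_matrices=False)\n",
"W_pca = Vt[:2].T                         # shape (3, 2): plays the encoder's role\n",
"codings_pca = X_centered @ W_pca         # shape (60, 2): the 2D codings\n",
"X_reconstructed = codings_pca @ W_pca.T  # the transpose plays the decoder's role\n",
"pca_mse = np.mean((X_centered - X_reconstructed) ** 2)\n",
"print(codings_pca.shape, pca_mse)"
]
},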
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 288x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig = plt.figure(figsize=(4,3))\n",
"plt.plot(codings[:,0], codings[:, 1], \"b.\")\n",
"plt.xlabel(\"$z_1$\", fontsize=18)\n",
"plt.ylabel(\"$z_2$\", fontsize=18, rotation=0)\n",
"plt.grid(True)\n",
"save_fig(\"linear_autoencoder_pca_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Stacked Autoencoders"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"## Implementing a Stacked Autoencoder Using Keras"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load the fashion MNIST dataset, scale it, and split it into a training set, a validation set, and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# extra code loads, scales, and splits the fashion MNIST dataset\n",
"fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()\n",
"(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist\n",
"X_train_full = X_train_full.astype(np.float32) / 255\n",
"X_test = X_test.astype(np.float32) / 255\n",
"X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]\n",
"y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build and train a stacked autoencoder with 3 hidden layers and 1 output layer (i.e., 2 stacked autoencoders):"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"1719/1719 [==============================] - 5s 2ms/step - loss: 0.0241 - val_loss: 0.0194\n",
"Epoch 2/20\n",
"1719/1719 [==============================] - 4s 3ms/step - loss: 0.0174 - val_loss: 0.0165\n",
"Epoch 3/20\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0161 - val_loss: 0.0160\n",
"Epoch 4/20\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0153 - val_loss: 0.0152\n",
"Epoch 5/20\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0148 - val_loss: 0.0147\n",
"Epoch 6/20\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0144 - val_loss: 0.0144\n",
"Epoch 7/20\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0142 - val_loss: 0.0143\n",
"Epoch 8/20\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0139 - val_loss: 0.0143\n",
"Epoch 9/20\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0137 - val_loss: 0.0140\n",
"Epoch 10/20\n",
"1719/1719 [==============================] - 6s 3ms/step - loss: 0.0136 - val_loss: 0.0137\n",
"Epoch 11/20\n",
"1719/1719 [==============================] - 6s 3ms/step - loss: 0.0135 - val_loss: 0.0136\n",
"Epoch 12/20\n",
"1719/1719 [==============================] - 6s 3ms/step - loss: 0.0134 - val_loss: 0.0136\n",
"Epoch 13/20\n",
"1719/1719 [==============================] - 6s 4ms/step - loss: 0.0133 - val_loss: 0.0135\n",
"Epoch 14/20\n",
"1719/1719 [==============================] - 6s 4ms/step - loss: 0.0133 - val_loss: 0.0134\n",
"Epoch 15/20\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0132 - val_loss: 0.0134\n",
"Epoch 16/20\n",
"1719/1719 [==============================] - 6s 4ms/step - loss: 0.0132 - val_loss: 0.0134\n",
"Epoch 17/20\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0131 - val_loss: 0.0133\n",
"Epoch 18/20\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0131 - val_loss: 0.0133\n",
"Epoch 19/20\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0130 - val_loss: 0.0132\n",
"Epoch 20/20\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0130 - val_loss: 0.0132\n"
]
}
],
"source": [
"tf.random.set_seed(42) # extra code ensures reproducibility on CPU\n",
"\n",
"stacked_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(30, activation=\"relu\"),\n",
"])\n",
"stacked_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(28 * 28),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"stacked_ae = tf.keras.Sequential([stacked_encoder, stacked_decoder])\n",
"\n",
"stacked_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = stacked_ae.fit(X_train, X_train, epochs=20,\n",
"                         validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Edq73sa5yLhC",
"tags": []
},
"source": [
"## Visualizing the Reconstructions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function processes a few validation images through the autoencoder and displays the original images and their reconstructions:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import numpy as np\n",
"\n",
"def plot_reconstructions(model, images=X_valid, n_images=5):\n",
"    reconstructions = np.clip(model.predict(images[:n_images]), 0, 1)\n",
"    fig = plt.figure(figsize=(n_images * 1.5, 3))\n",
"    for image_index in range(n_images):\n",
"        plt.subplot(2, n_images, 1 + image_index)\n",
"        plt.imshow(images[image_index], cmap=\"binary\")\n",
"        plt.axis(\"off\")\n",
"        plt.subplot(2, n_images, 1 + n_images + image_index)\n",
"        plt.imshow(reconstructions[image_index], cmap=\"binary\")\n",
"        plt.axis(\"off\")\n",
"\n",
"plot_reconstructions(stacked_ae)\n",
"save_fig(\"reconstruction_plot\") # extra code saves the high res figure\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reconstructions look fuzzy, but remember that the images were compressed down to just 30 numbers, instead of 784."
]
},
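{
"cell_type": "markdown",
"metadata": {},
"source": [
"To put that bottleneck in numbers, here is a small sketch (not part of the original notebook) that computes the compression ratio and counts the trainable parameters of the stacked autoencoder defined above. It assumes each `Dense` layer has one weight per input-output pair plus one bias per unit, and the `dense_params()` helper is hypothetical:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Layer widths of the stacked autoencoder: 784 -> 100 -> 30 -> 100 -> 784\n",
"layer_sizes = [28 * 28, 100, 30, 100, 28 * 28]\n",
"\n",
"def dense_params(n_inputs, n_units):  # hypothetical helper\n",
"    return n_inputs * n_units + n_units  # kernel weights + biases\n",
"\n",
"compression_ratio = layer_sizes[0] / min(layer_sizes)\n",
"total_params = sum(dense_params(n_in, n_out)\n",
"                   for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))\n",
"print(f'{compression_ratio:.1f}x compression, {total_params:,} parameters')"
]
},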
{
"cell_type": "markdown",
"metadata": {
"id": "Edq73sa5yLhC",
"tags": []
},
"source": [
"## Visualizing the Fashion MNIST Dataset"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "g69nLh2GyLhC",
"outputId": "47f84d4f-b41f-4ef9-d653-d73506cb4839"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/ageron/miniconda3/envs/homl3/lib/python3.9/site-packages/sklearn/manifold/_t_sne.py:982: FutureWarning: The PCA initialization in TSNE will change to have the standard deviation of PC1 equal to 1e-4 in 1.2. This will ensure better convergence.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from sklearn.manifold import TSNE\n",
"\n",
"X_valid_compressed = stacked_encoder.predict(X_valid)\n",
"tsne = TSNE(init=\"pca\", learning_rate=\"auto\", random_state=42)\n",
"X_valid_2D = tsne.fit_transform(X_valid_compressed)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "klppwl4gyLhC",
"outputId": "b66330ea-5d49-4153-82ce-39fb78614a64"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(X_valid_2D[:, 0], X_valid_2D[:, 1], c=y_valid, s=10, cmap=\"tab10\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xs95C-AtyLhC"
},
"source": [
"Let's make this diagram a bit prettier (adapted from [this Scikit-Learn example](https://scikit-learn.org/stable/auto_examples/manifold/plot_lle_digits.html)):"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "JIkuJSpPyLhC",
"outputId": "239e4fe5-7756-40c1-cc2c-70c83eee6c0f"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code beautifies the previous diagram for the book\n",
"\n",
"import matplotlib as mpl\n",
"\n",
"plt.figure(figsize=(10, 8))\n",
"cmap = plt.cm.tab10\n",
"Z = X_valid_2D\n",
"Z = (Z - Z.min()) / (Z.max() - Z.min()) # normalize to the 0-1 range\n",
"plt.scatter(Z[:, 0], Z[:, 1], c=y_valid, s=10, cmap=cmap)\n",
"image_positions = np.array([[1., 1.]])\n",
"for index, position in enumerate(Z):\n",
"    dist = ((position - image_positions) ** 2).sum(axis=1)\n",
"    if dist.min() > 0.02: # if far enough from other images\n",
"        image_positions = np.r_[image_positions, [position]]\n",
"        imagebox = mpl.offsetbox.AnnotationBbox(\n",
"            mpl.offsetbox.OffsetImage(X_valid[index], cmap=\"binary\"),\n",
"            position, bboxprops={\"edgecolor\": cmap(y_valid[index]), \"lw\": 2})\n",
"        plt.gca().add_artist(imagebox)\n",
"\n",
"plt.axis(\"off\")\n",
"save_fig(\"fashion_mnist_visualization_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1OJPbzoNyLhD"
},
"source": [
"## Tying Weights"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XLMREMfUyLhD"
},
"source": [
"It is common to tie the weights of the encoder and the decoder by using the transpose of the encoder's weights as the decoder's weights. This roughly halves the number of weights in the model. For this, we need to use a custom layer:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "jTru0q0CyLhD"
},
"outputs": [],
"source": [
"class DenseTranspose(tf.keras.layers.Layer):\n",
"    def __init__(self, dense, activation=None, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        self.dense = dense\n",
"        self.activation = tf.keras.activations.get(activation)\n",
"\n",
"    def build(self, batch_input_shape):\n",
"        self.biases = self.add_weight(name=\"bias\",\n",
"                                      shape=self.dense.input_shape[-1],\n",
"                                      initializer=\"zeros\")\n",
"        super().build(batch_input_shape)\n",
"\n",
"    def call(self, inputs):\n",
"        Z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)\n",
"        return self.activation(Z + self.biases)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "f_mU3e6IyLhD",
"outputId": "e8045bff-7004-4940-8d00-295ad99c0f23"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 6s 3ms/step - loss: 0.0246 - val_loss: 0.0185\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 6s 3ms/step - loss: 0.0169 - val_loss: 0.0161\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0155 - val_loss: 0.0151\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0148 - val_loss: 0.0146\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.0142 - val_loss: 0.0141\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.0138 - val_loss: 0.0138\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0136 - val_loss: 0.0136\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0134 - val_loss: 0.0137\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0133 - val_loss: 0.0134\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0132 - val_loss: 0.0133\n"
]
}
],
"source": [
"tf.random.set_seed(42)  # extra code ensures reproducibility on CPU\n",
"\n",
"dense_1 = tf.keras.layers.Dense(100, activation=\"relu\")\n",
"dense_2 = tf.keras.layers.Dense(30, activation=\"relu\")\n",
"\n",
"tied_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    dense_1,\n",
"    dense_2\n",
"])\n",
"\n",
"tied_decoder = tf.keras.Sequential([\n",
"    DenseTranspose(dense_2, activation=\"relu\"),\n",
"    DenseTranspose(dense_1),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"\n",
"tied_ae = tf.keras.Sequential([tied_encoder, tied_decoder])\n",
"\n",
"# extra code compiles and fits the model\n",
"tied_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = tied_ae.fit(X_train, X_train, epochs=10,\n",
"                      validation_data=(X_valid, X_valid))"
]
},
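{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on what tying buys us (worked out by hand from the layer sizes above, not printed by the notebook; the symbols are our own notation). Each `DenseTranspose` layer reuses the kernel $\\mathbf{W}$ of the `Dense` layer it wraps and adds only its own bias vector $\\mathbf{b}$, so for a batch of inputs $\\mathbf{X}$ it outputs:\n",
"\n",
"$$\n",
"\\mathbf{Y} = \\phi\\left(\\mathbf{X}\\,\\mathbf{W}^\\top + \\mathbf{b}\\right)\n",
"$$\n",
"\n",
"where $\\phi$ is the layer's activation function. Counting the distinct trainable parameters of `tied_ae`:\n",
"\n",
"$$\n",
"\\underbrace{784 \\times 100 + 100 \\times 30}_{\\text{shared kernels}} + \\underbrace{100 + 30 + 100 + 784}_{\\text{biases}} = 81{,}400 + 1{,}014 = 82{,}414\n",
"$$\n",
"\n",
"With ordinary, untied `Dense` layers in the decoder, the same architecture would instead have $2 \\times 81{,}400 + 1{,}014 = 163{,}814$ parameters, roughly twice as many."
]
},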
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"id": "WS2bm471yLhD",
"outputId": "8427f9cc-630f-4cc9-a1f5-7361c5eab6dc",
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code plots reconstructions\n",
"plot_reconstructions(tied_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ucUS-1GwyLhE"
},
"source": [
"## Extra Material: Training one Autoencoder at a Time"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "CIqJ922pyLhE"
},
"outputs": [],
"source": [
"def train_autoencoder(n_neurons, X_train, X_valid, n_epochs=10,\n",
"                      output_activation=None):\n",
"    # trains a one-hidden-layer autoencoder and returns its two layers,\n",
"    # plus the encoded training and validation sets for the next phase\n",
"    n_inputs = X_train.shape[-1]\n",
"    encoder = tf.keras.layers.Dense(n_neurons, activation=\"relu\")\n",
"    decoder = tf.keras.layers.Dense(n_inputs, activation=output_activation)\n",
"    autoencoder = tf.keras.Sequential([encoder, decoder])\n",
"    autoencoder.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"    autoencoder.fit(X_train, X_train, epochs=n_epochs,\n",
"                    validation_data=(X_valid, X_valid))\n",
"    return encoder, decoder, encoder(X_train), encoder(X_valid)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "-3TSUkv4yLhE",
"outputId": "bca2fb96-7ddf-4457-b1df-30cdce8d803e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 8s 4ms/step - loss: 0.0202 - val_loss: 0.0153\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 8s 5ms/step - loss: 0.0120 - val_loss: 0.0115\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0106 - val_loss: 0.0111\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 8s 5ms/step - loss: 0.0099 - val_loss: 0.0097\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 8s 5ms/step - loss: 0.0096 - val_loss: 0.0093\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 8s 5ms/step - loss: 0.0093 - val_loss: 0.0091\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 8s 5ms/step - loss: 0.0092 - val_loss: 0.0090\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 8s 4ms/step - loss: 0.0090 - val_loss: 0.0098\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0090 - val_loss: 0.0089\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 8s 4ms/step - loss: 0.0089 - val_loss: 0.0089\n",
"Epoch 1/10\n",
"1719/1719 [==============================] - 3s 763us/step - loss: 0.1377 - val_loss: 0.0621\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 1s 694us/step - loss: 0.0585 - val_loss: 0.0554\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 1s 710us/step - loss: 0.0542 - val_loss: 0.0537\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 1s 712us/step - loss: 0.0533 - val_loss: 0.0531\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 1s 739us/step - loss: 0.0522 - val_loss: 0.0519\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 1s 738us/step - loss: 0.0516 - val_loss: 0.0517\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 1s 699us/step - loss: 0.0515 - val_loss: 0.0518\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 1s 754us/step - loss: 0.0515 - val_loss: 0.0515\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 1s 787us/step - loss: 0.0514 - val_loss: 0.0515\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 1s 769us/step - loss: 0.0514 - val_loss: 0.0516\n"
]
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"X_train_flat = tf.keras.layers.Flatten()(X_train)\n",
"X_valid_flat = tf.keras.layers.Flatten()(X_valid)\n",
"enc1, dec1, X_train_enc1, X_valid_enc1 = train_autoencoder(\n",
"    100, X_train_flat, X_valid_flat)\n",
"enc2, dec2, _, _ = train_autoencoder(\n",
"    30, X_train_enc1, X_valid_enc1, output_activation=\"relu\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"id": "FpjUrD77yLhE"
},
"outputs": [],
"source": [
"stacked_ae_1_by_1 = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    enc1, enc2, dec2, dec1,\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"id": "KbEe4tQqyLhE",
"outputId": "52eb1adf-f4c4-4029-f68b-1ef420a5daed"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_reconstructions(stacked_ae_1_by_1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If needed, we can then continue training the full stacked autoencoder for a few epochs:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "vRgEnA-WyLhF",
"outputId": "3d367c4e-82d4-48d9-8757-b106edf00650"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1719/1719 [==============================] - 8s 4ms/step - loss: 0.0173 - val_loss: 0.0161\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0151 - val_loss: 0.0144\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 7s 4ms/step - loss: 0.0142 - val_loss: 0.0141\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 5s 3ms/step - loss: 0.0137 - val_loss: 0.0136\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 6s 3ms/step - loss: 0.0133 - val_loss: 0.0133\n"
]
}
],
"source": [
"stacked_ae_1_by_1.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = stacked_ae_1_by_1.fit(X_train, X_train, epochs=5,\n",
"                                validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"id": "Ln1j0u4eyLhF",
"outputId": "fef0981b-6b8d-4e9d-e837-9e3c20358e9e"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_reconstructions(stacked_ae_1_by_1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HGixWVtuyLhF"
},
"source": [
"## Convolutional Autoencoders"
]
},
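{
"cell_type": "markdown",
"metadata": {
"id": "xVK4vETzyLhF"
},
"source": [
"Let's build a convolutional autoencoder for Fashion MNIST: the encoder is a regular CNN made of convolutional and max pooling layers that reduces each image to 30 codings, and the decoder uses transposed convolutional layers to upsample the codings back to a 28 × 28 image."
]
},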
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "qom8ZHdgyLhF",
"outputId": "45b2544f-0812-4fc7-a4e5-d8df9f193f5f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 64s 37ms/step - loss: 0.0335 - val_loss: 0.0235\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 62s 36ms/step - loss: 0.0209 - val_loss: 0.0194\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 64s 37ms/step - loss: 0.0179 - val_loss: 0.0178\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 61s 36ms/step - loss: 0.0162 - val_loss: 0.0155\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 62s 36ms/step - loss: 0.0150 - val_loss: 0.0144\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 57s 33ms/step - loss: 0.0141 - val_loss: 0.0137\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 61s 36ms/step - loss: 0.0134 - val_loss: 0.0132\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 63s 37ms/step - loss: 0.0129 - val_loss: 0.0128\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 63s 36ms/step - loss: 0.0125 - val_loss: 0.0126\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 59s 34ms/step - loss: 0.0121 - val_loss: 0.0121\n"
]
}
],
"source": [
"tf.random.set_seed(42)  # extra code ensures reproducibility on CPU\n",
"\n",
"conv_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Reshape([28, 28, 1]),\n",
"    tf.keras.layers.Conv2D(16, 3, padding=\"same\", activation=\"relu\"),\n",
"    tf.keras.layers.MaxPool2D(pool_size=2),  # output: 14 × 14 × 16\n",
"    tf.keras.layers.Conv2D(32, 3, padding=\"same\", activation=\"relu\"),\n",
"    tf.keras.layers.MaxPool2D(pool_size=2),  # output: 7 × 7 × 32\n",
"    tf.keras.layers.Conv2D(64, 3, padding=\"same\", activation=\"relu\"),\n",
"    tf.keras.layers.MaxPool2D(pool_size=2),  # output: 3 × 3 × 64\n",
"    tf.keras.layers.Conv2D(30, 3, padding=\"same\", activation=\"relu\"),\n",
"    tf.keras.layers.GlobalAvgPool2D()  # output: 30\n",
"])\n",
"conv_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(3 * 3 * 16),\n",
"    tf.keras.layers.Reshape((3, 3, 16)),\n",
"    tf.keras.layers.Conv2DTranspose(32, 3, strides=2, activation=\"relu\"),\n",
"    tf.keras.layers.Conv2DTranspose(16, 3, strides=2, padding=\"same\",\n",
"                                    activation=\"relu\"),\n",
"    tf.keras.layers.Conv2DTranspose(1, 3, strides=2, padding=\"same\"),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"conv_ae = tf.keras.Sequential([conv_encoder, conv_decoder])\n",
"\n",
"# extra code compiles and fits the model\n",
"conv_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = conv_ae.fit(X_train, X_train, epochs=10,\n",
"                      validation_data=(X_valid, X_valid))"
]
},
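{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see why the decoder ends up at 28 × 28 again, it helps to track the spatial size through the three `Conv2DTranspose` layers. The rule sketched here is the usual Keras output-size rule for an input of size $i$, kernel size $k$ and stride $s$, under the assumptions that $k \\geq s$ and that no `output_padding` or dilation is used (the symbols are our own notation):\n",
"\n",
"$$\n",
"o_{\\text{valid}} = (i - 1)\\,s + k, \\qquad o_{\\text{same}} = i\\,s\n",
"$$\n",
"\n",
"The first layer uses the default `\"valid\"` padding and the other two use `\"same\"`, all with $k = 3$ and $s = 2$:\n",
"\n",
"$$\n",
"3 \\;\\xrightarrow{\\text{valid}}\\; (3 - 1) \\times 2 + 3 = 7 \\;\\xrightarrow{\\text{same}}\\; 7 \\times 2 = 14 \\;\\xrightarrow{\\text{same}}\\; 14 \\times 2 = 28\n",
"$$\n",
"\n",
"With `\"same\"` padding in the first layer as well, the sizes would be 3 → 6 → 12 → 24, which could not be reshaped to 28 × 28."
]
},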
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "6p0Yud0RyLhG",
"outputId": "7112e556-b714-4645-af12-3bb3bd7cd6ce"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code shows the reconstructions\n",
"plot_reconstructions(conv_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FKuHGp1DyLhG"
},
"source": [
"# Extra Material: Recurrent Autoencoders"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's treat each Fashion MNIST image as a sequence of 28 vectors, each with 28 dimensions:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {
"id": "ZGOuVy0kyLhG"
},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"recurrent_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.LSTM(100, return_sequences=True),\n",
"    tf.keras.layers.LSTM(30)\n",
"])\n",
"recurrent_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.RepeatVector(28),\n",
"    tf.keras.layers.LSTM(100, return_sequences=True),\n",
"    tf.keras.layers.Dense(28)\n",
"])\n",
"recurrent_ae = tf.keras.Sequential([recurrent_encoder, recurrent_decoder])\n",
"recurrent_ae.compile(loss=\"mse\", optimizer=\"nadam\")"
]
},
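{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough guide to how the sequence is compressed and then rebuilt, here is the tensor shape after each layer, written by hand from the layer arguments above ($B$ is our own shorthand for the batch size, not a name used in the code):\n",
"\n",
"$$\n",
"[B, 28, 28] \\xrightarrow{\\text{LSTM}(100)} [B, 28, 100] \\xrightarrow{\\text{LSTM}(30)} [B, 30] \\xrightarrow{\\text{RepeatVector}(28)} [B, 28, 30] \\xrightarrow{\\text{LSTM}(100)} [B, 28, 100] \\xrightarrow{\\text{Dense}(28)} [B, 28, 28]\n",
"$$\n",
"\n",
"The second encoder LSTM does not return sequences, so it keeps only its last output: those 30 values are the codings. `RepeatVector(28)` then feeds the same coding vector to the decoder at each of the 28 time steps, and the final `Dense(28)` layer is applied at every time step to produce one 28-pixel row per step."
]
},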
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"id": "4grialnWyLhG",
"outputId": "e6bb8d79-9a57-4fd3-e336-b81687de6781"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 224s 128ms/step - loss: 0.0293 - val_loss: 0.0208\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 242s 141ms/step - loss: 0.0190 - val_loss: 0.0171\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 253s 147ms/step - loss: 0.0162 - val_loss: 0.0172\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 286s 166ms/step - loss: 0.0146 - val_loss: 0.0137\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 224s 130ms/step - loss: 0.0134 - val_loss: 0.0130\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 211s 123ms/step - loss: 0.0126 - val_loss: 0.0121\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 239s 139ms/step - loss: 0.0119 - val_loss: 0.0120\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 273s 159ms/step - loss: 0.0113 - val_loss: 0.0112\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 256s 149ms/step - loss: 0.0109 - val_loss: 0.0111\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 285s 166ms/step - loss: 0.0106 - val_loss: 0.0104\n"
]
}
],
"source": [
"history = recurrent_ae.fit(X_train, X_train, epochs=10,\n",
"                           validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"id": "ZLOnTctFyLhG",
"outputId": "e5b59e40-09bc-4d82-bb6d-58ae7d7cc376"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_reconstructions(recurrent_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JafadovRyLhH"
},
"source": [
"# Denoising Autoencoders"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NdMv5zGbyLhH"
},
"source": [
"Using dropout:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"id": "wTKWjK3EyLhH",
"outputId": "3be1a3bc-7f30-4793-955e-07bf60fc3b6e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 16s 9ms/step - loss: 0.0290 - val_loss: 0.0220\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 15s 9ms/step - loss: 0.0223 - val_loss: 0.0200\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 15s 9ms/step - loss: 0.0209 - val_loss: 0.0191\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 16s 9ms/step - loss: 0.0201 - val_loss: 0.0185\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 14s 8ms/step - loss: 0.0196 - val_loss: 0.0180\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 14s 8ms/step - loss: 0.0193 - val_loss: 0.0178\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 15s 9ms/step - loss: 0.0190 - val_loss: 0.0174\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 15s 9ms/step - loss: 0.0187 - val_loss: 0.0171\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 17s 10ms/step - loss: 0.0185 - val_loss: 0.0169\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 16s 9ms/step - loss: 0.0183 - val_loss: 0.0166\n"
]
}
],
"source": [
"tf.random.set_seed(42)  # extra code: ensures reproducibility on CPU\n",
"\n",
"dropout_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.Dropout(0.5),\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(30, activation=\"relu\")\n",
"])\n",
"dropout_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(28 * 28),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"dropout_ae = tf.keras.Sequential([dropout_encoder, dropout_decoder])\n",
"\n",
"# extra code: compiles and fits the model\n",
"dropout_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = dropout_ae.fit(X_train, X_train, epochs=10,\n",
"                         validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"id": "Pucj7MI3yLhI",
"outputId": "d9a4289a-fde7-4e3b-c968-72421ddbfe6d"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 17-9\n",
"tf.random.set_seed(42)\n",
"dropout = tf.keras.layers.Dropout(0.5)\n",
"plot_reconstructions(dropout_ae, dropout(X_valid, training=True))\n",
"save_fig(\"dropout_denoising_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want, you can try replacing the `Dropout` layer with `tf.keras.layers.GaussianNoise(0.2)`."
]
},
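{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough sketch of that suggestion (not part of the original run, so no training output is shown), the encoder below matches `dropout_encoder` except that the `Dropout` layer is swapped for `GaussianNoise`. The names `gaussian_encoder`, `gaussian_decoder`, `gaussian_ae`, `gaussian_noise` and `noisy_images` are made up for this illustration, and the decoder is assumed to keep the same architecture as `dropout_decoder`. Like `Dropout`, the `GaussianNoise` layer is only active during training, which is why the noise is applied explicitly with `training=True` when preparing images to visualize:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative sketch: Gaussian-noise variant of the denoising autoencoder\n",
"tf.random.set_seed(42)\n",
"\n",
"gaussian_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.GaussianNoise(0.2),  # adds noise during training only\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(30, activation=\"relu\")\n",
"])\n",
"gaussian_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(28 * 28),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"gaussian_ae = tf.keras.Sequential([gaussian_encoder, gaussian_decoder])\n",
"gaussian_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"# uncomment to train it the same way as the dropout version:\n",
"# gaussian_ae.fit(X_train, X_train, epochs=10,\n",
"#                 validation_data=(X_valid, X_valid))\n",
"\n",
"# corrupt a few validation images explicitly, as done for the dropout plot\n",
"gaussian_noise = tf.keras.layers.GaussianNoise(0.2)\n",
"noisy_images = gaussian_noise(X_valid[:5], training=True)"
]
},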
{
"cell_type": "markdown",
"metadata": {
"id": "Fd8xb5g-yLhI"
},
"source": [
"# Sparse Autoencoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kDseIwd4yLhJ"
},
"source": [
"Let's use the sigmoid activation function in the coding layer. Let's also add $\\ell_1$ regularization to it: to do this, we add an `ActivityRegularization` layer after the coding layer. Alternatively, we could add `activity_regularizer=tf.keras.regularizers.l1(1e-4)` to the coding layer itself."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "ORrSHhNxyLhJ",
"outputId": "664ca147-e8f2-4ccf-960e-a86e60c2c60d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 21s 12ms/step - loss: 0.0294 - val_loss: 0.0204\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 21s 12ms/step - loss: 0.0184 - val_loss: 0.0176\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 21s 12ms/step - loss: 0.0159 - val_loss: 0.0152\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 22s 13ms/step - loss: 0.0145 - val_loss: 0.0139\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 21s 12ms/step - loss: 0.0134 - val_loss: 0.0131\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 21s 12ms/step - loss: 0.0128 - val_loss: 0.0124\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 20s 12ms/step - loss: 0.0122 - val_loss: 0.0120\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 18s 11ms/step - loss: 0.0118 - val_loss: 0.0119\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 18s 10ms/step - loss: 0.0114 - val_loss: 0.0114\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 19s 11ms/step - loss: 0.0111 - val_loss: 0.0111\n"
]
}
],
"source": [
"tf.random.set_seed(42)  # extra code: ensures reproducibility on CPU\n",
"\n",
"sparse_l1_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(300, activation=\"sigmoid\"),\n",
"    tf.keras.layers.ActivityRegularization(l1=1e-4)\n",
"])\n",
"sparse_l1_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(28 * 28),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"sparse_l1_ae = tf.keras.Sequential([sparse_l1_encoder, sparse_l1_decoder])\n",
"\n",
"# extra code: compiles and fits the model\n",
"sparse_l1_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = sparse_l1_ae.fit(X_train, X_train, epochs=10,\n",
"                           validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"id": "2vU9twBAyLhK",
"outputId": "feb78185-b67e-4775-c61e-3d4a8d082c75"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: shows the reconstructions\n",
"plot_reconstructions(sparse_l1_ae)\n",
"plt.show()"
]
},
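{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a minimal sketch of the alternative mentioned earlier, where the $\\ell_1$ penalty is attached directly to the coding layer through its `activity_regularizer` argument instead of a separate `ActivityRegularization` layer. The name `sparse_l1_encoder_alt` is made up for this illustration, and the encoder is not trained here. It should apply a comparable penalty to the coding layer's activations:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative sketch: l1 activity regularization set on the coding layer itself\n",
"sparse_l1_encoder_alt = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(300, activation=\"sigmoid\",\n",
"                          activity_regularizer=tf.keras.regularizers.l1(1e-4))\n",
"])"
]
},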
{
"cell_type": "markdown",
"metadata": {
"id": "J3vxPDGUyLhK"
},
"source": [
"Let's plot the KL divergence loss against the MAE and MSE:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "D0fOy7aMyLhK",
"outputId": "35a609c2-6d23-462f-a519-6d93b07d0ec4"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 17-10\n",
"p = 0.1\n",
"q = np.linspace(0.001, 0.999, 500)\n",
"kl_div = p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))\n",
"mse = (p - q) ** 2\n",
"mae = np.abs(p - q)\n",
"plt.plot([p, p], [0, 0.3], \"k:\")\n",
"plt.text(0.05, 0.32, \"Target\\nsparsity\", fontsize=14)\n",
"plt.plot(q, kl_div, \"b-\", label=\"KL divergence\")\n",
"plt.plot(q, mae, \"g--\", label=r\"MAE ($\\ell_1$)\")\n",
"plt.plot(q, mse, \"r--\", linewidth=1, label=r\"MSE ($\\ell_2$)\")\n",
"plt.legend(loc=\"upper left\", fontsize=14)\n",
"plt.xlabel(\"Actual sparsity\")\n",
"plt.ylabel(\"Cost\", rotation=0)\n",
"plt.axis([0, 1, 0, 0.95])\n",
"plt.grid(True)\n",
"save_fig(\"sparsity_loss_plot\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a custom regularizer for KL divergence regularization:"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"id": "N-MoRrQGyLhK"
},
"outputs": [],
"source": [
"kl_divergence = tf.keras.losses.kullback_leibler_divergence\n",
"\n",
"class KLDivergenceRegularizer(tf.keras.regularizers.Regularizer):\n",
"    def __init__(self, weight, target):\n",
"        self.weight = weight\n",
"        self.target = target\n",
"\n",
"    def __call__(self, inputs):\n",
"        mean_activities = tf.reduce_mean(inputs, axis=0)\n",
"        return self.weight * (\n",
"            kl_divergence(self.target, mean_activities) +\n",
"            kl_divergence(1. - self.target, 1. - mean_activities))"
]
},
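{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for what this regularizer computes, here is a small sanity check on made-up activations (the names `toy_activations` and `toy_reg` are only for illustration). The first unit's mean activation equals the 10% target, so it should contribute approximately zero. The second unit's mean activation is 0.7, far from the target, so it should account for essentially all of the penalty:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative sketch: a batch of 2 instances with 2 coding units each\n",
"toy_activations = tf.constant([[0.1, 0.5],\n",
"                               [0.1, 0.9]])  # mean activations: [0.1, 0.7]\n",
"toy_reg = KLDivergenceRegularizer(weight=1.0, target=0.1)\n",
"toy_reg(toy_activations)  # scalar penalty, summed over both units"
]
},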
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's use this regularizer to push the model to have about 10% sparsity in the coding layer:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {
"id": "HIoQ-1D4yLhK",
"outputId": "8874ead0-468a-47f8-9e35-4a855bace6b4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 20s 11ms/step - loss: 0.0273 - val_loss: 0.0186\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 21s 12ms/step - loss: 0.0167 - val_loss: 0.0148\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 20s 11ms/step - loss: 0.0141 - val_loss: 0.0142\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 18s 11ms/step - loss: 0.0126 - val_loss: 0.0121\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 17s 10ms/step - loss: 0.0116 - val_loss: 0.0116\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 17s 10ms/step - loss: 0.0110 - val_loss: 0.0106\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 18s 10ms/step - loss: 0.0105 - val_loss: 0.0103\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 19s 11ms/step - loss: 0.0102 - val_loss: 0.0108\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 20s 12ms/step - loss: 0.0100 - val_loss: 0.0104\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 20s 12ms/step - loss: 0.0098 - val_loss: 0.0099\n"
]
}
],
"source": [
"tf.random.set_seed(42)  # extra code: ensures reproducibility on CPU\n",
"\n",
"kld_reg = KLDivergenceRegularizer(weight=5e-3, target=0.1)\n",
"sparse_kl_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(300, activation=\"sigmoid\",\n",
"                          activity_regularizer=kld_reg)\n",
"])\n",
"sparse_kl_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(28 * 28),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"sparse_kl_ae = tf.keras.Sequential([sparse_kl_encoder, sparse_kl_decoder])\n",
"\n",
"# extra code: compiles and fits the model\n",
"sparse_kl_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = sparse_kl_ae.fit(X_train, X_train, epochs=10,\n",
"                           validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {
"id": "o0jIPg2ayLhL",
"outputId": "35512003-0640-4b9a-835b-6c7af95c930f"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: shows the reconstructions\n",
"plot_reconstructions(sparse_kl_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R1WIZRNByLhL"
},
"source": [
"# Variational Autoencoder"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {
"id": "64ji8G6eyLhL"
},
"outputs": [],
"source": [
"class Sampling(tf.keras.layers.Layer):\n",
"    def call(self, inputs):\n",
"        mean, log_var = inputs\n",
"        return tf.random.normal(tf.shape(log_var)) * tf.exp(log_var / 2) + mean"
]
},
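{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Sampling` layer draws codings as $\\mathbf{z} = \\boldsymbol{\\mu} + \\boldsymbol{\\sigma} \\odot \\boldsymbol{\\epsilon}$ with $\\boldsymbol{\\epsilon} \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I})$, where $\\boldsymbol{\\sigma} = \\exp(\\boldsymbol{\\gamma} / 2)$ and $\\boldsymbol{\\gamma} = \\log(\\boldsymbol{\\sigma}^2)$ is the log-variance. As a quick check on made-up values (the names `toy_mean` and `toy_log_var` are only for illustration), a very negative log-variance makes $\\boldsymbol{\\sigma}$ tiny, so the sampled codings should be almost equal to the mean:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# illustrative sketch: sigma = exp(-20 / 2) = exp(-10), so the noise is negligible\n",
"toy_mean = tf.constant([[0., 5.]])\n",
"toy_log_var = tf.constant([[-20., -20.]])\n",
"Sampling()([toy_mean, toy_log_var])  # expected to be very close to toy_mean"
]
},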
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code: ensures reproducibility on CPU\n",
"\n",
"codings_size = 10\n",
"\n",
"inputs = tf.keras.layers.Input(shape=[28, 28])\n",
"Z = tf.keras.layers.Flatten()(inputs)\n",
"Z = tf.keras.layers.Dense(150, activation=\"relu\")(Z)\n",
"Z = tf.keras.layers.Dense(100, activation=\"relu\")(Z)\n",
"codings_mean = tf.keras.layers.Dense(codings_size)(Z) # μ\n",
"codings_log_var = tf.keras.layers.Dense(codings_size)(Z) # γ\n",
"codings = Sampling()([codings_mean, codings_log_var])\n",
"variational_encoder = tf.keras.Model(\n",
"    inputs=[inputs], outputs=[codings_mean, codings_log_var, codings])"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"decoder_inputs = tf.keras.layers.Input(shape=[codings_size])\n",
"x = tf.keras.layers.Dense(100, activation=\"relu\")(decoder_inputs)\n",
"x = tf.keras.layers.Dense(150, activation=\"relu\")(x)\n",
"x = tf.keras.layers.Dense(28 * 28)(x)\n",
"outputs = tf.keras.layers.Reshape([28, 28])(x)\n",
"variational_decoder = tf.keras.Model(inputs=[decoder_inputs], outputs=[outputs])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"_, _, codings = variational_encoder(inputs)\n",
"reconstructions = variational_decoder(codings)\n",
"variational_ae = tf.keras.Model(inputs=[inputs], outputs=[reconstructions])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"latent_loss = -0.5 * tf.reduce_sum(\n",
"    1 + codings_log_var - tf.exp(codings_log_var) - tf.square(codings_mean),\n",
"    axis=-1)\n",
"variational_ae.add_loss(tf.reduce_mean(latent_loss) / 784.)"
]
},
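{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written out, the latent loss computed in the previous cell for a single instance is the following, where $n$ is `codings_size`, $\\mu_i$ is the $i^\\text{th}$ component of `codings_mean`, and $\\gamma_i = \\log(\\sigma_i^2)$ is the $i^\\text{th}$ component of `codings_log_var`:\n",
"\n",
"$$\\mathcal{L} = -\\dfrac{1}{2} \\sum_{i=1}^{n} \\left[ 1 + \\gamma_i - \\exp(\\gamma_i) - \\mu_i^2 \\right]$$\n",
"\n",
"The code then averages this over the batch and divides by 784 (= 28 × 28), presumably so that the latent loss is on a per-pixel scale comparable to the MSE reconstruction loss."
]
},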
{
"cell_type": "code",
"execution_count": 44,
"metadata": {
"id": "YXUKW_UPyLhL",
"outputId": "6db5b948-1b3d-419b-e810-08512fa65369"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/25\n",
"430/430 [==============================] - 9s 19ms/step - loss: 0.0513 - val_loss: 0.0398\n",
"Epoch 2/25\n",
"430/430 [==============================] - 8s 19ms/step - loss: 0.0376 - val_loss: 0.0369\n",
"Epoch 3/25\n",
"430/430 [==============================] - 7s 17ms/step - loss: 0.0356 - val_loss: 0.0351\n",
"Epoch 4/25\n",
"430/430 [==============================] - 8s 18ms/step - loss: 0.0345 - val_loss: 0.0343\n",
"Epoch 5/25\n",
"430/430 [==============================] - 8s 19ms/step - loss: 0.0338 - val_loss: 0.0339\n",
"Epoch 6/25\n",
"430/430 [==============================] - 8s 18ms/step - loss: 0.0333 - val_loss: 0.0333\n",
"Epoch 7/25\n",
"430/430 [==============================] - 7s 17ms/step - loss: 0.0329 - val_loss: 0.0336\n",
"Epoch 8/25\n",
"430/430 [==============================] - 7s 16ms/step - loss: 0.0327 - val_loss: 0.0327\n",
"Epoch 9/25\n",
"430/430 [==============================] - 7s 17ms/step - loss: 0.0324 - val_loss: 0.0326\n",
"Epoch 10/25\n",
"430/430 [==============================] - 8s 19ms/step - loss: 0.0322 - val_loss: 0.0325\n",
"Epoch 11/25\n",
"430/430 [==============================] - 8s 19ms/step - loss: 0.0320 - val_loss: 0.0322\n",
"Epoch 12/25\n",
"430/430 [==============================] - 9s 20ms/step - loss: 0.0318 - val_loss: 0.0322\n",
"Epoch 13/25\n",
"430/430 [==============================] - 8s 19ms/step - loss: 0.0317 - val_loss: 0.0319\n",
"Epoch 14/25\n",
"430/430 [==============================] - 7s 16ms/step - loss: 0.0316 - val_loss: 0.0321\n",
"Epoch 15/25\n",
"430/430 [==============================] - 6s 15ms/step - loss: 0.0315 - val_loss: 0.0318\n",
"Epoch 16/25\n",
"430/430 [==============================] - 6s 15ms/step - loss: 0.0314 - val_loss: 0.0317\n",
"Epoch 17/25\n",
"430/430 [==============================] - 7s 17ms/step - loss: 0.0313 - val_loss: 0.0316\n",
"Epoch 18/25\n",
"430/430 [==============================] - 7s 15ms/step - loss: 0.0312 - val_loss: 0.0315\n",
"Epoch 19/25\n",
"430/430 [==============================] - 6s 15ms/step - loss: 0.0312 - val_loss: 0.0315\n",
"Epoch 20/25\n",
"430/430 [==============================] - 7s 17ms/step - loss: 0.0311 - val_loss: 0.0318\n",
"Epoch 21/25\n",
"430/430 [==============================] - 9s 20ms/step - loss: 0.0311 - val_loss: 0.0314\n",
"Epoch 22/25\n",
"430/430 [==============================] - 8s 19ms/step - loss: 0.0310 - val_loss: 0.0312\n",
"Epoch 23/25\n",
"430/430 [==============================] - 7s 16ms/step - loss: 0.0310 - val_loss: 0.0312\n",
"Epoch 24/25\n",
"430/430 [==============================] - 8s 17ms/step - loss: 0.0309 - val_loss: 0.0311\n",
"Epoch 25/25\n",
"430/430 [==============================] - 7s 17ms/step - loss: 0.0309 - val_loss: 0.0311\n"
]
}
],
"source": [
"variational_ae.compile(loss=\"mse\", optimizer=\"nadam\")\n",
"history = variational_ae.fit(X_train, X_train, epochs=25, batch_size=128,\n",
" validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {
"id": "FDs525puyLhL",
"outputId": "049794c0-4407-49b7-ba65-138f31d5e5dc",
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_reconstructions(variational_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fJfl2bXbyLhM"
},
"source": [
"## Generate Fashion Images"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's generate a few random codings and decode them:"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code ensures reproducibility on CPU\n",
"\n",
"codings = tf.random.normal(shape=[3 * 7, codings_size])\n",
"images = variational_decoder(codings).numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's plot these images:"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"id": "gzwLpDIDyLhM"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 504x216 with 21 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 17-12\n",
"\n",
"def plot_multiple_images(images, n_cols=None):\n",
" n_cols = n_cols or len(images)\n",
" n_rows = (len(images) - 1) // n_cols + 1\n",
" if images.shape[-1] == 1:\n",
" images = images.squeeze(axis=-1)\n",
" plt.figure(figsize=(n_cols, n_rows))\n",
" for index, image in enumerate(images):\n",
" plt.subplot(n_rows, n_cols, index + 1)\n",
" plt.imshow(image, cmap=\"binary\")\n",
" plt.axis(\"off\")\n",
"\n",
"plot_multiple_images(images, 7)\n",
"save_fig(\"vae_generated_images_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-JLZ_HWqyLhM"
},
"source": [
"Now let's perform semantic interpolation between 2 images:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {
"id": "GsVXES9byLhM",
"outputId": "eca5be95-9eee-47d1-e9ce-50756e24d235"
},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code ensures reproducibility on CPU\n",
"\n",
"codings = np.zeros([7, codings_size])\n",
"codings[:, 3] = np.linspace(-0.8, 0.8, 7) # axis 3 looks best in this case\n",
"images = variational_decoder(codings).numpy()"
]
},
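{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above sweeps a single coding dimension while keeping the others at zero. Another option, sketched here under the assumption that two coding vectors are already available (for example from the encoder), is to interpolate linearly between them. `codings_a`, `codings_b` and `n_codings` are made-up placeholders, not values from the trained model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"n_codings = 10  # assumed to match the codings size used by the variational autoencoder\n",
"rng = np.random.default_rng(42)\n",
"codings_a = rng.standard_normal(n_codings)  # placeholder for the codings of a first image\n",
"codings_b = rng.standard_normal(n_codings)  # placeholder for the codings of a second image\n",
"\n",
"# np.linspace accepts array endpoints: each row is a point on the segment from a to b\n",
"interpolated = np.linspace(codings_a, codings_b, 7)\n",
"assert interpolated.shape == (7, n_codings)\n",
"assert np.allclose(interpolated[0], codings_a)\n",
"assert np.allclose(interpolated[-1], codings_b)\n",
"# the rows could then be decoded like the codings above, e.g. with variational_decoder"
]
},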
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 504x72 with 7 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 17-13\n",
"plot_multiple_images(images)\n",
"save_fig(\"semantic_interpolation_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9hA5xhT_yLhN"
},
"source": [
"# Generative Adversarial Networks"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {
"id": "N31KxgksyLhN"
},
"outputs": [],
"source": [
"tf.random.set_seed(42) # extra code ensures reproducibility on CPU\n",
"\n",
"codings_size = 30\n",
"\n",
"Dense = tf.keras.layers.Dense\n",
"generator = tf.keras.Sequential([\n",
" Dense(100, activation=\"relu\", kernel_initializer=\"he_normal\"),\n",
" Dense(150, activation=\"relu\", kernel_initializer=\"he_normal\"),\n",
" Dense(28 * 28, activation=\"sigmoid\"),\n",
" tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"discriminator = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(),\n",
" Dense(150, activation=\"relu\", kernel_initializer=\"he_normal\"),\n",
" Dense(100, activation=\"relu\", kernel_initializer=\"he_normal\"),\n",
" Dense(1, activation=\"sigmoid\")\n",
"])\n",
"gan = tf.keras.Sequential([generator, discriminator])"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {
"id": "Qm4Uk3MTyLhN"
},
"outputs": [],
"source": [
"discriminator.compile(loss=\"binary_crossentropy\", optimizer=\"rmsprop\")\n",
"discriminator.trainable = False\n",
"gan.compile(loss=\"binary_crossentropy\", optimizer=\"rmsprop\")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"id": "UQoPjJthyLhN"
},
"outputs": [],
"source": [
"batch_size = 32\n",
"dataset = tf.data.Dataset.from_tensor_slices(X_train).shuffle(1000)\n",
"dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {
"id": "cxrUUhrjyLhN",
"outputId": "2813bb7d-84c4-4ed2-8d03-c1fbd138c576"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<<44 more epochs>>\n",
"Epoch 48/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 49/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 50/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"def train_gan(gan, dataset, batch_size, codings_size, n_epochs):\n",
"    generator, discriminator = gan.layers\n",
"    for epoch in range(n_epochs):\n",
"        print(f\"Epoch {epoch + 1}/{n_epochs}\")  # extra code\n",
"        for X_batch in dataset:\n",
"            # phase 1 - training the discriminator\n",
"            noise = tf.random.normal(shape=[batch_size, codings_size])\n",
"            generated_images = generator(noise)\n",
"            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)\n",
"            # labels: 0 for the fake images, 1 for the real ones\n",
"            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)\n",
"            discriminator.train_on_batch(X_fake_and_real, y1)\n",
"            # phase 2 - training the generator\n",
"            noise = tf.random.normal(shape=[batch_size, codings_size])\n",
"            # targets: 1 for every fake image, so the generator learns to fool the discriminator\n",
"            y2 = tf.constant([[1.]] * batch_size)\n",
"            gan.train_on_batch(noise, y2)\n",
"        # extra code: plot images during training\n",
"        plot_multiple_images(generated_images.numpy(), 8)\n",
"        plt.show()\n",
"\n",
"train_gan(gan, dataset, batch_size, codings_size, n_epochs=50)"
]
},
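{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two phases of `train_gan()` above can be summarized as two loss functions. This is a reading aid rather than part of the original code: the symbols $D$ (discriminator), $G$ (generator), $m$ (`batch_size`), $\\mathbf{x}_i$ (real images) and $\\mathbf{z}_i$ (noise vectors) are notation introduced here, and the sketch assumes that both models are compiled with the binary cross-entropy loss (averaged over the batch) and that every `X_batch` holds exactly `batch_size` real images.\n",
"\n",
"Phase 1 labels the fake images 0 and the real images 1, so the discriminator minimizes:\n",
"\n",
"$$\n",
"\\mathcal{L}_D = -\\frac{1}{2m} \\left[ \\sum_{i=1}^{m} \\log\\left(1 - D(G(\\mathbf{z}_i))\\right) + \\sum_{i=1}^{m} \\log D(\\mathbf{x}_i) \\right]\n",
"$$\n",
"\n",
"Phase 2 draws fresh noise and labels the generated images 1, so the generator minimizes:\n",
"\n",
"$$\n",
"\\mathcal{L}_G = -\\frac{1}{m} \\sum_{i=1}^{m} \\log D(G(\\mathbf{z}_i))\n",
"$$\n",
"\n",
"Only the generator's weights should change in phase 2, which relies on the discriminator being frozen (`discriminator.trainable = False`) before the combined `gan` model is compiled."
]
},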
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code: ensures reproducibility on CPU\n",
"\n",
"codings = tf.random.normal(shape=[batch_size, codings_size])\n",
"generated_images = generator.predict(codings)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"id": "NrLzMxweyLhO",
"outputId": "b21da95e-da35-45b6-bbee-fd6e8b915382"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 17-15\n",
"plot_multiple_images(generated_images, 8)\n",
"save_fig(\"gan_generated_images_plot\", tight_layout=False)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RZN7cSsKyLhO"
},
"source": [
"# Deep Convolutional GAN"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"id": "NgT77mdVyLhO"
},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code: ensures reproducibility on CPU\n",
"\n",
"codings_size = 100\n",
"\n",
"generator = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(7 * 7 * 128),\n",
"    tf.keras.layers.Reshape([7, 7, 128]),\n",
"    tf.keras.layers.BatchNormalization(),\n",
"    tf.keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2,\n",
"                                    padding=\"same\", activation=\"relu\"),\n",
"    tf.keras.layers.BatchNormalization(),\n",
"    tf.keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2,\n",
"                                    padding=\"same\", activation=\"tanh\"),\n",
"])\n",
"discriminator = tf.keras.Sequential([\n",
"    tf.keras.layers.Conv2D(64, kernel_size=5, strides=2, padding=\"same\",\n",
"                           activation=tf.keras.layers.LeakyReLU(0.2)),\n",
"    tf.keras.layers.Dropout(0.4),\n",
"    tf.keras.layers.Conv2D(128, kernel_size=5, strides=2, padding=\"same\",\n",
"                           activation=tf.keras.layers.LeakyReLU(0.2)),\n",
"    tf.keras.layers.Dropout(0.4),\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"gan = tf.keras.Sequential([generator, discriminator])"
]
},
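{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check for the DCGAN defined above, written as a sketch rather than as output from the notebook. It assumes the usual Keras behavior for strided layers with `padding=\"same\"`: a `Conv2DTranspose` layer multiplies the spatial size by the stride $s$, and a `Conv2D` layer divides it by $s$, rounding up. The symbols $n_{\\text{in}}$ and $n_{\\text{out}}$ are notation introduced here.\n",
"\n",
"$$\n",
"n_{\\text{out}}^{\\text{transpose}} = s \\, n_{\\text{in}}, \\qquad n_{\\text{out}}^{\\text{conv}} = \\left\\lceil \\frac{n_{\\text{in}}}{s} \\right\\rceil\n",
"$$\n",
"\n",
"With $s = 2$, the generator grows the reshaped dense output into a full image:\n",
"\n",
"$$\n",
"7 \\times 7 \\times 128 \\;\\to\\; 14 \\times 14 \\times 64 \\;\\to\\; 28 \\times 28 \\times 1\n",
"$$\n",
"\n",
"and the discriminator mirrors this before flattening to a single probability:\n",
"\n",
"$$\n",
"28 \\times 28 \\times 1 \\;\\to\\; 14 \\times 14 \\times 64 \\;\\to\\; 7 \\times 7 \\times 128 \\;\\to\\; 6272 \\;\\to\\; 1\n",
"$$\n",
"\n",
"The flattened size is $7 \\cdot 7 \\cdot 128 = 6272$, the same number of units as the generator's first `Dense` layer."
]
},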
{
"cell_type": "code",
"execution_count": 57,
"metadata": {
"id": "0-Sj_zc9yLhO"
},
"outputs": [],
"source": [
"# extra code: compiles the discriminator and the gan, as earlier\n",
"discriminator.compile(loss=\"binary_crossentropy\", optimizer=\"rmsprop\")\n",
"discriminator.trainable = False\n",
"gan.compile(loss=\"binary_crossentropy\", optimizer=\"rmsprop\")"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {
"id": "4Ss7zwLGyLhO"
},
"outputs": [],
"source": [
"X_train_dcgan = X_train.reshape(-1, 28, 28, 1) * 2. - 1. # reshape and rescale"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {
"id": "xKWHUiUqyLhP",
"outputId": "576b04b2-5604-405a-a672-9693db6ef26a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/50\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"<<44 more epochs>>\n",
"Epoch 48/50\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAADnCAYAAACTx2bHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9eZDl51Uejj933/etb+/r7KtmtMuSbMkGY9kKBpMiJDgsMQEqBP5IgHIqCSQhSaWgIJVUBQoTMBgKG9vBG8ZYtrWOpJmRZp/pfd/uvu/L74/+PmfeezWSRtN9u9v+3VPV1TN9b9/+fD7v+57lOc85R9NsNtGVrnSlK13pSlfeWbR7fQFd6UpXutKVrnw/SNdgdqUrXelKV7pyF9I1mF3pSle60pWu3IV0DWZXutKVrnSlK3chXYPZla50pStd6cpdiP5dXr9nCm2z2USlUsGf/umfIpPJYGxsDJlMBrFYDMViEY1GAzabDQDQaDRQLpdRq9UwPDwMs9kMnU6HsbExHDt2DAaDAVrtPdl2zd1c6r188D6SXb/H9fV1zM7OYnl5GZlMBsViEV6vF8eOHcPAwAACgcBO/jlgB++xXq8jm83ir//6r/HGG2/g85//PADAYDDgk5/8JIaGhtDX14ehoSFMTExsfXCziWw2i+XlZbz00ktYXFzE5uYmrFYrSqUSbt68iUwmg1wuh5/6qZ/CmTNn8PGPf1z2927f470I2fIazdZl1Ot1NBoN6PV6+dkOyI7eY7PZRLlcRqlUQjKZxHe+8x1cuXIFDz30EPr7+3H//ffDYDBAp9O97e83m00kk0kkEglcuHABWq0WBoMBX//615HNZvG7v/u7CAQCMJlMWF5eRiwWQygUgtVqhcfj6fg9vpOkUim8+OKLePnll/H3f//3CIfDcLvdOHv2LIxGIwDgyJEjGB4eRm9vL4xG473q0XZ513tsNpvN7e6bZrOJarWKRCKB119/Hd/+9rdx+vRp9PT0IBwOI5FI4PLlyxgbG8PAwAAmJiZgNpvRaDSg0+nedt3vUu548e9mMO9Zms0mGo0GkskkIpEIGo0G0uk0EokEKpUKGo0GrFarvL9UKqFWq0Gr1cJut8PhcKBYLEKr1e7kge3KNqRWqyGZTGJychKvvfYa0uk0qtUqPB4PdDodZmdnUa/XUSqV0NPTA4PBsNeXLBKJRLCysoJoNIpkMonLly9jYWEB1WoVNpsNTqcTXq8XBoMBly9fxvLyMmZmZgBsOXT5fB6xWAy3bt1CNBpFKpUSpeRyuWCxWOB2u1GpVLC6uornnnsOvb29OH78OPR6/XYPb0eEZ9NgMMBgMKC3t1eu8/r161hYWMATTzwBl8u1x1f6VqnVaiiXy7h58yY0Gg0cDgeMRiNsNhsqlQqKxSIqlYroDhoKGslGo4FsNot8Po/19XUUCgU4HI6W1zY2NnD58mUMDQ1hbGwMxWIRhUIBKysrsNlscLvdu66bGo0GqtUqXnvtNWxsbGBqagqRSARmsxk2mw0Oh0POXaFQwK1bt7C8vAyn0wm9fkvdj4yM4NixYx29znt5Ls1mE7VaDaVSCaVSCdlsFrVaDbVaDYlEArFYDFNTU9jc3MT8/DwKhQLW1tYQCARQr9exvr4u33t7ezE2NgadTreja9RRg1mr1TA1NYXp6WlUKhVkMhmkUil5j9lshkajgdFoRD6fR7VaxaFDh+Dz+TA+Po7e3t6d9nC7sg0pFAq4dOkSvvzlL+NP/uRP4HQ64fF48JGPfASxWAwXLlzAxMQEhoeH8SM/8iPwer17fckily5dwl/+5V/ilVdewdramqAcABAOhzE8PIyJiQkUCgX84R/+IYrFImq1Gur1ukQyQKviBQC/348f/uEfhs1mg8ViQTQaxSuvvII//MM/xKOPPorf/d3fhdPphMVi2ZsbfweZnp7Gt771Lbjdbvh8Pjz77LPixP7Zn/0ZPvvZz+Lb3/42Tp48uc
dX+lYplUqIxWL4sz/7MwQCAXz0ox+FzWZDf38/8vk84vE4crkc6vU6LBYL9Ho9tFotKpUK6vU6qtUqpqamsLCwgOvXr8NiseAjH/kIisUiEokElpaWcPnyZfzRH/0Rzp49i09+8pOIx+OIRqOYnZ2Fx+PBoUOHxAjtlpTLZaTTafz6r/86bt26BbvdDr/fj3A4jFAoBJ/PB5vNhnw+j83NTXzzm9/E1NQU8vm87OVf+qVfwn/7b//tno1ap/RxrVZDPp8Xx/b69etoNBoYGBjA1NQUJicn8cYbbwiqYDKZ4PP5EAqFMDo6igsXLmBlZQVf/epX8dGPfhSf+tSnYLVad9Rx79hql8tlZLNZrK+vIxKJwOfzwel0wmQyoVwuo9FowGQyQafTwWw2w2KxoF6vo1arIR6PI5VKwW63Y3h4GENDQ3A4HJ26VACd3Qh7Ldu9t1qthi996UtYWFjAtWvX0Gw28Y//8T9GOp2GwWDAo48+2gI/ajQabG5uolgsore3d0+fa6FQwM2bN/HGG2/gypUrsFgsGBgYwPz8PAAI/F8sFnHp0iXo9Xo8/PDDSCaTiEajYjgzmQwAQK/Xw2QywWAwYHBwEH6/H4cPH4bBYIBer4ff70cul0M+n0exWMSf/Mmf4KGHHsKRI0fg8/l2XcG+k5TLZWQyGZw9exajo6MSMQPAxz72MXFa95MwAlxbW8Py8jI2Njag1+vFuSFypdFoMDs7K7De5uamOOWVSgXZbBbVahWNRgNmsxl2ux2JRAK5XA6JRAKPPPIIDh06JHrp/PnzEsU99NBDcLlcOwVv3pWk02nMz8/j3LlzuHLlCuLxOCwWC/x+P7RaLeLxOMxmM/L5PLLZrKQearUaDAYD/H4/6vU6isUidDqdGJz3eg+dOstMy6XTaej1erhcLiSTScTjcczMzECr1eKJJ57A9773PcTjcZRKJXmfzWaDwWDA66+/jnQ6jQ996EMIBoO4cuUK+vr6BD3aifXq2OktlUpIp9OIxWJIpVLo6emB0WiE1WqVjWs0GmE0GmGxWGRjVioV5PN5JBIJ9PX1YXZ2Fn6/v+MG8wdZ2vNT70VqtRoKhQKee+45gekeeeQRPPPMM5ibm0Oj0cDx48fh8Xhgt9uxurqKaDSKeDyOcrmMcDi8pwazWCzi5s2bmJ6exurqKsbGxuD3+7G2tgYAsNvt0Gg0qFarmJmZgcvlwokTJ7CxsQGTyYR8Po9yuQyTyYRmswm9Xg+r1Qq73Y6TJ0/C7/ejr69P7tHr9aJUKmF1dRXpdBrf+ta34HA44Pf7W2CxvZZmsymRVn9/Pw4cONDy+iOPPIIHHngAZrN5j67wzqKmejY3N5HL5VAqlVCtVsVoMqLY2NhAoVBANpvF9evXEYvFxGgkk0kEg0H4fD4cPXoUVqsVmUxGULDjx49Dq9Vibm4OADAzM4ORkRG43W4cO3YMdrt91wxmtVpFMpnEjRs38I1vfAPf/e534XA4ZF/lcjmkUimYTCYUi0WkUiloNBrodDrUajUYjUbY7XYAgE6ng8FgQLVa3Q43ZMeFEHsmk4HBYIDFYpEoeWVlBYcPH8YDDzyAc+fOoVQqoVwuo9lswm63Q6/Xo9FoYHJyEs1mEz/5kz+JYrGImZkZOdtut3t/G8yXX34ZL7zwAuLxOLRaLfR6vRxSFR7R6XSwWCzIZrOSczCbzQiHw4jFYvjbv/1bDA0NoaenpyPXWavVZHP9oMq9bJRGo4FKpYJ/+Id/wOuvv47JyUk0Gg186EMfwsmTJ3HkyBFRqG63G/F4HC+99BK0Wi20Wq149oVCQSKyvZBSqYTZ2VkEg0F88pOfxODgIAwGA5xOJ9LpNIrFIvR6PQwGAxqNBjQaDfr6+hAOh3Hs2DG43W5x6PhMuI8TiQTy+TwuXbqEYrGIfD4Pp9MJg8Egz8Xr9co+/vmf//l9Ac1WKhVsbm5idHQUv/qrv4pgMPiW92SzWeRyOXF094toNBpotVocPX
oUo6Oj6OvrQzabxeLiItbX15FIJKDT6eBwONDX1we3241arQaXy4VyuQy73Q6dTge9Xg+bzQaz2YxMJoNarYZisYh
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 49/50\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAADnCAYAAACTx2bHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9eXCd53Ue/tx931fsILhvWi1Zki1KlhxFzuLYcRMnsZMuaaZN2iRN0kzSNJ35NZlJO2mbadJOYyeTsbPYjR07seRVomVJlizJEmlKpEiQIABix933fft+f6DP4XuvKBEkLkDSuWcGQxC4uPf7vvd9z/Kc55yj0zQNAxnIQAYykIEM5J1Ff6MvYCADGchABjKQW0EGBnMgAxnIQAYykE3IwGAOZCADGchABrIJGRjMgQxkIAMZyEA2IQODOZCBDGQgAxnIJsR4ld9fN4U2m82iUCggl8shl8vh9ddfRzKZxPr6Om677TYMDw/jzjvvRLVaxaVLl9BsNtHpdOByueDz+XDPPfdAr9+yPddt4jWbukdN09But1EsFmEwGGC326HX66HTbXxEvV7H8vIyzp07h9deew3tdhuapsFqtWLfvn340Ic+BLPZDKPxao/8mqVv93gTy1XvUft/dG+uxy0oO7KOxWIR6XQar7zyCk6fPo0TJ07AaDTCaDTCZrOh0Wggn88jGo0iGo3iYx/7GMbGxhAKhfrxbG/YXm21Wuh0Omi329DpdNDr9TAajf3QMb3St3us1Wpot9swGAwwGAwwmUxbvLSei9A0aJom63oN67tt57HT6eCll17C2bNnUa1W0Wg0kMvl0Gw20Wg0sHfvXgSDQYyPjyOXy+Hs2bNYXl5GLpfDxz72MUxOTmJqagoGg2Gra3vFC++79uaBu3DhApaXl1GtVlGpVBCLxVCpVAAAKysryOVyqNfr6HQ6KBQKaDQaaDab0Ol0CIfDOHr0KCwWCwwGQ78v8bpFr9fDZDJ1GUpg4zCWSiVcunQJy8vLyGQyADYWv1qtwmAwYH5+HsPDw/D7/Tfq8r/v5RY2ltsuZ86cwaVLl1Aul1EqlTA7O4vFxUWkUinZ0xaLBa1WC+VyGQBQrVbx3HPPIRgMwu/3Y3h4GJOTk3C5XDCbzTf4jq5NVOWp0+luib3Ca6RhYwngO127pmniGLRaLfl5p9ORNabDoL7XdjyPa31PXrvdboff78f6+joAwGq1wmQywWQyYWlpCUtLS1hYWBD9arFYEAwGYTQa0Ww2kclkYLVaYbfbYTQa+3pvfTeY+Xwep0+fxt/93d/h+eefl6jK6/XC4/HA6/XihRdeQCaTgcFggNfrxfj4OIrFIqrVKlZXVzE1NYXHHnsMXq8Xdru935d4XcJD5nQ6AUA2r6ZpKJfLWFtbw9e//nXE43HEYjFRPm+88QZ27doFm82GRx99dGAwt0luBQV4o0TTNHzqU5/C//k//weNRgMGgwE2m01QE2DDoBgMhi6FCwBPPPEEdDodHA4HfuzHfgz/8l/+Sxw5cgSBQOBG3tI1i16vf0vEwXvdhijzHWUzho+/1+v16HQ6Xf+/ksFX9VGtVkOtVkM+n5fXNhoNmEwmhMNhGI1GmEymbXUcrud9O50Oms0mQqEQDAYDkskkAMDr9cr7/fVf/zXOnTuHcrmMsbExPPTQQ9i3bx8mJiZgsVhQrVZx/vx5BAIBjIyMwOl09jUy77vBXF1dxec//3msrq7C6/XCYDCg0+mgWCzC4/EgEomg0+kgl8thfn4eZrMZgUAAzWZTPNtms4m1tTXodLqbxmBSWq0WisUiVlZWkMlkUCgUMD8/j2KxCKPRiKGhIYyMjOC1115DJpPBgw8+CIPBgBdffBEjIyMYGhqC3++HyWS6XjhkIAPZtJTLZcRiMZTLZVgsFlG+VLxqioDfdzodiWiobAwGA86fP4+//Mu/xLFjxzA5OYl9+/bB4XDA6XS+BXW5FeRGGczNSrvdRr
1eRyKRQLvdRrvdht1uh9lsFn25srICi8UCh8OBRCKBarWKer2OZrMpOokOktlshtPplPcl4rVv377tSBVds3Q6HdTrdayurmJxcVHux2AwwOfzIRAI4Ad/8Adx3333odVqweFwYHh4GD6fD06nE5VKBdVqFdVqFa1WC9VqFXv27OkyuFuVvj+lZDKJZ599Fn6/H263GwBk8QDA5XJJZLm4uAiDwQCPx4N0Oo1OpwOdTodOp4N0Og2v19vvy9uSdDodVCoVpNNpzMzMYGVlBbFYDG+88Qb0ej3uuOMOuFwuuFwuvPDCC8jlcrj77ruRTCbx1a9+FYuLi0gmk3A6ndDpdGIw6f3eagpnIDe/VCoVLC4uolKpwGw2o16vA7gcUXIfAnhbpclIdGVlBfF4HAaDAel0Gh6PR5y/bcrPb6vQYN6sRrPZbKJarSIej4sRdLlcsNvtqFQqyOVyOHXqFJxOJwKBAObn55HP59FsNsWgEsoMBoOwWCywWCwol8soFAo4dOgQRkZGMDk5KXvhRkqn00Gj0UAymcSlS5ckVQdAHIW7774bZrNZHD9VqtWqvEej0UCpVMLIyAg8Hs/NaTAZUheLRQQCATgcDtTrdeh0OgSDQeRyOTz77LPYu3cvbDYbrFYrACAej6NWq8FgMKDRaKDT6cDn891U0WW1WkU6ncYf//EfI5VKIZ/Pw2w2w2QyYXR0FNFoFB/60Ifw0ksv4YknnsDS0hIMBgOOHTuGRCKBxcVFxONxPP3003jkkUfgdDqRzWbFO3Q4HLBYLPD5fDd84w7k+0MajQbOnz+PP/qjP5L9aLPZoNfr4Xa7BXql8rFarUKWaDab0DRNcl6apgmUt7q6imq1Cr1eD5fLBY/HgzvvvBN79uy5wXd8bUIjUSwWodPpxMHfbtns+X755ZexsLCAarWKZrOJer0Op9Mp0WKn0xEdOz8/j927d8PpdCIej6PT6cBgMCCXy6FSqWB4eBgOh6MrCDlz5gxWVlYQDocRjUYxPj6+TXe8Oel0OqjValheXsaFCxdw++23w2w2I51Oo91uY2ZmBh6PR+7dZDLB4XBI+iAQCMBiscBut8s+1jQN9XodVqu1L3q17waT0AG9ViafdTodKpUKyuUy/H4/PB4PPB4PjEYjarWasGR1Oh0MBgOsVutN5bFmMhmsra1hbm4OuVwOwIbXwwUym81wuVxoNptYWlqC1WpFIBCAx+NBqVQCsMGkLRaLSCaTKJfLKBaLcDqdMBgMaLVaMJvN8Hg8NxXRaSC3pjSbTZw/fx7nzp3D/Pw8qtWqsEINBoPAemo+TUU7uCd5HtvttqAgxWIR7XYbCwsLcDgccLlc2LVrV1eK4VaRmxHZYeSbTqexvr4OnU6HdruNWq3WBSNzbbheFosFNpsNFosFAGCxWIRISafcbrfDZDLBaDTCYrGgUqkgkUjAarXecIPJeyuXy8hms2i1WrBarXKd5XJZ9q6maYJq0N60222YTCZxhPjc+PubymAy2dxoNCTJTNiAi5bNZpFOp5HNZuF2u/GjP/qjaDabWFlZEQ/KYrHA6XTC7XZLBHozyHe+8x2cOXNGIstIJCJEpWQyCYPBgMXFRczPz+PChQv4+Mc/jjvuuAOtVguJRALz8/MIBoPweDw4c+aMeOfRaBQ+nw/z8/NotVoYHh4eGMzrkFtRWW+XaJqGTCaDX/7lX8ba2hpKpZI4dg6HAzqdThiULLMALiNEVDSdTgetVku+SBwplUrQ6XSiaL1eLw4ePIg77rjjpnJyNyuMLLeyh5j37cfZbbVaqNVqSCQSiMViCIVCaLfbQtiy2+3CBSkUCvD7/QgEAiiXy1hdXUWpVBLuiM1mg91uh9PphNlslvcwm824/fbbUS6XMTs7i3a7jdtuu23L174VoQNQKBSwvLyM6elpBAIBRKNRIaQxR0sylMViEYNINBO4vB6VSkXQzH6QnPq2uzudDvL5PKrVKux2O7xeL3w+HwqFAgDAbD
YjHA4jGAzKIhKjLxQKaLVassg0niot+kYLcX9N06DX6+FwOFCpVNBsNmGz2VCv1zEzM4NEIiEQdDgcxsrKCtbX1wV
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 50/50\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAADnCAYAAACTx2bHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9V5Dc13Um/nXOOU6ewQAYZDCIBCmSlkRJlkhZtpLLa7145X2w92WrXFu1Lw4bqvbB3tra3dqHLXttl11re122LNvaolQKFDMJRgAkMmYGE3p6Oucc/w/z/87cbg5IDKZn0LT7VKEATPzd3733hO985xxNt9vFSEYykpGMZCQj+WjR3u8HGMlIRjKSkYzkkyAjgzmSkYxkJCMZyV3IyGCOZCQjGclIRnIXMjKYIxnJSEYykpHchYwM5khGMpKRjGQkdyH6j/pkp9PpAoBWe292tdPp4Pr169jY2MBLL72EU6dO4amnnoLb7YbJZEK320W320Wr1UKz2USz2USn00Gr1UIul4PX60UoFIJGo7mn3w/gbr7xk04TPrA11ut1rK+v48///M/x+7//+zh9+jRCoRC8Xi/K5TLW1tbwxS9+EY8++iiefvppuN3uQfxa4ADXyPN47do1vPPOO/jsZz+LqakpaDQadDodNBqNrQfSaGCxWO75buwgA19ju93e+sEaDUqlEqrVKgqFArRaLYLBIAwGA4xG49YP7nZRr9fle2/cuIH19XV0u124XC48/vjjMBgMe13vgexjo9HAysoK/ut//a949NFH8ZWvfAV2ux16vR6lUgmtVgutVgvXr1/HysoKVldXYbPZ8KUvfQmhUAhjY2P7rnNYndBut9HpdNDpdKDRaGAymdBut+XjAGAwGNDtdtFut5FMJlEoFDA+Pg6LxdKzf7FYDNVqFQaDAcCW/nW73bDb7dDpdFsPt8O6+CxqxQT3mc+gfk6n0+3rPkajUcRiMfzoRz+C0WjEr/7qr8JkMkGv18vffJ5Op4N4PI5EIoH/+3//L9bW1rCysoLTp0/jySefxNNPP42pqSlZ/y5kxzV+pMEEdn7BdyPtdhuNRgPr6+tYWVnB8vIy6vU6yuUyHA4HjEYj9Ho92u026vW6GE+dTieH48iRI/D5fNDpdINUTCO5B+l0Osjn83jllVewuLiIbreLcrmMbDaLSqWCVquFdruNcrmMXC4nyvqTJK1WC/V6HYlEAouLi7h8+TI0Gg38fj80Gg263S4ajYYot0996lPweDxwuVz3+9E/JJ1OB5VKBbVaDYVCARsbG0gmk8hms9Dr9ZienobBYIBer4fNZoNWq0WhUJA7mk6nkc/nUSgUYLFYUK1W4fP5EAwGEQwGYbFY9mJU9k06nQ4SiYQo3StXrsBkMsFms0Gn06HRaPTopng8jmw2C5fLhbW1Neh0Ovh8vkE4Bx8r1HnAloHi+9RoNPK7u90uNBqNnD+LxYJutyvPx4+rDhwNCr9Oo9Gg1WoBgBjTfmm326jVatDr9WJcGo0GotEo9Ho9jEYjrFYr9Ho9rFbrwN5BsVhEtVpFIpGAyWSCz+fD0tISlpaWcPv2bZhMJly5cgVarRadTqfHFnAfeU4rlQp0Oh28Xi9MJhNqtRoikQja7TYmJiZgMBjuxXD2yEcazL0cmHq9jkKhgFdeeQWXL1/Giy++iFqthnq9Ll6Cw+FAo9FAoVCAyWSC0WiE0+mExWJBKBTCL/3SL2FhYUG+fmQ07580m02srKzg937v98Qgbm5uIh6Po1QqwWq1YmxsDNFoFA6HA7Va7X4/8q6k2+2iWq0ilUrhpZdewksvvYR/+Id/kHPHs9dqtVCtVqHVavHf/tt/w4MPPogHHnhgqM6mGo1Eo1FcvHgRr7/+Ot5//30kEgkYjUacOXNG1jY/Pw+73Y6lpSWk02msrKwgEAjA6XTi5s2bcmcffvhhPP300/j85z+P6elp6HS6oTKaRAAuXLiAS5cu4d
1338Xrr7+OP/uzP4PRaIRWq4XZbEa73Ua1WhXj6Xa74ff7odPpUK1WMT4+DqfTKdHbfj6vath4hnje2u12j8HUarXwer3wer09753/9ng8H/od/PnVahXAzgZTo9GgXq8jHo/D4XDAZrPBaDQim83ixz/+Mex2O0KhEKanp+FwOAZqMNfX1xGJRPD888/D5/PhqaeewnPPPYeXX34Z0WgUZrNZjGEul0Or1ZKot1arIZfLQavVwmAwYH5+Hg6HAwsLC7BarchkMnjjjTfgcDjw1a9+FS6XC3a7fU/P+7ER5r0KvXV6NlarVTbeaDRCp9NBr9fL52w2G8xms8BE3JRCoQCPxyNe0zBd0H/qwiilWq0KSnDkyBGYTCa4XC7U63W02220Wi243W4cOXIEfr8fHo8H9Xod6XQaAGAymWC1WofSqDCiqlQqiMfjyGQy2NzcRL1eh9/vRyaTQa1WQ7vdFo/eZDLBbrejXq+jWCyiUCjAbDbDbDbf72UBgNwzk8kEt9uNo0eP4urVq9BqtbDb7TAajeh0OigUCiiXyzCbzeK8Unnr9Xro9XpMTU2JI5HNZnHz5k088MAD8Pv9cDgc93upPcJ137x5E7du3UK73RYYjwaJZ1an08l+uVwu2Gw2VKtVFItFZLPZHrhz0EJDSaSCf/pFq9V+SOftVv/x69WoTI1mKUajET6fD61WC+VyGc1mE+12G+FwGGazWfSzyWTa7XJ3lFarhUajgfPnz+PmzZuyR+l0Gh6PB2fPnsXs7Kw4CR6PB+FwGEajEQaDQQw/9U+r1YJOp4PBYIDP55PUHtdz4cIFjI2N4eTJk3uKNPfVYNKDAwCz2SzwAx9YjRrpubRaLTnknU4HpVIJDoejB0MfGc39Fx64SqWCfD6P5eVlxGIxzMzMIBwOY3Z2FslkEvV6HVqtFoFAAKdPn5bL32w2kc/nAUCUNOGU+71/3W5XcuaVSgWZTAa5XA5ra2vIZrPIZrPodrtwu93i3dZqNWi1WhiNRrjdbng8HrmMhUIBnU5H4Kz7vT4A8qwulwtarRYul0tQHaY9aAidTmdPbpbKy2g0YmxsDJVKBdlsFqVSSd5RpVIRmHNYhO99Y2MD0WgUAESJms1myeUCEAOq0+ngcDhgt9vR6XRQr9clz7lfwqiPxvBOd+JOhnS3wt/B39npdHqcV41GA71eD6fTiVwuh2q1Kl/n8XhgNBrFgbgTpLtb4d25fv06Ll++jE9/+tPQ6XTIZrNwOBw4dOiQpHmazSYMBgNMJhM8Hg/MZjMsFgsMBgMsFgtKpRLK5TJSqRQAwOfzoVQqCRJWrVZx69YtNJtNHDp0CBqNZrgMJr1XwnW1Wq3HcyXkQA+12WzCarVKroTwSS6Xw7Vr12C1WuXAqwdsGBTTPzXpdrsoFotoNBool8uoVquS2/D7/fjiF78IvV4vnhwvIS8UUQVeOpvNJnlNQj1UyBaL5cDXl8vlUCwWcenSJXluKo9GoyHQjtFoRLfbhdfrRbFYxNraGjqdDkwmEx5//HGcOnUK4XAYjUYDL7/8MhwOBzweDw4dOgS32w2bzXZfz6dWq4Xb7Uaj0YDRaMT8/DweeOABFItF8by9Xi/sdjvsdrsgPzabDePj43C5XDCbzcjlcjAYDDh27BhqtRoSiQTeeOMNZDIZPPvss0NnNLVaLebm5lAqlXDlyhXRE4zmjEYjNBqNKH69Xo/Dhw9jdnYWX/rSl+B2u4WUuF/SaDQERvw4PTYoVI0BSLvdlrynmkOl0eQd1Wg0aDQaWF5exvT0NObm5mCxWAa212tra7h48SKmp6cRCASg1WpRrVaRTqfFsWaevNVqSZRrMplgMpkQCATkGYlmms1mtFotJJNJVCoVlMtlWU+5XMbGxgbeeecdHD16FJOTk/e0loEaTHUTeClp4FTWFeEePrBOp5McZqfTkQi0Xq8jk8n0wCg7JclHMhghvNFsNsXDJhJA2MpmswmrD9iO1tTLTwXFvS
RZhl4r9/Kghc5AOp1GIpEAAMlVkVQAbKEhLpcL4XAYTqcTtVoNOp0OrVYLBoMBExMTwqSsVquIx+MoFAooFApwuVz
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: generates the dataset and trains the GAN, just like earlier\n",
"batch_size = 32\n",
"dataset = tf.data.Dataset.from_tensor_slices(X_train_dcgan)\n",
"dataset = dataset.shuffle(1000)\n",
"dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)\n",
"train_gan(gan, dataset, batch_size, codings_size)"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"id": "yY8_xytVyLhP",
"outputId": "7f3ebd45-3bc5-42ce-d0a6-5734d360cef6"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAADnCAYAAACTx2bHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy9eXDc53kf/tn7vnexi5sAD4CXSEmkKFGWK9VX6jiKHcfN4bHrNEmnk6YzadrJP+lM0mQmk2kyv6lb52qS8UycOpabOIflK04kyzoskaJI8QRxAwtgd7H3fe/+/kA/D99dURJILADH2WeGIwpcAN/3+77vc3yez/M8mna7jb70pS996Utf+vLOot3vB+hLX/rSl7705Z+C9A1mX/rSl770pS/bkL7B7Etf+tKXvvRlG9I3mH3pS1/60pe+bEP6BrMvfelLX/rSl22I/l3+/Z86hVazjc9sa431eh2tVgsmk+ltP9NqtZBKpRCNRmE0GlEul/HFL34RxWIRAPDDP/zDePTRR+F0OqHT6ba1gG1Iz9b4fSy7usZWq4VWq4VGo4Fms4lyuQytVgudTgez2QydTgetViufTafTyOVySCQS0Ol0GB0dhc1mg91uv99HAPZwH1utFmq1GpaXl3Hr1i0Ui0VUq1Wsr6+j2WxCo9Hg6NGjOHDgAI4cOQKbzQaj0diLX33Pa2y32yCTn3tw129qt1Gv11Gv15HNZlGpVJDJZOD1ehEKhWAwGDr2sNFoIJvNIpvN4tVXX0U+n4fBYMCBAwdw9OhReL1emM3md/ydvVrj3YTP2Gq1YDQaodFooNF0/ujNzU0888wziMVi2NjYQKVSQbvdhsFggNVqhd/vx9DQEAYGBvC+970PHo/nXtfydrKrZzUSiSAWi+HZZ59FKpWCyWSC2+3GwMAATpw4AZ/Ph0ajgc3NTVy5cgUbGxtIJpOwWCzw+/143/veh1AohKGhoY59v0e56xrfzWD25f+JRqN52xffaDSQy+Xw+uuvI5PJIJVKYWxsDEajEW63G3a7HQaDAfF4HK+99hrOnTsHl8t1vxvZlx1Io9FAo9FALBZDq9WCRqNBJpNBoVBAuVxGs9lErVaTz/HCtVot+RnVahWNRgMAYDQakc1mYTQaYTAYkMvl0Gq1MDIyAo/Hg+Hh4f1a6l2lXq+jXC5jdXUV4XAYa2trsFgs0Ol0CIVC0Gq10Ov18Hg8MBgMKBaLaLVasNls0Ov10Ol0b1HceyHv9DtpXMrlMiqVCpLJJDKZDG7fvg2DwQCLxQKDwQCTyYRDhw6hVqshGo0inU4jn89jcXERjUYDPp8PyWQSsVgM7XYbNpsNbrd7X+4pDaRWq33L2pvNJt544w2srKzg2rVrKBQKKJVKqFaraLVaMBgMqNVqMJvNmJ+fx9raGsbGxjA8PIyhoaF92b/tSqvVwsrKCm7cuIHFxUUUi0VYLBZkMhkkEgkUi0XY7XbU63VkMhmsrKwgm82iWCzCZrOhVqvh0qVLGBgYwODgIKanp+H1ensWoPQN5jZFr3/7V1WpVDAzM4Nf/MVfRKFQQKPRwMc+9jFMT09jcnJSLt5f//Vf49KlS/jsZz+LI0eOwGKx7OEK+gIA5XIZ2WwW//iP/4hqtQqdTodr165hYWEByWQStVoNWq1WIkhGOKVSCcBWlOPxeOByuXDixAnY7XaJ2MrlMq5du4ZqtYpPfvKTePjhh/GTP/mT+7ziLVHXEYlE8NWvfhWxWAyRSASnT5/GyMgITp06JZFyrVZDvV5HLBaDTqfD4OAgLBYLbDZbL9GRbT3zuxmsZrOJUqmEdDqNbDaLubk5LC4u4s///M+xubmJzc1NGI1GeL1e/Kf/9J+QTqfx/PPPI5lMIp/Pw+l0wuPx4OGHH0apVEI+n8f4+Dh8Ph9OnjzZq+j6noXvWTVw7XYb1WoVv/M7v4PLly9jeXkZDocDgUAAtVpNUDCXy4VWq4WrV69ibW0N7X
YbDz74ID7+8Y/v2f7dq7RaLVSrVTz33HP4v//3/yIajQIAPB4PqtUqSqUSNBoN2u02KpUKgK13Y7fbYbFYBOF5/vnn4XK5EAgE8J//83/GmTNnYLPZeuIo9A3mfUi9Xkc+n8fy8jI2NjYEhh0cHITNZoPX68XBgwfh8XhgNBqxsbGBr3/965iZmUEikcB3vvMdrKysYHBwEAMDAxgeHt43z/2fixB2zWazyOVymJycRLvdhl6vh8vlwtTUFEqlkkQq1WoVlUpFHKBGowGTyQS73Y5isYh6vY6JiQnYbDbYbDaJciYmJlAoFAAA6XQac3NzGBgYgMvl2re1U8FUq1UUCgVUKhXY7XY4HA4cO3YM09PT8Pv98Pv9Aj8zwvH5fGi1Wkgmk3C5XAIP7kXUtd370G630Wg0oNfrYbPZYDAYYDab4ff7UavVkM/nEQwG5Y9er4fX64XJZEKtVoPD4YDb7cbY2BhcLhfsdrtA89VqFQA63steSDcE2263JR0Qj8clbRAMBuHxeBAIBJDJZFCv1wUxiMViqNVqMJlMKJVKKBaL+H5uVFMqlbC+vo5oNIpMJoOHH34YRqMRqVQKqVQKxWIRR44cgdvtlntZrVah1+vlHrfbbSQSCTQaDayvr2NjYwPRaBQTExM9cRT6BvMepV6vo1gsIh6P4+bNm7h58ybi8Tiq1SpCoRBCoRDGx8cRCARgtVqh1WqRTCbx/PPPi1f05ptvIpFI4ODBg6hWq/B6vTAajQJ59aX3wigwn8+jVCohFApBr9fDYDDA5/OhUqmg0WigVqsJrAoAmUxGIlGLxQKfz4dwOIx4PI5gMAibzQa/3w9gS6mNjo4il8thbm4OhUIB4XAYFotl3wwmDXmlUkG5XBangM89OjqKAwcOwOl0QqPRyHtqNpsCxdZqNSQSCRiNRvn63XJquyHb+R00JgaDARqNBkajEUajES6XC7lcDjabDYFAAIODg/B4PGg2m3A6nTCbzWi327Db7XC73XJnLRYLms0mqtWqIA50rnbjfrbb7XddZ6vVQr1eRzKZxNraGmq1GnQ6HbxeLzweDzwej0RoVqsV9Xod6XRacqCVSkW4FN+vUi6XEQ6HBXpl/vz69esol8vQaDQYHh7G8PAwcrkcKpUK8vm8fL/b7RbnKZFIIJFIIBqNYnNzE2NjY32DudeSz+fxpS99CdFoFGtra0KWsNvt8Pv9OHfuHBYXF3Hx4kU8/vjjsFgskoyfnp6G2WyGyWSCzWZDMpnEzMyMXNTJyUkEg0G8//3v70O1uyDxeBzLy8sSyafTaQBbEKvD4YDX60W9Xpe8JQ3D0NCQKF+DwQCDwYBjx46hXq9Dr9eLoqtWqyiXyxgbG5N/q1armJ+fh9vthsfjgcVi2bN8WLvdRqFQQCqVwsLCAgYHB+H1emG1WmEwGPDwww/D5XLB7/cLuYUQaLvdRrPZRKPRgFarhc1mw5EjRyTCIjmKBmq/kRGSs7gGv9+PVColeerBwUEMDQ0hGAyK00AYt1KpQKPRwGw2o9FooF6vw2AwCBTM3GC1WoXP54PT6ez587/T+1P3cXZ2FpcuXcLs7Cy8Xi8sFgtisRjy+TzS6TTcbrfok3Q6jbW1NRiNRlgsFszNzcFkMqHZbL5jemk/ZXFxEb/3e7+Hq1evIp/Pdzh1dIJisZiQnDQajcDpDodD7u6RI0dQr9dx69YtPPvss5ifn8exY8d6Aq1/f76570PhRt26dUtIBWTONptN2O12jI2NCaGC+SJ6QoSJzGYzarUaqtUqisWiKBydTodCoYBDhw7B5/NhYGBgv5f8AyU0BCaTCTqdDuVyWTx7KhWj0YhGoyGRlhqJOJ1OaLXatxgWfo4IAeHZ4eFhJJNJbGxsIJFIwOFwYGxs7B1Z1r0SRpW5XA75fB6NRkPW2mq1JA9rt9thtVpFYbdaLfmcTqcTwhP/Tli72WwCgLyPvYQq7yYajUaiv3a7DavVCrPZLM+pRoaxWAzJZB
KFQgG1Wg0AxJFV18X11Go1aDQaQSD2Umiol5aWkEgksLS0hJWVFYTDYTgcDvkcn1V95mazCZPJBLPZLMS0tbU13Lh
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 17-16\n",
"tf.random.set_seed(42)\n",
"noise = tf.random.normal(shape=[batch_size, codings_size])\n",
"generated_images = generator.predict(noise)\n",
"plot_multiple_images(generated_images, 8)\n",
"save_fig(\"dcgan_generated_images_plot\", tight_layout=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GUCR2FrdyLhP"
},
"source": [
"# Diffusion Models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dyvKvczoil-7"
},
"source": [
"Starting with an image from the dataset, at each time step $t$, the diffusion process adds Gaussian noise with mean 0 and variance $\\beta_t$. The model is then trained to reverse that process. More specifically, given a noisy image produced by the forward process, and given the time $t$, the model is trained to predict the total noise that was added to the original image, scaled to variance 1.\n",
"\n",
"The [DDPM paper](https://arxiv.org/abs/2006.11239) increased $\\beta_t$ linearly from $\\beta_1 = 0.0001$ to $\\beta_T = 0.02$ (where $T$ is the final time step), but the [Improved DDPM paper](https://arxiv.org/pdf/2102.09672.pdf) suggested using the following $\\cos^2(\\ldots)$ schedule instead, which gradually decreases $\\bar{\\alpha}_t = \\prod_{i=0}^{t} \\alpha_i$ from 1 to 0, where $\\alpha_t = 1 - \\beta_t$:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {
"id": "yxPNTP3QpHAw"
},
"outputs": [],
"source": [
"def variance_schedule(T, s=0.008, max_beta=0.999):\n",
"    t = np.arange(T + 1)\n",
"    f = np.cos((t / T + s) / (1 + s) * np.pi / 2) ** 2\n",
"    alpha = np.clip(f[1:] / f[:-1], 1 - max_beta, 1)\n",
"    alpha = np.append(1, alpha).astype(np.float32)  # add α₀ = 1\n",
"    beta = 1 - alpha\n",
"    alpha_cumprod = np.cumprod(alpha)\n",
"    return alpha, alpha_cumprod, beta  # αₜ , α̅ₜ , βₜ for t = 0 to T\n",
"\n",
"np.random.seed(42)  # extra code: for reproducibility\n",
"T = 4000\n",
"alpha, alpha_cumprod, beta = variance_schedule(T)"
]
},
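{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the schedule computed by `variance_schedule()` can be summarized as follows. This is a sketch read directly off the code above; the symbols $f(t)$ and $\\beta_{\\max}$ are notation introduced here for the array `f` and the argument `max_beta`:\n",
"\n",
"$$f(t) = \\cos^2\\left(\\frac{t/T + s}{1 + s} \\cdot \\frac{\\pi}{2}\\right), \\qquad \\alpha_0 = 1, \\qquad \\alpha_t = \\operatorname{clip}\\left(\\frac{f(t)}{f(t-1)},\\ 1 - \\beta_{\\max},\\ 1\\right) \\text{ for } t = 1, \\ldots, T$$\n",
"\n",
"$$\\beta_t = 1 - \\alpha_t, \\qquad \\bar{\\alpha}_t = \\prod_{i=0}^{t} \\alpha_i$$\n",
"\n",
"Wherever the clipping is inactive, the product telescopes to $\\bar{\\alpha}_t = f(t) / f(0)$, which equals 1 at $t = 0$ and falls toward 0 as $t$ approaches $T$. Since $f(T) = \\cos^2(\\pi/2) = 0$, the last ratio would be 0 without clipping, so the clip is what keeps $\\beta_T$ from exceeding $\\beta_{\\max}$."
]
},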
{
"cell_type": "markdown",
"metadata": {
"id": "P4ijMDaVkyXi"
},
"source": [
"In the DDPM paper, the authors used $T = 1{,}000$, while the Improved DDPM paper bumped this up to $T = 4{,}000$, so we use this value. The variable `alpha` is a vector containing $\\alpha_0, \\alpha_1, \\ldots, \\alpha_T$. The variable `alpha_cumprod` is a vector containing $\\bar{\\alpha}_0, \\bar{\\alpha}_1, \\ldots, \\bar{\\alpha}_T$."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9z6rfMntnSMi"
},
"source": [
"Let's plot `beta` and `alpha_cumprod`:"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAADICAYAAACu5jXaAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAqSklEQVR4nO3dd3wVZdbA8d8hCSUEQkkoAkqAgITeLSBgA1TErthdhUXFfZXyKrIrKoJl7WVBXzsLsiCusqiAhaCoCERKKCJZQUGlKUJQIJCc94/nhsSQnrmZe5Pz/Xzmk3vvM/eZcyflZGaeOY+oKsYYY0x5q+J3AMYYYyonS0DGGGN8YQnIGGOMLywBGWOM8YUlIGOMMb6wBGSMMcYXRSYgEXlZRHaKyNoC2kVEnhaRNBFZIyJdvQ/TGGNMRVOcI6BXgYGFtA8CEgPLcGBK2cMyxhhT0RWZgFT1E+CXQlYZAryuzlKgjog09ipAY4wxFZMX14CaAFtzPd8WeM0YY4wpUKQHfUg+r+Vb30dEhuNO0yFSt1tUVHNUIStLUAVV97UwVaookZFKZGRW4GvO86pVs4iKUkSCV14oKyuLKlXCb+xGOMYdjjGDxV3ewjHucIy56p49rN21a7eqxnvVpxcJaBvQLNfzpsCP+a2oqi8ALwC0adNGN2786ph1srIgPR327IFffnFfs5ddu+Cnn+DHH3OWn36Cw4f/2EeTJtCiBbRsCYmJkJQE7dtDQgJERJTtwyYnJ9OvX7+ydeKDcIw7HGMGi7u8hWPc4RgzY8Ygjz32nZddepGA5gIjRWQm0AvYq6o/lbazKlUgNtYtzZsXvX5WFuzeDZs3w3//65Zvv3VfFy6EV1/NWbd6dZeM2rVzS+fO0K0bxMWVNlpjjDGlVWQCEpE3gH5AnIhsAyYAUQCqOhV4DzgHSAN+B24IVrD5qVIFGjRwS69ex7bv3w/r18PatbBunfv68ccwbVrOOs2bQ/fu0KOH+9q1K9SpU16fwBhjwsANN8Bjj3naZZEJSFWHFtGuwK2eReSxmBjo2dMtue3ZAytXwooVOcubb+a0t2sHffq4pXdvOP748o3bGGNCSrt2nnfpxSm4sFS3Lpx+uluy/fwzpKTAsmXw2WcwfTpMnerajj/eJaMGDRrTrJm7vmSMMZXGunWed1lpE1B+6teHs892C0BmJqxZA0uWwKefwkcfwfbtbXjiCTeg4eyz4ayzXBKrW9ff2I0xJqhefNHzLi0BFSIiArp0ccttt4EqTJv2Jfv29eKDD2DGDHj+eXcdqnt3GDAABg92AxvCbISlMQbYt28fO3fu5HDeobUei42NZcOGDUHdRklFRUXRoEEDateunf8KQZg92xJQCYjA8ccfoF8/GDnSDf9etgw++MAtkybBxInQqJFLROefD2ecATVq+B25MaYo+/btY8eOHTRp0oQaNWogkt8tjt5IT0+nVq1aQeu/pFSVAwcO8MMPPwAUnIQ8Zv+nl0FUFJx6Ktx7r7tmtHOnG13Xpw/MnOmSUP36LhG99JK7xmSMCU07d+6kSZMmREdHBzX5hCIRITo6miZNmrBz5878VwrCEZAlIA/Vrw9XXw2zZrl7kxYuhBtvdNeRbrrJHRkNGgQvv+xusjXGhI7Dhw9To5KfrqhRo0bQTz/mZgkoSKpWdQMUnnnG3SSbkgKjR8PGjS4pNWwI55zjbpTds8fvaI0xQKU78smr0M8/cqTn27MEVA5E3M2tDz3kKjQsXw6jRsGGDe7eroYN4aKL4O23ISPD72iNMSYfiYmed2kJqJyJuBFzDz/sSgYtW+b+sfj8c7jwQjjuODfibvnyoJxyNcaY0lmxwvMuLQH5SMSV/3n8cdi2Dd59F848E/7v/1zlhqQkmDzZtRljjK9yF9b0iCWgEBEZ6a4JzZwJ27e7JBQfD+PHwwknuJF0777rbo41xpj8zJgxgw4dOhAdHU1iYiKzZs3yrnMbBVc51KnjRs
198gmkpcGdd7pTdeed5yow3HefHRUZY/5o3rx53HjjjYwdO5a1a9dy+eWXM2zYMDJD+L9WS0AhrmVLdxpu61aYM8edlrvvvpyjonnz7KjIGAOPPvoot912G9deey0tWrRgyJAhpKenezfxnR0BVV5RUW6k3Pz5biTduHFuoMLgwdC6NTzxBPz6q99RGmP88Pvvv7NkyRLOPffco6/Nnz+fTp06hfTQcktAYSghAR54AL7/HmbPdjPAjhoFTZvCLbe44d3GmMpjzZo1ZGVl0aVLFw4cOMC0adOYPHkyY8eO9W4jd97pXV8BVgsujEVFwSWXuGXlSnfT68svw5Qp7ibYv/zFDWywwqjGlEF+U2dfdpn7b+/3390vWV7XX++W3bvdL2gekdntW7fCNdf8sTE5ucQhrlq1ipYtW5KWlkb37t1RVQYMGMBll10GwJYtW1i9ejVDhgwpcd9HnXBC6d9bAPvTVEF06eKSz9atrijq+vU5p+emTIFDh+xbbUxFtXLlSrp27Urr1q1ZunQpzz77LEuXLmXUqFEALFiwgLVr15ZtI59+6kGkf2RHQBVMfDzcfTeMHQtvveXuMbrlFqhT5yTuuANuvdXVrDPGFFNhRyTR0YW3x8Xl234kPd09aNasVEc8ea1atYqLLrqImJgYevbsSc+ePdmyZQvJycksXryYcePGUb9+fWbPns3ixYuJjY0t+UamTStznHnZv8UVVFQUXH45LF0KixfDiSemM2GCm9n1L39x9emMMeEvMzOT1NRU2rZt+4fX16xZQ58+fejbty8dO3Zk4cKFrFq1qnTJJ0gsAVVwInDaafDgg6msXetOXU+dCq1awdCh8NVXfkdojCmLjRs3cuDAASZNmkRKSgrffPMN48ePZ/ny5YwMFBDdsmULzZs3L9uGbBi2KYt27eCVV9zRz+jRrrJCt27uGuoXX/gdnTGmNFauXEnDhg2pW7cu/fr1o3fv3qSmppKcnEyLFi3Ytm0bjRo1KvtwbL8SkIgMFJGNIpImInfl0x4rIv8RkdUisk5EbvA8UuOZJk3gkUfcgIXJk939RKec4urQLV7sd3TGmJJYtWoVPXr0YP78+aSnp7Nz507mzp1Lx44dAdi6dSvHHXecz1Hmr8gEJCIRwHPAICAJGCoiSXlWuxVYr6qdgH7AYyJS1eNYjcdiY90NrVu2wKOPwtq1bsTpaae5KcatGrcxoW/lypVHk01+kpKS+O677+jQoQOpqaml39ADD5T+vQUozhFQTyBNVb9V1QxgJpB3MLkCtcQd48UAvwBHPI3UBE3Nmu6U3ObN8PTTbpqIs8+Gk092p+ksERkTulavXl1oAoqNjSUlJYXU1FQ6dOhQ+g01alT69xZAtIi/LiJyCTBQVW8KPL8G6KWqI3OtUwuYC5wI1AIuV9V38+lrODAcID4+vpunlVrLyf79+4mJifE7jBIrSdwZGcL8+Y2YMeMEduyoTps2+/jTnzbTo8ceyrOqR2XY16GksscdGxtLq1atPIioaJmZmURERJTLtkoqLS2NvXv3HvN63WXL6HznnSmq2t2zjalqoQtwKfBirufXAM/kWecS4AlAgFbAZqB2Yf22bt1aw9GiRYv8DqFUShN3RobqSy+pNm+uCqq9e6smJ3sfW0Eq074OBZU97vXr13vST3Hs27ev3LZVUgXuh+uvV2CFFpEzSrIU5xTcNqBZrudNgR/zrHMD8FYgzLRAAjqxtEnRhIaoKPjTn2DjRvjHP9ypuX79XJmfpUv9js4YE+6Kk4CWA4kikhAYWHAF7nRbbt8DZwCISEOgDfCtl4Ea/1StCjff7OYmevxxWL3aXR8aPNjVoDPGVAJ+DMNW1SPASGABsAGYparrRGSEiIwIrDYROEVEUoGPgDtVdbfn0Rpf1agBd9zhjoQmT4YlS6BrV7j0UqvAbUyF59d9QKr6nqq2VtWWqjop8NpUVZ0aePyjqp6tqh1Utb2q/tPzSE3IiIlxw7c3b4Z77n
FzFLVvD8OGwY95T84aY0wBrBKCKbU6ddzsrJs3w223wWuvuRI/48dDPoNojDHh7MknPe/SEpAps7g497P59ddw4YX
"text/plain": [
"<Figure size 432x216 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: this cell generates and saves Figure 17-21\n",
"plt.figure(figsize=(6, 3))\n",
"plt.plot(beta, \"r--\", label=r\"$\\beta_t$\")\n",
"plt.plot(alpha_cumprod, \"b\", label=r\"$\\bar{\\alpha}_t$\")\n",
"plt.axis([0, T, 0, 1])\n",
"plt.grid(True)\n",
"plt.xlabel(r\"t\")\n",
"plt.legend()\n",
"save_fig(\"variance_schedule_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZL7ZBkTKndy-"
},
"source": [
"The `prepare_batch()` function takes a batch of images and adds noise to each of them, using a different random time between 1 and $T$ for each image, and it returns a tuple containing the inputs and the targets:\n",
"\n",
"* The inputs are a `dict` containing the noisy images and the corresponding times. The function uses equation (4) from the DDPM paper to compute the noisy images in one shot, directly from the original images. It's a shortcut for the forward diffusion process.\n",
"* The target is the noise that was used to produce the noisy images."
]
},
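{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder of what this shortcut looks like, equation (4) of the DDPM paper gives the distribution of the noisy image $\\mathbf{x}_t$ directly in terms of the original image $\\mathbf{x}_0$. The symbol $\\boldsymbol{\\epsilon}$ for the standard Gaussian noise is notation chosen here for illustration:\n",
"\n",
"$$q(\\mathbf{x}_t \\mid \\mathbf{x}_0) = \\mathcal{N}\\left(\\sqrt{\\bar{\\alpha}_t}\\,\\mathbf{x}_0,\\ (1 - \\bar{\\alpha}_t)\\,\\mathbf{I}\\right)$$\n",
"\n",
"so a noisy image can be sampled in a single step as:\n",
"\n",
"$$\\mathbf{x}_t = \\sqrt{\\bar{\\alpha}_t}\\,\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\,\\boldsymbol{\\epsilon}, \\qquad \\boldsymbol{\\epsilon} \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I})$$\n",
"\n",
"Here $\\boldsymbol{\\epsilon}$ is the training target. When the noise is known, as it is for a training batch, the relation can be inverted to recover the original image:\n",
"\n",
"$$\\mathbf{x}_0 = \\frac{\\mathbf{x}_t - \\sqrt{1 - \\bar{\\alpha}_t}\\,\\boldsymbol{\\epsilon}}{\\sqrt{\\bar{\\alpha}_t}}$$"
]
},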
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"def prepare_batch(X):\n",
" X = tf.cast(X[..., tf.newaxis], tf.float32) * 2 - 1 # scale from 1 to +1\n",
" X_shape = tf.shape(X)\n",
" t = tf.random.uniform([X_shape[0]], minval=1, maxval=T + 1, dtype=tf.int32)\n",
" alpha_cm = tf.gather(alpha_cumprod, t)\n",
" alpha_cm = tf.reshape(alpha_cm, [X_shape[0]] + [1] * (len(X_shape) - 1))\n",
" noise = tf.random.normal(X_shape)\n",
" return {\n",
" \"X_noisy\": alpha_cm ** 0.5 * X + (1 - alpha_cm) ** 0.5 * noise,\n",
" \"time\": t,\n",
" }, noise"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's prepare a `tf.data.Dataset` for training, and one for validation."
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {
"id": "i9qt5HPZ8rhv"
},
"outputs": [],
"source": [
"def prepare_dataset(X, batch_size=32, shuffle=False):\n",
"    ds = tf.data.Dataset.from_tensor_slices(X)\n",
"    if shuffle:\n",
"        ds = ds.shuffle(10_000)\n",
"    return ds.batch(batch_size).map(prepare_batch).prefetch(1)\n",
"\n",
"tf.random.set_seed(43)  # extra code: ensures reproducibility on CPU\n",
"train_set = prepare_dataset(X_train, batch_size=32, shuffle=True)\n",
"valid_set = prepare_dataset(X_valid, batch_size=32)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oAxL8EmF3Hcl"
},
"source": [
"As a quick sanity check, let's take a look at a few training samples, along with the corresponding noise to predict, and the original images (which we get by subtracting the appropriately scaled noise from the appropriately scaled noisy image):"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 271
},
"id": "y8mwYDDtz6Zm",
"outputId": "e5187d06-3d7b-4ab8-8b29-dc57c3428964"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original images\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAAA9CAYAAAAgYlmOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA+HElEQVR4nO29aYhk13k+/tS+79XVe0+v09KMltFuyXIsGVtYRIoJCsb5kA/54EAcSAJJPoSAIRBICCSQgAIJhASDg00wNnKMosQaS6NlNFpm1XT3TO9r7fu+3t+H/j+nT92uXqa7q3qUfz0wzEx3Lffce867Pu/7ahRFQRdddNFFF110sT+0p30BXXTRRRdddPFFQFdhdtFFF1100cUh0FWYXXTRRRdddHEIdBVmF1100UUXXRwCXYXZRRdddNFFF4eA/oDfd4xCe+XKFczMzCCTycDn8+E73/kO9PqDLu9AaA7xmgPXqCgKNBrNrp8tLS1hY2MDS0tLsNvtePXVV2E0GqHV7m2HNBoNzM3NIRQK4dNPP8VDDz2E559/Hlar9ajrPZE13udo6xpXV1cRDAYRDAZht9vxta99DTqdbt/31Go1vPnmm8jn8+jv78fw8DDGx8ePeglAm9aoKArq9Tp0Ot2uPfzpp5/il7/8JX74wx9iYWEBo6OjKJVKCIVCeOmll/ClL30Jv/u7v4u+vr57/dq90LG9WigUcPfuXRSLRRSLRfh8PlgsFmg0GjQaDVQqFZRKJdRqNTzyyCOw2Wwn8bVAG9dYr9eh1Wqh0WiQy+WwsrKC119/XTxfnU4HrVYLi8Ui/l+pVFCtVmE2mzE8PIzvfve7941cvc/Rco3HvnP3gmKxiFwuh1gshnw+j0QigVqthkqlgoWFBWxubiKTyWBkZAT3Y7lLtVrF7OwsSqUSKpUKQqEQYrEYNjc3YTQaYTKZYLVaxeHT6XRiLcViEdFoFMlkEsFgEJlMBqFQCCaTCXq9HiaTCSaTCWNjY7BarXA4HKe82v+bKJfLKJVKWF9fR6lUQqFQwObmJiKRCOLxOCwWCxqNBur1OiqVSsvP0Ov10Gq1uHbtGorFIjY2NrC2toa1tTVYLBZYLBZMT0/DYDDsazy1GzxDWq0WtVoNtVoNc3NzyOfzqNfruH37NmZnZ1Eul2EymVCv1wEAZrMZ8Xgcd+7cwYcffoienh7o9Xp4PB4EAgE4HA4YDIZTW9dhoCiKUP5ra2vo7e2F1WqFRqNBtVpFLpeDyWSCxWJBLpcTZ/B+hkajgUajQb1ex/Xr1zE/P4+VlRUAgMlkQq1WAwAYjUYAEIqR+6BarSKdTsNut9/3a71foTlAMZ2o1lpfX8edO3dw6dIlLCws4PLly0in00gmkzCZTDAYDMjn83jiiSdw6dIlWCyW437lsSwh3hta5olEAt///veFwqvVaqjX6ygWiyiXy4jFYujp6cHAwAAAwG6343d+53fQaDSwsbGBixcv4tq1a2KtLpdLWIKNRgMOhwPf+973MDo6igcffLAja/yC4MTWGAwGEQ6H8YMf/ADBYBCrq6tIJBJIp9OoVCrQ6XRwOBzIZrOIRqMtP8Pj8cDpdKJQKKBer8NoNMJqtcLlcuHMmTMYGRnBn//5n8Pr9Qrh1ck1EvV6HRqNBlqtFtlsFslkEn/5l3+JpaUl5PN5RCIRrK+vw+/37zprmUwGpVIJ09PTcDgcsNvtePrpp/H1r38d586dg9frvZdLITq2V9PpNC5fvoxLly7h5z//OUZHR+FwOKDVapHL5bC5uYnz589jcnISr7zyCvr7+xEIBHZ54UfAicqcViiVSvi93/s9XL16FTMzMzCbzXC5XEgmkyiXy02vtdlsMJlMyGQymJ6exg9+8AMMDAwcN2rw/1uZ0zYPMx6PI5FI4N1330UkEkE4HEYmk0EymcTW1hbS6TRSqRQURYHb7YZerxdho3q9jmvXrmFkZARDQ0PtusQDIW/ad999F3
fv3sXKygqKxSL0ej30ej00Gg0cDgfMZrPwTDY2NlCv12GxWPDZZ58hl8thdnYWoVBIvJ4eJaHValEqlfDzn/8ck5OTKJfLGBoagt/vP9Qh6mI35PtWKBSQyWTwox/9CLdu3cL8/DwKhQIKhQIajQYsFgtMJhMURUGtVoPNZoPL5YLX64XZbBYeSzweR71eF8+XVr9Wq0WxWMTi4iJCoRBef/11PProo/it3/qtpuvp5DPktQFAKBTC0tISFhcXsb6+DqPRCI1GA7fbLZS6RqMRIVwadaVSCdVqFcFgUIT8AoHAURVm26AoCtbX1xGJRFAqlaDX6zE+Po6rV68iHo9DURSxpnK5jFQqBYvFgqGhIRFlWF1dRSAQwNDQUMsQdqfXI39/rVZDPp/H1atXMTs7K4wen88HnU4HvV4Pp9PZFH7XarVCrur1epTLZfzDP/wDLly4gGeffRbT09PweDyntsYvItqmMBOJBJaWlvDGG29gfn4ed+7caQqzajQaEca02+3iMDYaDSiKgrm5ORiNxlNVmESj0cD169fxySefIBKJQKfTwe12Q6PRiOs2Go2o1WrCY1YUBUajEXfu3EEqlcLVq1dhsVhgNpuF1UfPstFooFaroVqt4qOPPkI8Hsfg4CBsNhv8fj8ajQYAHJhX66IZ8n0rFAoIh8N4++238cEHH8BgMECn08FoNIp9yPfk83nYbDb09PTgzJkzcDqdaDQayGQyQiin02lYrVYRcm00GiiXy4jH46jVasjlcshms/jWt74lDCugs0pT/p54PI7l5WVsbW0hGo3C5/MBAJxO567XK4oiwsmVSgX1el3se51Oh5dffrkj138Y8PzU63VsbW1hfn4euVwOHo8H58+fh91uR6FQQLlcFs+7Xq+jVCrBbDajr68P5XIZuVwOW1tbqFQq8Pv9MBqN4myftqFaq9VQKBQQi8XwySef4J133sHm5iZKpRLsdrt4HUOwfHbydRuNRlQqFfz0pz9FNBqF1WqF1+sVhuJpr/GLgrYpzA8++AC/+MUv8MknnyCbzcLlcgmLhw+TCpKhIyKVSuHNN9+EXq/Hk08+2a5LPBTy+TzS6TQ2NjawtbUFh8MhFBeFZa1Wg0ajQX9/P3w+H4rFojAOMpkMFEXBuXPnYDQamxLu1WpVKF2dTie8m3w+jytXrmBwcBBTU1NdRXlEyPnDubk5vPHGGwgGgzCZTHC5XIIAAuyELxuNBrRarfAoGXqnEiyVSlAUBXq9XrxXURShCD0eDxRFQSQSwa1bt/CTn/wETzzxBKampjoulOTvCwaDmJ2dhVarhd1uh81mE8YplQ7XQ4FLz0RRFJEDrNVq4nWnCZJ5ZmZmkEqlEAqFkEgkkM1mm4yg/v5+XLhwAQsLC0in06hWqzAajXC5XBgbG8PDDz+MYDCIaDSKpaUlbG1t4dq1awgEAnC5XHjooYdgt9vhdrs7si71Hsnn8/jJT36Cubk5vPPOO8hmsygUCtBqtU1EQe4/ylX5udbrdej1ehgMBng8HiwsLOD111/Hv/7rv6K3txf//M//jJ6envs+L30/oG0Ks1KpIJ/Pi/yezWYTh1AOdzQaDVSrVQDNyemNjQ2kUql2Xd6hwfBNPp9HtVqFzWZrCnURNAb4bwBCAVJI0WJl7hPArs/S6XSo1+st8xFd3Bvk+5rJZLCysoJyudxktFHZydEPev589sViEdVqFdVqVShMGjFqDgCffbVaRSaTwcLCAiYnJzuw2v2RTCaxvr6ORqMhvGsarIS8J3lOuU6Gq3kOGPrrNKrVquALZLNZxONxZDIZpFKppvNCg0en08Hj8QiPkUQft9uNer2ORCKBXC6HYrEIjUaDcrmMarUKg8GAer2O9fV1eDwe6HQ6mM3mjiqVSqWCbDaLzz//HHfu3MHi4qKQM2azucmLlOWIVqsVz1IGI2H83GKxiEgkgps3b2JiYgJTU1MdW9sXFW1TmOPj43juuedw8+ZNlEolGI1GkR
/hw+SBowKldVStVrG8vIxYLNauyzs0UqkU7t69i3w+D2B703GjUvExpFEoFMT7ZIHDtVWr1SbmJS15OfRFYZ1Op7s
"text/plain": [
"<Figure size 576x72 with 8 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time steps: [3405 312 3441 1991 2443 1657 3308 1151]\n",
"Noisy images\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAAA9CAYAAAAgYlmOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACAg0lEQVR4nO29WYxl53UdvO48z/OtW/fWPHR39cTuJqmmSFEcIpKSYkWxnchGHMdwkIcABpIgQfKSh/gtCPJk5yGAASOK7djyIFvWSIqDSDbZ3eyxqmue71h3nud7/4fS2jwlm92Xefnx468PaEjsruGec76zv73XXmtt1XA4xOk6XafrdJ2u03W6Hr/U/29/gNN1uk7X6Tpdp+v/C+v0wDxdp+t0na7TdbpGWKcH5uk6XafrdJ2u0zXCOj0wT9fpOl2n63SdrhHW6YF5uk7X6Tpdp+t0jbC0j/vH//bf/ttwOBxiOBziwYMHODg4wNe//nVMTk7i0qVLSKVS2NzcxO3bt5HP5zExMYFms4l4PA6bzQaj0YhcLgeHw4HXX38dR0dH2N3dRaPRgFqtxtLSEjqdDlKpFJaWlhCLxfDgwQNUKhU0Gg04nU4Eg0GMj49Dr9fj9u3bcDgcuHDhAu7du4e9vT18/etfR7FYxH/4D/8BCwsLuH79OlKpFOr1On73d39X9aQbcPPmzaFOp4PFYoHL5YLFYkE8Hke/34fdbken00Gz2YROp4Ner4fP50O73cbR0RHS6TTy+Tz29vZgMpnw4osvQqVSod/vo1Qqodvtwu/3w2QywWKxoNVqodlsYnt7G/l8HpubmwiHw5iZmUEgEIDZbEatVoNOp4PT6YRafZzPxONxdDod6PV6GAwGmM1mNBoNdLtdXLp06YnX+Ou//utDjUYDrVYLv98Ph8MBAHA4HHj22WfxR3/0R/iv//W/4rXXXsO5c+fwj//xP0av18PBwQHS6TQKhQJqtRparRaKxSJmZmawtLQEn88Hg8GAQqGAg4MDvPfee1Cr1dDr9bh27RqMRiOOjo7QarXQarUAAGq1Gg6HA2q1GiqVCrVaDZ1OBzMzM0in0/gv/+W/IBgMYn5+HlevXkUgEMC//tf/+onXmEgkhjqdDmazGTdu3MDKygqcTieq1SreeecdTE1N4fz58+B98Pl80Gg0UKvVWF1dxf7+PjQaDdxuN1544QVks1ns7+8jHA7DaDQikUggk8lgZWUFw+EQarUaV69eBQDcu3cPc3NzuHLlCm7cuIGjoyPEYjHo9Xro9XqMjY3B7Xbj3XffRb1eh8vlQqvVQqPRQCwWg81mwy//8i8/8Rp///d/f9jtdlGv12EymWA0GjExMQG1Wo1sNotSqYRSqQSbzQaNRoNcLodQKITr16/jz//8z/Hhhx/ia1/7GpxOJyqVCux2O9xuN95//32USiW8/PLLSCaT+P73vw+LxQKbzYbf+I3fgFqtxp//+Z+j2Wyi3W5jfn4eoVAIzz33HACgWCxie3sbiUQCR0dHGAwGcDgccv1+vx9msxm/9Eu/9MRr/MM//MOhRqOBTqeDz+eDzWbDwcEBut0uDAYDTCYTzGYz4vE4arWa/A6r1Sq/84/+6I/Q7/fx+uuvy551u91Qq9W4d+8eWq0Wut0uvvzlL+PChQvY3NxEp9OB1WqFRqMBf3+73cbbb7+NsbExvPHGG9ja2kI8Hkev10O73UYqlYLFYoHT6cS5c+fg8XgwPj7+xGv87ne/O0wmk7hx4wY8Hg9cLheWlpbgcDhgtVrRbrfRbDaxsbGBXC6Ho6Mj2O12TE1NYX19HalUChcuXMBwOMTdu3dht9vh9XqhUqkwGAyQSCRgNBoRCATQ7Xblj1qtht1ux+zsLC5cuIDvf//7ODw8xMTEBDqdDo6OjpDL5VCpVGCz2eBwOHD+/HnodDpoNBpMT0/D6XTiwoULI12j0WiEy+UCAKhUKqjVagwGA1QqFej1ephMJiSTSb
RaLYTDYWg0GgyHQ1itVuh0Ojx48AC1Wg1qtVrihdFoxGAwwPr6OgKBAL7whS/g4OAAR0dH8Pv9GAwGiMfj8Hq9GB8fRyqVQrvdRiAQgMPhQCgUws7OjsRurVaLaDQKvV4PnU6HTqeDfr+Pl1566e+9xscemCaTSQLg2NgYVCoVut0uKpUKarUa6vU6Wq0WLBYLut0u0uk0tFotwuEweNByEzLQGAwGuYGNRgO1Wg2ZTEaC79HRkRyoPKwymYwEV51Oh2aziW63i8FggLt376LZbGJ8fBwGgwHJZBJqtRo2m+1JzxQAoNfroVar0ev10O120el05LN3Oh10u130+31otVp0Oh1sbm6iUqkgkUigXC7L59dqtbh586Y83Hq9DpVKBY1GA71eD5VKhVwuh2KxiFQqhVqthnw+j+FwiH6/j6OjI5jNZng8Hvns/N06nQ4qlQoq1fEz5OcbdUUiETQaDeRyOdTrdeh0Oni9Xjmg9Xo9IpEIjEYjer0e8vm8fG7lfRwOhxKAkskkBoMBzGYzut0uWq0WKpUKnE4njEYjqtUqBoMBgsEgstksisUi7Ha7PD/em06ng1arhVKphGaziVAoBL/fD6fTCYvFApPJNNI1VioVqNVq1Go19Ho9mM1m2fyhUAg2mw2DwUBeykKhAK1WC4PBgEajgWazCZPJhEajgY2NDRQKBSSTSRiNRlgsFhweHqJYLKLdbstzKZVK0Gg06PV6KBaL2NraQq1Wk8+k0WhgMBiQzWaRyWSg0Whgs9lgMpnksO73+2g0GiNdY7ValZ87GAzQbrflpedn4h5m4G+1WtjZ2QEAhEIhmM1mAEC5XEaz2USlUpFkhnuaz9BsNmNra0uevUajkZjApKHT6aBYLKJer6Pdbste1Wq1GA6HaLfbqFQqaLfbI11juVyWBLPT6aBSqaDf76PX66HX66HZbKJcLsNqtcLpdKLf70OlUsFgMKDX66HT6SAQCMjz4TUy4PLnabVa1Ot15PN59Pt9DIdDNJtNiQNerxdarRaxWAx2ux3lchmNRgOdTkc+i9FohN1uh8/nQy6XQ6FQwPj4+BOvsd1uQ6PRIBgMwm63w2azIZFIIJ/PY2xsDFqtFlqtFg6HA4PBAKVSCe12G3t7e1Cr1QgGg2g2mwCAcDgMh8MBr9eLdDqNSqWCYrEIk8kEm80me6LX60Gr1UKn06Hf76NSqcBqtcLr9cq+rtfrMBgM8Hq9EsdyuRzcbjd8Ph8ajQZ6vd5Iz3EwGEhcUF43P89gMJBrZIzh8+M7yvcTAHQ6HbRaLfr9PgaDATweDywWi9wHo9GIbreL4XAIo9EosaXf78v+YQysVCqo1+vo9/sAgGazCb1eD7vdjlwuh06n85nX9dgD0+PxwGw2w+fzwev1olAoYGtrC/v7+wgEAsjlcsjn8/D7/bDb7fiTP/kTxGIxfOtb38LBwQEymQwikQhKpRK+/e1v48yZM/jCF74Ah8OB4XCITCaDVColWV8mk8Hm5iYGgwEikYgcEltbW3Kz+/2+XFS/38f/+B//A3q9Ht/61rewu7uLGzdu4Pnnn8fY2NhID9Zut6Pb7aLRaKDRaMhhNBgMUK/XMRgM5KXsdDr4m7/5G8TjcaytrcnGZpD6q7/6KxgMBlitVqjValgsFnzjG98AAORyOdy7dw+bm5sSbFjpDYdDqFQq2O12/JN/8k8wNjYGn8+Her2OZrMJq9UKlUqFZrMJrVYr3zMYDEa6xuvXr2N7exsrKyvodrtot9s4e/Ys7Ha7ZMlf/OIXYTQa0Wq1sLq6CpfLhXA4DKvViuFwiN3dXSSTSRQKBTSbTWSzWcRiMbhcLgSDQRSLRRweHsJut8PlcuHw8BBerxdf+tKXcOfOHdy/fx/BYBA2mw1ra2sYDAay4dvtNnZ2djAYDHDt2jU4HA54PB4EAgE4nc6RrvHw8FAOX41GA7/fj+3tbfT7fVy+fBkA0O/3YTAYMBwOsb6+Dr1eD6/Xi0wmg0
KhALfbjUajgXv37qFaraJWq2E4HMLtduPtt99Gv9+H1WpFLpdDo9GAx+OByWRCt9vF1tYWlpeX4Xa7YbPZoFKpoNf
"text/plain": [
"<Figure size 576x72 with 8 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Noise to predict\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcwAAAA9CAYAAAAgYlmOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACIR0lEQVR4nO39Z5Dk53UdDp/OOYfp7unJOe7O5gV2AYIgABIEIUqkaIoKpizTKqnkKLvK5Sp/sWT7g1364iqpZMuiLJOWJUqUKIKkQGQssDnNTs6hu6d7Oueenunwflicy17+jd2G3w9vvfWfpwpFAgvM9K9/z3Ofe88951xFs9nE8Tpex+t4Ha/jdbwev5T/v/4Ax+t4Ha/jdbyO1/8/rOML83gdr+N1vI7X8WpjHV+Yx+t4Ha/jdbyOVxvr+MI8XsfreB2v43W82ljHF+bxOl7H63gdr+PVxlI/7g//83/+z00AaDabePDgAXZ3d/GFL3wBfX19mJmZwd7eHjY2NnD79m2kUin09vaiUqkgHA7DbDbDYDAgmUzCYrHg85//POLxOLa3t1Eul6FQKDA1NYXDw0PEYjHMzMygv78fc3NzyGazyGazsNvt6OjoQFdXF7RaLW7fvg2bzYYTJ07g/v372N7exquvvop0Oo1//a//NUZGRnDp0iVEo1EUi0X8h//wHxRP+gJ+5md+pmm32zE4OIhz585hcHAQ3/72t1Eul3HixAmkUins7e3BbrfDbrfjM5/5DPb39/HWW2/BZDJBp9Mhn88DAMxmM4rFIrLZLG7cuIFcLofPfvazCAQCGBkZQSQSQTweR6lUkn8/HA5jbW0NZ8+ehdPpxPLyMhwOB86cOQOtVgsA+Pa3v41UKgWn0wm/34/u7m5sb28jl8vhT/7kT574jL/0S7/UVKlUUKvV8Hq9sNlsAACbzYaLFy/if/2v/4X/9J/+Ez73uc9hcnISX/7yl1Gr1bC7u4tYLIZ0Oo1isYiDgwNkMhkMDg5iamoKHo8HOp0O6XQau7u7eP/996FUKqHVanHu3Dno9XrE43EcHBzg4OAAAKBUKmGz2aBUKqFQKFAsFnF4eIjBwUHEYjH8zu/8Dnw+H0ZGRnD27Fl0dHTgt37rt574jBsbG02tVgur1Yof/vCHuHHjBtxuN9LpNP7mb/4GExMTePrpp6FSqaDVatHZ2QmNRgOFQoHFxUXs7OzAbDbD5XLh6aefRjgcxvLyMvr7+2E0GrG5uYmdnR1cv34dAKDRaPCLv/iLUKvV+PGPf4zR0VGcO3cOf/d3f4dwOIyRkRHo9XrodDoMDg7C6/Xib/7mb5DNZuH1elGpVFAoFDAyMgKn04l//I//8ROf8d/9u3/XPDw8RD6fh8lkgtFoxNDQEFQqFaLRKDKZDBKJBKxWK9RqNeLxOLq7u/HSSy/hj/7oj/D666/jV37lV+B0OpHJZNDV1YWhoSHcvXsXhUIBn//85xEKhfCXf/mXqNVq8oz1eh3/7b/9N1QqFRwcHGB6elp+LgCkUiksLi5ia2sLe3t7aDQacDqd0Gq10Ov1CAQCMJvN+PrXv/7EZ/zd3/3dpkqlgkajgcfjgdVqxc7ODo6OjqDT6WAwGGAwGBCJRFAsFqHVaqHVamEymWC326HT6fDtb38bCoUCX/va11CtVlEsFmEymdBoNHDv3j1Uq1XU63W8+OKLmJmZwY0bN1Aul2EymaBWq8Gzcnh4iLfffhuBQACvvvoqVldXEQ6HUa/XUalUEIvF5PdOTU3B5XLh4sWLT3zG733ve829vT1cu3YNLpcLDocDU1NTsNlsMJvNqFarqFQqWF1dRTKZRDweh9VqxcDAAJaXlxGNRnHy5EkolUqsrKzAYrHA7XZDqVSiVqthb28POp0OHR0dqNVqqNVqqFaraDab0Gg0GBoawokTJ/DDH/4QoVAIvb29ODw8RDweRzKZRD6fh8Vigc1mw/T0NDQaDVQqFQYGBmC323HixI
knPuPv/d7vNU0mE/x+PwBAoVBArVajVqshmUzCYDDAbDZjbW0NpVIJw8PDUKvVaDab8h7ffvttZLNZqFQqKJVKKJVKmEwm1Ot1XL9+HX19ffjyl7+M+fl5bG9vo6+vD/V6HUtLS+jq6sLY2Bg2NjZQKpXQ398Pj8eDwcFBzM7OYm1tDXt7e9BoNJiYmIBer4fBYEClUkGtVvvYvfrYC1Ov10OtVkOr1SIQCAAAjo6OkM/nUSwWUS6XUalUYDQacXR0hFgsBpVKBb/fj2aziWazCbPZDLVajcXFRTSbTej1eiiVSqhUKhwdHaFcLiOTycjFmk6nkc/ncXh4iMPDQ1QqFezv70OpVMoB4UM1m03cv38f5XIZ3d3dcpBUKhWsVuuT3ikAwOVyQa/Xo1QqIZfLIZPJyGfPZrMoFAo4ODhAtVrF4eEhFAoFms0myuUytFotNBqNXGwAUKvVUKlUYDKZ5CXXajXkcjk0Gg3o9XpUq1U0Gg0oFArodDrY7XaoVCo0m004HA5YrVb5/eVyGTabDWq1Wn4eADQaDTQajbaesbOzE5VKBclkEqVSCRqNBm63G0ajUb7TYDAIvV6PWq2GVColF4vFYpGf02w25dJkYOS7Pzg4QD6fh91uh16vR6FQQKPRgM/nQyKRQCaTgdVqhUajQaVSgUKhgEqlwuHhIQ4ODpDNZlGpVOD3++H1emG322EymWAwGNp6xnQ6DZVKhVwuh6OjI1gsFhwcHKDRaKCnpwcOhwONRgM6nQ4qlUoucLPZDK1WC51OB7PZDI1Gg729PSQSCRSLRahUKuh0OhSLRdTrdbjdbpTLZdTrddTrdSiVSlitVtTrddn/VqsVSqUSGo0GJpMJ4XAYu7u7UKlUcDqdsFgs0Gg0sjcKhUJbz5jNZgEAarVagnYikZDgzj3K/cUzs7i4CADo7e2F1+uFyWRCMplENptFKBTCwcEBms0mtre3kc1m4XK5ADxMbnZ3d3F0dCTvjt8fAKhUKlSrVQmyBwcHchY0Gg0AyLvl9/2klcvlJJgeHh4il8vJd814o1KpYLFY4HA4UKvVZK/WajUcHR2ho6MDCoUC+Xxe9hnPNM+cwWCAQqFAqVSCQqGAQqHAwcGB/Ay32w21Wo2enh7YbDY5i4eHh6jVaqjX69DpdLBYLPB4PEgmk0ilUrh48eITn7FarUKlUsHn88FqtcJisSASiSCVSqGzsxNqtRpqtRo2mw2NRgPZbBbVahU7OztQKpXw+Xzy3MFgEHa7HS6XC7FYTBJ2k8kk3w8/L5PZRqOBfD4Ps9kMt9uNTCaDarWKUqkEnU4nl69SqUQymYTT6YTH40G5XEatVmvrPdbrdRwdHaFYLMo/43/P/cmkyO12Q6FQSLznxanVamE0GqFQKCTO1mo1NBoNBAIB2O12FAoFKJVKSTTq9bokPpVKBYeHhzg6OgITzXA4jEQiIXECAAqFAoxGI1wuF0KhEMrl8sc+12MvTAZVj8cDl8uFdDqNjY0N7O7uoqOjA8lkEslkEl6vFxaLBX/+53+Onp4e/OIv/iJ2dnYQj8fhcDiQzWbxrW99CxMTE3jmmWeg0+mkuojH41hbW4NGo0GxWMTa2hqOjo5gNpvlc6yvr8uXXK/XkUqlUK1WUavV8Ad/8AeSCW9tbeHGjRu4fPmyXPBPWlNTU8jlctje3sbW1pZsrqOjI6ytreHg4ADlchlqtRp6vR4qlUqyJI1GI4dboVDIZkgkEujo6JDgUiwWsb29DavVCqfTKQdToVDAarWit7cXGo0GjUYDg4OD0Gq1aDabWF9fx+7uLsbGxqBUKhEKhWAwGNBoNORgt7MuXbqEjY0NLCws4OjoCNVqFRMTE7BarYhGozCZTLh8+TL0ej0ODg6wtLQEh8MhlUGz2ZTqIZ1OS6DmReTz+ZDJZBAKhWC1WuFwOBAKheB2u/GpT30Kd+/exezsLHw+HywWC5aXl9FoNGAymVCpVFCtVrG5uYlGo4Fz587BZr
PB5XKho6MDdru9rWdcWVlBrVZDsViEWq1GZ2cn5ufnUavV8NxzzwF4mOw5HA7o9XpkMhnYbDYEAgHYbDbY7XYEAgH
"text/plain": [
"<Figure size 576x72 with 8 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code: just a quick sanity check\n",
"\n",
"def subtract_noise(X_noisy, time, noise):\n",
"    X_shape = tf.shape(X_noisy)\n",
"    # look up the cumulative alpha for each instance's time step\n",
"    alpha_cm = tf.gather(alpha_cumprod, time)\n",
"    # reshape to [batch, 1, 1, 1] so it broadcasts over each image\n",
"    alpha_cm = tf.reshape(alpha_cm, [X_shape[0]] + [1] * (len(X_shape) - 1))\n",
"    return (X_noisy - (1 - alpha_cm) ** 0.5 * noise) / alpha_cm ** 0.5\n",
"\n",
"X_dict, Y_noise = list(train_set.take(1))[0] # get the first batch\n",
"X_original = subtract_noise(X_dict[\"X_noisy\"], X_dict[\"time\"], Y_noise)\n",
"\n",
"print(\"Original images\")\n",
"plot_multiple_images(X_original[:8].numpy())\n",
"plt.show()\n",
"print(\"Time steps:\", X_dict[\"time\"].numpy()[:8])\n",
"print(\"Noisy images\")\n",
"plot_multiple_images(X_dict[\"X_noisy\"][:8].numpy())\n",
"plt.show()\n",
"print(\"Noise to predict\")\n",
"plot_multiple_images(Y_noise[:8].numpy())\n",
"plt.show()"
]
},
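{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a brief aside on why this sanity check works (a sketch, assuming the noisy images were produced by the standard DDPM forward process, with $\\bar{\\alpha}_t$ denoting the value stored in `alpha_cumprod[t]`), the forward process and its inversion are:\n",
"\n",
"$$\\mathbf{x}_t = \\sqrt{\\bar{\\alpha}_t}\\,\\mathbf{x}_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\,\\boldsymbol{\\epsilon} \\quad\\Longrightarrow\\quad \\mathbf{x}_0 = \\frac{\\mathbf{x}_t - \\sqrt{1 - \\bar{\\alpha}_t}\\,\\boldsymbol{\\epsilon}}{\\sqrt{\\bar{\\alpha}_t}}$$\n",
"\n",
"The second expression is what `subtract_noise()` computes, so the recovered images should match the originals up to floating-point error."
]
},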
{
"cell_type": "markdown",
"metadata": {
"id": "dcBpIkNxrPw9"
},
"source": [
"Now we're ready to build the diffusion model itself. It will need to process both images and times. We will encode the times using a sinusoidal encoding, as suggested in the DDPM paper, just like the positional encodings in the [Attention is all you need](https://arxiv.org/abs/1706.03762) paper. Given a vector of _m_ integer time indices, the layer returns an _m_ × _d_ matrix, where _d_ is the chosen embedding size."
]
},
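{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sketch of what the layer below precomputes (assuming the same convention as its code, where $t$ is the time index, $d$ is `embed_size`, and $j = 0, 1, \\dots, d/2 - 1$), the encoding of time step $t$ is:\n",
"\n",
"$$\\mathrm{TE}(t, 2j) = \\sin\\left(\\frac{t}{10000^{2j/d}}\\right), \\qquad \\mathrm{TE}(t, 2j+1) = \\cos\\left(\\frac{t}{10000^{2j/d}}\\right)$$\n",
"\n",
"Low-index dimensions oscillate quickly with $t$ and high-index dimensions slowly, so each time step gets a distinct vector. Since there are only $T + 1$ possible time steps, the whole table can be computed once and then looked up with `tf.gather()`."
]
},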
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"id": "0HGLhS4oNkNC"
},
"outputs": [],
"source": [
"# extra code: implements a custom time encoding layer\n",
"\n",
"embed_size = 64\n",
"\n",
"class TimeEncoding(tf.keras.layers.Layer):\n",
"    def __init__(self, T, embed_size, dtype=tf.float32, **kwargs):\n",
"        super().__init__(dtype=dtype, **kwargs)\n",
"        assert embed_size % 2 == 0, \"embed_size must be even\"\n",
"        p, i = np.meshgrid(np.arange(T + 1), 2 * np.arange(embed_size // 2))\n",
"        t_emb = np.empty((T + 1, embed_size))\n",
"        t_emb[:, ::2] = np.sin(p / 10_000 ** (i / embed_size)).T\n",
"        t_emb[:, 1::2] = np.cos(p / 10_000 ** (i / embed_size)).T\n",
"        self.time_encodings = tf.constant(t_emb.astype(self.dtype))\n",
"\n",
"    def call(self, inputs):\n",
"        return tf.gather(self.time_encodings, inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "18SYrUNysJ62"
},
"source": [
"Now let's build the model. The Improved DDPM paper uses a UNet model. We'll create a UNet-like model that processes the image through `Conv2D` + `BatchNormalization` layers and skip connections, gradually downsampling the image (using `MaxPooling2D` layers with `strides=2`), then growing it back again (using `UpSampling2D` layers). Skip connections are also added across the downsampling part and the upsampling part. We also add the time encodings to the output of each block, after passing them through a `Dense` layer to resize them to the right dimension.\n",
"\n",
"* **Note**: an image's time encoding is added to every pixel in the image, along the last axis (channels). So the number of units in the `Dense` layer that resizes the time encoding must match the number of channels output by the block, and we must reshape the result to add the width and height dimensions.\n",
"* This UNet implementation was inspired by keras.io's [image segmentation example](https://keras.io/examples/vision/oxford_pets_image_segmentation/), as well as by the [official diffusion models implementation](https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py). Compared to the first implementation, I added a few things, especially time encodings and skip connections across down/up parts. Compared to the second implementation, I removed a few things, especially the attention layers. They seemed like overkill for Fashion MNIST, but feel free to add them."
]
},
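{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the note about time encodings concrete (a sketch using notation that is not in the original: $b$ indexes the instance in the batch, $(h, w)$ the pixel, $c$ the channel, and $\\boldsymbol{\\tau}$ the output of the `Dense` layer applied to the time encoding), each block's output is updated as:\n",
"\n",
"$$Z_{b,h,w,c} \\leftarrow Z_{b,h,w,c} + \\tau_{b,c}$$\n",
"\n",
"This is why $\\boldsymbol{\\tau}$ needs one value per channel, and why it is reshaped from `[batch, channels]` to `[batch, 1, 1, channels]` before the addition: broadcasting then copies it across the height and width axes."
]
},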
{
"cell_type": "code",
"execution_count": 67,
"metadata": {
"id": "4QLnc8wkfXrh"
},
"outputs": [],
"source": [
"def build_diffusion_model():\n",
"    X_noisy = tf.keras.layers.Input(shape=[28, 28, 1], name=\"X_noisy\")\n",
"    time_input = tf.keras.layers.Input(shape=[], dtype=tf.int32, name=\"time\")\n",
"    time_enc = TimeEncoding(T, embed_size)(time_input)\n",
"\n",
"    dim = 16\n",
"    # pad 28x28 to 34x34; the unpadded 3x3 convolution then yields 32x32\n",
"    Z = tf.keras.layers.ZeroPadding2D((3, 3))(X_noisy)\n",
"    Z = tf.keras.layers.Conv2D(dim, 3)(Z)\n",
"    Z = tf.keras.layers.BatchNormalization()(Z)\n",
"    Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
"\n",
"    time = tf.keras.layers.Dense(dim)(time_enc)  # adapt time encoding\n",
"    Z = time[:, tf.newaxis, tf.newaxis, :] + Z  # add time data to every pixel\n",
"\n",
"    skip = Z\n",
"    cross_skips = []  # skip connections across the down & up parts of the UNet\n",
"\n",
"    # downsampling path: 32x32 -> 16x16 -> 8x8 -> 4x4\n",
"    for dim in (32, 64, 128):\n",
"        Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
"        Z = tf.keras.layers.SeparableConv2D(dim, 3, padding=\"same\")(Z)\n",
"        Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
"        Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
"        Z = tf.keras.layers.SeparableConv2D(dim, 3, padding=\"same\")(Z)\n",
"        Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
"        cross_skips.append(Z)\n",
"        Z = tf.keras.layers.MaxPooling2D(3, strides=2, padding=\"same\")(Z)\n",
"        skip_link = tf.keras.layers.Conv2D(dim, 1, strides=2,\n",
"                                           padding=\"same\")(skip)\n",
"        Z = tf.keras.layers.add([Z, skip_link])\n",
"\n",
"        time = tf.keras.layers.Dense(dim)(time_enc)\n",
"        Z = time[:, tf.newaxis, tf.newaxis, :] + Z\n",
"        skip = Z\n",
"\n",
"    # upsampling path: 4x4 -> 8x8 -> 16x16 -> 32x32\n",
"    for dim in (64, 32, 16):\n",
"        Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
"        Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding=\"same\")(Z)\n",
"        Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
"        Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
"        Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding=\"same\")(Z)\n",
"        Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
"        Z = tf.keras.layers.UpSampling2D(2)(Z)\n",
"\n",
"        skip_link = tf.keras.layers.UpSampling2D(2)(skip)\n",
"        skip_link = tf.keras.layers.Conv2D(dim, 1, padding=\"same\")(skip_link)\n",
"        Z = tf.keras.layers.add([Z, skip_link])\n",
"\n",
"        time = tf.keras.layers.Dense(dim)(time_enc)\n",
"        Z = time[:, tf.newaxis, tf.newaxis, :] + Z\n",
"        Z = tf.keras.layers.concatenate([Z, cross_skips.pop()], axis=-1)\n",
"        skip = Z\n",
"\n",
"    # crop the 32x32 output back to 28x28\n",
"    outputs = tf.keras.layers.Conv2D(1, 3, padding=\"same\")(Z)[:, 2:-2, 2:-2]\n",
"    return tf.keras.Model(inputs=[X_noisy, time_input], outputs=[outputs])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R60NMcUT0b1K"
},
"source": [
"Let's train the model!"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5chylX-Vp9te",
"outputId": "fdd4f9f1-801e-4830-c8db-c8983f444ad8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_diffusion_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 347s 199ms/step - loss: 0.1120 - val_loss: 0.0719\n",
"Epoch 2/100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_diffusion_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 352s 205ms/step - loss: 0.0649 - val_loss: 0.0589\n",
"Epoch 3/100\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_diffusion_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 358s 208ms/step - loss: 0.0550 - val_loss: 0.0539\n",
"<<94 more epochs>>\n",
"Epoch 98/100\n",
"1719/1719 [==============================] - 351s 204ms/step - loss: 0.0377 - val_loss: 0.0375\n",
"Epoch 99/100\n",
"1719/1719 [==============================] - 407s 237ms/step - loss: 0.0376 - val_loss: 0.0379\n",
"Epoch 100/100\n",
"1719/1719 [==============================] - 418s 243ms/step - loss: 0.0376 - val_loss: 0.0379\n"
]
}
],
"source": [
"tf.random.set_seed(42)  # extra code: ensures reproducibility on the CPU\n",
"model = build_diffusion_model()\n",
"model.compile(loss=tf.keras.losses.Huber(), optimizer=\"nadam\")\n",
"\n",
"# extra code: adds a ModelCheckpoint callback\n",
"checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_diffusion_model\",\n",
"                                                   save_best_only=True)\n",
"\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=100,\n",
"                    callbacks=[checkpoint_cb])  # extra code"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that the model is trained, we can use it to generate new images. For this, we just generate Gaussian noise, and pretend this is the result of the diffusion process, and we're at time $T$. Then we use the model to predict the image at time $T - 1$, then we call it again to get $T - 2$, and so on, removing a bit of noise at each step. At the end, we get an image that looks like it's from the Fashion MNIST dataset. The equation for this reverse process is at the top of page 4 in the DDPM paper (step 4 in algorithm 2)."
]
},
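{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the update that the generation code applies at each step can be written as follows (a sketch in the DDPM paper's notation, assuming $\\alpha_t = 1 - \\beta_t$, $\\bar{\\alpha}_t$ is the cumulative product stored in `alpha_cumprod`, and $\\boldsymbol{\\epsilon}_\\theta$ is the trained model's noise prediction):\n",
"\n",
"$$\\mathbf{x}_{t-1} = \\frac{1}{\\sqrt{\\alpha_t}} \\left( \\mathbf{x}_t - \\frac{\\beta_t}{\\sqrt{1 - \\bar{\\alpha}_t}}\\, \\boldsymbol{\\epsilon}_\\theta(\\mathbf{x}_t, t) \\right) + \\sqrt{\\beta_t}\\, \\mathbf{z}$$\n",
"\n",
"where $\\mathbf{z} \\sim \\mathcal{N}(\\mathbf{0}, \\mathbf{I})$ when $t > 1$ and $\\mathbf{z} = \\mathbf{0}$ at the last step, so no fresh noise is added to the final image."
]
},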
{
"cell_type": "code",
"execution_count": 69,
"metadata": {
"id": "vMXxz4qV8Luk"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"t = 1 "
]
}
],
"source": [
"def generate(model, batch_size=32):\n",
"    X = tf.random.normal([batch_size, 28, 28, 1])\n",
"    for t in range(T - 1, 0, -1):\n",
"        print(f\"\\rt = {t}\", end=\" \")  # extra code: shows progress\n",
"        noise = (tf.random.normal if t > 1 else tf.zeros)(tf.shape(X))\n",
"        X_noise = model({\"X_noisy\": X, \"time\": tf.constant([t] * batch_size)})\n",
"        X = (\n",
"            1 / alpha[t] ** 0.5\n",
"            * (X - beta[t] / (1 - alpha_cumprod[t]) ** 0.5 * X_noise)\n",
"            + (1 - alpha[t]) ** 0.5 * noise\n",
"        )\n",
"    return X\n",
"\n",
"tf.random.set_seed(42)  # extra code: ensures reproducibility on the CPU\n",
"X_gen = generate(model) # generated images"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 248
},
"id": "lfMa2NVwJuzE",
"outputId": "fc827f13-0b57-4372-a385-4cf70349ebd3"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAi4AAAEQCAYAAACX0v+fAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOy92Y9d2XUe/p07z1PdmllkcWqSraakbnfL6pY8KIosx1OE2AYSJDGchwTwo/+KIHmMA+QlQALnwQbiePjJkwLFkiy3BqvVaja7SXZzqmLNded5vuf3UP5Wrbt5zq1L1r2UBdQCClV1h3P2WXvvNXxr2JZt2zijMzqjMzqjMzqjM/pJIM+PewBndEZndEZndEZndEaT0pnhckZndEZndEZndEY/MXRmuJzRGZ3RGZ3RGZ3RTwydGS5ndEZndEZndEZn9BNDZ4bLGZ3RGZ3RGZ3RGf3E0JnhckZndEZndEZndEY/MeQb9+ZwOHSslWYJtVMp9XA4xHA4lPcGgwEGgwE2NjZQKpXw5MkTtFotXLp0CfF4HMlkEv1+Hw8fPkSr1UKz2UQ0GsW1a9fg8x0Nz7IseDweBINBRCIRRCIRxGIxWJYl75vj02Pj+3zN6/WOfuEZaTAYjDy4x+MZuad5P7fXAGB3dxeVSgWDwQDD4RBerxfD4RCHh4ewbRuLi4vw+XywbRt+vx+JRAIejwcejweDwQD9fh+RSATpdHrkupon5r3d5u+0fHFbLyTzvnr+zLH0+30Mh0O02230+32Uy2V0Oh00m03Yto1oNMp7AgB8Ph+CwSCSyST8fj9CoRAsy4LX631qHE684bic1vRp+WL/w0XNa5vrFgCq1SpKpRJ6vR56vR7q9TqazSYKhQIqlQr+6I/+CPfv38cv/dIv4dKlS3jllVeQSqUQjUblWYfDIXq9HqrVKm7fvo1yuYzNzU1kMhlcvXoViUQC2WwW0WgUsVgMoVAIoVAIsVgMkUhE9rBlWY5jnNY++oex2voenE+n++p7m2tnXFsHc/277UWPx/PU65RllEGTksfjmaqMmYTceEYin5w+x7k+iZcn3cu8x6xlr/k8psy1LEt0UKvVQrvdxg9/+EM8efIEh4eHaDabyGQyiMVieOONN5DJZBAIBER2DAYDVKtVuebBwQHefvttVCoV5PN5vPrqq/jVX/1VkcXA0ZrJZrMii7xeryNf9bhPu17cZK85p04y10lPOclkknmtceS219y+77AnHW9ijVukg8HAPmmBmlSpVEQRDwYD1Ot1tNtt3LlzBwcHB7JYqIwrlQp6vR4ajQb6/T76/b4YLqFQCJFIBP1+H81mE/F4HJlMBjdv3sRnPvMZeL1eUeJui9fp/9Nunn6/bwPHG8NcBNxMetGYSmA4HKLb7eI//+f/jG9/+9sol8uwbRs//dM/jfn5eSwtLaHT6eAb3/gGarUa+v0+lpaW8Cu/8isIBoPweDwolUrY2NjAq6++il/7tV8Tfpikx6nHqPli2zZ8Pt+p+eJkFDzL5rFtG8PhEMViEY1GA7u7u6jVatjc3ES1WsXGxoYIoHa7ja2tLXg8HsTjcaytreGLX/wiVlZW8MorryAcDiMajT7Fe5MX+rf+m5+ZllBxUxj6vt/4xjfwJ3/yJ8jlcjg8PESlUkGz2ZT3Dw8P0Wq1ZA34/X4Eg0HcuHFDjJder4f9/X20220Ui0VYloVgMDjyzP1+HwsLC1hZWcHly5dx8eJFvPXWW7h58+ZTvJqVofsP1xwx6tyEopNSNfeeNnr066ZBYs61bdsiR8x1qt/X/Bu3foDpOAF6LJMaTU6GuNtYnXis33e7tqnU3BSy/px6baoKmga2vodJzWYTzWYT3/ve9/DBBx/g7/7u7/Do0SPE43GEQiFkMhkEg0H0ej14PB6srq4iFAohGAyi0Wjg3XffRafTgW3b6PV6aLVaaLVaKBQKWF9fx2uvvYZyuYzDw0MMh0MMBgP823/7b/
HP/tk/w+LiIuLxuKxNzTNtDE+LL5PMoZPuNnWUOceT6Fc+i54Tk5z0oZORdNJ6GYu4OCEZfK3b7aJaraLb7aLdbovF2W630Wq10O/3MRgM0G630e12UavV0Gq1BDnpdDro9/vyWQDiMXo8HjSbTQyHQ/h8PjGC6G3v7e1hY2NDJj6ZTCIWiyEQCLh62LNotPe8a00bDu12G41GA6VSCcPhEAcHB4I2dLtdHB4eol6vYzAYwOfz4eDgAKFQCD6fD9VqFfV6HYeHh3jw4AHm5uawsLDw1L2exTI+DZ3kJTtRp9NBp9NBr9eT5x4Oh2LwVioVNBoNMVTK5bL83+v1YNs2vF4vIpEIAoGArMdWqwWv14toNHoi8jWLtTEpcX5arRZqtRp2d3fx5MkTlEollMtl2TccN1G5TqcDAOIE7O3tIRwOw+v1ot/vI5fLodPpoNVqwefzCXJAHg8GA3i9XliWBb/fDwBYXV1FKpXC3NycIFqz5s2k1z/J+wOeb5+PE8jjPvMi1sy4PemEHp20h8mfacnDZ7nGpHLoecfhdO1cLofd3V00Gg00m03cv38fW1tbKBaLqFQqsCwL/X4foVBI0E3K5WAwiHA4LDKIhgsdq8FgAACo1+vY2tpCvV4XpLTf7+Pu3bvIZrN4/fXXEQqFnjJ8p82XafHWbTzPMs5J1uFpaSzior0h0/rf3d3Fu+++i/39fTx69Aj1eh2VSgWpVErCP71eT4wJWpzpdBo+nw+5XA6tVgv1eh3D4VBCIIFAQGC9UCiE+fl5seQqlYoo7kgkIsbSW2+9hU9+8pNYWFhAPB5/CoHRzwBMB3E5CUYHnvY4LMsSBTIYDNDtdvFf/+t/xXe/+11sbm4KbBkIBHgflMtlDIdDMdCuX78uz08aDAbo9Xr4+Z//eXzlK18R75H3ooIyxzgrZGEcTzRZloWdnR3s7Owgl8vJsw6HQ0SjUfh8PhSLRQkXtVotfPe730W1WoVlWQiHw3jppZeQTqdx+fJlCS3GYjEsLS1hbm4Oq6urjmiLkzdojndafLENBjgZTY8fP8a7776Lv/3bv8Wf//mfw+/3w+fzCQpp7kETPeNrRA3a7bYgLQBG1l0gEEA4HMZwOES/30cwGEQoFMKrr76K69ev4xd/8Rdx8+bNpzx2k07Ll38Yl2MYjTRubtzmz7yeNvr+YdwnjsvN23QjfY9poJfmmjHH5oQKOaG85vicLjupw6E9ZJ0OcBKSqGlaMsZEMPScclx/9Ed/hP/23/4b2u02Op0OfD4fvF4vCoUC6vW6oCrZbBaBQEAMDxr1qVRKHGGiLcBx+sNgMECn05G95vV6xcHinv2P//E/4jd+4zeQSCQQCARGQo+aT7OUMZqc0G+neT7JuB0nF9z27LjPTRVxcbupbduoVqv46KOP0Gw2JT/Dtm10u100m02xPAEgGAyK8qTHR/g1Ho8DgCjiVqsli5KoA4W4z+eT3IVOpyOGy87ODrxeL/b39xGPx7G+vo5UKuUKN0+bTK/HzSvT/9Mb9ng8kqPCZ+v3+/D5fAgEAlheXobf70csFpNnC4fDIwZau91GrVZDPB6X7zrBp/9YUAbgeFydTgf1eh2NRgONRkM2TbfbxXA4RLPZRLvdBgBBWDwej+R1RCIR4QcNZBq/nU4HtVoN3W4XnU4Hfr//qR+GLPWYpk1Om1KvkX6/j2KxiI8++gh7e3uCpjCPSRstZjjEHDc/T0GuBTohXNP763a7sG0bpVIJ+/v7aLVaM+HDOBq3N59nrZ52r0/y/XGKfVb0vM/0vDzU3zUVrpvyOsmQnAXp6xcKBWxvb+P+/fuSKwhgJO+NBj1DP0RQ+FzUaTr8oY01flbvQ36eUYLhcIhHjx7hBz/4AT7xiU9gYWEBPp/P0ciaBR9mTU4G2IuisYaLtsyA40kdDofY2trCH//xH2NlZQWvvvqqeG0MDzG5MB6PIxwOy4OVy2WxXP1+P+bn5wXKbjab2NjYwHA4RCqVQr
/fR61WQyAQQDQaRTAYxMrKioRXmJD5wx/+EP/3//5fhMNhxGIx/M7v/A7eeOONkTHzeaZFTsrOjJGbn9cLvVKpoFA
"text/plain": [
"<Figure size 576x288 with 32 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_multiple_images(X_gen.numpy(), 8)\n",
"save_fig(\"ddpm_generated_images_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some of these images are really convincing! Compared to GANs, diffusion models tend to generate more diverse images, and they have surpassed GANs in image quality. Moreover, training is much more stable. However, generating images takes *much* longer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FKyC0R9dyLhV"
},
"source": [
"# Extra Material: Hashing Using a Binary Autoencoder"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fDE4f-XpyLhW"
},
"source": [
"Let's train an autoencoder where the encoder has a 16-neuron output layer, using the sigmoid activation function, and heavy Gaussian noise just before it. During training, the noise layer will encourage the previous layer to output large values, since small values will just be crushed by the noise. In turn, this means that the output layer will output values close to 0 or 1, thanks to the sigmoid activation function. Once we round the output values to 0s and 1s, we get a 16-bit \"semantic\" hash. If everything works well, images that look alike will have the same hash. This can be very useful for search engines: for example, if we store each image on a server identified by the image's semantic hash, then all similar images will end up on the same server. Users of the search engine can then provide an image to search for, and the search engine will compute the image's hash using the encoder, and quickly return all the images on the server identified by that hash."
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"id": "3DRvS7twyLhW",
"outputId": "80a63cbf-0f33-4b6c-8e25-32474da2e2fb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1719/1719 [==============================] - 14s 8ms/step - loss: 0.4093 - rounded_accuracy: 0.8134 - val_loss: 0.3972 - val_rounded_accuracy: 0.8185\n",
"Epoch 2/10\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.3726 - rounded_accuracy: 0.8452 - val_loss: 0.3806 - val_rounded_accuracy: 0.8319\n",
"Epoch 3/10\n",
"1719/1719 [==============================] - 9s 6ms/step - loss: 0.3609 - rounded_accuracy: 0.8564 - val_loss: 0.3678 - val_rounded_accuracy: 0.8435\n",
"Epoch 4/10\n",
"1719/1719 [==============================] - 10s 6ms/step - loss: 0.3540 - rounded_accuracy: 0.8629 - val_loss: 0.3604 - val_rounded_accuracy: 0.8536\n",
"Epoch 5/10\n",
"1719/1719 [==============================] - 9s 5ms/step - loss: 0.3498 - rounded_accuracy: 0.8676 - val_loss: 0.3549 - val_rounded_accuracy: 0.8573\n",
"Epoch 6/10\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.3471 - rounded_accuracy: 0.8705 - val_loss: 0.3501 - val_rounded_accuracy: 0.8639\n",
"Epoch 7/10\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.3451 - rounded_accuracy: 0.8725 - val_loss: 0.3494 - val_rounded_accuracy: 0.8692\n",
"Epoch 8/10\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.3430 - rounded_accuracy: 0.8747 - val_loss: 0.3444 - val_rounded_accuracy: 0.8755\n",
"Epoch 9/10\n",
"1719/1719 [==============================] - 11s 6ms/step - loss: 0.3416 - rounded_accuracy: 0.8763 - val_loss: 0.3402 - val_rounded_accuracy: 0.8783\n",
"Epoch 10/10\n",
"1719/1719 [==============================] - 12s 7ms/step - loss: 0.3403 - rounded_accuracy: 0.8774 - val_loss: 0.3403 - val_rounded_accuracy: 0.8824\n"
]
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"hashing_encoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Flatten(),\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.GaussianNoise(15.),\n",
"    tf.keras.layers.Dense(16, activation=\"sigmoid\"),\n",
"])\n",
"hashing_decoder = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(100, activation=\"relu\"),\n",
"    tf.keras.layers.Dense(28 * 28, activation=\"sigmoid\"),\n",
"    tf.keras.layers.Reshape([28, 28])\n",
"])\n",
"hashing_ae = tf.keras.Sequential([hashing_encoder, hashing_decoder])\n",
"hashing_ae.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
"                   metrics=[rounded_accuracy])\n",
"history = hashing_ae.fit(X_train, X_train, epochs=10,\n",
"                         validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4OpQ1O-qyLhW"
},
"source": [
"The autoencoder compresses the information so much (down to 16 bits!) that it's quite lossy, but that's okay: we're using it to produce semantic hashes, not to perfectly reconstruct the images."
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {
"id": "HyOsfAdiyLhW",
"outputId": "678cc611-c625-434e-90b9-ad90d88baef4"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAbAAAACvCAYAAACcuYvQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACXpElEQVR4nO292Y+cR5YdfnLf19pY3KmRmq1uaaRe/MN098w8tAcewPaTH/3XGfCTH/xgGDDcsA2PG/AYjXEvmlFrp0SKLBZryX3ffw/0iTp5GV9WksysKrbyAoXM+vJbIuK7cZdzb9wIzWYzbGhDG9rQhjb0plH4shuwoQ1taEMb2tCr0EaBbWhDG9rQht5I2iiwDW1oQxva0BtJGwW2oQ1taEMbeiNpo8A2tKENbWhDbyRtFNiGNrShDW3ojaToOb+vJMdeU/VDodALv0+n0xeO+c572WfpfYLud17blqBXa+iKxvZPnF5lbFfOs7PZDLPZDL1eD/1+H99++y3+/b//9/j0008xGo0wGo0QjUZx8+ZNbG1tIZvNYnd3F4lEAuFwGOFwGMPhEKenp+h2u6jX63j27Bn6/T4ODw9RqVRw584d/PN//s9x7do1fPDBB/jxj3+MaDQ6x7+vOic8dKk8O5lMMJ1OcXh4iH/8x3/EYDDAD37wA9y/fx/h8MXY1K1WC7///e9xeHiI7e1t3L59G8lkEqVSCdls9nVuvZEH66MXxvY8Bbaap4ZCmM1mcxOQE7/dbuOLL75AvV5HOBxGKBRCJBJBsVhEOp1GJBJxx/X6yWSC2WzmJsN0OnUCZjweo9frYTqdYnd3F9euXUMikUC5XEYqlTq3bRv67tJ0OsVoNMKzZ8/QaDTQ7XZRrVYxHA7R7XbR7XZRq9VwcHCAXq/nFN1sNsNwOESn08F4PMZkMkE0eja9xuMxWq0WBoMB2u02hsMhxuMxQqEQYrEYhsMhHj58iGq1ipOTE3z00UdIJBLY2tpCKpVCqVTC9evXkUgkkMlkkEwmL2uIXptCoRDC4TDS6TT29/cxGo2Qy+UuZA7SGAmFQsjlchiNRiiVSsjn80gkEojFYmtvw4ZWRxeiwIAXvRtO9qdPn+K///f/jq+//topr3g8jrfffht7e3vufzJ9OBx2QmYymWA8HjthcHJyglqthl6vh0qlgtFohB/96Ef4yU9+4hjUKjBf2zb03aTZbIbpdIp+v4+vvvoK33zzDU5OTvD555+j0+mg1Wqh0+lgMpk4/ovH44jH4wCAfr+PUCiEdruNer0OAHMGlvJsv9/HdDpFKBRCKpXCYDDA559/jkgkgna7jVarhUwmg3fffRc7Ozt4++238Rd/8RfI5/O4du3aG63AwuEwZrMZstksbt26hel0imw2e2EKbDKZIBQKoVQqIRaLoVgsuu8bWfBm0doV2Gw2w2AwwGg0chN3Mpmg0+mg2+3i8PAQp6enOD09RSQScQrr+PgYABCJRJxVFI1GHfOrMBgMBhiPx6hUKmg2m+j1eqhWq5hMJjg9PcWzZ8/Q7XaRz+cxHA6dIgyHw0ilUkgmk05BXjWi8JtMJs67ZNtJvkn3OhVWLHz2qvcOupaCO5FIIBqNzhkpF02TyQS9Xg/j8Rjj8Rij0ch5WfV6Ha1Wy3n2NJQAzCEDpOl0OueR8Tc7DvQCptMpIpGI42uiDNPp1PF1p9NBMplEvV7H8fExer0eAGAwGCAWiyGVSrl7qMf3JhC9T47DRT87Go0imUwiFoshEolslNcbSKFzhNFr47L9fh+fffYZDg8PcXR0hE8++QTtdhv9fh+j0QidTgcPHz5Eu912SoTuPa1MHovFYk6ZjcdjN9GHwyEmk4kTMBRKs9kMpVIJW1tbSCaTuHbtGrLZLNLpNPL5PNLpND788EPcv38f8XgcmUzmVYTA2jDv2WzmrP56vY6PP/4YlUoFqVQKmUxmTuipAqbSs/fiH/+nx2t/Jzyr96Fw5nFfW92A/D
9BwHcBwLWNyiCRSODP/uzPsL29ja2tLdy7d895MkJrj4HVajX84Q9/wMnJCVqtFur1OgaDAY6Pj9FsNh1/UchGIhFMJhMMBgP32e12EQqFkM/nkUqlHC9FIhE3hlSOk8nEeWhUhuFwGPF4HMViEbFYDN1uF51OBwCQSCQQiUSQSCSQz+cRi8UcD29vb+NHP/oRyuUytre3sbe3t6wRdqlxGuWzwWCA2WyGRCLhe/8rJ/KvGoR8XytSYJsY2Pro4mNgk8kER0dHDpL5u7/7O9RqNWfxTqdTDIdDJwDJRIeHhwDmBWMikUAikQDwXDhyEgyHw7lrlWq1Gr7++msHFSSTSeTzeezs7LjPW7duYTabIZ1Or3UsliEbjxsOh2i32zg9PcUnn3yCg4MDZLNZlMtl54kpvBoKhTCZTJziUIVFxUMBEg6HneXJsQSexyc5yam4+D9wFn+07QbOkg0YExqNRnOeCuGxbDbrDI1QKIRbt25diACz1O/38fDhQ3z77beoVCo4Ojpy8dnxeIxYLIZ0Oo14PI50Oo1UKoXJZIJut4vRaIRWq4VutztnHFAh0bvQGC0/yb+xWAzRaNQZVfF4HMlkEul02hl4/X4f1WoVX3/9NQC4e9++fRvlchnj8RjJZBK7u7sXPn6vShwrGlAXiX7QcEsmk84w2XhfbyatTYE9e/YMjx8/RqPRwB/+8Ac8fPgQR0dHTtmo4CQsCATDYXT5leEVhvGRKgOeMx6PXWC+1+vho48+wmAwQD6fx71795DL5VAoFFAuly+FqfWZzHxrNBrur9lsYjabOQhLBSfJp7yCPDCFufibKi56Cfq7z7uzfaBCZLyBnm2/38dgMEA4HEalUkE0GkUul3NtvgiaTqc4OTlBtVrF6empi5vOZjPnQXEMwuGwg5gjkYgbNwb8qTyYODQYDJxgjEajbtwICVL5cQwZl2X8hfMhHo87nk0kEkgmk0gkEnP8PhqN8PXXX6NarWI2m6FYLM7d77KJxqk1nKxhBGAOXaGBZOffq0DXi+BbesTRaNRB2DQ2wuGwyyLVtm3oatFaFNhsNsMf//hH/Mf/+B9Rq9Xw1Vdf4dmzZ3MwFy0gnq/XAvOC/LxUeB9xgliFMB6PUa/XUalUEAqF8OjRIySTSezv7+Ov//qvce3aNfzwhz9EoVC49JjCdDpFo9HAwcEBnj175v6azaaDmHyerO2zT3nx/j4FaK9V8qVznydYKAxCoRD6/b6DEb/66iscHR0hmUxiNBq93OC8Bk0mE3z22Wf47W9/i263i0qlgn6/j0gkgkKhMAfL2uso3PL5vIMUh8MhhsMhjo+P0Wg0HNyn/EPISvsZi8WQyWRQLpfnzmOGHgBkMhkAcBAkBS/jyb/+9a8BPPdsi8Uicrkc9vf3r4TAHY/HaDabDibkn2ZgcoxpEJCPafj4DKwgY9cabBqTtEQodzQaufcQCoVcvJ5Zy/F4HPl8Hvl8fuOlXUFaqYRW5qzVajg8PHRWbr1ed/g9J74VEkFruAC47EMlKsPzBCotW4VzKOwZFAeA4+NjRKNRBwldBRoMBi4OxkSCeDyOwWAAAE6YEUr1KTB+WgFAS9fXV46ZPcZP3/vQe+v5XM/E+BFjQoQP+/3+hY73dDp1iRqDwcB5RppQQdjTjhsTUJhcFIvFkEgk3Dmj0cgJQnqvfCbnBpEE/sViMWdcqXetMUr+byHdVquF8XiMRqOBdrvtxvgqEPs0Go3mjCZCyxrz1v6pArP308/z+NwHdfM6GgD9fh/JZNI9n20DnsP3hOQ39OqkxovOH5vk9CoGwkoVWLVaxf/5P/8HR0dH+P3vf48HDx6g1+thMpkgm80iFovNMYttsBWa2jGfsKQ1rOdbq4vwFSeTJjwQziAzf/bZZ3j27BmuXbuGyWRy6VbsZDLBt99+i3/4h39w1iKzzjip2F9CXaQg6MQeU3gn6Fql85hMn2MVHgCXkKBCXiG1dZK+b8a5ZrOZGzcqMABOQS
ifUhHz/+l06mJPFHYa26NBobBjKpVCOp1GoVBwcbVYLObaxTbyGZqYo/eKRqNOeRKO/fjjj7G1tYVSqYRSqbT28Qw
"text/plain": [
"<Figure size 540x216 with 10 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plot_reconstructions(hashing_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I8Q_Rx-PyLhX"
},
"source": [
"Now let's see what the hashes look like for the first few images in the validation set:"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {
"id": "Akb6fNptyLhX",
"outputId": "efdbf0bd-1506-42ae-f1e1-7cd68d0f45a7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1010110110001011\n",
"1000110110001011\n",
"1100110110001011\n",
"0010111110101001\n",
"0000010111101101\n",
"...\n"
]
}
],
"source": [
"hashes = hashing_encoder.predict(X_valid).round().astype(np.int32)\n",
"hashes *= np.array([[2 ** bit for bit in range(16)]])\n",
"hashes = hashes.sum(axis=1)\n",
"for h in hashes[:5]:\n",
" print(f\"{h:016b}\")\n",
"print(\"...\")"
]
},
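{
"cell_type": "markdown",
"metadata": {},
"source": [
"The first three hashes printed above differ in only one or two bits. One common way to compare binary codes like these is the Hamming distance (the number of differing bits). The cell below is a minimal sketch of that idea and is not part of the original solution: `hamming_distance` is a hypothetical helper, and the bit strings are copied by hand from the printed output above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def hamming_distance(hash_a, hash_b):\n",
"    \"\"\"Count the bits that differ between two integer hashes.\"\"\"\n",
"    return bin(hash_a ^ hash_b).count(\"1\")\n",
"\n",
"# Bit strings copied from the printed output above (illustrative only)\n",
"printed = [\"1010110110001011\", \"1000110110001011\", \"1100110110001011\"]\n",
"as_ints = [int(bits, 2) for bits in printed]\n",
"\n",
"# Compare the first hash with the next two\n",
"for other in as_ints[1:]:\n",
"    print(hamming_distance(as_ints[0], other))"
]
},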
{
"cell_type": "markdown",
"metadata": {
"id": "T5n8N4RoyLhX"
},
"source": [
"Now let's find the most common image hashes in the validation set, and display a few images for each hash. In the following image, all the images on a given row have the same hash:"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {
"id": "hS80DfcLyLhX",
"outputId": "51a38569-f02e-4737-d40e-b478d8a557b3"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAcsAAAItCAYAAABW91FFAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9Z4yk2Xkdjp/KOXZ1jtPT05N2dmc2r7i7DEuRXFkiRUPBpCQYMm0Dgg3bsA3LBmwD+mZAlj8YhH4CJVEkTYsiLWpNMe+SS3K55KbZNDl2DpVzDu/7/9A6t596p3q6e6aqZ/V3HWAwM90V7n3vvU84T7gmXdfRRx999NFHH33sDPO9HkAfffTRRx99vNfRV5Z99NFHH330sQv6yrKPPvroo48+dkFfWfbRRx999NHHLugryz766KOPPvrYBdZdfr+vVFld12EymQAA9Xod+Xwe165dAwCYzWaEw2HY7XaYzWb1Ok3ToGkaWq0WqtUqarUaKpUKLBYLpqam4PP54Pf7O37HLtjTi7DHOWqahmQyibfffhuFQgHFYhEXL15EKpXC4uIiQqEQgsEgTp06hXA4jPHxcdhsNthsNuRyOZRKJayurmJlZQWLi4uo1+uw2+04duwY/H4//H4/ZmZmEAwGMT09DZ/Ph2Aw2NM5MhO61Wqh2Wzi+vXrqFaraDabiEajSCaTiMViaDQaaDQa2NjYQDwex/DwMAKBAObm5mCxWGAymVAoFJBOp/HTn/4UQ0NDmJycRCQSgc/nw8zMDPx+P4LBIOx2O5xOJ0ZGRmC322G323s6xx1fbNhHuq6j1WohlUohHo/jL/7iLxCNRqFpGqxWK2w2G4rFIiwWC4aHh/G+970PzzzzDAKBAKxWq/oMfu4e9yhx13Pk99brdXzjG99AKpVCLpfD8ePHMT09DYfDAZvNBrfbDYfDAbvdDofDoc6iruvQNA31eh3NZhOVSgW1Wg21Wk2dy1QqhWg0itXVVZw4cQJjY2N49NFHYbFY9nIue7KOAFAqlVAul7GysoJsNovV1VVcu3YNa2trmJiYgNfrRSgUUnOkjIlGoygWiygUChgaGkIkEsHTTz+NwcFBDA0NYWBgAA6HYz9D6brMAbZkp0S9XkcqlcI3vvENvPDCC0ilUtA0DR6PB9VqFeVyGel0GlarFadOncLHP/5x/OZv/iasVqtaI2MVxD72a8/WsdVqodFo4Gc/+xmy2SwKhQIajQaazSZKpRJ0XYfdbkepVEKxWGzbuzabDRaLBcDW87LZbHj44YfxwQ9+cL/DAHaY427Kcn/fIBYin8/jypUr+J//83/CZDLBYrHgscceQyAQgNPphNlshtlsVsK5XC4jmUwiGo1ifX0dbrcb/+gf/SPMz8+3Kct9CqGuoV6v48KFC/j93//9tnGYTCY4HA4sLS2h2Wziu9/9Lmw2G8bHxxEIBOD3+7G8vIxCoYB8Po9gMIiBgQHY7XbUajX85Cc/URvX5XIhHA6ree9BWd4x5EGs1+vI5XL4yle+gmQyiXq9Dk3TlNC0Wq2IRCKo1WrQdR1/8zd/g1qthpGRESVQW60WzGYzQqEQ/H4/HA4H1tfXUa/Xce7cOdhsNtjtdkQiEQwMDODZZ59VxhPn36u1lYJ8p+/SdR21Wg3nzp3D2bNnsbi4iFQqhUQioQ5iKpWC1WrFsWPHMDw8jKmpKZw8eRI+n099plEB93JeEvyOVquF69evIxaLoVqtYnNzEw6HAwMDA/D7/ZiensbAwABCoRAGBweV8UqDKZVKoVQqYX19Hel0GslkEpubm6jVarDb7Wg0Gmi1Wjh79iyGhobw4IMPKoPpXiEej2NlZQVf/epXkUwmUavVUK/XUa/Xcf78eTQaDWiapmROrVaD2WxGJBKB1+uFz+fD+vo6lpeX8fLLL+PMmTP40Ic+hMcff3y/yrJr0HUdjUYDJpPpFoOyWCzizTffxDvvvIMLFy4oY2VpaUm9l8
rk3LlzOHPmDCqVCjwej1Io0lnRdb3NgTlISDnUaDSQz+fxne98B4VCAZOTk3A6nXA4HHC73bDZbPD7/WoP8nxZLBYUi0Xk83msra0hlUrhwoUL0DStTVne7XnsmrJMpVJ46623cOnSJSwvL6NYLCKTyeDChQsAtia0vr4Oq9WqLHW73Y56vY5Wq6UsWVpGFosFuVwOw8PDGB0dRTgcRiQSwcc+9jH4fD54PJ5uDX1X1Ot1/N//+39x7tw51Go1uFwuOJ3OtofvcDjgcDiUpUMFVC6XYTKZ1KG02+1KuFgsFrhcLmXxAlsH4cUXX0Sz2cQDDzzQszlJa/Xdd9/FxYsXkclk0Gq14PV6leBvNBpK2QwODmJ4eBgWiwXlchnAlnButVpq/n6/H263Wwkkp9MJi8WiBFWlUkEikcCLL76I48eP48knn+z5IZWfL/8djUaRSCRw5coVZLNZRKNRZDIZ5Wm4XC5lxZZKJUQiEYRCIczOziKZTOK5557Dj3/8Y7jdbgwPD2NwcBBjY2OYmpqC2+0+EOEjhU02m0UqlUI2m0W1WoXNZgMANJtNZDIZFItF5HI5uN1uOJ1ONBoNWCwWBAIB5VnTsyyVSmg0GqjX620eCb1oGlONRgM2m+0W7+cgsLm5iR//+MdYW1tDPB5HNpuFrutqL+q6rs5aMBhU897c3FTMDvemw+GA1WqF3W5HIpHA888/j8XFRYyPj+PZZ59VXvhO3l63YTKZ1PoBwPr6OqLRKN555x1sbm7iypUr2NzchNPpRL1eh81mw+zsrJKh9XodJpMJkUgE169fx5/8yZ/A7XbD7/fjvvvuw8jICMbGxtQ5v1fGjqZpbWxktVpVht2xY8faDBwASp7U6/W2daRR8cQTT6DVamFkZATj4+PY3NxEOBzuitHTFWVZLBaxubmJt956Cz/72c9w/vx51Go1RXnwYWxsbCgF43Q6YbfblbCtVqtq0fiaWCwGr9cLv9+P0dFRTExM4MSJExgeHlYb/CAOabPZxNtvv41r165B0zRYLBbY7XZllQFo23RUlvS6eHhdLpf6TD4Tfk6r1VLP6+rVq5ibm0Oz2eyp1U6Bt7y8jIsXL6JSqcBsNt8iRChEPR4PPB4PdF1HpVJBoVBQz8Dn88HhcMDj8aDZbKLZbCohK+dAq/D69esIBAKoVCoHto78fs47FothZWUFb7/9NmKxGNbW1uBwOOB0OjExMQG73Q6v14tyuYx6vY5QKKTouXg8jhs3bihD6dChQ5ienkar1VKeJr22Xs5Nes3FYhGpVAqVSgWNRkMJCE3TUKlU1JpZrVZYLBZkMhmYTCaEw2E1RioDQho6jUYDZrMZFoulTVk2m03l/ewjTHJXc6UX/NprryEejyOfz8NmsyljnHuu1WrBZrNheHgYXq8XbrdbMSacq67r6n30ypLJJPL5PGKxGN73vvfB7/crw5aeWK9BR6LVamFtbQ0LCwv42c9+hmg0iqWlJcV4cDzhcBiNRgO1Wg3JZFIZDjQCnU4nQqEQnE4nNE2Dz+drM9rvNRqNBqrVKpxOpxorsL0+ANoMQNLpzWYTTqcTLpcLU1NTiqr1+XzIZDKK6bpbmHbp4LMr71wul/Ff/st/weXLl/HOO+/AarUqISupL2Ms53YHSr6e/67X6wAAn8+HU6dO4UMf+hA+8pGPYGpqase57Tb2vcxR13UUCgX803/6T7G2ttZmrcoDY5wbnwOpZgAd43OcH3l3AMjn83jf+96HT3/60zh06FAbDd3NOcbjcbz99tt47bXXcPPmTUXTcC6k5zgfCYvFon5GgUolSsFDr5QUD/cFABQKBQwMDGB6ehpPPfUUxsfHezJHYFvINptNvP7661hfX8eFCxcUXcc5ynh6q9UCsOVBZDIZZLNZnDhxAk6nE5lMRv2O1DTnZzab4fF4EAgE8PTTT2N4eBhjY2O7UUB3PEeO3Ww24+WXX8bFixfxzjvvKCONv/N6vYrNsVgsio7TdR
3NZlN9HulKrpOu66hWqyqexL1KWu83fuM3FPuziyK56/h6q9VSivrcuXO4ePEivvrVr8Llcimjjq8F0HZGaSRReZp
"text/plain": [
"<Figure size 576x720 with 80 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"from collections import Counter\n",
"\n",
"n_hashes = 10\n",
"n_images = 8\n",
"\n",
"top_hashes = Counter(hashes).most_common(n_hashes)\n",
"\n",
"plt.figure(figsize=(n_images, n_hashes))\n",
"for hash_index, (image_hash, hash_count) in enumerate(top_hashes):\n",
" indices = (hashes == image_hash)\n",
" for index, image in enumerate(X_valid[indices][:n_images]):\n",
" plt.subplot(n_hashes, n_images, hash_index * n_images + index + 1)\n",
" plt.imshow(image, cmap=\"binary\")\n",
" plt.axis(\"off\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "J77AxVz5yLhX"
},
"source": [
"# Exercise Solutions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QIf1WVKcyLhX"
},
"source": [
"## 1. to 8."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hSUWHPHUyLhX"
},
"source": [
"1. Here are some of the main tasks that autoencoders are used for:\n",
" * Feature extraction\n",
" * Unsupervised pretraining\n",
" * Dimensionality reduction\n",
" * Generative models\n",
" * Anomaly detection (an autoencoder is generally bad at reconstructing outliers)\n",
"2. If you want to train a classifier and you have plenty of unlabeled training data but only a few thousand labeled instances, then you could first train a deep autoencoder on the full dataset (labeled + unlabeled), then reuse its lower half for the classifier (i.e., reuse the layers up to the codings layer, included) and train the classifier using the labeled data. If you have little labeled data, you probably want to freeze the reused layers when training the classifier.\n",
"3. The fact that an autoencoder perfectly reconstructs its inputs does not necessarily mean that it is a good autoencoder; perhaps it is simply an overcomplete autoencoder that learned to copy its inputs to the codings layer and then to the outputs. In fact, even if the codings layer contained a single neuron, it would be possible for a very deep autoencoder to learn to map each training instance to a different coding (e.g., the first instance could be mapped to 0.001, the second to 0.002, the third to 0.003, and so on), and it could learn \"by heart\" to reconstruct the right training instance for each coding. It would perfectly reconstruct its inputs without really learning any useful pattern in the data. In practice such a mapping is unlikely to happen, but it illustrates the fact that perfect reconstructions are not a guarantee that the autoencoder learned anything useful. However, if it produces very bad reconstructions, then it is almost guaranteed to be a bad autoencoder. To evaluate the performance of an autoencoder, one option is to measure the reconstruction loss (e.g., compute the MSE, or the mean square of the outputs minus the inputs). Again, a high reconstruction loss is a good sign that the autoencoder is bad, but a low reconstruction loss is not a guarantee that it is good. You should also evaluate the autoencoder according to what it will be used for. For example, if you are using it for unsupervised pretraining of a classifier, then you should also evaluate the classifier's performance.\n",
"4. An undercomplete autoencoder is one whose codings layer is smaller than the input and output layers. If it is larger, then it is an overcomplete autoencoder. The main risk of an excessively undercomplete autoencoder is that it may fail to reconstruct the inputs. The main risk of an overcomplete autoencoder is that it may just copy the inputs to the outputs, without learning any useful features.\n",
"5. To tie the weights of an encoder layer and its corresponding decoder layer, you simply make the decoder weights equal to the transpose of the encoder weights. This reduces the number of parameters in the model by half, often making training converge faster with less training data and reducing the risk of overfitting the training set.\n",
"6. A generative model is a model capable of randomly generating outputs that resemble the training instances. For example, once trained successfully on the MNIST dataset, a generative model can be used to randomly generate realistic images of digits. The output distribution is typically similar to the training data. For example, since MNIST contains many images of each digit, the generative model would output roughly the same number of images of each digit. Some generative models can be parametrized—for example, to generate only some kinds of outputs. An example of a generative autoencoder is the variational autoencoder.\n",
"7. A generative adversarial network is a neural network architecture composed of two parts, the generator and the discriminator, which have opposing objectives. The generator's goal is to generate instances similar to those in the training set, to fool the discriminator. The discriminator must distinguish the real instances from the generated ones. At each training iteration, the discriminator is trained like a normal binary classifier, then the generator is trained to maximize the discriminator's error. GANs are used for advanced image processing tasks such as super resolution, colorization, image editing (replacing objects with realistic background), turning a simple sketch into a photorealistic image, or predicting the next frames in a video. They are also used to augment a dataset (to train other models), to generate other types of data (such as text, audio, and time series), and to identify the weaknesses in other models and strengthen them.\n",
"8. Training GANs is notoriously difficult, because of the complex dynamics between the generator and the discriminator. The biggest difficulty is mode collapse, where the generator produces outputs with very little diversity. Moreover, training can be terribly unstable: it may start out fine and then suddenly start oscillating or diverging, without any apparent reason. GANs are also very sensitive to the choice of hyperparameters.\n",
"9. Diffusion models are good at generating diverse and high quality images. They are also much easier to train than GANs. However, compared to GANs and VAEs, they are much slower when generating images, since they must go through each step in the reverse diffusion process."
]
},
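{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make answer 5 more concrete, the cell below counts the parameters of a symmetric stacked autoencoder with and without tied weights. It is only a sketch: the layer sizes (784, 100, 30, 100, 784) are an illustrative assumption, `count_parameters` is a hypothetical helper, and it assumes that each decoder layer keeps its own bias vector, so tying saves slightly less than half of the parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def count_parameters(layer_sizes, tied=False):\n",
"    \"\"\"Count kernel weights and biases of a symmetric stack of dense layers.\"\"\"\n",
"    pairs = list(zip(layer_sizes[:-1], layer_sizes[1:]))\n",
"    if tied:\n",
"        # The decoder reuses the transposed encoder kernels,\n",
"        # so only the encoder half of the kernels is counted\n",
"        pairs = pairs[:len(pairs) // 2]\n",
"    n_weights = sum(n_in * n_out for n_in, n_out in pairs)\n",
"    n_biases = sum(layer_sizes[1:])  # every layer keeps its own biases\n",
"    return n_weights + n_biases\n",
"\n",
"sizes = [784, 100, 30, 100, 784]  # assumed sizes, for illustration only\n",
"untied = count_parameters(sizes)\n",
"tied = count_parameters(sizes, tied=True)\n",
"print(untied, tied, round(tied / untied, 3))"
]
},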
{
"cell_type": "markdown",
"metadata": {
"id": "0GXqpaDHyLhY"
},
"source": [
"## 10.\n",
"_Exercise: Try using a denoising autoencoder to pretrain an image classifier. You can use MNIST (the simplest option), or a more complex image dataset such as [CIFAR10](https://homl.info/122) if you want a bigger challenge. Regardless of the dataset you're using, follow these steps:_\n",
"* Split the dataset into a training set and a test set. Train a deep denoising autoencoder on the full training set.\n",
"* Check that the images are fairly well reconstructed. Visualize the images that most activate each neuron in the coding layer.\n",
"* Build a classification DNN, reusing the lower layers of the autoencoder. Train it using only 500 images from the training set. Does it perform better with or without pretraining?"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {
"id": "RVAVmIznyLhY"
},
"outputs": [],
"source": [
"(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
"X_train = X_train / 255\n",
"X_test = X_test / 255"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {
"id": "qXcft-kHyLhY"
},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"denoising_encoder = tf.keras.Sequential([\n",
" tf.keras.layers.GaussianNoise(0.1),\n",
" tf.keras.layers.Conv2D(32, 3, padding=\"same\", activation=\"relu\"),\n",
" tf.keras.layers.MaxPool2D(),\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(512, activation=\"relu\"),\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {
"id": "a2lC7XBqyLhY"
},
"outputs": [],
"source": [
"denoising_decoder = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(16 * 16 * 32, activation=\"relu\"),\n",
" tf.keras.layers.Reshape([16, 16, 32]),\n",
" tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=3, strides=2,\n",
" padding=\"same\", activation=\"sigmoid\")\n",
"])"
]
},
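{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check on the two models above, the cell below traces the tensor shapes through the encoder and the decoder by hand. This is a sketch under the assumption that the inputs are 32×32×3 CIFAR10 images and that the layers behave as configured above (`padding=\"same\"` convolutions and a 2×2 max pooling layer); the helper functions are hypothetical and only do shape arithmetic, they do not call Keras."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def conv_same(shape, filters, strides=1):\n",
"    \"\"\"Output shape of a 2D convolution with \"same\" padding.\"\"\"\n",
"    height, width, _ = shape\n",
"    # Ceiling division: \"same\" padding yields ceil(size / strides)\n",
"    return (-(-height // strides), -(-width // strides), filters)\n",
"\n",
"def max_pool(shape, pool_size=2):\n",
"    \"\"\"Output shape of 2D max pooling with stride equal to pool_size.\"\"\"\n",
"    height, width, channels = shape\n",
"    return (height // pool_size, width // pool_size, channels)\n",
"\n",
"def conv_transpose_same(shape, filters, strides):\n",
"    \"\"\"Output shape of a 2D transposed convolution with \"same\" padding.\"\"\"\n",
"    height, width, _ = shape\n",
"    return (height * strides, width * strides, filters)\n",
"\n",
"shape = (32, 32, 3)                      # one CIFAR10 image\n",
"shape = conv_same(shape, filters=32)     # (32, 32, 32)\n",
"shape = max_pool(shape)                  # (16, 16, 32)\n",
"n_flat = shape[0] * shape[1] * shape[2]  # 8192, matches Dense(16 * 16 * 32)\n",
"decoded = conv_transpose_same((16, 16, 32), filters=3, strides=2)\n",
"print(shape, n_flat, decoded)"
]
},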
{
"cell_type": "code",
"execution_count": 78,
"metadata": {
"id": "RgtxA9BjyLhY",
"outputId": "7b43827a-cc75-45cd-e5e0-dbd4ccfea440"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"1563/1563 [==============================] - 165s 105ms/step - loss: 0.5931 - mse: 0.0184 - val_loss: 0.5960 - val_mse: 0.0185\n",
"Epoch 2/10\n",
"1563/1563 [==============================] - 167s 107ms/step - loss: 0.5723 - mse: 0.0098 - val_loss: 0.5761 - val_mse: 0.0109\n",
"Epoch 3/10\n",
"1563/1563 [==============================] - 164s 105ms/step - loss: 0.5674 - mse: 0.0079 - val_loss: 0.5724 - val_mse: 0.0094\n",
"Epoch 4/10\n",
"1563/1563 [==============================] - 163s 104ms/step - loss: 0.5651 - mse: 0.0071 - val_loss: 0.5720 - val_mse: 0.0093\n",
"Epoch 5/10\n",
"1563/1563 [==============================] - 168s 108ms/step - loss: 0.5639 - mse: 0.0066 - val_loss: 0.5685 - val_mse: 0.0078\n",
"Epoch 6/10\n",
"1563/1563 [==============================] - 165s 105ms/step - loss: 0.5631 - mse: 0.0063 - val_loss: 0.5667 - val_mse: 0.0072\n",
"Epoch 7/10\n",
"1563/1563 [==============================] - 168s 107ms/step - loss: 0.5625 - mse: 0.0061 - val_loss: 0.5662 - val_mse: 0.0069\n",
"Epoch 8/10\n",
"1563/1563 [==============================] - 164s 105ms/step - loss: 0.5620 - mse: 0.0059 - val_loss: 0.5653 - val_mse: 0.0066\n",
"Epoch 9/10\n",
"1563/1563 [==============================] - 166s 107ms/step - loss: 0.5617 - mse: 0.0058 - val_loss: 0.5646 - val_mse: 0.0064\n",
"Epoch 10/10\n",
"1563/1563 [==============================] - 164s 105ms/step - loss: 0.5614 - mse: 0.0057 - val_loss: 0.5639 - val_mse: 0.0061\n"
]
}
],
"source": [
"denoising_ae = tf.keras.Sequential([denoising_encoder, denoising_decoder])\n",
"denoising_ae.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"mse\"])\n",
"history = denoising_ae.fit(X_train, X_train, epochs=10,\n",
" validation_data=(X_test, X_test))"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {
"id": "dTlFyrmSyLhY",
"outputId": "7a842325-f99e-4e85-962f-ba3bcbf8cdfd"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVgAAAI/CAYAAAAydcE/AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAEAAElEQVR4nOz9edBtaZbWh/3WO+x9hm+4Uw6V1dVVTB20QAK3wEYSBiQjISRbIRCywrLCQvgPyxYOh2SCkEGEAYUskAUeomkHyBgICZkQlm2F2iETIIExDllqwAha3YKeu7qqsjLz3vsNZ9h7v8PyH+vd+5yblZVZ1Zl5C5r7Ru483z3Dnvd613rWs54lqsqr8Wq8Gq/Gq/HJD/et3oFX49V4NV6Nn6njlYF9NV6NV+PV+JTGKwP7arwar8ar8SmNVwb21Xg1Xo1X41Marwzsq/FqvBqvxqc0XhnYV+PVeDVejU9p/C1nYEXkx0Xkt3yTv1ER+Q2f8H78ThH5/k9yna/GRw8R+aMi8r3f6v14NT79ISJ/TkS++1PexveLyO/8tNb/LTGwIvJZEflDIvJTIjKJyJdE5N8SkW/7Bn7+S4Hv+SY3+RngP/zm9/TV+KRGM4wqIv/K+97/Ve39J9/gqv5nwD/zye/hq3F2jVREkoi8IyJ/VkT+BRGJ34Jd+vXA/+JbsN1PbLx0AysiPwv4i8AvBP5Z4OdiD8wvAL5PRL7wdX7XAajqu6p6+Ga2qapvq+r4cfb71fhExgD8VhF57ae7AlW9VdWbT26XXo33jT+DOSRfAP4hzDH5XcD/W0S2L3NHVPWZqt6/zG1+0uNb4cH+AaACv1pV/2NV/UlV/bPAr27v/wFYwoP/g4j8myLyLvD/ae+/ABGIyHeIyP9LRAYR+esi8o+IyE5EfuPZdxaIQES+0P79T4jInxaRg4j8gIj8g2ff9yLyh0Xkx0TkKCI/JCK/VUT+loNU/iYbfxb4ceB3fL0viMivEJH/rF3Pr4rI/2aeXNvnL0AE7fv/33bNb9tvf6GIbEXk7v3QkIj8g807e+NTOL6fCWNsDsmXVPWvqOrvB34V8F3AbwVzdkTk97YIdC8i3yciv2ZewVlU8t9q1+MgIn9RRL7rfEMi8utF5K+JyCgiXxSR3y4icvb5CxBB+/5fbc/ks/bcv3H2+X9HRP5Su3d+TET+tffdO6+LyH/Qfv8TIvKbPo0TeD5eqsEQkUfAPwz8gfd7oe3f3wP8WhF52N7+ZwAB/pvA/+AD1ueA/xuQgV8G/Ebgfwn038Du/GvA/x74RcD3AX9CRC7aZw74EvDfBb4T+O3AbwP+uW/wUF+NDx4V+JeBf15Efs77PxSRzwL/EfD/A/5rwP8Q+O8B//oHrUxEAvAfAH8Bu47/DeB/BxRV3QP/Z+D9D9FvAr5XVb/6SRzQ3w5DVb8f+H8C/0R7648AvxL4p4G/E/hjwH8oIr/ofT/917Hr/V3AU+CPzwZURP5u4E8C/9e2jn8ZgwN+8wftg4i8CfyJtq3vBH4F8G+fff5rgD8OfDcWDf8m4DcA/6uz1fxRLGL+1cA/jtmUL3zDJ+KnM1T1pS3YA6DAr/s6n/+69vl/HfhzwF/9gO/8OPBb2t+/BjOunz37/O9t6/iNZ+8p8Bva319o//4fnX3+2fbeL/+Qff89wJ85+/fvBL7/ZZ6/v5UX7Ob+3vb3nwX+RPv7V7Vz/wSb9H4YcGe/+43ACGw+YD2P2m9/5dfZ5i85vz+Ah8AR+G9/q8/H34zL+bn9gM9+D3AAfg42UX77+z7/vwPf875r+mvOPv/72nvf1v79x4H/5H3r+J3AT539+88B393+/q72+89/nf3788DveN97/ziww5y072i///vOPv88UIDf+Wmd029VyPv1FGbkfZ//pY9Yz88HvqyqXzp77/uwG+Cjxl89+/vL7fX1ZUdE/vkW1rwrIjvgXwS+/RtY76vx0eO3Av+kiPyS973/ncB/qq
rn1+8vAB3mebwwVPUZZhT+lIj8P0TkXxKRz519/heBv4Zh/WAe13PMS341vrkh2HP5Xe3vH2iwzK49H/8oZnzPx4c9Y99Jg/3Oxl8APisiVx+w/f8Cw4e/X0T+fRH5H78Py/+7gd/+vn36d4Et8GbbXgX+8/kHqvoTZ/v1qYyXbWB/CLtIv+DrfP6d7fMfaf/ef8T65ov+0xlp/kPbdEY7HyLyTwH/W+zh/TXAL8bgi45X42MPVf0+4N8Hfu/7Pvqw6/mB76vqP4dFRn8e+MeAv3GOBwL/R07Qzm8C/qiqlp/mrv/tPP4O4EexZ0QxNs8vPlu+k6+FY9LZ3y88Y3yT17pds3+oLX8Vg49+6AyWcFgy7nyf/i7g5wHvcnLeXup4qQa2eRx/CvifiMjm/LP2738B+I/a976R8YPYjPfW2Xu/hI9/XL8c+M9U9btV9S+r6g/ztbPzq/Hxxm/DsPV/+Oy9HwD+nvclE385MHGadL9mqOp/oaq/V1V/FRZW/rNnH/872D3ymzHv6498Inv/t9EQkV+IXaf/C4aPC/Cmqv7w+5YvfeiKXhw/gF3b8/HLMYjgA5kDauM/VdXfhRn4LwP/VPv4LwM//wP26YdVNWO2wrXfzcf17cBbfIrjWwER/GYgAH9GRP4BEfmciPwq4E9jF+4DQe6vM/408NeBPyYiv0hEfhnw+zHc7eMI3f4N4LtE5NeKyM8Tkd+Bgfqvxic02qT1hzBe6zy+B7vhv0dEvlNE/lEM+/tu/QBqnoj8LBH5PSLy94rI50Xk78e8lh84284tlkz5fcCfV9Uf+vSO6mfE6EXkTRF5qz1T/xI2af0l4N9U1b+B4ad/VER+g4j8bBH5JSLyW0Tk138T2/l9wK8UK9j5DhH57wP/c+Df+KAvi8gvE5F/RUR+aTOM/xjwOU7X+ncD/7SI/O7GIvn5bf/+DQBV/etYou4PisjfIyK/GItQj9/U2fkmx0s3sKr6I5iX+V9iWcAfxbCSHwR+qar+2DexroolxnoMW/ljWKJEMc7lT3f8QeDfa/v1fVhi7Pd9jPW9Gh88fjc2GQLQPKBfizEI/grwf8KYAL/t6/z+gCUv/iQ2Kf4x7OF/P/TwhzF45w9/crv+M3b8auArwE8C/zFmyH4X8CvUmBlgkMsfwYzhfwV8L5bV/4lvdCOq+peBfxJjJnw/NpH+HowF8EHjFkuUfS8GNf4+4F9V1X+nre9PYTjw34/Zgv8cYyb85Nk6fiPwY8B/gvF7/10saf6pDTnBjz8zRsNk/grwS1T1o5Jkr8bfBqNh6n8QeOuDPOFX49X4tEb4Vu/Axx0i8uuwZNgPYZ7m78cyjn/5W7hbr8bfBKPh+l/APOB/65VxfTVe9viZUJl0iYUVP4CFhz+I8e9+Zrnmr8ZPZ/xWbLJ9Bvyr3+J9eTX+Nhw/4yCCV+PVeDVejb9Zxs8ED/bVeDVejVfjb8rxysC+Gq/Gq/FqfErjQ5Ncf+Rf/C5lgRAUAUQE7x1uFr1pr7VWSsmtBrdSW7Wjcw7vfVvDSfvgxFJVELX1S1udyUFAiaDBvrt8p+J9RqQigBMHKme/s3VWLW1bUKtS2/bscOYvOxSx7yioChXIy3c4/xH6AdTaUiq1VlRB1aEq7djsu2JbsF+rUtTWUhRStddSlVTqso//63/vRz/VqpM/+T/9BxRgpSNXdUck4deKv1ZcB8cD3N8KOUG6vWN69yllyhxXj9itX6f4Hq1XaH0IonTXP07/8EfBJTKOrB6nju24YjOu0M6x/7zn+Jan1sp0eyDtRsJeefgjhe1XKu4i4r7jGnlzA2mg3D9Fhz1Zrxj1C2Qe0h0Tl892xCExdcLuwpOisLnY8+jJM7p+Ypce8Hz8DFNdUaenpOHLaE0c1m+yu/g81XVcuANXfodTZRwvGIZLkEx/9cOsrn4YVyv9e9f0zy6pg3D4smd4x8NKCb8g4b6QKSly95WHHJ5tCapsdWKlmdKPHB/ekjZH0jFzeP
fAdDAm2nd/r37q1UTf81vs2kotuDIh1YrWThsWRAQEfHuWRQTnPM4HQOz5rbU9r7T7UqmloDUv72tbr/MOJw5xQoi
"text/plain": [
"<Figure size 432x720 with 15 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"n_images = 5\n",
"new_images = X_test[:n_images]\n",
"new_images_noisy = new_images + np.random.randn(n_images, 32, 32, 3) * 0.1\n",
"new_images_denoised = denoising_ae.predict(new_images_noisy)\n",
"\n",
"plt.figure(figsize=(6, n_images * 2))\n",
"for index in range(n_images):\n",
" plt.subplot(n_images, 3, index * 3 + 1)\n",
" plt.imshow(new_images[index])\n",
" plt.axis('off')\n",
" if index == 0:\n",
" plt.title(\"Original\")\n",
" plt.subplot(n_images, 3, index * 3 + 2)\n",
" plt.imshow(new_images_noisy[index].clip(0., 1.))\n",
" plt.axis('off')\n",
" if index == 0:\n",
" plt.title(\"Noisy\")\n",
" plt.subplot(n_images, 3, index * 3 + 3)\n",
" plt.imshow(new_images_denoised[index])\n",
" plt.axis('off')\n",
" if index == 0:\n",
" plt.title(\"Denoised\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ddB3aKyFyLhZ"
},
"source": [
"## 11.\n",
"_Exercise: Train a variational autoencoder on the image dataset of your choice, and use it to generate images. Alternatively, you can try to find an unlabeled dataset that you are interested in and see if you can generate new samples._\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HEfVD1TTyLhZ"
},
"source": [
"See the VAE code above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0mePWI8ryLhZ"
},
"source": [
"## 12.\n",
"_Exercise: Train a DCGAN to tackle the image dataset of your choice, and use it to generate images. Add experience replay and see if this helps. Turn it into a conditional GAN where you can control the generated class._\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hWBNDAFvyLhZ"
},
"source": [
"TODO"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "17_autoencoders_and_gans.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
},
"nav_menu": {
"height": "381px",
"width": "453px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}