From 3f9ff484a65302c02afc48c6f56bddb9628ae352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sun, 10 Oct 2021 13:55:50 +1300 Subject: [PATCH] Simplify plot_digits() and add comments, fixes #479 --- 03_classification.ipynb | 24 ++++++++++++++++-------- 08_dimensionality_reduction.ipynb | 25 +++++++++++++++++-------- 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index 6c8c7fc..877ab61 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -197,16 +197,24 @@ "def plot_digits(instances, images_per_row=10, **options):\n", " size = 28\n", " images_per_row = min(len(instances), images_per_row)\n", - " images = [instance.reshape(size,size) for instance in instances]\n", + " # This is equivalent to n_rows = ceil(len(instances) / images_per_row):\n", " n_rows = (len(instances) - 1) // images_per_row + 1\n", - " row_images = []\n", + "\n", + " # Append empty images to fill the end of the grid, if needed:\n", " n_empty = n_rows * images_per_row - len(instances)\n", - " images.append(np.zeros((size, size * n_empty)))\n", - " for row in range(n_rows):\n", - " rimages = images[row * images_per_row : (row + 1) * images_per_row]\n", - " row_images.append(np.concatenate(rimages, axis=1))\n", - " image = np.concatenate(row_images, axis=0)\n", - " plt.imshow(image, cmap = mpl.cm.binary, **options)\n", + " padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)\n", + "\n", + " # Reshape the array so it's organized as a grid containing 28×28 images:\n", + " image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))\n", + "\n", + " # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),\n", + " # and axes 1 and 3 (horizontal axes). We first need to move the axes that we\n", + " # want to combine next to each other, using transpose(), and only then we\n", + " # can reshape:\n", + " big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,\n", + " images_per_row * size)\n", + " # Now that we have a big image, we just need to show it:\n", + " plt.imshow(big_image, cmap = mpl.cm.binary, **options)\n", " plt.axis(\"off\")" ] }, diff --git a/08_dimensionality_reduction.ipynb b/08_dimensionality_reduction.ipynb index ecb8b9d..ab103af 100644 --- a/08_dimensionality_reduction.ipynb +++ b/08_dimensionality_reduction.ipynb @@ -947,19 +947,28 @@ "metadata": {}, "outputs": [], "source": [ + "# EXTRA\n", "def plot_digits(instances, images_per_row=5, **options):\n", " size = 28\n", " images_per_row = min(len(instances), images_per_row)\n", - " images = [instance.reshape(size,size) for instance in instances]\n", + " # This is equivalent to n_rows = ceil(len(instances) / images_per_row):\n", " n_rows = (len(instances) - 1) // images_per_row + 1\n", - " row_images = []\n", + "\n", + " # Append empty images to fill the end of the grid, if needed:\n", " n_empty = n_rows * images_per_row - len(instances)\n", - " images.append(np.zeros((size, size * n_empty)))\n", - " for row in range(n_rows):\n", - " rimages = images[row * images_per_row : (row + 1) * images_per_row]\n", - " row_images.append(np.concatenate(rimages, axis=1))\n", - " image = np.concatenate(row_images, axis=0)\n", - " plt.imshow(image, cmap = mpl.cm.binary, **options)\n", + " padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)\n", + "\n", + " # Reshape the array so it's organized as a grid containing 28×28 images:\n", + " image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))\n", + "\n", + " # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),\n", + " # and axes 1 and 3 (horizontal axes). We first need to move the axes that we\n", + " # want to combine next to each other, using transpose(), and only then we\n", + " # can reshape:\n", + " big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,\n", + " images_per_row * size)\n", + " # Now that we have a big image, we just need to show it:\n", + " plt.imshow(big_image, cmap = mpl.cm.binary, **options)\n", " plt.axis(\"off\")" ] },