Simplify plot_digits() and add comments, fixes #479

main
Aurélien Geron 2021-10-10 13:55:50 +13:00
parent 495de15361
commit 3f9ff484a6
2 changed files with 33 additions and 16 deletions

View File

@ -197,16 +197,24 @@
"def plot_digits(instances, images_per_row=10, **options):\n",
" size = 28\n",
" images_per_row = min(len(instances), images_per_row)\n",
" images = [instance.reshape(size,size) for instance in instances]\n",
" # This is equivalent to n_rows = ceil(len(instances) / images_per_row):\n",
" n_rows = (len(instances) - 1) // images_per_row + 1\n",
" row_images = []\n",
"\n",
" # Append empty images to fill the end of the grid, if needed:\n",
" n_empty = n_rows * images_per_row - len(instances)\n",
" images.append(np.zeros((size, size * n_empty)))\n",
" for row in range(n_rows):\n",
" rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
" row_images.append(np.concatenate(rimages, axis=1))\n",
" image = np.concatenate(row_images, axis=0)\n",
" plt.imshow(image, cmap = mpl.cm.binary, **options)\n",
" padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)\n",
"\n",
" # Reshape the array so it's organized as a grid containing 28×28 images:\n",
" image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))\n",
"\n",
" # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),\n",
" # and axes 1 and 3 (horizontal axes). We first need to move the axes that we\n",
" # want to combine next to each other, using transpose(), and only then we\n",
" # can reshape:\n",
" big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,\n",
" images_per_row * size)\n",
" # Now that we have a big image, we just need to show it:\n",
" plt.imshow(big_image, cmap = mpl.cm.binary, **options)\n",
" plt.axis(\"off\")"
]
},

View File

@ -947,19 +947,28 @@
"metadata": {},
"outputs": [],
"source": [
"# EXTRA\n",
"def plot_digits(instances, images_per_row=5, **options):\n",
" size = 28\n",
" images_per_row = min(len(instances), images_per_row)\n",
" images = [instance.reshape(size,size) for instance in instances]\n",
" # This is equivalent to n_rows = ceil(len(instances) / images_per_row):\n",
" n_rows = (len(instances) - 1) // images_per_row + 1\n",
" row_images = []\n",
"\n",
" # Append empty images to fill the end of the grid, if needed:\n",
" n_empty = n_rows * images_per_row - len(instances)\n",
" images.append(np.zeros((size, size * n_empty)))\n",
" for row in range(n_rows):\n",
" rimages = images[row * images_per_row : (row + 1) * images_per_row]\n",
" row_images.append(np.concatenate(rimages, axis=1))\n",
" image = np.concatenate(row_images, axis=0)\n",
" plt.imshow(image, cmap = mpl.cm.binary, **options)\n",
" padded_instances = np.concatenate([instances, np.zeros((n_empty, size * size))], axis=0)\n",
"\n",
" # Reshape the array so it's organized as a grid containing 28×28 images:\n",
" image_grid = padded_instances.reshape((n_rows, images_per_row, size, size))\n",
"\n",
" # Combine axes 0 and 2 (vertical image grid axis, and vertical image axis),\n",
" # and axes 1 and 3 (horizontal axes). We first need to move the axes that we\n",
" # want to combine next to each other, using transpose(), and only then we\n",
" # can reshape:\n",
" big_image = image_grid.transpose(0, 2, 1, 3).reshape(n_rows * size,\n",
" images_per_row * size)\n",
" # Now that we have a big image, we just need to show it:\n",
" plt.imshow(big_image, cmap = mpl.cm.binary, **options)\n",
" plt.axis(\"off\")"
]
},