From aad6e5186a5f347a750e99e7cc6d656beb3145e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Thu, 7 Oct 2021 16:39:37 +1300 Subject: [PATCH] Target=blue dot, prediction=red cross, fixes #472 --- ...essing_sequences_using_rnns_and_cnns.ipynb | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/15_processing_sequences_using_rnns_and_cnns.ipynb b/15_processing_sequences_using_rnns_and_cnns.ipynb index 99dba80..d1410b3 100644 --- a/15_processing_sequences_using_rnns_and_cnns.ipynb +++ b/15_processing_sequences_using_rnns_and_cnns.ipynb @@ -162,12 +162,12 @@ "metadata": {}, "outputs": [], "source": [ - "def plot_series(series, y=None, y_pred=None, x_label=\"$t$\", y_label=\"$x(t)$\"):\n", + "def plot_series(series, y=None, y_pred=None, x_label=\"$t$\", y_label=\"$x(t)$\", legend=True):\n", " plt.plot(series, \".-\")\n", " if y is not None:\n", - " plt.plot(n_steps, y, \"bx\", markersize=10)\n", + " plt.plot(n_steps, y, \"bo\", label=\"Target\")\n", " if y_pred is not None:\n", - " plt.plot(n_steps, y_pred, \"ro\")\n", + " plt.plot(n_steps, y_pred, \"rx\", markersize=10, label=\"Prediction\")\n", " plt.grid(True)\n", " if x_label:\n", " plt.xlabel(x_label, fontsize=16)\n", @@ -175,16 +175,26 @@ " plt.ylabel(y_label, fontsize=16, rotation=0)\n", " plt.hlines(0, 0, 100, linewidth=1)\n", " plt.axis([0, n_steps + 1, -1, 1])\n", + " if legend and (y or y_pred):\n", + " plt.legend(fontsize=14, loc=\"upper left\")\n", "\n", "fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(12, 4))\n", "for col in range(3):\n", " plt.sca(axes[col])\n", " plot_series(X_valid[col, :, 0], y_valid[col, 0],\n", - " y_label=(\"$x(t)$\" if col==0 else None))\n", + " y_label=(\"$x(t)$\" if col==0 else None),\n", + " legend=(col == 0))\n", "save_fig(\"time_series_plot\")\n", "plt.show()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Note**: in this notebook, the blue dots represent targets, and red crosses represent predictions. In the book, I first used blue crosses for targets and red dots for predictions, then I reversed this later in the chapter. Sorry if this caused some confusion." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -499,8 +509,8 @@ " n_steps = X.shape[1]\n", " ahead = Y.shape[1]\n", " plot_series(X[0, :, 0])\n", - " plt.plot(np.arange(n_steps, n_steps + ahead), Y[0, :, 0], \"ro-\", label=\"Actual\")\n", - " plt.plot(np.arange(n_steps, n_steps + ahead), Y_pred[0, :, 0], \"bx-\", label=\"Forecast\", markersize=10)\n", + " plt.plot(np.arange(n_steps, n_steps + ahead), Y[0, :, 0], \"bo-\", label=\"Actual\")\n", + " plt.plot(np.arange(n_steps, n_steps + ahead), Y_pred[0, :, 0], \"rx-\", label=\"Forecast\", markersize=10)\n", " plt.axis([0, n_steps + ahead, -1, 1])\n", " plt.legend(fontsize=14)\n", "\n",