Target=blue dot, prediction=red cross, fixes #472

main
Aurélien Geron 2021-10-07 16:39:37 +13:00
parent 43f8795bcc
commit aad6e5186a
1 changed files with 16 additions and 6 deletions

View File

@ -162,12 +162,12 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"def plot_series(series, y=None, y_pred=None, x_label=\"$t$\", y_label=\"$x(t)$\"):\n", "def plot_series(series, y=None, y_pred=None, x_label=\"$t$\", y_label=\"$x(t)$\", legend=True):\n",
" plt.plot(series, \".-\")\n", " plt.plot(series, \".-\")\n",
" if y is not None:\n", " if y is not None:\n",
" plt.plot(n_steps, y, \"bx\", markersize=10)\n", " plt.plot(n_steps, y, \"bo\", label=\"Target\")\n",
" if y_pred is not None:\n", " if y_pred is not None:\n",
" plt.plot(n_steps, y_pred, \"ro\")\n", " plt.plot(n_steps, y_pred, \"rx\", markersize=10, label=\"Prediction\")\n",
" plt.grid(True)\n", " plt.grid(True)\n",
" if x_label:\n", " if x_label:\n",
" plt.xlabel(x_label, fontsize=16)\n", " plt.xlabel(x_label, fontsize=16)\n",
@ -175,16 +175,26 @@
" plt.ylabel(y_label, fontsize=16, rotation=0)\n", " plt.ylabel(y_label, fontsize=16, rotation=0)\n",
" plt.hlines(0, 0, 100, linewidth=1)\n", " plt.hlines(0, 0, 100, linewidth=1)\n",
" plt.axis([0, n_steps + 1, -1, 1])\n", " plt.axis([0, n_steps + 1, -1, 1])\n",
" if legend and (y or y_pred):\n",
" plt.legend(fontsize=14, loc=\"upper left\")\n",
"\n", "\n",
"fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(12, 4))\n", "fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(12, 4))\n",
"for col in range(3):\n", "for col in range(3):\n",
" plt.sca(axes[col])\n", " plt.sca(axes[col])\n",
" plot_series(X_valid[col, :, 0], y_valid[col, 0],\n", " plot_series(X_valid[col, :, 0], y_valid[col, 0],\n",
" y_label=(\"$x(t)$\" if col==0 else None))\n", " y_label=(\"$x(t)$\" if col==0 else None),\n",
" legend=(col == 0))\n",
"save_fig(\"time_series_plot\")\n", "save_fig(\"time_series_plot\")\n",
"plt.show()" "plt.show()"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: in this notebook, the blue dots represent targets, and red crosses represent predictions. In the book, I first used blue crosses for targets and red dots for predictions, then I reversed this later in the chapter. Sorry if this caused some confusion."
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -499,8 +509,8 @@
" n_steps = X.shape[1]\n", " n_steps = X.shape[1]\n",
" ahead = Y.shape[1]\n", " ahead = Y.shape[1]\n",
" plot_series(X[0, :, 0])\n", " plot_series(X[0, :, 0])\n",
" plt.plot(np.arange(n_steps, n_steps + ahead), Y[0, :, 0], \"ro-\", label=\"Actual\")\n", " plt.plot(np.arange(n_steps, n_steps + ahead), Y[0, :, 0], \"bo-\", label=\"Actual\")\n",
" plt.plot(np.arange(n_steps, n_steps + ahead), Y_pred[0, :, 0], \"bx-\", label=\"Forecast\", markersize=10)\n", " plt.plot(np.arange(n_steps, n_steps + ahead), Y_pred[0, :, 0], \"rx-\", label=\"Forecast\", markersize=10)\n",
" plt.axis([0, n_steps + ahead, -1, 1])\n", " plt.axis([0, n_steps + ahead, -1, 1])\n",
" plt.legend(fontsize=14)\n", " plt.legend(fontsize=14)\n",
"\n", "\n",