Target=blue dot, prediction=red cross, fixes #472
parent
43f8795bcc
commit
aad6e5186a
|
@ -162,12 +162,12 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def plot_series(series, y=None, y_pred=None, x_label=\"$t$\", y_label=\"$x(t)$\"):\n",
|
"def plot_series(series, y=None, y_pred=None, x_label=\"$t$\", y_label=\"$x(t)$\", legend=True):\n",
|
||||||
" plt.plot(series, \".-\")\n",
|
" plt.plot(series, \".-\")\n",
|
||||||
" if y is not None:\n",
|
" if y is not None:\n",
|
||||||
" plt.plot(n_steps, y, \"bx\", markersize=10)\n",
|
" plt.plot(n_steps, y, \"bo\", label=\"Target\")\n",
|
||||||
" if y_pred is not None:\n",
|
" if y_pred is not None:\n",
|
||||||
" plt.plot(n_steps, y_pred, \"ro\")\n",
|
" plt.plot(n_steps, y_pred, \"rx\", markersize=10, label=\"Prediction\")\n",
|
||||||
" plt.grid(True)\n",
|
" plt.grid(True)\n",
|
||||||
" if x_label:\n",
|
" if x_label:\n",
|
||||||
" plt.xlabel(x_label, fontsize=16)\n",
|
" plt.xlabel(x_label, fontsize=16)\n",
|
||||||
|
@ -175,16 +175,26 @@
|
||||||
" plt.ylabel(y_label, fontsize=16, rotation=0)\n",
|
" plt.ylabel(y_label, fontsize=16, rotation=0)\n",
|
||||||
" plt.hlines(0, 0, 100, linewidth=1)\n",
|
" plt.hlines(0, 0, 100, linewidth=1)\n",
|
||||||
" plt.axis([0, n_steps + 1, -1, 1])\n",
|
" plt.axis([0, n_steps + 1, -1, 1])\n",
|
||||||
|
" if legend and (y or y_pred):\n",
|
||||||
|
" plt.legend(fontsize=14, loc=\"upper left\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(12, 4))\n",
|
"fig, axes = plt.subplots(nrows=1, ncols=3, sharey=True, figsize=(12, 4))\n",
|
||||||
"for col in range(3):\n",
|
"for col in range(3):\n",
|
||||||
" plt.sca(axes[col])\n",
|
" plt.sca(axes[col])\n",
|
||||||
" plot_series(X_valid[col, :, 0], y_valid[col, 0],\n",
|
" plot_series(X_valid[col, :, 0], y_valid[col, 0],\n",
|
||||||
" y_label=(\"$x(t)$\" if col==0 else None))\n",
|
" y_label=(\"$x(t)$\" if col==0 else None),\n",
|
||||||
|
" legend=(col == 0))\n",
|
||||||
"save_fig(\"time_series_plot\")\n",
|
"save_fig(\"time_series_plot\")\n",
|
||||||
"plt.show()"
|
"plt.show()"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"**Note**: in this notebook, the blue dots represent targets, and red crosses represent predictions. In the book, I first used blue crosses for targets and red dots for predictions, then I reversed this later in the chapter. Sorry if this caused some confusion."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
@ -499,8 +509,8 @@
|
||||||
" n_steps = X.shape[1]\n",
|
" n_steps = X.shape[1]\n",
|
||||||
" ahead = Y.shape[1]\n",
|
" ahead = Y.shape[1]\n",
|
||||||
" plot_series(X[0, :, 0])\n",
|
" plot_series(X[0, :, 0])\n",
|
||||||
" plt.plot(np.arange(n_steps, n_steps + ahead), Y[0, :, 0], \"ro-\", label=\"Actual\")\n",
|
" plt.plot(np.arange(n_steps, n_steps + ahead), Y[0, :, 0], \"bo-\", label=\"Actual\")\n",
|
||||||
" plt.plot(np.arange(n_steps, n_steps + ahead), Y_pred[0, :, 0], \"bx-\", label=\"Forecast\", markersize=10)\n",
|
" plt.plot(np.arange(n_steps, n_steps + ahead), Y_pred[0, :, 0], \"rx-\", label=\"Forecast\", markersize=10)\n",
|
||||||
" plt.axis([0, n_steps + ahead, -1, 1])\n",
|
" plt.axis([0, n_steps + ahead, -1, 1])\n",
|
||||||
" plt.legend(fontsize=14)\n",
|
" plt.legend(fontsize=14)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|
Loading…
Reference in New Issue