Improve the code to forecast the next 10 time steps

main
Aurélien Geron 2019-07-13 11:28:50 +02:00
parent f8f2b9e4bb
commit 6c0cd5d2df
1 changed files with 189 additions and 100 deletions

View File

@ -487,7 +487,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Now let's create an RNN that predicts all 10 next values at once:" "Now let's use this model to predict the next 10 values. We first need to regenerate the sequences with 9 more time steps."
] ]
}, },
{ {
@ -505,11 +505,93 @@
"X_test, Y_test = series[9000:, :n_steps], series[9000:, -10:, 0]" "X_test, Y_test = series[9000:, :n_steps], series[9000:, -10:, 0]"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's predict the next 10 values one by one:"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 28,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [
"X = X_valid\n",
"for step_ahead in range(10):\n",
" y_pred_one = model.predict(X)[:, np.newaxis, :]\n",
" X = np.concatenate([X, y_pred_one], axis=1)\n",
"\n",
"Y_pred = X[:, n_steps:, 0]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"Y_pred.shape"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"np.mean(keras.metrics.mean_squared_error(Y_valid, Y_pred))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's compare this performance with some baselines: naive predictions and a simple linear model:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"Y_naive_pred = Y_valid[:, -1:]\n",
"np.mean(keras.metrics.mean_squared_error(Y_valid, Y_naive_pred))"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
"model = keras.models.Sequential([\n",
" keras.layers.Flatten(input_shape=[50, 1]),\n",
" keras.layers.Dense(10)\n",
"])\n",
"\n",
"model.compile(loss=\"mse\", optimizer=\"adam\")\n",
"history = model.fit(X_train, Y_train, epochs=20,\n",
" validation_data=(X_valid, Y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's create an RNN that predicts all 10 next values at once:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [ "source": [
"np.random.seed(42)\n", "np.random.seed(42)\n",
"tf.random.set_seed(42)\n", "tf.random.set_seed(42)\n",
@ -527,7 +609,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 29, "execution_count": 34,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -540,7 +622,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 30, "execution_count": 35,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -557,7 +639,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 36,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -578,7 +660,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 32, "execution_count": 37,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -587,7 +669,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 38,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -603,14 +685,14 @@
"def last_time_step_mse(Y_true, Y_pred):\n", "def last_time_step_mse(Y_true, Y_pred):\n",
" return keras.metrics.mean_squared_error(Y_true[:, -1], Y_pred[:, -1])\n", " return keras.metrics.mean_squared_error(Y_true[:, -1], Y_pred[:, -1])\n",
"\n", "\n",
"model.compile(loss=\"mse\", optimizer=\"adam\", metrics=[last_time_step_mse])\n", "model.compile(loss=\"mse\", optimizer=keras.optimizers.Adam(lr=0.01), metrics=[last_time_step_mse])\n",
"history = model.fit(X_train, Y_train, epochs=20,\n", "history = model.fit(X_train, Y_train, epochs=20,\n",
" validation_data=(X_valid, Y_valid))" " validation_data=(X_valid, Y_valid))"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 34, "execution_count": 39,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -623,7 +705,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 40,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -640,7 +722,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 41,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -669,7 +751,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": 42,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -678,7 +760,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 43,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -704,7 +786,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 44,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -732,7 +814,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 45,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -764,7 +846,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": 46,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -792,7 +874,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 47,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
@ -812,79 +894,6 @@
" validation_data=(X_valid, Y_valid))" " validation_data=(X_valid, Y_valid))"
] ]
}, },
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(X_valid, Y_valid)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"plot_learning_curves(history.history[\"loss\"], history.history[\"val_loss\"])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(43)\n",
"\n",
"series = generate_time_series(1, 50 + 10)\n",
"X_new, Y_new = series[:, :50, :], series[:, 50:, :]\n",
"Y_pred = model.predict(X_new)[:, -1][..., np.newaxis]"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"plot_multiple_forecasts(X_new, Y_new, Y_pred)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GRUs"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
"model = keras.models.Sequential([\n",
" keras.layers.GRU(20, return_sequences=True, input_shape=[None, 1]),\n",
" keras.layers.GRU(20, return_sequences=True),\n",
" keras.layers.TimeDistributed(keras.layers.Dense(10))\n",
"])\n",
"\n",
"model.compile(loss=\"mse\", optimizer=\"adam\", metrics=[last_time_step_mse])\n",
"history = model.fit(X_train, Y_train, epochs=20,\n",
" validation_data=(X_valid, Y_valid))"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": 48,
@ -929,6 +938,79 @@
"plt.show()" "plt.show()"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GRUs"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
"model = keras.models.Sequential([\n",
" keras.layers.GRU(20, return_sequences=True, input_shape=[None, 1]),\n",
" keras.layers.GRU(20, return_sequences=True),\n",
" keras.layers.TimeDistributed(keras.layers.Dense(10))\n",
"])\n",
"\n",
"model.compile(loss=\"mse\", optimizer=\"adam\", metrics=[last_time_step_mse])\n",
"history = model.fit(X_train, Y_train, epochs=20,\n",
" validation_data=(X_valid, Y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"model.evaluate(X_valid, Y_valid)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"plot_learning_curves(history.history[\"loss\"], history.history[\"val_loss\"])\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(43)\n",
"\n",
"series = generate_time_series(1, 50 + 10)\n",
"X_new, Y_new = series[:, :50, :], series[:, 50:, :]\n",
"Y_pred = model.predict(X_new)[:, -1][..., np.newaxis]"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"plot_multiple_forecasts(X_new, Y_new, Y_pred)\n",
"plt.show()"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -959,7 +1041,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52, "execution_count": 57,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1003,7 +1085,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 53, "execution_count": 58,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1030,7 +1112,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54, "execution_count": 59,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1049,7 +1131,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 55, "execution_count": 60,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1063,7 +1145,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 56, "execution_count": 61,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1091,7 +1173,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 57, "execution_count": 62,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1146,7 +1228,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 58, "execution_count": 63,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1191,7 +1273,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 59, "execution_count": 64,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1208,7 +1290,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60, "execution_count": 65,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1225,7 +1307,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 61, "execution_count": 66,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1246,7 +1328,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 62, "execution_count": 67,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1260,6 +1342,13 @@
"source": [ "source": [
"To be continued..." "To be continued..."
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {