To get GPU acceleration with GRUs, do not use recurrent_dropout

main
Aurélien Geron 2021-03-11 15:07:23 +13:00
parent d423e2254e
commit b46db0e1a1
1 changed files with 14 additions and 5 deletions

View File

@ -367,7 +367,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"**Warning**: the `predict_classes()` method is deprecated. Instead, we must use `np.argmax(model.predict(X_new), axis=-1)`." "**Warning**: the `predict_classes()` method is deprecated. Instead, we must use `np.argmax(model(X_new), axis=-1)`."
] ]
}, },
{ {
@ -378,7 +378,7 @@
"source": [ "source": [
"X_new = preprocess([\"How are yo\"])\n", "X_new = preprocess([\"How are yo\"])\n",
"#Y_pred = model.predict_classes(X_new)\n", "#Y_pred = model.predict_classes(X_new)\n",
"Y_pred = np.argmax(model.predict(X_new), axis=-1)\n", "Y_pred = np.argmax(model(X_new), axis=-1)\n",
"tokenizer.sequences_to_texts(Y_pred + 1)[0][-1] # 1st sentence, last char" "tokenizer.sequences_to_texts(Y_pred + 1)[0][-1] # 1st sentence, last char"
] ]
}, },
@ -401,7 +401,7 @@
"source": [ "source": [
"def next_char(text, temperature=1):\n", "def next_char(text, temperature=1):\n",
" X_new = preprocess([text])\n", " X_new = preprocess([text])\n",
" y_proba = model.predict(X_new)[0, -1:, :]\n", " y_proba = model(X_new)[0, -1:, :]\n",
" rescaled_logits = tf.math.log(y_proba) / temperature\n", " rescaled_logits = tf.math.log(y_proba) / temperature\n",
" char_id = tf.random.categorical(rescaled_logits, num_samples=1) + 1\n", " char_id = tf.random.categorical(rescaled_logits, num_samples=1) + 1\n",
" return tokenizer.sequences_to_texts(char_id.numpy())[0]" " return tokenizer.sequences_to_texts(char_id.numpy())[0]"
@ -512,6 +512,13 @@
"dataset = dataset.prefetch(1)" "dataset = dataset.prefetch(1)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: once again, I commented out `recurrent_dropout=0.2` (compared to the book) so you can get GPU acceleration (if you have one)."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 31,
@ -520,10 +527,12 @@
"source": [ "source": [
"model = keras.models.Sequential([\n", "model = keras.models.Sequential([\n",
" keras.layers.GRU(128, return_sequences=True, stateful=True,\n", " keras.layers.GRU(128, return_sequences=True, stateful=True,\n",
" dropout=0.2, recurrent_dropout=0.2,\n", " #dropout=0.2, recurrent_dropout=0.2,\n",
" dropout=0.2,\n",
" batch_input_shape=[batch_size, None, max_id]),\n", " batch_input_shape=[batch_size, None, max_id]),\n",
" keras.layers.GRU(128, return_sequences=True, stateful=True,\n", " keras.layers.GRU(128, return_sequences=True, stateful=True,\n",
" dropout=0.2, recurrent_dropout=0.2),\n", " #dropout=0.2, recurrent_dropout=0.2),\n",
" dropout=0.2),\n",
" keras.layers.TimeDistributed(keras.layers.Dense(max_id,\n", " keras.layers.TimeDistributed(keras.layers.Dense(max_id,\n",
" activation=\"softmax\"))\n", " activation=\"softmax\"))\n",
"])" "])"