Add more examples of TF-Addons seq2seq usage

main
Aurélien Geron 2020-04-22 19:21:56 +12:00
parent 0f35b8192f
commit 1f0bbc782a
1 changed file with 353 additions and 41 deletions

@ -1490,7 +1490,7 @@
"])\n",
"optimizer = keras.optimizers.SGD(lr=0.02, momentum = 0.95, nesterov=True)\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=optimizer, metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=20, validation_data=[X_valid, y_valid])"
"history = model.fit(X_train, y_train, epochs=20, validation_data=(X_valid, y_valid))"
]
},
{
@ -1742,7 +1742,7 @@
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, Y_train, epochs=20,\n",
" validation_data=[X_valid, Y_valid])"
" validation_data=(X_valid, Y_valid))"
]
},
{
@ -1878,7 +1878,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Second version: feeding the shifted targets to the decoder"
"### Second version: feeding the shifted targets to the decoder (teacher forcing)"
]
},
{
@ -1973,7 +1973,7 @@
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit([X_train, X_train_decoder], Y_train, epochs=10,\n",
" validation_data=[[X_valid, X_valid_decoder], Y_valid])"
" validation_data=([X_valid, X_valid_decoder], Y_valid))"
]
},
{
@ -2062,8 +2062,9 @@
"encoder_embeddings = keras.layers.Embedding(\n",
" len(INPUT_CHARS) + 1, encoder_embedding_size)(encoder_inputs)\n",
"\n",
"decoder_embeddings = keras.layers.Embedding(\n",
" len(INPUT_CHARS) + 2, decoder_embedding_size)(decoder_inputs)\n",
"decoder_embedding_layer = keras.layers.Embedding(\n",
" len(INPUT_CHARS) + 2, decoder_embedding_size)\n",
"decoder_embeddings = decoder_embedding_layer(decoder_inputs)\n",
"\n",
"encoder = keras.layers.LSTM(units, return_state=True)\n",
"encoder_outputs, state_h, state_c = encoder(encoder_embeddings)\n",
@ -2072,7 +2073,7 @@
"sampler = tfa.seq2seq.sampler.TrainingSampler()\n",
"\n",
"decoder_cell = keras.layers.LSTMCell(units)\n",
"output_layer = keras.layers.Dense(len(OUTPUT_CHARS) + 1, activation=\"softmax\")\n",
"output_layer = keras.layers.Dense(len(OUTPUT_CHARS) + 1)\n",
"\n",
"decoder = tfa.seq2seq.basic_decoder.BasicDecoder(decoder_cell,\n",
" sampler,\n",
@ -2080,14 +2081,15 @@
"final_outputs, final_state, final_sequence_lengths = decoder(\n",
" decoder_embeddings,\n",
" initial_state=encoder_state)\n",
"Y_proba = keras.layers.Activation(\"softmax\")(final_outputs.rnn_output)\n",
"\n",
"model = keras.models.Model(inputs=[encoder_inputs, decoder_inputs],\n",
" outputs=[final_outputs.rnn_output])\n",
" outputs=[Y_proba])\n",
"optimizer = keras.optimizers.Nadam()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit([X_train, X_train_decoder], Y_train, epochs=10,\n",
" validation_data=[[X_valid, X_valid_decoder], Y_valid])"
"history = model.fit([X_train, X_train_decoder], Y_train, epochs=15,\n",
" validation_data=([X_valid, X_valid_decoder], Y_valid))"
]
},
{
@ -2112,7 +2114,283 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fourth version: using TFA seq2seq, the Keras subclassing API and attention mechanisms"
"However, there's a much more efficient way to perform inference. Until now, during inference, we've run the model once for each new character. Instead, we can create a new decoder, based on the previously trained layers, but using a `GreedyEmbeddingSampler` instead of a `TrainingSampler`.\n",
"\n",
"At each time step, the `GreedyEmbeddingSampler` will compute the argmax of the decoder's outputs, and run the resulting token IDs through the decoder's embedding layer. Then it will feed the resulting embeddings to the decoder's LSTM cell at the next time step. This way, we only need to run the decoder once to get the full prediction."
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"inference_sampler = tfa.seq2seq.sampler.GreedyEmbeddingSampler(\n",
" embedding_fn=decoder_embedding_layer)\n",
"inference_decoder = tfa.seq2seq.basic_decoder.BasicDecoder(\n",
" decoder_cell, inference_sampler, output_layer=output_layer,\n",
" maximum_iterations=max_output_length)\n",
"batch_size = tf.shape(encoder_inputs)[:1]\n",
"start_tokens = tf.fill(dims=batch_size, value=sos_id)\n",
"final_outputs, final_state, final_sequence_lengths = inference_decoder(\n",
" start_tokens,\n",
" initial_state=encoder_state,\n",
" start_tokens=start_tokens,\n",
" end_token=0)\n",
"\n",
"inference_model = keras.models.Model(inputs=[encoder_inputs],\n",
" outputs=[final_outputs.sample_id])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A few notes:\n",
"* The `GreedyEmbeddingSampler` needs the `start_tokens` (a vector containing the start-of-sequence ID for each decoder sequence), and the `end_token` (the decoder will stop decoding a sequence once the model outputs this token).\n",
"* We must set `maximum_iterations` when creating the `BasicDecoder`, or else it may run into an infinite loop (if the model never outputs the end token for at least one of the sequences). This would force you would to restart the Jupyter kernel.\n",
"* The decoder inputs are not needed anymore, since all the decoder inputs are generated dynamically based on the outputs from the previous time step.\n",
"* The model's outputs are `final_outputs.sample_id` instead of the softmax of `final_outputs.rnn_outputs`. This allows us to directly get the argmax of the model's outputs. If you prefer to have access to the logits, you can replace `final_outputs.sample_id` with `final_outputs.rnn_outputs`."
]
},
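{
"cell_type": "markdown",
"metadata": {},
"source": [
"To illustrate that last point, here is a minimal sketch (not part of the original notebook, and left unexecuted): an alternative inference model that outputs the logits from `final_outputs.rnn_output` instead of the sampled token IDs. Applying a softmax and taking the argmax should give back the same predictions as `sample_id`. It assumes `encoder_inputs`, `final_outputs`, `prepare_date_strs_padded()` and `ids_to_date_strs()` defined above are still available."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only: reuses encoder_inputs and final_outputs from the inference\n",
"# decoder above, and exposes the logits instead of the sampled IDs.\n",
"logits_model = keras.models.Model(inputs=[encoder_inputs],\n",
"                                  outputs=[final_outputs.rnn_output])\n",
"\n",
"X = prepare_date_strs_padded([\"July 14, 1789\", \"May 01, 2020\"])\n",
"Y_proba = tf.nn.softmax(logits_model.predict(X))  # logits -> probabilities\n",
"ids_to_date_strs(tf.argmax(Y_proba, axis=-1))  # same result as sample_id"
]
},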
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can write a simple function that uses the model to perform the date format conversion:"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
"def fast_predict_date_strs(date_strs):\n",
" X = prepare_date_strs_padded(date_strs)\n",
" Y_pred = inference_model.predict(X)\n",
" return ids_to_date_strs(Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"fast_predict_date_strs([\"July 14, 1789\", \"May 01, 2020\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that it really is faster:"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"%timeit predict_date_strs([\"July 14, 1789\", \"May 01, 2020\"])"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"%timeit fast_predict_date_strs([\"July 14, 1789\", \"May 01, 2020\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's more than a 10x speedup! And it would be even more if we were handling longer sequences."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fourth version: using TF-Addons's seq2seq implementation with Teacher Forcing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When we trained the previous model, at each time step _t_ we gave the model the target token for time step _t_ - 1. However, at inference time, the model did not get the previous target at each time step. Instead, it got the previous prediction. So there is a discrepancy between training and inference, which may lead to disappointing performance. To alleviate this, we can gradually replace the targets with the predictions, during training. For this, we just need to replace the `TrainingSampler` with a `ScheduledEmbeddingTrainingSampler`, and use a Keras callback to gradually increase the `sampling_probability` (i.e., the probability that the decoder will use the prediction from the previous time step rather than the target for the previous time step)."
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_addons as tfa\n",
"\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
"n_epochs = 20\n",
"encoder_embedding_size = 32\n",
"decoder_embedding_size = 32\n",
"units = 128\n",
"\n",
"encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)\n",
"decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)\n",
"sequence_lengths = keras.layers.Input(shape=[], dtype=np.int32)\n",
"\n",
"encoder_embeddings = keras.layers.Embedding(\n",
" len(INPUT_CHARS) + 1, encoder_embedding_size)(encoder_inputs)\n",
"\n",
"decoder_embedding_layer = keras.layers.Embedding(\n",
" len(INPUT_CHARS) + 2, decoder_embedding_size)\n",
"decoder_embeddings = decoder_embedding_layer(decoder_inputs)\n",
"\n",
"encoder = keras.layers.LSTM(units, return_state=True)\n",
"encoder_outputs, state_h, state_c = encoder(encoder_embeddings)\n",
"encoder_state = [state_h, state_c]\n",
"\n",
"sampler = tfa.seq2seq.sampler.ScheduledEmbeddingTrainingSampler(\n",
" sampling_probability=0.,\n",
" embedding_fn=decoder_embedding_layer)\n",
"# we must set the sampling_probability after creating the sampler\n",
"# (see https://github.com/tensorflow/addons/pull/1714)\n",
"sampler.sampling_probability = tf.Variable(0.)\n",
"\n",
"decoder_cell = keras.layers.LSTMCell(units)\n",
"output_layer = keras.layers.Dense(len(OUTPUT_CHARS) + 1)\n",
"\n",
"decoder = tfa.seq2seq.basic_decoder.BasicDecoder(decoder_cell,\n",
" sampler,\n",
" output_layer=output_layer)\n",
"final_outputs, final_state, final_sequence_lengths = decoder(\n",
" decoder_embeddings,\n",
" initial_state=encoder_state)\n",
"Y_proba = keras.layers.Activation(\"softmax\")(final_outputs.rnn_output)\n",
"\n",
"model = keras.models.Model(inputs=[encoder_inputs, decoder_inputs],\n",
" outputs=[Y_proba])\n",
"optimizer = keras.optimizers.Nadam()\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"\n",
"def update_sampling_probability(epoch, logs):\n",
" proba = min(1.0, epoch / (n_epochs - 10))\n",
" sampler.sampling_probability.assign(proba)\n",
"\n",
"sampling_probability_cb = keras.callbacks.LambdaCallback(\n",
" on_epoch_begin=update_sampling_probability)\n",
"history = model.fit([X_train, X_train_decoder], Y_train, epochs=n_epochs,\n",
" validation_data=([X_valid, X_valid_decoder], Y_valid),\n",
" callbacks=[sampling_probability_cb])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not quite 100% validation accuracy, but close enough!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For inference, we could do the exact same thing as earlier, using a `GreedyEmbeddingSampler`. However, just for the sake of completeness, let's use a `SampleEmbeddingSampler` instead. It's almost the same thing, except that instead of using the argmax of the model's output to find the token ID, it treats the outputs as logits and uses them to sample a token ID randomly. This can be useful when you want to generate text. The `softmax_temperature` argument serves the \n",
"same purpose as when we generated Shakespeare-like text (the higher this argument, the more random the generated text will be)."
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"softmax_temperature = tf.Variable(1.)\n",
"\n",
"inference_sampler = tfa.seq2seq.sampler.SampleEmbeddingSampler(\n",
" embedding_fn=decoder_embedding_layer,\n",
" softmax_temperature=softmax_temperature)\n",
"inference_decoder = tfa.seq2seq.basic_decoder.BasicDecoder(\n",
" decoder_cell, inference_sampler, output_layer=output_layer,\n",
" maximum_iterations=max_output_length)\n",
"batch_size = tf.shape(encoder_inputs)[:1]\n",
"start_tokens = tf.fill(dims=batch_size, value=sos_id)\n",
"final_outputs, final_state, final_sequence_lengths = inference_decoder(\n",
" start_tokens,\n",
" initial_state=encoder_state,\n",
" start_tokens=start_tokens,\n",
" end_token=0)\n",
"\n",
"inference_model = keras.models.Model(inputs=[encoder_inputs],\n",
" outputs=[final_outputs.sample_id])"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"def creative_predict_date_strs(date_strs, temperature=1.0):\n",
" softmax_temperature.assign(temperature)\n",
" X = prepare_date_strs_padded(date_strs)\n",
" Y_pred = inference_model.predict(X)\n",
" return ids_to_date_strs(Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"creative_predict_date_strs([\"July 14, 1789\", \"May 01, 2020\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Dates look good at room temperature. Now let's heat things up a bit:"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"creative_predict_date_strs([\"July 14, 1789\", \"May 01, 2020\"],\n",
" temperature=5.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oops, the dates are overcooked, now. Let's call them \"creative\" dates."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fifth version: using TFA seq2seq, the Keras subclassing API and attention mechanisms"
]
},
{
@ -2121,19 +2399,19 @@
"source": [
"The sequences in this problem are pretty short, but if we wanted to tackle longer sequences, we would probably have to use attention mechanisms. While it's possible to code our own implementation, it's simpler and more efficient to use TF-Addons's implementation instead. Let's do that now, this time using Keras' subclassing API.\n",
"\n",
"**Warning**: due to a TensorFlow bug (see [this issue](https://github.com/tensorflow/addons/issues/1153) for details), the `get_initial_state()` method fails in eager mode, so it needs to be wrapped in a tf.function (which the subclassing API does automatically), until this issue is resolved."
"**Warning**: due to a TensorFlow bug (see [this issue](https://github.com/tensorflow/addons/issues/1153) for details), the `get_initial_state()` method fails in eager mode, so for now we have to use the subclassing API, as Keras automatically calls `tf.function()` on the `call()` method (so it runs in graph mode)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this implementation, the encoder is identical to the previous one. We will also see how to handle masks. This would be especially useful if the sequence lengths varied significantly."
"In this implementation, we've reverted back to using the `TrainingSampler`, for simplicity (but you can easily tweak it to use a `ScheduledEmbeddingTrainingSampler` instead). We also use a `GreedyEmbeddingSampler` during inference, so this class is pretty easy to use:"
]
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
@ -2143,59 +2421,67 @@
" super().__init__(**kwargs)\n",
" self.encoder_embedding = keras.layers.Embedding(\n",
" input_dim=len(INPUT_CHARS) + 1,\n",
" output_dim=encoder_embedding_size,\n",
" mask_zero=True)\n",
" output_dim=encoder_embedding_size)\n",
" self.encoder = keras.layers.LSTM(units,\n",
" return_sequences=True,\n",
" return_state=True)\n",
" self.decoder_embedding = keras.layers.Embedding(\n",
" input_dim=len(OUTPUT_CHARS) + 2,\n",
" output_dim=decoder_embedding_size,\n",
" mask_zero=True)\n",
" output_dim=decoder_embedding_size)\n",
" self.attention = tfa.seq2seq.LuongAttention(units)\n",
" decoder_inner_cell = keras.layers.LSTMCell(units)\n",
" self.decoder_cell = tfa.seq2seq.AttentionWrapper(\n",
" cell=decoder_inner_cell,\n",
" attention_mechanism=self.attention)\n",
" output_layer = keras.layers.Dense(len(OUTPUT_CHARS) + 1)\n",
" self.decoder = tfa.seq2seq.BasicDecoder(\n",
" cell=self.decoder_cell,\n",
" sampler=tfa.seq2seq.sampler.TrainingSampler(),\n",
" output_layer=keras.layers.Dense(len(OUTPUT_CHARS) + 1,\n",
" activation=\"softmax\"))\n",
" output_layer=output_layer)\n",
" self.inference_decoder = tfa.seq2seq.BasicDecoder(\n",
" cell=self.decoder_cell,\n",
" sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(\n",
" embedding_fn=self.decoder_embedding),\n",
" output_layer=output_layer,\n",
" maximum_iterations=max_output_length)\n",
"\n",
" def call(self, inputs, training=None):\n",
" encoder_input, decoder_input = inputs\n",
" encoder_embeddings = self.encoder_embedding(encoder_input)\n",
" encoder_mask = self.encoder_embedding.compute_mask(encoder_input)\n",
" encoder_outputs, encoder_state_h, encoder_state_c = self.encoder(\n",
" encoder_embeddings,\n",
" mask=encoder_mask,\n",
" training=training)\n",
" encoder_state = [encoder_state_h, encoder_state_c]\n",
"\n",
" self.attention(encoder_outputs,\n",
" memory_mask=encoder_mask,\n",
" setup_memory=True)\n",
" \n",
" decoder_embeddings = self.decoder_embedding(decoder_input)\n",
"\n",
" decoder_initial_state = self.decoder_cell.get_initial_state(\n",
" decoder_embeddings)\n",
" decoder_initial_state = decoder_initial_state.clone(\n",
" cell_state=encoder_state)\n",
" \n",
" decoder_mask = self.decoder_embedding.compute_mask(decoder_input)\n",
" if training:\n",
" decoder_outputs, _, _ = self.decoder(\n",
" decoder_embeddings,\n",
" initial_state=decoder_initial_state,\n",
" training=training,\n",
" mask=decoder_mask)\n",
" training=training)\n",
" else:\n",
" start_tokens = tf.zeros_like(encoder_input[:, 0]) + sos_id\n",
" decoder_outputs, _, _ = self.inference_decoder(\n",
" decoder_embeddings,\n",
" initial_state=decoder_initial_state,\n",
" start_tokens=start_tokens,\n",
" end_token=0)\n",
"\n",
" return decoder_outputs.rnn_output"
" return tf.nn.softmax(decoder_outputs.rnn_output)"
]
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
@ -2207,23 +2493,49 @@
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit([X_train, X_train_decoder], Y_train, epochs=25,\n",
" validation_data=[[X_valid, X_valid_decoder], Y_valid])"
" validation_data=([X_valid, X_valid_decoder], Y_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And there we go, 100% validation accuracy. It took a bit longer to converge this time, but there were also more parameters and more computations per iteration. To use the model, we can once again just reuse the `predict_date_strs()` function:"
"Not quite 100% validation accuracy, but close. It took a bit longer to converge this time, but there were also more parameters and more computations per iteration. And we did not use a scheduled sampler.\n",
"\n",
"To use the model, we can write yet another little function:"
]
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
"predict_date_strs([\"July 14, 1789\", \"May 01, 2020\"])"
"def fast_predict_date_strs_v2(date_strs):\n",
" X = prepare_date_strs_padded(date_strs)\n",
" X_decoder = tf.zeros(shape=(len(X), max_output_length), dtype=tf.int32)\n",
" Y_probas = model.predict([X, X_decoder])\n",
" Y_pred = tf.argmax(Y_probas, axis=-1)\n",
" return ids_to_date_strs(Y_pred)"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
"fast_predict_date_strs_v2([\"July 14, 1789\", \"May 01, 2020\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There are still a few interesting features from TF-Addons that you may want to look at:\n",
"* Using a `BeamSearchDecoder` rather than a `BasicDecoder` for inference. Instead of outputing the character with the highest probability, this decoder keeps track of the several candidates, and keeps only the most likely sequences of candidates (see chapter 16 in the book for more details).\n",
"* Setting masks or specifying `sequence_length` if the input or target sequences may have very different lengths.\n",
"* Using a `ScheduledOutputTrainingSampler`, which gives you more flexibility than the `ScheduledEmbeddingTrainingSampler` to decide how to feed the output at time _t_ to the cell at time _t_+1. By default it feeds the outputs directly to cell, without computing the argmax ID and passing it through an embedding layer. Alternatively, you specify a `next_inputs_fn` function that will be used to convert the cell outputs to inputs at the next step."
]
},
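{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, here is a minimal, unexecuted sketch of what beam search inference could look like (it is not part of the original notebook). It assumes the `decoder_cell`, `output_layer`, `decoder_embedding_layer`, `encoder_state`, `encoder_inputs` and `sos_id` from the fourth version above are still available, and it follows the `BeamSearchDecoder` pattern from chapter 16: the encoder state is tiled `beam_width` times, since the decoder tracks `beam_width` candidate sequences per input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch only (assumes decoder_cell, output_layer, decoder_embedding_layer,\n",
"# encoder_state, encoder_inputs and sos_id from the fourth version above).\n",
"beam_width = 10\n",
"beam_decoder = tfa.seq2seq.beam_search_decoder.BeamSearchDecoder(\n",
"    cell=decoder_cell, beam_width=beam_width, output_layer=output_layer)\n",
"# each tensor in the encoder state must be repeated beam_width times\n",
"tiled_encoder_state = tfa.seq2seq.beam_search_decoder.tile_batch(\n",
"    encoder_state, multiplier=beam_width)\n",
"batch_size = tf.shape(encoder_inputs)[:1]\n",
"start_tokens = tf.fill(dims=batch_size, value=sos_id)\n",
"final_outputs, final_state, final_sequence_lengths = beam_decoder(\n",
"    decoder_embedding_layer.embeddings,  # the decoder's embedding matrix\n",
"    start_tokens=start_tokens,\n",
"    end_token=0,\n",
"    initial_state=tiled_encoder_state)\n",
"# final_outputs.predicted_ids has shape [batch size, time steps, beam_width];\n",
"# it could be wrapped in a keras.models.Model, like the inference models above"
]
},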
{
@ -2265,7 +2577,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
@ -2283,7 +2595,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
@ -2301,7 +2613,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
@ -2321,7 +2633,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
@ -2351,7 +2663,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [