{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 16 – Natural Language Processing with RNNs and Attention**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 16._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/16_nlp_with_rnns_and_attention.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/16_nlp_with_rnns_and_attention.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dFXIv9qNpKzt",
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8IPbJEmZpKzu"
},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "TFSU3FCOpKzu"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GJtVEqxfpKzw"
},
"source": [
"And TensorFlow ≥ 2.8:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "0Piq5se2pKzx"
},
"outputs": [],
"source": [
"from packaging import version\n",
"import tensorflow as tf\n",
"\n",
"assert version.parse(tf.__version__) >= version.parse(\"2.8.0\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DDaDoLQTpKzx"
},
"source": [
"As we did in earlier chapters, let's define the default font sizes to make the figures prettier:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "8d4TH3NbpKzx"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RcoUIRsvpKzy"
},
"source": [
"And let's create the `images/nlp` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "PQFH5Y9PpKzy"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"nlp\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
"    path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
"    if tight_layout:\n",
"        plt.tight_layout()\n",
"    plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YTsawKlapKzy"
},
"source": [
"This chapter can be very slow without a GPU, so let's make sure there's one, or else issue a warning:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "Ekxzo6pOpKzy"
},
"outputs": [],
"source": [
"if not tf.config.list_physical_devices('GPU'):\n",
"    print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
"    if \"google.colab\" in sys.modules:\n",
"        print(\"Go to Runtime > Change runtime and select a GPU hardware \"\n",
"              \"accelerator.\")\n",
"    if \"kaggle_secrets\" in sys.modules:\n",
"        print(\"Go to Settings > Accelerator and select GPU.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Generating Shakespearean Text Using a Character RNN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Creating the Training Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's download the Shakespeare data from Andrej Karpathy's [char-rnn project](https://github.com/karpathy/char-rnn/):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading data from https://homl.info/shakespeare\n",
"1122304/1115394 [==============================] - 0s 0us/step\n",
"1130496/1115394 [==============================] - 0s 0us/step\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"shakespeare_url = \"https://homl.info/shakespeare\"  # shortcut URL\n",
"filepath = tf.keras.utils.get_file(\"shakespeare.txt\", shakespeare_url)\n",
"with open(filepath) as f:\n",
"    shakespeare_text = f.read()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"First Citizen:\n",
"Before we proceed any further, hear me speak.\n",
"\n",
"All:\n",
"Speak, speak.\n"
]
}
],
"source": [
"# extra code – shows a short text sample\n",
"print(shakespeare_text[:80])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"\\n !$&',-.3:;?abcdefghijklmnopqrstuvwxyz\""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows all 39 distinct characters (after converting to lower case)\n",
"\"\".join(sorted(set(shakespeare_text.lower())))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"text_vec_layer = tf.keras.layers.TextVectorization(split=\"character\",\n",
"                                                   standardize=\"lower\")\n",
"text_vec_layer.adapt([shakespeare_text])\n",
"encoded = text_vec_layer([shakespeare_text])[0]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"encoded -= 2  # drop tokens 0 (pad) and 1 (unknown), which we will not use\n",
"n_tokens = text_vec_layer.vocabulary_size() - 2  # number of distinct chars = 39\n",
"dataset_size = len(encoded)  # total number of chars = 1,115,394"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"39"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n_tokens"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1115394"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dataset_size"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def to_dataset(sequence, length, shuffle=False, seed=None, batch_size=32):\n",
"    ds = tf.data.Dataset.from_tensor_slices(sequence)\n",
"    ds = ds.window(length + 1, shift=1, drop_remainder=True)\n",
"    ds = ds.flat_map(lambda window_ds: window_ds.batch(length + 1))\n",
"    if shuffle:\n",
"        ds = ds.shuffle(100_000, seed=seed)\n",
"    ds = ds.batch(batch_size)\n",
"    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(<tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[ 4, 5, 2, 23]])>,\n",
" <tf.Tensor: shape=(1, 4), dtype=int64, numpy=array([[ 5, 2, 23, 3]])>)]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – a simple example using to_dataset()\n",
"# There's just one sample in this dataset: the input represents \"to b\" and the\n",
"# output represents \"o be\"\n",
"list(to_dataset(text_vec_layer([\"To be\"])[0], length=4))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"length = 100\n",
"tf.random.set_seed(42)\n",
"train_set = to_dataset(encoded[:1_000_000], length=length, shuffle=True,\n",
"                       seed=42)\n",
"valid_set = to_dataset(encoded[1_000_000:1_060_000], length=length)\n",
"test_set = to_dataset(encoded[1_060_000:], length=length)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Building and Training the Char-RNN Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following code may take one or two hours to run, depending on your GPU. Without a GPU, it may take over 24 hours. If you don't want to wait, just skip the next two code cells and run the code below to download a pretrained model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: the `GRU` class will only use cuDNN acceleration (assuming you have a GPU) when using the default values for the following arguments: `activation`, `recurrent_activation`, `recurrent_dropout`, `unroll`, `use_bias` and `reset_after`."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1407s 45ms/step - loss: 1.3873 - accuracy: 0.5754 - val_loss: 1.6155 - val_accuracy: 0.5333\n",
"Epoch 2/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1376s 44ms/step - loss: 1.2921 - accuracy: 0.5973 - val_loss: 1.5881 - val_accuracy: 0.5401\n",
"Epoch 3/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1379s 44ms/step - loss: 1.2743 - accuracy: 0.6015 - val_loss: 1.5885 - val_accuracy: 0.5407\n",
"Epoch 4/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1381s 44ms/step - loss: 1.2654 - accuracy: 0.6031 - val_loss: 1.5701 - val_accuracy: 0.5418\n",
"Epoch 5/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1379s 44ms/step - loss: 1.2594 - accuracy: 0.6045 - val_loss: 1.5674 - val_accuracy: 0.5450\n",
"Epoch 6/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1386s 44ms/step - loss: 1.2545 - accuracy: 0.6058 - val_loss: 1.5587 - val_accuracy: 0.5492\n",
"Epoch 7/10\n",
"31247/31247 [==============================] - 1381s 44ms/step - loss: 1.2514 - accuracy: 0.6062 - val_loss: 1.5532 - val_accuracy: 0.5460\n",
"Epoch 8/10\n",
"31247/31247 [==============================] - 1381s 44ms/step - loss: 1.2485 - accuracy: 0.6067 - val_loss: 1.5522 - val_accuracy: 0.5479\n",
"Epoch 9/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1382s 44ms/step - loss: 1.2460 - accuracy: 0.6073 - val_loss: 1.5521 - val_accuracy: 0.5497\n",
"Epoch 10/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"31247/31247 [==============================] - 1385s 44ms/step - loss: 1.2436 - accuracy: 0.6080 - val_loss: 1.5477 - val_accuracy: 0.5513\n"
]
}
],
"source": [
"tf.random.set_seed(42)  # extra code – ensures reproducibility on CPU\n",
"model = tf.keras.Sequential([\n",
"    tf.keras.layers.Embedding(input_dim=n_tokens, output_dim=16),\n",
"    tf.keras.layers.GRU(128, return_sequences=True),\n",
"    tf.keras.layers.Dense(n_tokens, activation=\"softmax\")\n",
"])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"nadam\",\n",
"              metrics=[\"accuracy\"])\n",
"model_ckpt = tf.keras.callbacks.ModelCheckpoint(\n",
"    \"my_shakespeare_model\", monitor=\"val_accuracy\", save_best_only=True)\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=10,\n",
"                    callbacks=[model_ckpt])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"shakespeare_model = tf.keras.Sequential([\n",
"    text_vec_layer,\n",
"    tf.keras.layers.Lambda(lambda X: X - 2),  # no <PAD> or <UNK> tokens\n",
"    model\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you don't want to wait for training to complete, I've pretrained a model for you. The following code will download it. Uncomment the last line if you want to use it instead of the model trained above."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# extra code – downloads a pretrained model\n",
"url = \"https://github.com/ageron/data/raw/main/shakespeare_model.tgz\"\n",
"path = tf.keras.utils.get_file(\"shakespeare_model.tgz\", url, extract=True)\n",
"model_path = Path(path).with_name(\"shakespeare_model\")\n",
"#shakespeare_model = tf.keras.models.load_model(model_path)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'e'"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_proba = shakespeare_model.predict([\"To be or not to b\"])[0, -1]\n",
"y_pred = tf.argmax(y_proba)  # choose the most probable character ID\n",
"text_vec_layer.get_vocabulary()[y_pred + 2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Generating Fake Shakespearean Text"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(1, 8), dtype=int64, numpy=array([[0, 1, 0, 2, 1, 0, 0, 1]])>"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"log_probas = tf.math.log([[0.5, 0.4, 0.1]])  # probas = 50%, 40%, and 10%\n",
"tf.random.set_seed(42)\n",
"tf.random.categorical(log_probas, num_samples=8)  # draw 8 samples"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def next_char(text, temperature=1):\n",
"    y_proba = shakespeare_model.predict([text])[0, -1:]\n",
"    rescaled_logits = tf.math.log(y_proba) / temperature\n",
"    char_id = tf.random.categorical(rescaled_logits, num_samples=1)[0, 0]\n",
"    return text_vec_layer.get_vocabulary()[char_id + 2]"
]
},
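{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra sketch (not from the book)**: to get a feel for the `temperature` argument, the cell below applies the same rescaling as `next_char()` to the toy distribution used earlier (50%, 40%, 10%) and prints the resulting sampling probabilities. It assumes that `tf.random.categorical()` treats its input as unnormalized log-probabilities, so a softmax recovers the distribution it effectively samples from. A temperature close to zero should favor the most likely class, while a very high temperature should flatten the distribution towards uniform."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – illustrative sketch, not part of the book's code\n",
"probas = tf.constant([[0.5, 0.4, 0.1]])\n",
"for temperature in (0.01, 1, 100):\n",
"    # softmax(log(p) / T) is the distribution that sampling would follow\n",
"    rescaled_probas = tf.nn.softmax(tf.math.log(probas) / temperature)\n",
"    print(temperature, rescaled_probas.numpy().round(3))"
]
},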
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"def extend_text(text, n_chars=50, temperature=1):\n",
"    for _ in range(n_chars):\n",
"        text += next_char(text, temperature)\n",
"    return text"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – ensures reproducibility on CPU"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"To be or not to be the duke\n",
"as it is a proper strange death,\n",
"and the\n"
]
}
],
"source": [
"print(extend_text(\"To be or not to be\", temperature=0.01))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"To be or not to behold?\n",
"\n",
"second push:\n",
"gremio, lord all, a sistermen,\n"
]
}
],
"source": [
"print(extend_text(\"To be or not to be\", temperature=1))"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"To be or not to bef ,mt'&o3fpadm!$\n",
"wh!nse?bws3est--vgerdjw?c-y-ewznq\n"
]
}
],
"source": [
"print(extend_text(\"To be or not to be\", temperature=100))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stateful RNN"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"def to_dataset_for_stateful_rnn(sequence, length):\n",
"    ds = tf.data.Dataset.from_tensor_slices(sequence)\n",
"    ds = ds.window(length + 1, shift=length, drop_remainder=True)\n",
"    ds = ds.flat_map(lambda window: window.batch(length + 1)).batch(1)\n",
"    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)\n",
"\n",
"stateful_train_set = to_dataset_for_stateful_rnn(encoded[:1_000_000], length)\n",
"stateful_valid_set = to_dataset_for_stateful_rnn(encoded[1_000_000:1_060_000],\n",
"                                                 length)\n",
"stateful_test_set = to_dataset_for_stateful_rnn(encoded[1_060_000:], length)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(<tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[0, 1, 2]], dtype=int32)>,\n",
" <tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[1, 2, 3]], dtype=int32)>),\n",
" (<tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[3, 4, 5]], dtype=int32)>,\n",
" <tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[4, 5, 6]], dtype=int32)>),\n",
" (<tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[6, 7, 8]], dtype=int32)>,\n",
" <tf.Tensor: shape=(1, 3), dtype=int32, numpy=array([[7, 8, 9]], dtype=int32)>)]"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – simple example using to_dataset_for_stateful_rnn()\n",
"list(to_dataset_for_stateful_rnn(tf.range(10), 3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you'd like to have more than one window per batch, you can use the `to_batched_dataset_for_stateful_rnn()` function instead of `to_dataset_for_stateful_rnn()`:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(<tf.Tensor: shape=(2, 3), dtype=int32, numpy=\n",
" array([[ 0, 1, 2],\n",
" [10, 11, 12]], dtype=int32)>,\n",
" <tf.Tensor: shape=(2, 3), dtype=int32, numpy=\n",
" array([[ 1, 2, 3],\n",
" [11, 12, 13]], dtype=int32)>),\n",
" (<tf.Tensor: shape=(2, 3), dtype=int32, numpy=\n",
" array([[ 3, 4, 5],\n",
" [13, 14, 15]], dtype=int32)>,\n",
" <tf.Tensor: shape=(2, 3), dtype=int32, numpy=\n",
" array([[ 4, 5, 6],\n",
" [14, 15, 16]], dtype=int32)>),\n",
" (<tf.Tensor: shape=(2, 3), dtype=int32, numpy=\n",
" array([[ 6, 7, 8],\n",
" [16, 17, 18]], dtype=int32)>,\n",
" <tf.Tensor: shape=(2, 3), dtype=int32, numpy=\n",
" array([[ 7, 8, 9],\n",
" [17, 18, 19]], dtype=int32)>)]"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows one way to prepare a batched dataset for a stateful RNN\n",
"\n",
"import numpy as np\n",
"\n",
"def to_non_overlapping_windows(sequence, length):\n",
"    ds = tf.data.Dataset.from_tensor_slices(sequence)\n",
"    ds = ds.window(length + 1, shift=length, drop_remainder=True)\n",
"    return ds.flat_map(lambda window: window.batch(length + 1))\n",
"\n",
"def to_batched_dataset_for_stateful_rnn(sequence, length, batch_size=32):\n",
"    parts = np.array_split(sequence, batch_size)\n",
"    datasets = tuple(to_non_overlapping_windows(part, length) for part in parts)\n",
"    ds = tf.data.Dataset.zip(datasets).map(lambda *windows: tf.stack(windows))\n",
"    return ds.map(lambda window: (window[:, :-1], window[:, 1:])).prefetch(1)\n",
"\n",
"list(to_batched_dataset_for_stateful_rnn(tf.range(20), length=3, batch_size=2))"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – ensures reproducibility on CPU\n",
"model = tf.keras.Sequential([\n",
"    tf.keras.layers.Embedding(input_dim=n_tokens, output_dim=16,\n",
"                              batch_input_shape=[1, None]),\n",
"    tf.keras.layers.GRU(128, return_sequences=True, stateful=True),\n",
"    tf.keras.layers.Dense(n_tokens, activation=\"softmax\")\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"class ResetStatesCallback(tf.keras.callbacks.Callback):\n",
"    def on_epoch_begin(self, epoch, logs):\n",
"        self.model.reset_states()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# extra code – use a different directory to save the checkpoints\n",
"model_ckpt = tf.keras.callbacks.ModelCheckpoint(\n",
"    \"my_stateful_shakespeare_model\",\n",
"    monitor=\"val_accuracy\",\n",
"    save_best_only=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell will take a while to run (possibly an hour if you are not using a GPU)."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 213s 21ms/step - loss: 1.8690 - accuracy: 0.4494 - val_loss: 1.7632 - val_accuracy: 0.4672\n",
"Epoch 2/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 211s 21ms/step - loss: 1.5635 - accuracy: 0.5284 - val_loss: 1.6334 - val_accuracy: 0.4994\n",
"Epoch 3/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 209s 21ms/step - loss: 1.4875 - accuracy: 0.5478 - val_loss: 1.5788 - val_accuracy: 0.5153\n",
"Epoch 4/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 208s 21ms/step - loss: 1.4483 - accuracy: 0.5579 - val_loss: 1.5471 - val_accuracy: 0.5236\n",
"Epoch 5/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 213s 21ms/step - loss: 1.4241 - accuracy: 0.5643 - val_loss: 1.5270 - val_accuracy: 0.5286\n",
"Epoch 6/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 215s 21ms/step - loss: 1.4074 - accuracy: 0.5686 - val_loss: 1.5109 - val_accuracy: 0.5338\n",
"Epoch 7/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 210s 21ms/step - loss: 1.3953 - accuracy: 0.5714 - val_loss: 1.5008 - val_accuracy: 0.5361\n",
"Epoch 8/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 212s 21ms/step - loss: 1.3863 - accuracy: 0.5737 - val_loss: 1.4938 - val_accuracy: 0.5381\n",
"Epoch 9/10\n",
"9999/9999 [==============================] - 207s 21ms/step - loss: 1.3790 - accuracy: 0.5757 - val_loss: 1.4890 - val_accuracy: 0.5380\n",
"Epoch 10/10\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:tensorflow:Assets written to: my_stateful_shakespeare_model/assets\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9999/9999 [==============================] - 208s 21ms/step - loss: 1.3729 - accuracy: 0.5770 - val_loss: 1.4786 - val_accuracy: 0.5420\n"
]
}
],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"nadam\",\n",
"              metrics=[\"accuracy\"])\n",
"history = model.fit(stateful_train_set, validation_data=stateful_valid_set,\n",
"                    epochs=10, callbacks=[ResetStatesCallback(), model_ckpt])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Extra Material: converting the stateful RNN to a stateless RNN and using it**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To use the model with different batch sizes, we need to create a stateless copy:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"stateless_model = tf.keras.Sequential([\n",
"    tf.keras.layers.Embedding(input_dim=n_tokens, output_dim=16),\n",
"    tf.keras.layers.GRU(128, return_sequences=True),\n",
"    tf.keras.layers.Dense(n_tokens, activation=\"softmax\")\n",
"])"
]
},
{
2022-03-21 09:43:01 +01:00
"cell_type": "markdown",
2019-04-05 11:04:38 +02:00
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"To set the weights, we first need to build the model (so the weights get created):"
2019-04-05 11:04:38 +02:00
]
},
{
2022-03-21 09:43:01 +01:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 35,
2019-04-05 11:04:38 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"stateless_model.build(tf.TensorShape([None, None]))"
2019-04-05 11:04:38 +02:00
]
},
{
2019-04-15 18:09:10 +02:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 36,
2019-04-05 11:04:38 +02:00
"metadata": {},
2019-04-15 18:09:10 +02:00
"outputs": [],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"stateless_model.set_weights(model.get_weights())"
2019-04-05 11:04:38 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 37,
2019-04-05 11:04:38 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"shakespeare_model = tf.keras.Sequential([\n",
" text_vec_layer,\n",
" tf.keras.layers.Lambda(lambda X: X - 2), # no <PAD> or <UNK> tokens\n",
" stateless_model\n",
"])"
2019-04-05 11:04:38 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 38,
2019-04-05 11:04:38 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"to be or not to be so in the world and the strangeness\n",
"to see the wo\n"
]
}
],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"tf.random.set_seed(42)\n",
"\n",
"print(extend_text(\"to be or not to be\", temperature=0.01))"
2019-04-05 11:04:38 +02:00
]
},
2021-03-11 03:07:23 +01:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"# Sentiment Analysis"
2021-03-11 03:07:23 +01:00
]
},
2019-04-05 11:04:38 +02:00
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 39,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mDownloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /home/ageron/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...\u001b[0m\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "055c0f544ac349d9a14da8f843651df0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Dl Completed...: 0 url [00:00, ? url/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e2abc244f4844d56919979b33cc2fa79",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Dl Size...: 0 MiB [00:00, ? MiB/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "af507eed124c4ff6900538205b1b00fd",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating splits...: 0%| | 0/3 [00:00<?, ? splits/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "18cd596aa97b46f1aa3f93d0c29edd59",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating train examples...: 0%| | 0/25000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7c623038199e46909b7a8b0a39cecbab",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Shuffling /home/ageron/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete0WPKUH/imdb_reviews-train.t…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b2c0a46cc37b4eb6b9feb67d715d7022",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating test examples...: 0%| | 0/25000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4abb656c416049c085e0f2f761d5bf9c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Shuffling /home/ageron/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete0WPKUH/imdb_reviews-test.tf…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "edb7ceb384634b8ebd766e55ba21c5d4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating unsupervised examples...: 0%| | 0/50000 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ad80e1205d5e4914840999fcd3ae3b88",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Shuffling /home/ageron/tensorflow_datasets/imdb_reviews/plain_text/1.0.0.incomplete0WPKUH/imdb_reviews-unsuper…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mDataset imdb_reviews downloaded and prepared to /home/ageron/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.\u001b[0m\n"
]
}
],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"import tensorflow_datasets as tfds\n",
"\n",
"raw_train_set, raw_valid_set, raw_test_set = tfds.load(\n",
" name=\"imdb_reviews\",\n",
" split=[\"train[:90%]\", \"train[90%:]\", \"test\"],\n",
" as_supervised=True\n",
")\n",
"tf.random.set_seed(42)\n",
"train_set = raw_train_set.shuffle(5000, seed=42).batch(32).prefetch(1)\n",
"valid_set = raw_valid_set.batch(32).prefetch(1)\n",
"test_set = raw_test_set.batch(32).prefetch(1)"
2019-04-05 11:04:38 +02:00
]
},
{
2019-04-15 18:09:10 +02:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 40,
2019-04-05 11:04:38 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting ...\n",
"Label: 0\n",
"I have been known to fall asleep during films, but this is usually due to a combination of things including, really tired, being warm and comfortable on the sette and having just eaten a lot. However ...\n",
"Label: 0\n",
"Mann photographs the Alberta Rocky Mountains in a superb fashion, and Jimmy Stewart and Walter Brennan give enjoyable performances as they always seem to do. <br /><br />But come on Hollywood - a Moun ...\n",
"Label: 0\n",
"This is the kind of film for a snowy Sunday afternoon when the rest of the world can go ahead with its own business as you descend into a big arm-chair and mellow for a couple of hours. Wonderful perf ...\n",
"Label: 1\n"
]
}
],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"for review, label in raw_train_set.take(4):\n",
" print(review.numpy().decode(\"utf-8\")[:200], \"...\")\n",
" print(\"Label:\", label.numpy())"
2019-04-05 11:04:38 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 41,
2019-04-05 11:04:38 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"vocab_size = 1000\n",
"text_vec_layer = tf.keras.layers.TextVectorization(max_tokens=vocab_size)\n",
"text_vec_layer.adapt(train_set.map(lambda reviews, labels: reviews))"
2019-04-05 11:04:38 +02:00
]
},
{
2019-04-15 18:09:10 +02:00
"cell_type": "markdown",
2019-04-05 11:04:38 +02:00
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"**Warning**: the following cell will take a few minutes to run and the model will probably not learn anything because we didn't mask the padding tokens (that's the point of the next section)."
2019-04-05 11:04:38 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 42,
2019-04-05 11:04:38 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"704/704 [==============================] - 255s 359ms/step - loss: 0.6934 - accuracy: 0.4990 - val_loss: 0.6931 - val_accuracy: 0.5016\n",
"Epoch 2/2\n",
"704/704 [==============================] - 250s 355ms/step - loss: 0.6934 - accuracy: 0.5042 - val_loss: 0.6942 - val_accuracy: 0.5008\n"
]
}
],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"embed_size = 128\n",
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential([\n",
" text_vec_layer,\n",
" tf.keras.layers.Embedding(vocab_size, embed_size),\n",
" tf.keras.layers.GRU(128),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=2)"
2019-04-05 11:04:38 +02:00
]
},
{
2019-04-15 18:09:10 +02:00
"cell_type": "markdown",
2019-04-05 11:04:38 +02:00
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"## Masking"
2019-04-05 11:04:38 +02:00
]
},
{
2022-03-21 09:43:01 +01:00
"cell_type": "markdown",
2019-04-05 11:04:38 +02:00
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"**Warning**: the following cell will take a while to run (possibly 30 minutes if you are not using a GPU)."
2019-04-05 11:04:38 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 43,
2019-04-05 11:04:38 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"704/704 [==============================] - 303s 426ms/step - loss: 0.5296 - accuracy: 0.7234 - val_loss: 0.4045 - val_accuracy: 0.8244\n",
"Epoch 2/5\n",
"704/704 [==============================] - 295s 419ms/step - loss: 0.3702 - accuracy: 0.8418 - val_loss: 0.3390 - val_accuracy: 0.8532\n",
"Epoch 3/5\n",
"704/704 [==============================] - 298s 423ms/step - loss: 0.3057 - accuracy: 0.8747 - val_loss: 0.3196 - val_accuracy: 0.8696\n",
"Epoch 4/5\n",
"704/704 [==============================] - 294s 418ms/step - loss: 0.2784 - accuracy: 0.8871 - val_loss: 0.3162 - val_accuracy: 0.8596\n",
"Epoch 5/5\n",
"704/704 [==============================] - 293s 417ms/step - loss: 0.2597 - accuracy: 0.8961 - val_loss: 0.3209 - val_accuracy: 0.8548\n"
]
}
],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"embed_size = 128\n",
2019-04-05 11:04:38 +02:00
"tf.random.set_seed(42)\n",
2022-03-21 09:43:01 +01:00
"model = tf.keras.Sequential([\n",
" text_vec_layer,\n",
" tf.keras.layers.Embedding(vocab_size, embed_size, mask_zero=True),\n",
" tf.keras.layers.GRU(128),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=5)"
2019-04-05 11:04:38 +02:00
]
},
{
2019-04-15 18:09:10 +02:00
"cell_type": "markdown",
2019-04-05 11:04:38 +02:00
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"Or using manual masking:"
2016-09-27 23:31:21 +02:00
]
},
2019-04-05 11:04:38 +02:00
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 44,
2019-04-05 11:04:38 +02:00
"metadata": {},
"outputs": [],
2019-04-15 18:09:10 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"tf.random.set_seed(42) # extra code – ensures reproducibility on the CPU\n",
"inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)\n",
"token_ids = text_vec_layer(inputs)\n",
"mask = tf.math.not_equal(token_ids, 0)\n",
"Z = tf.keras.layers.Embedding(vocab_size, embed_size)(token_ids)\n",
"Z = tf.keras.layers.GRU(128, dropout=0.2)(Z, mask=mask)\n",
"outputs = tf.keras.layers.Dense(1, activation=\"sigmoid\")(Z)\n",
"model = tf.keras.Model(inputs=[inputs], outputs=[outputs])"
2019-04-15 18:09:10 +02:00
]
2019-04-05 11:04:38 +02:00
},
{
2019-04-15 18:09:10 +02:00
"cell_type": "markdown",
2019-04-05 11:04:38 +02:00
"metadata": {},
2019-04-15 18:09:10 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"**Warning**: the following cell will take a while to run (possibly 30 minutes if you are not using a GPU)."
2019-04-15 18:09:10 +02:00
]
2019-04-05 11:04:38 +02:00
},
2016-09-27 23:31:21 +02:00
{
2019-04-15 18:09:10 +02:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 45,
2017-10-05 13:22:06 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"704/704 [==============================] - 303s 427ms/step - loss: 0.5447 - accuracy: 0.7198 - val_loss: 0.4604 - val_accuracy: 0.7720\n",
"Epoch 2/5\n",
"704/704 [==============================] - 301s 427ms/step - loss: 0.3469 - accuracy: 0.8512 - val_loss: 0.3214 - val_accuracy: 0.8608\n",
"Epoch 3/5\n",
"704/704 [==============================] - 295s 419ms/step - loss: 0.3054 - accuracy: 0.8713 - val_loss: 0.3069 - val_accuracy: 0.8672\n",
"Epoch 4/5\n",
"704/704 [==============================] - 295s 420ms/step - loss: 0.2798 - accuracy: 0.8828 - val_loss: 0.3028 - val_accuracy: 0.8672\n",
"Epoch 5/5\n",
"704/704 [==============================] - 298s 423ms/step - loss: 0.2622 - accuracy: 0.8920 - val_loss: 0.2953 - val_accuracy: 0.8700\n"
]
}
],
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"# extra code – compiles and trains the model, as usual\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=5)"
2016-09-27 23:31:21 +02:00
]
},
{
2022-03-21 09:43:01 +01:00
"cell_type": "markdown",
2017-10-27 16:19:15 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"**Extra material: using ragged tensors**"
2016-09-27 23:31:21 +02:00
]
},
{
2019-04-05 11:04:38 +02:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 46,
2017-10-05 13:22:06 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"<tf.RaggedTensor [[86, 18], [11, 7, 1, 116, 217]]>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 46,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"text_vec_layer_ragged = tf.keras.layers.TextVectorization(\n",
" max_tokens=vocab_size, ragged=True)\n",
"text_vec_layer_ragged.adapt(train_set.map(lambda reviews, labels: reviews))\n",
"text_vec_layer_ragged([\"Great movie!\", \"This is DiCaprio's best role.\"])"
2016-09-27 23:31:21 +02:00
]
},
{
2019-04-05 11:04:38 +02:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 47,
2019-04-05 11:04:38 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2, 5), dtype=int64, numpy=\n",
"array([[ 86, 18, 0, 0, 0],\n",
" [ 11, 7, 1, 116, 217]])>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 47,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2019-04-05 11:04:38 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"text_vec_layer([\"Great movie!\", \"This is DiCaprio's best role.\"])"
2019-04-05 11:04:38 +02:00
]
},
{
2022-03-21 09:43:01 +01:00
"cell_type": "markdown",
2017-10-05 13:22:06 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"**Warning**: the following cell will take a while to run (possibly 30 minutes if you are not using a GPU)."
2016-09-27 23:31:21 +02:00
]
},
2019-04-05 11:04:38 +02:00
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 48,
2019-04-05 11:04:38 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-04-16 07:30:08 +02:00
"Epoch 1/5\n",
2022-03-21 09:43:01 +01:00
"704/704 [==============================] - 280s 395ms/step - loss: 0.5038 - accuracy: 0.7496 - val_loss: 0.6706 - val_accuracy: 0.6752\n",
"Epoch 2/5\n",
"704/704 [==============================] - 277s 393ms/step - loss: 0.4499 - accuracy: 0.7892 - val_loss: 0.3494 - val_accuracy: 0.8500\n",
"Epoch 3/5\n",
"704/704 [==============================] - 276s 392ms/step - loss: 0.3270 - accuracy: 0.8592 - val_loss: 0.3855 - val_accuracy: 0.8260\n",
"Epoch 4/5\n",
"704/704 [==============================] - 277s 394ms/step - loss: 0.2935 - accuracy: 0.8760 - val_loss: 0.3401 - val_accuracy: 0.8520\n",
"Epoch 5/5\n",
"704/704 [==============================] - 275s 390ms/step - loss: 0.2742 - accuracy: 0.8854 - val_loss: 0.3971 - val_accuracy: 0.8208\n"
]
}
],
2019-04-15 18:09:10 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"embed_size = 128\n",
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential([\n",
" text_vec_layer_ragged,\n",
" tf.keras.layers.Embedding(vocab_size, embed_size),\n",
" tf.keras.layers.GRU(128),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=5)"
2019-04-15 18:09:10 +02:00
]
2019-04-05 11:04:38 +02:00
},
2016-09-27 23:31:21 +02:00
{
2022-03-21 09:43:01 +01:00
"cell_type": "markdown",
2017-10-27 16:19:15 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"## Reusing Pretrained Embeddings and Language Models"
2016-09-27 23:31:21 +02:00
]
},
{
2022-03-21 09:43:01 +01:00
"cell_type": "markdown",
2017-10-27 16:19:15 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"**Warning**: the following cell will take a while to run (possibly an hour if you are not using a GPU)."
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 49,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"704/704 [==============================] - 224s 303ms/step - loss: 0.3141 - accuracy: 0.8648 - val_loss: 0.2397 - val_accuracy: 0.9008\n",
"Epoch 2/10\n",
"704/704 [==============================] - 205s 291ms/step - loss: 0.0489 - accuracy: 0.9852 - val_loss: 0.3257 - val_accuracy: 0.8936\n",
"Epoch 3/10\n",
"704/704 [==============================] - 204s 290ms/step - loss: 0.0061 - accuracy: 0.9988 - val_loss: 0.3963 - val_accuracy: 0.8944\n",
"Epoch 4/10\n",
"704/704 [==============================] - 204s 290ms/step - loss: 9.4918e-04 - accuracy: 0.9999 - val_loss: 0.4291 - val_accuracy: 0.8924\n",
"Epoch 5/10\n",
"704/704 [==============================] - 203s 289ms/step - loss: 5.1920e-04 - accuracy: 1.0000 - val_loss: 0.4691 - val_accuracy: 0.8932\n",
"Epoch 6/10\n",
"704/704 [==============================] - 204s 289ms/step - loss: 5.0053e-04 - accuracy: 1.0000 - val_loss: 0.4687 - val_accuracy: 0.8912\n",
"Epoch 7/10\n",
"704/704 [==============================] - 208s 296ms/step - loss: 3.7360e-04 - accuracy: 1.0000 - val_loss: 0.5034 - val_accuracy: 0.8984\n",
"Epoch 8/10\n",
"704/704 [==============================] - 209s 297ms/step - loss: 2.3907e-05 - accuracy: 1.0000 - val_loss: 0.5773 - val_accuracy: 0.8924\n",
"Epoch 9/10\n",
"704/704 [==============================] - 204s 290ms/step - loss: 9.0970e-06 - accuracy: 1.0000 - val_loss: 0.6163 - val_accuracy: 0.8972\n",
"Epoch 10/10\n",
"704/704 [==============================] - 205s 291ms/step - loss: 5.2528e-06 - accuracy: 1.0000 - val_loss: 0.6455 - val_accuracy: 0.8956\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f89897f6d30>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 49,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import os\n",
"import tensorflow_hub as hub\n",
"\n",
"os.environ[\"TFHUB_CACHE_DIR\"] = \"my_tfhub_cache\"\n",
"tf.random.set_seed(42) # extra code – ensures reproducibility on CPU\n",
"model = tf.keras.Sequential([\n",
" hub.KerasLayer(\"https://tfhub.dev/google/universal-sentence-encoder/4\",\n",
" trainable=True, dtype=tf.string, input_shape=[]),\n",
" tf.keras.layers.Dense(64, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"model.fit(train_set, validation_data=valid_set, epochs=10)"
2016-09-27 23:31:21 +02:00
]
},
{
2022-03-21 09:43:01 +01:00
"cell_type": "markdown",
2017-10-05 13:22:06 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"# An Encoder– Decoder Network for Neural Machine Translation"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 50,
2017-10-05 13:22:06 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"url = \"https://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip\"\n",
"path = tf.keras.utils.get_file(\"spa-eng.zip\", origin=url, cache_dir=\"datasets\",\n",
" extract=True)\n",
"text = (Path(path).with_name(\"spa-eng\") / \"spa.txt\").read_text()"
2016-09-27 23:31:21 +02:00
]
},
{
2019-04-15 18:09:10 +02:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 51,
2017-10-05 13:22:06 +02:00
"metadata": {},
2019-04-15 18:09:10 +02:00
"outputs": [],
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"import numpy as np\n",
2019-04-15 18:09:10 +02:00
"\n",
2022-03-21 09:43:01 +01:00
"text = text.replace(\"¡\", \"\").replace(\"¿\", \"\")\n",
"pairs = [line.split(\"\\t\") for line in text.splitlines()]\n",
"np.random.seed(42) # extra code – ensures reproducibility on CPU\n",
"np.random.shuffle(pairs)\n",
"sentences_en, sentences_es = zip(*pairs) # separates the pairs into 2 lists"
2016-09-27 23:31:21 +02:00
]
},
2018-12-25 14:54:14 +01:00
{
2019-04-15 18:09:10 +02:00
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 52,
2018-12-25 14:54:14 +01:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"How boring! => Qué aburrimiento!\n",
"I love sports. => Adoro el deporte.\n",
"Would you like to swap jobs? => Te gustaría que intercambiemos los trabajos?\n"
]
}
],
2018-12-25 14:54:14 +01:00
"source": [
2022-03-21 09:43:01 +01:00
"for i in range(3):\n",
" print(sentences_en[i], \"=>\", sentences_es[i])"
2018-12-25 14:54:14 +01:00
]
},
2016-09-27 23:31:21 +02:00
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 53,
2017-10-27 16:19:15 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"vocab_size = 1000\n",
"max_length = 50\n",
"text_vec_layer_en = tf.keras.layers.TextVectorization(\n",
" vocab_size, output_sequence_length=max_length)\n",
"text_vec_layer_es = tf.keras.layers.TextVectorization(\n",
" vocab_size, output_sequence_length=max_length)\n",
"text_vec_layer_en.adapt(sentences_en)\n",
"text_vec_layer_es.adapt([f\"startofseq {s} endofseq\" for s in sentences_es])"
2017-06-05 19:18:20 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 54,
2017-10-27 16:19:15 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"['', '[UNK]', 'the', 'i', 'to', 'you', 'tom', 'a', 'is', 'he']"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 54,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 19:18:20 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"text_vec_layer_en.get_vocabulary()[:10]"
2017-06-05 19:18:20 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 55,
2017-10-27 16:19:15 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"['', '[UNK]', 'startofseq', 'endofseq', 'de', 'que', 'a', 'no', 'tom', 'la']"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 55,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2017-06-05 19:18:20 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"text_vec_layer_es.get_vocabulary()[:10]"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 56,
2017-10-27 16:19:15 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"X_train = tf.constant(sentences_en[:100_000])\n",
"X_valid = tf.constant(sentences_en[100_000:])\n",
"X_train_dec = tf.constant([f\"startofseq {s}\" for s in sentences_es[:100_000]])\n",
"X_valid_dec = tf.constant([f\"startofseq {s}\" for s in sentences_es[100_000:]])\n",
"Y_train = text_vec_layer_es([f\"{s} endofseq\" for s in sentences_es[:100_000]])\n",
"Y_valid = text_vec_layer_es([f\"{s} endofseq\" for s in sentences_es[100_000:]])"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 57,
2017-10-05 13:22:06 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"tf.random.set_seed(42) # extra code – ensures reproducibility on CPU\n",
"encoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)\n",
"decoder_inputs = tf.keras.layers.Input(shape=[], dtype=tf.string)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 58,
2017-10-05 13:22:06 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"embed_size = 128\n",
"encoder_input_ids = text_vec_layer_en(encoder_inputs)\n",
"decoder_input_ids = text_vec_layer_es(decoder_inputs)\n",
"encoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size,\n",
" mask_zero=True)\n",
"decoder_embedding_layer = tf.keras.layers.Embedding(vocab_size, embed_size,\n",
" mask_zero=True)\n",
"encoder_embeddings = encoder_embedding_layer(encoder_input_ids)\n",
"decoder_embeddings = decoder_embedding_layer(decoder_input_ids)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 59,
2017-10-27 16:19:15 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"encoder = tf.keras.layers.LSTM(512, return_state=True)\n",
"encoder_outputs, *encoder_state = encoder(encoder_embeddings)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 60,
2017-10-05 13:22:06 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"decoder = tf.keras.layers.LSTM(512, return_sequences=True)\n",
"decoder_outputs = decoder(decoder_embeddings, initial_state=encoder_state)"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 61,
2017-10-27 16:19:15 +02:00
"metadata": {},
2016-09-27 23:31:21 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"output_layer = tf.keras.layers.Dense(vocab_size, activation=\"softmax\")\n",
"Y_proba = output_layer(decoder_outputs)"
2019-04-15 18:09:10 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"**Warning**: the following cell will take a while to run (possibly a couple hours if you are not using a GPU)."
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 62,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"3125/3125 [==============================] - 698s 221ms/step - loss: 0.4154 - accuracy: 0.4256 - val_loss: 0.3069 - val_accuracy: 0.5246\n",
"Epoch 2/10\n",
"3125/3125 [==============================] - 686s 219ms/step - loss: 0.2631 - accuracy: 0.5745 - val_loss: 0.2367 - val_accuracy: 0.6055\n",
"Epoch 3/10\n",
"3125/3125 [==============================] - 686s 220ms/step - loss: 0.2066 - accuracy: 0.6457 - val_loss: 0.2061 - val_accuracy: 0.6500\n",
"Epoch 4/10\n",
"3125/3125 [==============================] - 682s 218ms/step - loss: 0.1740 - accuracy: 0.6907 - val_loss: 0.1920 - val_accuracy: 0.6691\n",
"Epoch 5/10\n",
"3125/3125 [==============================] - 676s 216ms/step - loss: 0.1507 - accuracy: 0.7237 - val_loss: 0.1865 - val_accuracy: 0.6767\n",
"Epoch 6/10\n",
"3125/3125 [==============================] - 675s 216ms/step - loss: 0.1316 - accuracy: 0.7522 - val_loss: 0.1847 - val_accuracy: 0.6804\n",
"Epoch 7/10\n",
"3125/3125 [==============================] - 675s 216ms/step - loss: 0.1154 - accuracy: 0.7774 - val_loss: 0.1866 - val_accuracy: 0.6822\n",
"Epoch 8/10\n",
"3125/3125 [==============================] - 673s 215ms/step - loss: 0.1011 - accuracy: 0.8007 - val_loss: 0.1907 - val_accuracy: 0.6829\n",
"Epoch 9/10\n",
"3125/3125 [==============================] - 673s 215ms/step - loss: 0.0888 - accuracy: 0.8215 - val_loss: 0.1961 - val_accuracy: 0.6792\n",
"Epoch 10/10\n",
"3125/3125 [==============================] - 673s 215ms/step - loss: 0.0782 - accuracy: 0.8402 - val_loss: 0.2027 - val_accuracy: 0.6763\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f897878ac10>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 62,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.Model(inputs=[encoder_inputs, decoder_inputs],\n",
" outputs=[Y_proba])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"model.fit((X_train, X_train_dec), Y_train, epochs=10,\n",
" validation_data=((X_valid, X_valid_dec), Y_valid))"
2017-06-05 19:18:20 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 63,
2017-10-27 16:19:15 +02:00
"metadata": {},
2017-06-05 19:18:20 +02:00
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"def translate(sentence_en):\n",
" translation = \"\"\n",
" for word_idx in range(max_length):\n",
" X = np.array([sentence_en]) # encoder input \n",
" X_dec = np.array([\"startofseq \" + translation]) # decoder input\n",
" y_proba = model.predict((X, X_dec))[0, word_idx] # last token's probas\n",
" predicted_word_id = np.argmax(y_proba)\n",
" predicted_word = text_vec_layer_es.get_vocabulary()[predicted_word_id]\n",
" if predicted_word == \"endofseq\":\n",
" break\n",
" translation += \" \" + predicted_word\n",
" return translation.strip()"
2017-06-05 19:18:20 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 64,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'me gusta el fútbol'"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 64,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"translate(\"I like soccer\")"
]
},
{
"cell_type": "markdown",
2017-10-27 16:19:15 +02:00
"metadata": {},
2017-06-05 19:18:20 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"Nice! However, the model struggles with longer sentences:"
2016-09-27 23:31:21 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 65,
2017-10-27 16:19:15 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"'me gusta el fútbol y a veces mismo al bus'"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 65,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2016-09-27 23:31:21 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"translate(\"I like soccer and also going to the beach\")"
2016-09-27 23:31:21 +02:00
]
},
2021-10-15 10:46:27 +02:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"## Bidirectional RNNs"
2021-10-15 10:46:27 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To create a bidirectional recurrent layer, just wrap a regular recurrent layer in a `Bidirectional` layer:"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – ensures reproducibility on CPU\n",
"encoder = tf.keras.layers.Bidirectional(\n",
"    tf.keras.layers.LSTM(256, return_state=True))"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"encoder_outputs, *encoder_state = encoder(encoder_embeddings)\n",
"encoder_state = [tf.concat(encoder_state[::2], axis=-1),  # short-term (0 & 2)\n",
"                 tf.concat(encoder_state[1::2], axis=-1)]  # long-term (1 & 3)"
]
},
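{
"cell_type": "markdown",
"metadata": {},
"source": [
"The slicing above can be illustrated with a small NumPy sketch. It assumes the bidirectional LSTM returns its four states in the order forward short-term, forward long-term, backward short-term, backward long-term, which is what the index comments in the previous cell suggest. The arrays below are toy stand-ins for the real state tensors, and the names `fwd_h`, `fwd_c`, `bwd_h` and `bwd_c` are made up for this illustration:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# toy stand-ins for the four states of a Bidirectional LSTM with 256 units\n",
"fwd_h, fwd_c = np.zeros((1, 256)), np.ones((1, 256))\n",
"bwd_h, bwd_c = np.zeros((1, 256)), np.ones((1, 256))\n",
"states = [fwd_h, fwd_c, bwd_h, bwd_c]\n",
"\n",
"short_term = np.concatenate(states[::2], axis=-1)  # states 0 & 2\n",
"long_term = np.concatenate(states[1::2], axis=-1)  # states 1 & 3\n",
"print(short_term.shape, long_term.shape)  # both have 2 * 256 = 512 features"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each concatenated state has 512 features, which is presumably why the decoder LSTM uses 512 units when it is initialized with these states."
]
},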
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell will take a while to run (possibly a couple of hours if you are not using a GPU)."
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"3125/3125 [==============================] - 574s 181ms/step - loss: 0.3075 - accuracy: 0.5393 - val_loss: 0.2192 - val_accuracy: 0.6319\n",
"Epoch 2/10\n",
"3125/3125 [==============================] - 564s 180ms/step - loss: 0.1916 - accuracy: 0.6689 - val_loss: 0.1880 - val_accuracy: 0.6731\n",
"Epoch 3/10\n",
"3125/3125 [==============================] - 566s 181ms/step - loss: 0.1602 - accuracy: 0.7119 - val_loss: 0.1751 - val_accuracy: 0.6916\n",
"Epoch 4/10\n",
"3125/3125 [==============================] - 566s 181ms/step - loss: 0.1395 - accuracy: 0.7415 - val_loss: 0.1715 - val_accuracy: 0.6979\n",
"Epoch 5/10\n",
"3125/3125 [==============================] - 566s 181ms/step - loss: 0.1227 - accuracy: 0.7666 - val_loss: 0.1707 - val_accuracy: 0.7025\n",
"Epoch 6/10\n",
"3125/3125 [==============================] - 567s 181ms/step - loss: 0.1085 - accuracy: 0.7887 - val_loss: 0.1730 - val_accuracy: 0.6995\n",
"Epoch 7/10\n",
"3125/3125 [==============================] - 571s 183ms/step - loss: 0.0961 - accuracy: 0.8089 - val_loss: 0.1764 - val_accuracy: 0.7000\n",
"Epoch 8/10\n",
"3125/3125 [==============================] - 567s 181ms/step - loss: 0.0852 - accuracy: 0.8273 - val_loss: 0.1821 - val_accuracy: 0.6981\n",
"Epoch 9/10\n",
"3125/3125 [==============================] - 565s 181ms/step - loss: 0.0759 - accuracy: 0.8438 - val_loss: 0.1881 - val_accuracy: 0.6956\n",
"Epoch 10/10\n",
"3125/3125 [==============================] - 565s 181ms/step - loss: 0.0682 - accuracy: 0.8577 - val_loss: 0.1951 - val_accuracy: 0.6906\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f892d2d5fa0>"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – completes the model and trains it\n",
"decoder = tf.keras.layers.LSTM(512, return_sequences=True)\n",
"decoder_outputs = decoder(decoder_embeddings, initial_state=encoder_state)\n",
"output_layer = tf.keras.layers.Dense(vocab_size, activation=\"softmax\")\n",
"Y_proba = output_layer(decoder_outputs)\n",
"model = tf.keras.Model(inputs=[encoder_inputs, decoder_inputs],\n",
"                       outputs=[Y_proba])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"nadam\",\n",
"              metrics=[\"accuracy\"])\n",
"model.fit((X_train, X_train_dec), Y_train, epochs=10,\n",
"          validation_data=((X_valid, X_valid_dec), Y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'me gusta el fútbol'"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"translate(\"I like soccer\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Beam Search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a very basic implementation of beam search. I tried to make it readable and understandable, but it's definitely not optimized for speed! The function first uses the model to find the top _k_ words to start the translations (where _k_ is the beam width). Then, for each of the top _k_ translations, it evaluates the conditional probabilities of all the words it could append to that translation. These extended translations and their probabilities are added to the list of candidates. Once it has gone through all top _k_ translations and all the words that could extend them, it keeps only the _k_ candidates with the highest probability, and it repeats this process until they all end with an EOS token. The top translation is then returned (after removing its EOS token).\n",
"\n",
"* Note: if p(S) is the probability of sentence S, and p(W|S) is the conditional probability of the word W given that the translation starts with S, then the probability of the sentence S' = concat(S, W) is p(S') = p(S) * p(W|S). As we add more words, the probability gets smaller and smaller. To avoid the risk of it getting too small, which could cause floating-point underflow, the function keeps track of log probabilities instead of probabilities: recall that log(a\\*b) = log(a) + log(b), therefore log(p(S')) = log(p(S)) + log(p(W|S))."
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
"# extra code – a basic implementation of beam search\n",
"\n",
"def beam_search(sentence_en, beam_width, verbose=False):\n",
"    X = np.array([sentence_en])  # encoder input\n",
"    X_dec = np.array([\"startofseq\"])  # decoder input\n",
"    y_proba = model.predict((X, X_dec))[0, 0]  # first token's probas\n",
"    top_k = tf.math.top_k(y_proba, k=beam_width)\n",
"    top_translations = [  # list of best (log_proba, translation)\n",
"        (np.log(word_proba), text_vec_layer_es.get_vocabulary()[word_id])\n",
"        for word_proba, word_id in zip(top_k.values, top_k.indices)\n",
"    ]\n",
"\n",
"    # extra code – displays the top first words in verbose mode\n",
"    if verbose:\n",
"        print(\"Top first words:\", top_translations)\n",
"\n",
"    for idx in range(1, max_length):\n",
"        candidates = []\n",
"        for log_proba, translation in top_translations:\n",
"            if translation.endswith(\"endofseq\"):\n",
"                candidates.append((log_proba, translation))\n",
"                continue  # translation is finished, so don't try to extend it\n",
"            X = np.array([sentence_en])  # encoder input\n",
"            X_dec = np.array([\"startofseq \" + translation])  # decoder input\n",
"            y_proba = model.predict((X, X_dec))[0, idx]  # last token's proba\n",
"            for word_id, word_proba in enumerate(y_proba):\n",
"                word = text_vec_layer_es.get_vocabulary()[word_id]\n",
"                candidates.append((log_proba + np.log(word_proba),\n",
"                                   f\"{translation} {word}\"))\n",
"        top_translations = sorted(candidates, reverse=True)[:beam_width]\n",
"\n",
"        # extra code – displays the top translation so far in verbose mode\n",
"        if verbose:\n",
"            print(\"Top translations so far:\", top_translations)\n",
"\n",
"        if all([tr.endswith(\"endofseq\") for _, tr in top_translations]):\n",
"            return top_translations[0][1].replace(\"endofseq\", \"\").strip()\n",
"\n",
"    # fallback: max_length was reached before every candidate ended with endofseq\n",
"    return top_translations[0][1].replace(\"endofseq\", \"\").strip()"
]
},
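{
"cell_type": "markdown",
"metadata": {},
"source": [
"The note above about log probabilities can be checked with a small NumPy sketch. The numbers are deliberately exaggerated (500 hypothetical words, each with probability 0.01, far more than `max_length` allows) just to make the underflow visible:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"word_probas = np.full(500, 0.01)  # hypothetical conditional probabilities\n",
"product = np.prod(word_probas)  # 0.01**500 is too small for float64: underflows to 0.0\n",
"log_sum = np.sum(np.log(word_probas))  # stays finite: 500 * log(0.01)\n",
"print(product, log_sum)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The product collapses to zero, so all such candidates would look equally (im)probable, whereas the sum of logs still allows them to be ranked."
]
},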
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'me [UNK] los gatos y los gatos'"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows the model making an error\n",
"sentence_en = \"I love cats and dogs\"\n",
"translate(sentence_en)"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top first words: [(-0.012974381, 'me'), (-4.592527, '[UNK]'), (-6.314033, 'yo')]\n",
"Top translations so far: [(-0.4831518, 'me [UNK]'), (-1.4920667, 'me encanta'), (-1.986235, 'me gustan')]\n",
"Top translations so far: [(-0.6793061, 'me [UNK] los'), (-1.9889652, 'me gustan los'), (-2.0470557, 'me encanta los')]\n",
"Top translations so far: [(-0.7609749, 'me [UNK] los gatos'), (-2.0677316, 'me gustan los gatos'), (-2.26029, 'me encanta los gatos')]\n",
"Top translations so far: [(-0.76985043, 'me [UNK] los gatos y'), (-2.0701222, 'me gustan los gatos y'), (-2.2649746, 'me encanta los gatos y')]\n",
"Top translations so far: [(-0.81283045, 'me [UNK] los gatos y los'), (-2.118244, 'me gustan los gatos y los'), (-2.96167, 'me encanta los gatos y los')]\n",
"Top translations so far: [(-1.2259341, 'me [UNK] los gatos y los gatos'), (-1.9556838, 'me [UNK] los gatos y los perros'), (-2.7524388, 'me gustan los gatos y los perros')]\n",
"Top translations so far: [(-1.2261332, 'me [UNK] los gatos y los gatos endofseq'), (-1.9560521, 'me [UNK] los gatos y los perros endofseq'), (-2.7566314, 'me gustan los gatos y los perros endofseq')]\n"
]
},
{
"data": {
"text/plain": [
"'me [UNK] los gatos y los gatos'"
]
},
"execution_count": 72,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows how beam search can help\n",
"beam_search(sentence_en, beam_width=3, verbose=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The correct translation is in the top 3 sentences found by beam search, but it's not the first. Since we're using a small vocabulary, the \\[UNK] token is quite frequent, so you may want to penalize it (e.g., divide its probability by 2 in the beam search function): this will discourage beam search from using it too much."
]
},
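{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a minimal sketch of such a penalty, working in log space like the beam search function does: dividing a probability by 2 is the same as subtracting log(2) from its log probability. The helper `penalize_token` is hypothetical (it is not part of the notebook or of Keras), and the toy distribution simply pretends that index 1 is the \\[UNK] token:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def penalize_token(y_proba, token_id, factor=2.0):\n",
"    \"\"\"Returns log probabilities, with token_id's probability divided by factor.\"\"\"\n",
"    log_proba = np.log(y_proba)  # new array, the input is left untouched\n",
"    log_proba[token_id] -= np.log(factor)\n",
"    return log_proba\n",
"\n",
"y_proba = np.array([0.1, 0.5, 0.4])  # toy distribution, index 1 plays [UNK]\n",
"log_proba = penalize_token(y_proba, token_id=1)\n",
"print(np.argmax(y_proba), np.argmax(log_proba))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the penalized scores no longer sum to 1 once exponentiated. That does not matter for ranking candidates, but they should not be read as true probabilities."
]
},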
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Attention Mechanisms"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to feed all the encoder's outputs to the `Attention` layer, so we must add `return_sequences=True` to the encoder:"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – ensures reproducibility on CPU\n",
"encoder = tf.keras.layers.Bidirectional(\n",
"    tf.keras.layers.LSTM(256, return_sequences=True, return_state=True))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this part of the model is exactly the same as earlier\n",
"encoder_outputs, *encoder_state = encoder(encoder_embeddings)\n",
"encoder_state = [tf.concat(encoder_state[::2], axis=-1),  # short-term (0 & 2)\n",
"                 tf.concat(encoder_state[1::2], axis=-1)]  # long-term (1 & 3)\n",
"decoder = tf.keras.layers.LSTM(512, return_sequences=True)\n",
"decoder_outputs = decoder(decoder_embeddings, initial_state=encoder_state)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And finally, let's add the `Attention` layer and the output layer:"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"attention_layer = tf.keras.layers.Attention()\n",
"attention_outputs = attention_layer([decoder_outputs, encoder_outputs])\n",
"output_layer = tf.keras.layers.Dense(vocab_size, activation=\"softmax\")\n",
"Y_proba = output_layer(attention_outputs)"
]
},
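{
"cell_type": "markdown",
"metadata": {},
"source": [
"To give an intuition for what the `Attention` layer computes, here is a rough NumPy sketch of unscaled dot-product attention for a single sequence, with the decoder outputs as queries and the encoder outputs as both keys and values. It is a simplification made under assumptions: it ignores batching and masking, it uses toy sizes, and `dot_product_attention` is a hypothetical helper rather than the Keras implementation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def dot_product_attention(query, value):\n",
"    scores = query @ value.T  # similarity of each query with each value\n",
"    scores -= scores.max(axis=-1, keepdims=True)  # for numerical stability\n",
"    weights = np.exp(scores)\n",
"    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over value positions\n",
"    return weights @ value, weights  # weighted sums of the values\n",
"\n",
"rng = np.random.default_rng(42)\n",
"query = rng.normal(size=(3, 4))  # 3 decoder steps, 4 features (toy sizes)\n",
"value = rng.normal(size=(5, 4))  # 5 encoder steps, 4 features\n",
"outputs, weights = dot_product_attention(query, value)\n",
"print(outputs.shape, weights.sum(axis=-1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"There is one output vector per decoder step, and each row of attention weights sums to 1: every decoder step gets its own weighted average of the encoder outputs."
]
},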
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell will take a while to run (possibly a couple of hours if you are not using a GPU)."
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"3125/3125 [==============================] - 597s 189ms/step - loss: 0.3074 - accuracy: 0.5469 - val_loss: 0.2106 - val_accuracy: 0.6487\n",
"Epoch 2/10\n",
"3125/3125 [==============================] - 585s 187ms/step - loss: 0.1902 - accuracy: 0.6789 - val_loss: 0.1865 - val_accuracy: 0.6830\n",
"Epoch 3/10\n",
"3125/3125 [==============================] - 585s 187ms/step - loss: 0.1659 - accuracy: 0.7123 - val_loss: 0.1759 - val_accuracy: 0.7005\n",
"Epoch 4/10\n",
"3125/3125 [==============================] - 584s 187ms/step - loss: 0.1493 - accuracy: 0.7359 - val_loss: 0.1728 - val_accuracy: 0.7060\n",
"Epoch 5/10\n",
"3125/3125 [==============================] - 582s 186ms/step - loss: 0.1358 - accuracy: 0.7548 - val_loss: 0.1724 - val_accuracy: 0.7084\n",
"Epoch 6/10\n",
"3125/3125 [==============================] - 583s 186ms/step - loss: 0.1245 - accuracy: 0.7712 - val_loss: 0.1738 - val_accuracy: 0.7103\n",
"Epoch 7/10\n",
"3125/3125 [==============================] - 582s 186ms/step - loss: 0.1148 - accuracy: 0.7863 - val_loss: 0.1770 - val_accuracy: 0.7111\n",
"Epoch 8/10\n",
"3125/3125 [==============================] - 582s 186ms/step - loss: 0.1064 - accuracy: 0.7992 - val_loss: 0.1806 - val_accuracy: 0.7110\n",
"Epoch 9/10\n",
"3125/3125 [==============================] - 582s 186ms/step - loss: 0.0991 - accuracy: 0.8101 - val_loss: 0.1862 - val_accuracy: 0.7088\n",
"Epoch 10/10\n",
"3125/3125 [==============================] - 581s 186ms/step - loss: 0.0929 - accuracy: 0.8205 - val_loss: 0.1903 - val_accuracy: 0.7077\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f87e5c8ad90>"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.Model(inputs=[encoder_inputs, decoder_inputs],\n",
"                       outputs=[Y_proba])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"nadam\",\n",
"              metrics=[\"accuracy\"])\n",
"model.fit((X_train, X_train_dec), Y_train, epochs=10,\n",
"          validation_data=((X_valid, X_valid_dec), Y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'me gusta el fútbol y también ir a la playa'"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"translate(\"I like soccer and also going to the beach\")"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Top first words: [(-0.26210824, 'me'), (-2.553061, 'prefiero'), (-3.2005944, 'yo')]\n",
"Top translations so far: [(-0.32478744, 'me gusta'), (-3.0608056, 'prefiero el'), (-3.1685317, 'me gustan')]\n",
"Top translations so far: [(-0.7464272, 'me gusta el'), (-2.4712462, 'me gusta fútbol'), (-2.9149299, 'me gusta al')]\n",
"Top translations so far: [(-1.0369574, 'me gusta el fútbol'), (-2.3301778, 'me gusta el el'), (-2.9658434, 'me gusta fútbol y')]\n",
"Top translations so far: [(-1.0404125, 'me gusta el fútbol y'), (-2.5983238, 'me gusta el el fútbol'), (-2.9736564, 'me gusta fútbol y también')]\n",
"Top translations so far: [(-1.0520902, 'me gusta el fútbol y también'), (-2.6003318, 'me gusta el el fútbol y'), (-3.128903, 'me gusta fútbol y también me')]\n",
"Top translations so far: [(-1.9568634, 'me gusta el fútbol y también ir'), (-2.6169589, 'me gusta el el fútbol y también'), (-2.6949644, 'me gusta el fútbol y también fuera')]\n",
"Top translations so far: [(-1.9676423, 'me gusta el fútbol y también ir a'), (-2.8482866, 'me gusta el fútbol y también fuera a'), (-3.7197533, 'me gusta el el fútbol y también ir')]\n",
"Top translations so far: [(-1.9692448, 'me gusta el fútbol y también ir a la'), (-2.8501132, 'me gusta el fútbol y también fuera a la'), (-3.7309551, 'me gusta el el fútbol y también ir a')]\n",
"Top translations so far: [(-1.9733216, 'me gusta el fútbol y también ir a la playa'), (-2.851697, 'me gusta el fútbol y también fuera a la playa'), (-3.7333717, 'me gusta el el fútbol y también ir a la')]\n",
"Top translations so far: [(-1.9737166, 'me gusta el fútbol y también ir a la playa endofseq'), (-2.8547554, 'me gusta el fútbol y también fuera a la playa endofseq'), (-3.737218, 'me gusta el el fútbol y también ir a la playa')]\n",
"Top translations so far: [(-1.9737166, 'me gusta el fútbol y también ir a la playa endofseq'), (-2.8547554, 'me gusta el fútbol y también fuera a la playa endofseq'), (-3.7375438, 'me gusta el el fútbol y también ir a la playa endofseq')]\n"
]
},
{
"data": {
"text/plain": [
"'me gusta el fútbol y también ir a la playa'"
]
},
"execution_count": 78,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"beam_search(\"I like soccer and also going to the beach\", beam_width=3,\n",
"            verbose=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Attention Is All You Need: The Transformer Architecture\n",
"### Positional Encodings"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"max_length = 50 # max length in the whole training set\n",
"embed_size = 128\n",
"tf.random.set_seed(42) # extra code – ensures reproducibility on CPU\n",
"pos_embed_layer = tf.keras.layers.Embedding(max_length, embed_size)\n",
"batch_max_len_enc = tf.shape(encoder_embeddings)[1]\n",
"encoder_in = encoder_embeddings + pos_embed_layer(tf.range(batch_max_len_enc))\n",
"batch_max_len_dec = tf.shape(decoder_embeddings)[1]\n",
"decoder_in = decoder_embeddings + pos_embed_layer(tf.range(batch_max_len_dec))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, we can use fixed, non-trainable positional encodings:"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"class PositionalEncoding(tf.keras.layers.Layer):\n",
"    def __init__(self, max_length, embed_size, dtype=tf.float32, **kwargs):\n",
"        super().__init__(dtype=dtype, **kwargs)\n",
"        assert embed_size % 2 == 0, \"embed_size must be even\"\n",
"        p, i = np.meshgrid(np.arange(max_length),\n",
"                           2 * np.arange(embed_size // 2))\n",
"        pos_emb = np.empty((1, max_length, embed_size))\n",
"        pos_emb[0, :, ::2] = np.sin(p / 10_000 ** (i / embed_size)).T\n",
"        pos_emb[0, :, 1::2] = np.cos(p / 10_000 ** (i / embed_size)).T\n",
"        self.pos_encodings = tf.constant(pos_emb.astype(self.dtype))\n",
"        self.supports_masking = True\n",
"\n",
"    def call(self, inputs):\n",
"        batch_max_length = tf.shape(inputs)[1]\n",
"        return inputs + self.pos_encodings[:, :batch_max_length]"
]
},
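{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, here is what the `PositionalEncoding` class computes, written as equations. This is only a transcription of the code above, where $p$ is the position in the sequence, $d$ is `embed_size`, and $j$ ranges from 0 to $d/2 - 1$ (the symbols $d$ and $j$ are introduced here for convenience):\n",
"\n",
"$$P_{(p,\\,2j)} = \\sin\\left(\\frac{p}{10000^{2j/d}}\\right) \\qquad P_{(p,\\,2j+1)} = \\cos\\left(\\frac{p}{10000^{2j/d}}\\right)$$\n",
"\n",
"In other words, the even dimensions of the encoding hold sines and the odd dimensions hold cosines, and each sine/cosine pair shares the same frequency, which decreases as $j$ grows."
]
},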
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"pos_embed_layer = PositionalEncoding(max_length, embed_size)\n",
"encoder_in = pos_embed_layer(encoder_embeddings)\n",
"decoder_in = pos_embed_layer(decoder_embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 648x360 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 16–9\n",
"figure_max_length = 201\n",
"figure_embed_size = 512\n",
"pos_emb = PositionalEncoding(figure_max_length, figure_embed_size)\n",
"zeros = np.zeros((1, figure_max_length, figure_embed_size), np.float32)\n",
"P = pos_emb(zeros)[0].numpy()\n",
"i1, i2, crop_i = 100, 101, 150\n",
"p1, p2, p3 = 22, 60, 35\n",
"fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(9, 5))\n",
"ax1.plot([p1, p1], [-1, 1], \"k--\", label=\"$p = {}$\".format(p1))\n",
"ax1.plot([p2, p2], [-1, 1], \"k--\", label=\"$p = {}$\".format(p2), alpha=0.5)\n",
"ax1.plot(p3, P[p3, i1], \"bx\", label=\"$p = {}$\".format(p3))\n",
"ax1.plot(P[:,i1], \"b-\", label=\"$i = {}$\".format(i1))\n",
"ax1.plot(P[:,i2], \"r-\", label=\"$i = {}$\".format(i2))\n",
"ax1.plot([p1, p2], [P[p1, i1], P[p2, i1]], \"bo\")\n",
"ax1.plot([p1, p2], [P[p1, i2], P[p2, i2]], \"ro\")\n",
"ax1.legend(loc=\"center right\", fontsize=14, framealpha=0.95)\n",
"ax1.set_ylabel(\"$P_{(p,i)}$\", rotation=0, fontsize=16)\n",
"ax1.grid(True, alpha=0.3)\n",
"ax1.hlines(0, 0, figure_max_length - 1, color=\"k\", linewidth=1, alpha=0.3)\n",
"ax1.axis([0, figure_max_length - 1, -1, 1])\n",
"ax2.imshow(P.T[:crop_i], cmap=\"gray\", interpolation=\"bilinear\", aspect=\"auto\")\n",
"ax2.hlines(i1, 0, figure_max_length - 1, color=\"b\", linewidth=3)\n",
"cheat = 2  # need to raise the red line a bit, or else it hides the blue one\n",
"ax2.hlines(i2+cheat, 0, figure_max_length - 1, color=\"r\", linewidth=3)\n",
"ax2.plot([p1, p1], [0, crop_i], \"k--\")\n",
"ax2.plot([p2, p2], [0, crop_i], \"k--\", alpha=0.5)\n",
"ax2.plot([p1, p2], [i2+cheat, i2+cheat], \"ro\")\n",
"ax2.plot([p1, p2], [i1, i1], \"bo\")\n",
"ax2.axis([0, figure_max_length - 1, 0, crop_i])\n",
"ax2.set_xlabel(\"$p$\", fontsize=16)\n",
"ax2.set_ylabel(\"$i$\", rotation=0, fontsize=16)\n",
"save_fig(\"positional_embedding_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Multi-Head Attention"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"N = 2  # instead of 6\n",
"num_heads = 8\n",
"dropout_rate = 0.1\n",
"n_units = 128  # for the first Dense layer in each Feed Forward block\n",
"encoder_pad_mask = tf.math.not_equal(encoder_input_ids, 0)[:, tf.newaxis]\n",
"Z = encoder_in\n",
"for _ in range(N):\n",
"    skip = Z\n",
"    attn_layer = tf.keras.layers.MultiHeadAttention(\n",
"        num_heads=num_heads, key_dim=embed_size, dropout=dropout_rate)\n",
"    Z = attn_layer(Z, value=Z, attention_mask=encoder_pad_mask)\n",
"    Z = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([Z, skip]))\n",
"    skip = Z\n",
"    Z = tf.keras.layers.Dense(n_units, activation=\"relu\")(Z)\n",
"    Z = tf.keras.layers.Dense(embed_size)(Z)\n",
"    Z = tf.keras.layers.Dropout(dropout_rate)(Z)\n",
"    Z = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([Z, skip]))"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"decoder_pad_mask = tf.math.not_equal(decoder_input_ids, 0)[:, tf.newaxis]\n",
"causal_mask = tf.linalg.band_part(  # creates a lower triangular matrix\n",
"    tf.ones((batch_max_len_dec, batch_max_len_dec), tf.bool), -1, 0)"
]
},
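{
"cell_type": "markdown",
"metadata": {},
"source": [
"To visualize what these masks look like, here is a small NumPy sketch for a toy sequence of 4 tokens whose last token is padding. It uses `np.tril()` as a stand-in for `tf.linalg.band_part(..., -1, 0)`, on the assumption that both keep the lower triangle, and it leaves out the batch dimension for simplicity:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"seq_len = 4\n",
"causal_mask = np.tril(np.ones((seq_len, seq_len), dtype=bool))  # lower triangular\n",
"pad_mask = np.array([[True, True, True, False]])  # toy example: last token is padding\n",
"combined = causal_mask & pad_mask  # broadcasts to shape (seq_len, seq_len)\n",
"print(combined.astype(int))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Row _i_ of the combined mask shows which positions token _i_ may attend to: itself and earlier tokens only, and never the padding token."
]
},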
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
"encoder_outputs = Z  # let's save the encoder's final outputs\n",
"Z = decoder_in  # the decoder starts with its own inputs\n",
"for _ in range(N):\n",
"    skip = Z\n",
"    attn_layer = tf.keras.layers.MultiHeadAttention(\n",
"        num_heads=num_heads, key_dim=embed_size, dropout=dropout_rate)\n",
"    Z = attn_layer(Z, value=Z, attention_mask=causal_mask & decoder_pad_mask)\n",
"    Z = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([Z, skip]))\n",
"    skip = Z\n",
"    attn_layer = tf.keras.layers.MultiHeadAttention(\n",
"        num_heads=num_heads, key_dim=embed_size, dropout=dropout_rate)\n",
"    Z = attn_layer(Z, value=encoder_outputs, attention_mask=encoder_pad_mask)\n",
"    Z = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([Z, skip]))\n",
"    skip = Z\n",
"    Z = tf.keras.layers.Dense(n_units, activation=\"relu\")(Z)\n",
"    Z = tf.keras.layers.Dense(embed_size)(Z)\n",
"    Z = tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([Z, skip]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: the following cell will take a while to run (possibly 2 or 3 hours if you are not using a GPU)."
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"3125/3125 [==============================] - 828s 263ms/step - loss: 0.2982 - accuracy: 0.5545 - val_loss: 0.2105 - val_accuracy: 0.6476\n",
"Epoch 2/10\n",
"3125/3125 [==============================] - 820s 262ms/step - loss: 0.2006 - accuracy: 0.6601 - val_loss: 0.1876 - val_accuracy: 0.6802\n",
"Epoch 3/10\n",
"3125/3125 [==============================] - 820s 263ms/step - loss: 0.1842 - accuracy: 0.6816 - val_loss: 0.1766 - val_accuracy: 0.6975\n",
"Epoch 4/10\n",
"3125/3125 [==============================] - 820s 262ms/step - loss: 0.1748 - accuracy: 0.6942 - val_loss: 0.1704 - val_accuracy: 0.7055\n",
"Epoch 5/10\n",
"3125/3125 [==============================] - 820s 262ms/step - loss: 0.1683 - accuracy: 0.7021 - val_loss: 0.1657 - val_accuracy: 0.7102\n",
"Epoch 6/10\n",
"3125/3125 [==============================] - 821s 263ms/step - loss: 0.1628 - accuracy: 0.7096 - val_loss: 0.1628 - val_accuracy: 0.7130\n",
"Epoch 7/10\n",
"3125/3125 [==============================] - 826s 264ms/step - loss: 0.1588 - accuracy: 0.7154 - val_loss: 0.1595 - val_accuracy: 0.7205\n",
"Epoch 8/10\n",
"3125/3125 [==============================] - 822s 263ms/step - loss: 0.1550 - accuracy: 0.7205 - val_loss: 0.1590 - val_accuracy: 0.7199\n",
"Epoch 9/10\n",
"3125/3125 [==============================] - 821s 263ms/step - loss: 0.1518 - accuracy: 0.7249 - val_loss: 0.1547 - val_accuracy: 0.7258\n",
"Epoch 10/10\n",
"3125/3125 [==============================] - 821s 263ms/step - loss: 0.1492 - accuracy: 0.7279 - val_loss: 0.1538 - val_accuracy: 0.7281\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f8946cdf9a0>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 86,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y_proba = tf.keras.layers.Dense(vocab_size, activation=\"softmax\")(Z)\n",
"model = tf.keras.Model(inputs=[encoder_inputs, decoder_inputs],\n",
" outputs=[Y_proba])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"model.fit((X_train, X_train_dec), Y_train, epochs=10,\n",
" validation_data=((X_valid, X_valid_dec), Y_valid))"
2019-04-16 13:52:49 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 87,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'me gusta el fútbol y yo también voy a la playa'"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 87,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"translate(\"I like soccer and also going to the beach\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# HuggingFace"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Install the Transformers and Datasets libraries if we're running on Colab:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 88,
2019-04-16 13:52:49 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-03-21 09:43:01 +01:00
"if \"google.colab\" in sys.modules:\n",
" %pip install -q -U transformers\n",
" %pip install -q -U datasets"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 89,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english)\n",
"All model checkpoint layers were used when initializing TFDistilBertForSequenceClassification.\n",
"\n",
"All the layers of TFDistilBertForSequenceClassification were initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english.\n",
"If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertForSequenceClassification for predictions without further training.\n"
]
}
],
"source": [
"from transformers import pipeline\n",
2019-04-16 13:52:49 +02:00
"\n",
2022-03-21 09:43:01 +01:00
"classifier = pipeline(\"sentiment-analysis\") # many other tasks are available\n",
"result = classifier(\"The actors were very convincing.\")"
]
},
2022-09-25 06:31:58 +02:00
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Models can be very biased. For example, it may like or dislike some countries depending on the data it was trained on, and how it is used, so use it with care:"
]
},
2022-03-21 09:43:01 +01:00
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 90,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'label': 'POSITIVE', 'score': 0.9896161556243896},\n",
" {'label': 'NEGATIVE', 'score': 0.9811071157455444}]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 90,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"classifier([\"I am from India.\", \"I am from Iraq.\"])"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 91,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some layers from the model checkpoint at huggingface/distilbert-base-uncased-finetuned-mnli were not used when initializing TFDistilBertForSequenceClassification: ['dropout_19']\n",
"- This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at huggingface/distilbert-base-uncased-finetuned-mnli and are newly initialized: ['dropout_39']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"data": {
"text/plain": [
"[{'label': 'contradiction', 'score': 0.9790192246437073}]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 91,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_name = \"huggingface/distilbert-base-uncased-finetuned-mnli\"\n",
"classifier_mnli = pipeline(\"text-classification\", model=model_name)\n",
"classifier_mnli(\"She loves me. [SEP] She loves me not.\")"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 92,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some layers from the model checkpoint at huggingface/distilbert-base-uncased-finetuned-mnli were not used when initializing TFDistilBertForSequenceClassification: ['dropout_19']\n",
"- This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at huggingface/distilbert-base-uncased-finetuned-mnli and are newly initialized: ['dropout_59']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoTokenizer, TFAutoModelForSequenceClassification\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = TFAutoModelForSequenceClassification.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 93,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': <tf.Tensor: shape=(2, 15), dtype=int32, numpy=\n",
"array([[ 101, 1045, 2066, 4715, 1012, 102, 2057, 2035, 2293, 4715, 999,\n",
" 102, 0, 0, 0],\n",
" [ 101, 3533, 2973, 2005, 1037, 2200, 2146, 2051, 1012, 102, 3533,\n",
" 2003, 2214, 1012, 102]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(2, 15), dtype=int32, numpy=\n",
"array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)>}"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 93,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_ids = tokenizer([\"I like soccer. [SEP] We all love soccer!\",\n",
" \"Joe lived for a very long time. [SEP] Joe is old.\"],\n",
" padding=True, return_tensors=\"tf\")\n",
"token_ids"
2019-04-16 13:52:49 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 94,
2019-04-16 13:52:49 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': <tf.Tensor: shape=(2, 15), dtype=int32, numpy=\n",
"array([[ 101, 1045, 2066, 4715, 1012, 102, 2057, 2035, 2293, 4715, 999,\n",
" 102, 0, 0, 0],\n",
" [ 101, 3533, 2973, 2005, 1037, 2200, 2146, 2051, 1012, 102, 3533,\n",
" 2003, 2214, 1012, 102]], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(2, 15), dtype=int32, numpy=\n",
"array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=int32)>}"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 94,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"token_ids = tokenizer([(\"I like soccer.\", \"We all love soccer!\"),\n",
" (\"Joe lived for a very long time.\", \"Joe is old.\")],\n",
" padding=True, return_tensors=\"tf\")\n",
"token_ids"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 95,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TFSequenceClassifierOutput(loss=None, logits=<tf.Tensor: shape=(2, 3), dtype=float32, numpy=\n",
"array([[-2.1123817 , 1.1786783 , 1.4101017 ],\n",
" [-0.01478387, 1.0962474 , -0.9919954 ]], dtype=float32)>, hidden_states=None, attentions=None)"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 95,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs = model(token_ids)\n",
"outputs"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 96,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2, 3), dtype=float32, numpy=\n",
"array([[0.01619702, 0.43523544, 0.5485676 ],\n",
" [0.22655967, 0.6881726 , 0.0852678 ]], dtype=float32)>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 96,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y_probas = tf.keras.activations.softmax(outputs.logits)\n",
"Y_probas"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 97,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 1])>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 97,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y_pred = tf.argmax(Y_probas, axis=1)\n",
"Y_pred # 0 = contradiction, 1 = entailment, 2 = neutral"
]
},
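{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, the softmax and argmax steps above can be reproduced with plain NumPy. This is only a sketch: the logits are copied from the output shown earlier, the `softmax()` helper and the `MNLI_LABELS` list are names introduced here for illustration, and the label order is assumed to be the one given in the comment above (0 = contradiction, 1 = entailment, 2 = neutral)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Assumed label order, taken from the comment in the cell above\n",
"MNLI_LABELS = [\"contradiction\", \"entailment\", \"neutral\"]\n",
"\n",
"# Logits copied from the model output shown earlier\n",
"logits = np.array([[-2.1123817, 1.1786783, 1.4101017],\n",
"                   [-0.01478387, 1.0962474, -0.9919954]])\n",
"\n",
"def softmax(z):\n",
"    # subtract the row-wise max for numerical stability\n",
"    e = np.exp(z - z.max(axis=1, keepdims=True))\n",
"    return e / e.sum(axis=1, keepdims=True)\n",
"\n",
"probas = softmax(logits)\n",
"pred_ids = probas.argmax(axis=1)\n",
"pred_labels = [MNLI_LABELS[i] for i in pred_ids]\n",
"pred_labels"
]
},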
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 98,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/2\n",
"1/1 [==============================] - 10s 10s/step - loss: 1.1190 - accuracy: 0.5000\n",
"Epoch 2/2\n",
"1/1 [==============================] - 0s 491ms/step - loss: 0.6666 - accuracy: 0.5000\n"
]
}
],
2019-04-16 13:52:49 +02:00
"source": [
2022-03-21 09:43:01 +01:00
"sentences = [(\"Sky is blue\", \"Sky is red\"), (\"I love her\", \"She loves me\")]\n",
"X_train = tokenizer(sentences, padding=True, return_tensors=\"tf\").data\n",
"y_train = tf.constant([0, 2]) # contradiction, neutral\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"model.compile(loss=loss, optimizer=\"nadam\", metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=2)"
2019-04-16 13:52:49 +02:00
]
2020-04-19 06:01:14 +02:00
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 7."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-02-19 10:09:28 +01:00
"1. Stateless RNNs can only capture patterns whose length is less than, or equal to, the size of the windows the RNN is trained on. Conversely, stateful RNNs can capture longer-term patterns. However, implementing a stateful RNN is much harder —especially preparing the dataset properly. Moreover, stateful RNNs do not always work better, in part because consecutive batches are not independent and identically distributed (IID). Gradient Descent is not fond of non-IID datasets.\n",
"2. In general, if you translate a sentence one word at a time, the result will be terrible. For example, the French sentence \"Je vous en prie\" means \"You are welcome,\" but if you translate it one word at a time, you get \"I you in pray.\" Huh? It is much better to read the whole sentence first and then translate it. A plain sequence-to-sequence RNN would start translating a sentence immediately after reading the first word, while an Encoder– Decoder RNN will first read the whole sentence and then translate it. That said, one could imagine a plain sequence-to-sequence RNN that would output silence whenever it is unsure about what to say next (just like human translators do when they must translate a live broadcast).\n",
2022-03-21 09:43:01 +01:00
"3. Variable-length input sequences can be handled by padding the shorter sequences so that all sequences in a batch have the same length, and using masking to ensure the RNN ignores the padding token. For better performance, you may also want to create batches containing sequences of similar sizes. Ragged tensors can hold sequences of variable lengths, and Keras now supports them, which simplifies handling variable-length input sequences (at the time of this writing, it still does not handle ragged tensors as targets on the GPU, though). Regarding variable-length output sequences, if the length of the output sequence is known in advance (e.g., if you know that it is the same as the input sequence), then you just need to configure the loss function so that it ignores tokens that come after the end of the sequence. Similarly, the code that will use the model should ignore tokens beyond the end of the sequence. But generally the length of the output sequence is not known ahead of time, so the solution is to train the model so that it outputs an end-of-sequence token at the end of each sequence.\n",
"4. Beam search is a technique used to improve the performance of a trained Encoder– Decoder model, for example in a neural machine translation system. The algorithm keeps track of a short list of the _k_ most promising output sentences (say, the top three), and at each decoder step it tries to extend them by one word; then it keeps only the _k_ most likely sentences. The parameter _k_ is called the _beam width_: the larger it is, the more CPU and RAM will be used, but also the more accurate the system will be. Instead of greedily choosing the most likely next word at each step to extend a single sentence, this technique allows the system to explore several promising sentences simultaneously. Moreover, this technique lends itself well to parallelization. You can implement beam search by writing a custom memory cell. Alternatively, TensorFlow Addons's seq2seq API provides an implementation.\n",
2022-02-19 10:09:28 +01:00
"5. An attention mechanism is a technique initially used in Encoder– Decoder models to give the decoder more direct access to the input sequence, allowing it to deal with longer input sequences. At each decoder time step, the current decoder's state and the full output of the encoder are processed by an alignment model that outputs an alignment score for each input time step. This score indicates which part of the input is most relevant to the current decoder time step. The weighted sum of the encoder output (weighted by their alignment score) is then fed to the decoder, which produces the next decoder state and the output for this time step. The main benefit of using an attention mechanism is the fact that the Encoder– Decoder model can successfully process longer input sequences. Another benefit is that the alignment scores make the model easier to debug and interpret: for example, if the model makes a mistake, you can look at which part of the input it was paying attention to, and this can help diagnose the issue. An attention mechanism is also at the core of the Transformer architecture, in the Multi-Head Attention layers. See the next answer.\n",
"6. The most important layer in the Transformer architecture is the Multi-Head Attention layer (the original Transformer architecture contains 18 of them, including 6 Masked Multi-Head Attention layers). It is at the core of language models such as BERT and GPT-2. Its purpose is to allow the model to identify which words are most aligned with each other, and then improve each word's representation using these contextual clues.\n",
"7. Sampled softmax is used when training a classification model when there are many classes (e.g., thousands). It computes an approximation of the cross-entropy loss based on the logit predicted by the model for the correct class, and the predicted logits for a sample of incorrect words. This speeds up training considerably compared to computing the softmax over all logits and then estimating the cross-entropy loss. After training, the model can be used normally, using the regular softmax function to compute all the class probabilities based on all the logits."
2020-04-19 06:01:14 +02:00
]
},
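{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make answer 4 more concrete, here is a minimal, framework-free sketch of beam search. It is not the TensorFlow Addons implementation: `beam_search()`, `toy_next_log_probas()` and the hard-coded `TOY_TABLE` are hypothetical names made up for this illustration, and the toy \"model\" only looks at the previous token, whereas a real decoder would condition on the encoder outputs and on all previous tokens. With this table, a beam width of 1 (greedy search) keeps extending the locally most likely token, while a beam width of 2 finds a more probable complete sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"def beam_search(next_log_probas, beam_width, max_len, eos=\"<eos>\"):\n",
"    # each beam is a pair: (cumulative log-probability, list of tokens)\n",
"    beams = [(0.0, [])]\n",
"    for _ in range(max_len):\n",
"        candidates = []\n",
"        for score, tokens in beams:\n",
"            if tokens and tokens[-1] == eos:\n",
"                candidates.append((score, tokens))  # finished: keep as is\n",
"                continue\n",
"            for token, log_p in next_log_probas(tokens).items():\n",
"                candidates.append((score + log_p, tokens + [token]))\n",
"        # keep only the beam_width most likely sequences\n",
"        candidates.sort(key=lambda c: c[0], reverse=True)\n",
"        beams = candidates[:beam_width]\n",
"        if all(tokens[-1] == eos for _, tokens in beams):\n",
"            break\n",
"    return beams\n",
"\n",
"# Toy next-token probabilities, indexed by the previous token (made up)\n",
"TOY_TABLE = {\n",
"    None: {\"a\": 0.6, \"b\": 0.4},\n",
"    \"a\": {\"a\": 0.4, \"b\": 0.3, \"<eos>\": 0.3},\n",
"    \"b\": {\"a\": 0.05, \"b\": 0.05, \"<eos>\": 0.9},\n",
"}\n",
"\n",
"def toy_next_log_probas(tokens):\n",
"    previous = tokens[-1] if tokens else None\n",
"    return {t: math.log(p) for t, p in TOY_TABLE[previous].items()}\n",
"\n",
"greedy_tokens = beam_search(toy_next_log_probas, beam_width=1, max_len=4)[0][1]\n",
"best_score, best_tokens = beam_search(toy_next_log_probas, beam_width=2,\n",
"                                      max_len=4)[0]\n",
"greedy_tokens, best_tokens, math.exp(best_score)"
]
},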
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8.\n",
"_Exercise:_ Embedded Reber grammars _were used by Hochreiter and Schmidhuber in [their paper](https://homl.info/93) about LSTMs. They are artificial grammars that produce strings such as \"BPBTSXXVPSEPE.\" Check out Jenny Orr's [nice introduction](https://homl.info/108) to this topic. Choose a particular embedded Reber grammar (such as the one represented on Jenny Orr's page), then train an RNN to identify whether a string respects that grammar or not. You will first need to write a function capable of generating a training batch containing about 50% strings that respect the grammar, and 50% that don't._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First we need to build a function that generates strings based on a grammar. The grammar will be represented as a list of possible transitions for each state. A transition specifies the string to output (or a grammar to generate it) and the next state."
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 99,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"default_reber_grammar = [\n",
" [(\"B\", 1)], # (state 0) =B=>(state 1)\n",
" [(\"T\", 2), (\"P\", 3)], # (state 1) =T=>(state 2) or =P=>(state 3)\n",
" [(\"S\", 2), (\"X\", 4)], # (state 2) =S=>(state 2) or =X=>(state 4)\n",
" [(\"T\", 3), (\"V\", 5)], # and so on...\n",
" [(\"X\", 3), (\"S\", 6)],\n",
" [(\"P\", 4), (\"V\", 6)],\n",
" [(\"E\", None)]] # (state 6) =E=>(terminal state)\n",
"\n",
"embedded_reber_grammar = [\n",
" [(\"B\", 1)],\n",
" [(\"T\", 2), (\"P\", 3)],\n",
" [(default_reber_grammar, 4)],\n",
" [(default_reber_grammar, 5)],\n",
" [(\"T\", 6)],\n",
" [(\"P\", 6)],\n",
" [(\"E\", None)]]\n",
"\n",
"def generate_string(grammar):\n",
" state = 0\n",
" output = []\n",
" while state is not None:\n",
" index = np.random.randint(len(grammar[state]))\n",
" production, state = grammar[state][index]\n",
" if isinstance(production, list):\n",
" production = generate_string(grammar=production)\n",
" output.append(production)\n",
" return \"\".join(output)"
]
},
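{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same transition-list representation can also be used the other way around, to check whether a given string respects a grammar. The cell below is a sketch of such a checker, not part of the original solution: `match_ends()` and `respects_grammar()` are helper names introduced here, and the two grammars are repeated under different names (`sketch_reber_grammar`, `sketch_embedded_grammar`) only so that the cell runs on its own. A checker like this can be handy for sanity-checking generated or corrupted strings."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Grammars repeated from the cell above so that this sketch is self-contained\n",
"sketch_reber_grammar = [\n",
"    [(\"B\", 1)],\n",
"    [(\"T\", 2), (\"P\", 3)],\n",
"    [(\"S\", 2), (\"X\", 4)],\n",
"    [(\"T\", 3), (\"V\", 5)],\n",
"    [(\"X\", 3), (\"S\", 6)],\n",
"    [(\"P\", 4), (\"V\", 6)],\n",
"    [(\"E\", None)]]\n",
"\n",
"sketch_embedded_grammar = [\n",
"    [(\"B\", 1)],\n",
"    [(\"T\", 2), (\"P\", 3)],\n",
"    [(sketch_reber_grammar, 4)],\n",
"    [(sketch_reber_grammar, 5)],\n",
"    [(\"T\", 6)],\n",
"    [(\"P\", 6)],\n",
"    [(\"E\", None)]]\n",
"\n",
"def match_ends(grammar, s, start=0):\n",
"    # hypothetical helper: returns every index at which a walk through\n",
"    # `grammar`, reading `s` from `start`, can reach the terminal state\n",
"    ends = set()\n",
"    pending = [(0, start)]  # (grammar state, position in s)\n",
"    while pending:\n",
"        state, pos = pending.pop()\n",
"        if state is None:\n",
"            ends.add(pos)\n",
"            continue\n",
"        for production, next_state in grammar[state]:\n",
"            if isinstance(production, list):  # nested grammar\n",
"                for sub_end in match_ends(production, s, pos):\n",
"                    pending.append((next_state, sub_end))\n",
"            elif s.startswith(production, pos):\n",
"                pending.append((next_state, pos + len(production)))\n",
"    return ends\n",
"\n",
"def respects_grammar(grammar, s):\n",
"    # the terminal state must be reached exactly at the end of the string\n",
"    return len(s) in match_ends(grammar, s)\n",
"\n",
"respects_grammar(sketch_embedded_grammar, \"BPBPVVEPE\")"
]
},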
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's generate a few strings based on the default Reber grammar:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 100,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BTXXTTVPXTVPXTTVPSE BPVPSE BTXSE BPVVE BPVVE BTSXSE BPTVPXTTTVVE BPVVE BTXSE BTXXVPSE BPTTTTTTTTVVE BTXSE BPVPSE BTXSE BPTVPSE BTXXTVPSE BPVVE BPVVE BPVVE BPTTVVE BPVVE BPVVE BTXXVVE BTXXVVE BTXXVPXVVE "
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"np.random.seed(42)\n",
"\n",
"for _ in range(25):\n",
" print(generate_string(default_reber_grammar), end=\" \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks good. Now let's generate a few strings based on the embedded Reber grammar:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 101,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BTBPTTTVPXTVPXTTVPSETE BPBPTVPSEPE BPBPVVEPE BPBPVPXVVEPE BPBTXXTTTTVVEPE BPBPVPSEPE BPBTXXVPSEPE BPBTSSSSSSSXSEPE BTBPVVETE BPBTXXVVEPE BPBTXXVPSEPE BTBTXXVVETE BPBPVVEPE BPBPVVEPE BPBTSXSEPE BPBPVVEPE BPBPTVPSEPE BPBTXXVVEPE BTBPTVPXVVETE BTBPVVETE BTBTSSSSSSSXXVVETE BPBTSSSXXTTTTVPSEPE BTBPTTVVETE BPBTXXTVVEPE BTBTXSETE "
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"np.random.seed(42)\n",
"\n",
"for _ in range(25):\n",
" print(generate_string(embedded_reber_grammar), end=\" \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, now we need a function to generate strings that do not respect the grammar. We could generate a random string, but the task would be a bit too easy, so instead we will generate a string that respects the grammar, and we will corrupt it by changing just one character:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 102,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"POSSIBLE_CHARS = \"BEPSTVX\"\n",
"\n",
"def generate_corrupted_string(grammar, chars=POSSIBLE_CHARS):\n",
" good_string = generate_string(grammar)\n",
" index = np.random.randint(len(good_string))\n",
" good_char = good_string[index]\n",
" bad_char = np.random.choice(sorted(set(chars) - set(good_char)))\n",
" return good_string[:index] + bad_char + good_string[index + 1:]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at a few corrupted strings:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 103,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"BTBPTTTPPXTVPXTTVPSETE BPBTXEEPE BPBPTVVVEPE BPBTSSSSXSETE BPTTXSEPE BTBPVPXTTTTTTEVETE BPBTXXSVEPE BSBPTTVPSETE BPBXVVEPE BEBTXSETE BPBPVPSXPE BTBPVVVETE BPBTSXSETE BPBPTTTPTTTTTVPSEPE BTBTXXTTSTVPSETE BBBTXSETE BPBTPXSEPE BPBPVPXTTTTVPXTVPXVPXTTTVVEVE BTBXXXTVPSETE BEBTSSSSSXXVPXTVVETE BTBXTTVVETE BPBTXSTPE BTBTXXTTTVPSBTE BTBTXSETX BTBTSXSSTE "
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"np.random.seed(42)\n",
"\n",
"for _ in range(25):\n",
" print(generate_corrupted_string(embedded_reber_grammar), end=\" \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We cannot feed strings directly to an RNN, so we need to encode them somehow. One option would be to one-hot encode each character. Another option is to use embeddings. Let's go for the second option (but since there are just a handful of characters, one-hot encoding would probably be a good option as well). For embeddings to work, we need to convert each string into a sequence of character IDs. Let's write a function for that, using each character's index in the string of possible characters \"BEPSTVX\":"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 104,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"def string_to_ids(s, chars=POSSIBLE_CHARS):\n",
2020-10-19 18:38:17 +02:00
" return [chars.index(c) for c in s]"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 105,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"[0, 4, 4, 4, 6, 6, 5, 5, 1, 4, 1]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 105,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"string_to_ids(\"BTTTXXVVETE\")"
]
},
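{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, the one-hot option mentioned above could look like the sketch below. `string_to_one_hot()` and `CHARS` are names introduced here for illustration (the alphabet is assumed to be the same as `POSSIBLE_CHARS`); the rest of this solution sticks with character IDs and an embedding layer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"CHARS = \"BEPSTVX\"  # assumed to match POSSIBLE_CHARS above\n",
"\n",
"def string_to_one_hot(s, chars=CHARS):\n",
"    ids = [chars.index(c) for c in s]\n",
"    # row i of the identity matrix is the one-hot vector for character ID i\n",
"    return np.eye(len(chars), dtype=np.float32)[ids]\n",
"\n",
"one_hot = string_to_one_hot(\"BTTTXXVVETE\")\n",
"one_hot.shape  # one row per character, one column per possible character"
]
},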
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now generate the dataset, with 50% good strings, and 50% bad strings:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 106,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"def generate_dataset(size):\n",
2022-03-21 09:43:01 +01:00
" good_strings = [\n",
" string_to_ids(generate_string(embedded_reber_grammar))\n",
" for _ in range(size // 2)\n",
" ]\n",
" bad_strings = [\n",
" string_to_ids(generate_corrupted_string(embedded_reber_grammar))\n",
" for _ in range(size - size // 2)\n",
" ]\n",
2020-04-19 06:01:14 +02:00
" all_strings = good_strings + bad_strings\n",
" X = tf.ragged.constant(all_strings, ragged_rank=1)\n",
" y = np.array([[1.] for _ in range(len(good_strings))] +\n",
" [[0.] for _ in range(len(bad_strings))])\n",
" return X, y"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 107,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"\n",
"X_train, y_train = generate_dataset(10000)\n",
"X_valid, y_valid = generate_dataset(2000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at the first training sequence:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 108,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(22,), dtype=int32, numpy=\n",
"array([0, 4, 0, 2, 4, 4, 4, 5, 2, 6, 4, 5, 2, 6, 4, 4, 5, 2, 3, 1, 4, 1],\n",
" dtype=int32)>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 108,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"X_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2020-10-19 18:44:33 +02:00
"What class does it belong to?"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 109,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"array([1.])"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 109,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perfect! We are ready to create the RNN to identify good strings. We build a simple sequence binary classifier:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 110,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
2022-04-16 07:30:08 +02:00
"Epoch 1/20\n",
2022-03-21 09:43:01 +01:00
"313/313 [==============================] - 4s 8ms/step - loss: 0.6910 - accuracy: 0.5095 - val_loss: 0.6825 - val_accuracy: 0.5645\n",
"Epoch 2/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.6678 - accuracy: 0.5659 - val_loss: 0.6635 - val_accuracy: 0.6105\n",
"Epoch 3/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.6504 - accuracy: 0.5766 - val_loss: 0.6521 - val_accuracy: 0.6110\n",
"Epoch 4/20\n",
"313/313 [==============================] - 2s 8ms/step - loss: 0.6347 - accuracy: 0.5980 - val_loss: 0.6224 - val_accuracy: 0.6445\n",
"Epoch 5/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.6054 - accuracy: 0.6361 - val_loss: 0.5779 - val_accuracy: 0.6980\n",
"Epoch 6/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.5414 - accuracy: 0.7093 - val_loss: 0.4695 - val_accuracy: 0.7795\n",
"Epoch 7/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.3756 - accuracy: 0.8418 - val_loss: 0.2685 - val_accuracy: 0.9115\n",
"Epoch 8/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.2601 - accuracy: 0.9044 - val_loss: 0.1534 - val_accuracy: 0.9615\n",
"Epoch 9/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.1774 - accuracy: 0.9427 - val_loss: 0.1063 - val_accuracy: 0.9735\n",
"Epoch 10/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.0624 - accuracy: 0.9826 - val_loss: 0.0219 - val_accuracy: 0.9975\n",
"Epoch 11/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.0371 - accuracy: 0.9914 - val_loss: 0.0055 - val_accuracy: 1.0000\n",
"Epoch 12/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.0029 - accuracy: 0.9995 - val_loss: 8.7265e-04 - val_accuracy: 1.0000\n",
"Epoch 13/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 6.7552e-04 - accuracy: 1.0000 - val_loss: 4.9408e-04 - val_accuracy: 1.0000\n",
"Epoch 14/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 4.4514e-04 - accuracy: 1.0000 - val_loss: 3.6322e-04 - val_accuracy: 1.0000\n",
"Epoch 15/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 3.3943e-04 - accuracy: 1.0000 - val_loss: 2.8524e-04 - val_accuracy: 1.0000\n",
"Epoch 16/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 2.7723e-04 - accuracy: 1.0000 - val_loss: 2.3880e-04 - val_accuracy: 1.0000\n",
"Epoch 17/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 2.3477e-04 - accuracy: 1.0000 - val_loss: 2.0363e-04 - val_accuracy: 1.0000\n",
"Epoch 18/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 2.0382e-04 - accuracy: 1.0000 - val_loss: 1.7760e-04 - val_accuracy: 1.0000\n",
"Epoch 19/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 1.8077e-04 - accuracy: 1.0000 - val_loss: 1.5916e-04 - val_accuracy: 1.0000\n",
"Epoch 20/20\n",
"313/313 [==============================] - 2s 8ms/step - loss: 1.6246e-04 - accuracy: 1.0000 - val_loss: 1.4362e-04 - val_accuracy: 1.0000\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
"embedding_size = 5\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential([\n",
" tf.keras.layers.InputLayer(input_shape=[None], dtype=tf.int32, ragged=True),\n",
2022-03-21 09:43:01 +01:00
" tf.keras.layers.Embedding(input_dim=len(POSSIBLE_CHARS),\n",
" output_dim=embedding_size),\n",
2021-10-17 04:04:08 +02:00
" tf.keras.layers.GRU(30),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
2020-04-19 06:01:14 +02:00
"])\n",
2022-03-21 09:43:01 +01:00
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.02, momentum = 0.95,\n",
" nesterov=True)\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, y_train, epochs=20,\n",
" validation_data=(X_valid, y_valid))"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's test our RNN on two tricky strings: the first one is bad while the second one is good. They only differ by the second to last character. If the RNN gets this right, it shows that it managed to notice the pattern that the second letter should always be equal to the second to last letter. That requires a fairly long short-term memory (which is the reason why we used a GRU cell)."
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 111,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Estimated probability that these are Reber strings:\n",
"BPBTSSSSSSSXXTTVPXVPXTTTTTVVETE: 0.02%\n",
"BPBTSSSSSSSXXTTVPXVPXTTTTTVVEPE: 99.99%\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"test_strings = [\"BPBTSSSSSSSXXTTVPXVPXTTTTTVVETE\",\n",
" \"BPBTSSSSSSSXXTTVPXVPXTTTTTVVEPE\"]\n",
"X_test = tf.ragged.constant([string_to_ids(s) for s in test_strings], ragged_rank=1)\n",
"\n",
"y_proba = model.predict(X_test)\n",
"print()\n",
"print(\"Estimated probability that these are Reber strings:\")\n",
"for index, string in enumerate(test_strings):\n",
" print(\"{}: {:.2f}%\".format(string, 100 * y_proba[index][0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ta-da! It worked fine. The RNN found the correct answers with very high confidence. :)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9.\n",
"_Exercise: Train an Encoder– Decoder model that can convert a date string from one format to another (e.g., from \"April 22, 2019\" to \"2019-04-22\")._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by creating the dataset. We will use random days between 1000-01-01 and 9999-12-31:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 112,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"from datetime import date\n",
"\n",
"# cannot use strftime()'s %B format since it depends on the locale\n",
"MONTHS = [\"January\", \"February\", \"March\", \"April\", \"May\", \"June\",\n",
" \"July\", \"August\", \"September\", \"October\", \"November\", \"December\"]\n",
"\n",
"def random_dates(n_dates):\n",
" min_date = date(1000, 1, 1).toordinal()\n",
" max_date = date(9999, 12, 31).toordinal()\n",
"\n",
" ordinals = np.random.randint(max_date - min_date, size=n_dates) + min_date\n",
" dates = [date.fromordinal(ordinal) for ordinal in ordinals]\n",
"\n",
" x = [MONTHS[dt.month - 1] + \" \" + dt.strftime(\"%d, %Y\") for dt in dates]\n",
" y = [dt.isoformat() for dt in dates]\n",
" return x, y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are a few random dates, displayed in both the input format and the target format:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 113,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input Target \n",
"--------------------------------------------------\n",
"September 20, 7075 7075-09-20 \n",
"May 15, 8579 8579-05-15 \n",
"January 11, 7103 7103-01-11 \n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"np.random.seed(42)\n",
"\n",
"n_dates = 3\n",
"x_example, y_example = random_dates(n_dates)\n",
"print(\"{:25s}{:25s}\".format(\"Input\", \"Target\"))\n",
"print(\"-\" * 50)\n",
"for idx in range(n_dates):\n",
" print(\"{:25s}{:25s}\".format(x_example[idx], y_example[idx]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's get the list of all possible characters in the inputs:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 114,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"' ,0123456789ADFJMNOSabceghilmnoprstuvy'"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 114,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
2020-10-19 19:33:35 +02:00
"INPUT_CHARS = \"\".join(sorted(set(\"\".join(MONTHS) + \"0123456789, \")))\n",
2020-04-19 06:01:14 +02:00
"INPUT_CHARS"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And here's the list of possible characters in the outputs:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 115,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"OUTPUT_CHARS = \"0123456789-\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's write a function to convert a string to a list of character IDs, as we did in the previous exercise:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 116,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"def date_str_to_ids(date_str, chars=INPUT_CHARS):\n",
" return [chars.index(c) for c in date_str]"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 117,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"[19, 23, 31, 34, 23, 28, 21, 23, 32, 0, 4, 2, 1, 0, 9, 2, 9, 7]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 117,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"date_str_to_ids(x_example[0], INPUT_CHARS)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 118,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"[7, 0, 7, 5, 10, 0, 9, 10, 2, 0]"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 118,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"date_str_to_ids(y_example[0], OUTPUT_CHARS)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 119,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"def prepare_date_strs(date_strs, chars=INPUT_CHARS):\n",
" X_ids = [date_str_to_ids(dt, chars) for dt in date_strs]\n",
" X = tf.ragged.constant(X_ids, ragged_rank=1)\n",
" return (X + 1).to_tensor() # using 0 as the padding token ID\n",
"\n",
"def create_dataset(n_dates):\n",
" x, y = random_dates(n_dates)\n",
" return prepare_date_strs(x, INPUT_CHARS), prepare_date_strs(y, OUTPUT_CHARS)"
]
},
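{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the `+ 1` shift and the zero padding in `prepare_date_strs()` concrete, here is a minimal pure-Python sketch of the same idea. The helper names `encode_padded()` and `decode()` are hypothetical (they are not used elsewhere in this notebook): each character ID is shifted up by one so that ID 0 stays free for padding, and shorter strings are right-padded with zeros.\n",
"\n",
"```python\n",
"CHARS = '0123456789-'  # same alphabet as OUTPUT_CHARS\n",
"\n",
"def encode_padded(strings, chars=CHARS):\n",
"    # hypothetical helper: character index + 1, so that ID 0 is free for padding\n",
"    max_len = max(len(s) for s in strings)\n",
"    return [[chars.index(c) + 1 for c in s] + [0] * (max_len - len(s))\n",
"            for s in strings]\n",
"\n",
"def decode(ids, chars=CHARS):\n",
"    # hypothetical helper: inverse mapping, skipping the padding ID 0\n",
"    return [''.join(chars[i - 1] for i in row if i != 0) for row in ids]\n",
"\n",
"ids = encode_padded(['7075-09-20', '2020'])\n",
"print(ids)\n",
"print(decode(ids))\n",
"```\n",
"\n",
"The first row is the shifted version of the IDs returned by `date_str_to_ids(y_example[0], OUTPUT_CHARS)` above, and the second row shows the trailing zeros added as padding."
]
},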
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 120,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"\n",
"X_train, Y_train = create_dataset(10000)\n",
"X_valid, Y_valid = create_dataset(2000)\n",
"X_test, Y_test = create_dataset(2000)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 121,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(10,), dtype=int32, numpy=array([ 8, 1, 8, 6, 11, 1, 10, 11, 3, 1], dtype=int32)>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 121,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"Y_train[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### First version: a very basic seq2seq model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's first try the simplest possible model: we feed in the input sequence, which first goes through the encoder (an embedding layer followed by a single LSTM layer), which outputs a vector, then it goes through a decoder (a single LSTM layer, followed by a dense output layer), which outputs a sequence of vectors, each representing the estimated probabilities for all possible output character.\n",
"\n",
2022-02-19 10:09:28 +01:00
"Since the decoder expects a sequence as input, we repeat the vector (which is output by the encoder) as many times as the longest possible output sequence."
2020-04-19 06:01:14 +02:00
]
},
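{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what repeating the encoder's output vector means in terms of tensor shapes, here is a small NumPy sketch. The sizes are toy values chosen for illustration, and `encoder_output` is just a stand-in array, not the output of a trained encoder.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"batch_size, n_units, n_steps = 2, 4, 10  # toy sizes, chosen for illustration\n",
"\n",
"# stand-in for the encoder output: one vector per input sequence\n",
"encoder_output = np.arange(batch_size * n_units, dtype=np.float32).reshape(\n",
"    batch_size, n_units)\n",
"\n",
"# same effect as a RepeatVector(n_steps) layer: add a time axis, then repeat\n",
"# the vector along that axis\n",
"decoder_input = np.repeat(encoder_output[:, np.newaxis, :], n_steps, axis=1)\n",
"\n",
"print(encoder_output.shape)  # (2, 4)\n",
"print(decoder_input.shape)   # (2, 10, 4)\n",
"```\n",
"\n",
"Every time step of `decoder_input` holds the same vector, so the decoder's LSTM must rely on its own state to work out which output position it is currently producing."
]
},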
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 122,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"313/313 [==============================] - 10s 23ms/step - loss: 1.8150 - accuracy: 0.3489 - val_loss: 1.3726 - val_accuracy: 0.4939\n",
"Epoch 2/20\n",
"313/313 [==============================] - 7s 22ms/step - loss: 1.2447 - accuracy: 0.5510 - val_loss: 1.0725 - val_accuracy: 0.6115\n",
"Epoch 3/20\n",
"313/313 [==============================] - 7s 23ms/step - loss: 1.0937 - accuracy: 0.6125 - val_loss: 1.0548 - val_accuracy: 0.6130\n",
"Epoch 4/20\n",
"313/313 [==============================] - 7s 23ms/step - loss: 1.0032 - accuracy: 0.6413 - val_loss: 3.8747 - val_accuracy: 0.1788\n",
"Epoch 5/20\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.8159 - accuracy: 0.7023 - val_loss: 0.6623 - val_accuracy: 0.7474\n",
"Epoch 6/20\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.5645 - accuracy: 0.7795 - val_loss: 0.5005 - val_accuracy: 0.8032\n",
"Epoch 7/20\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.5037 - accuracy: 0.8103 - val_loss: 0.3798 - val_accuracy: 0.8500\n",
"Epoch 8/20\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.3131 - accuracy: 0.8795 - val_loss: 0.2582 - val_accuracy: 0.9043\n",
"Epoch 9/20\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.2141 - accuracy: 0.9280 - val_loss: 0.1637 - val_accuracy: 0.9498\n",
"Epoch 10/20\n",
"313/313 [==============================] - 9s 28ms/step - loss: 0.1282 - accuracy: 0.9650 - val_loss: 0.0918 - val_accuracy: 0.9774\n",
"Epoch 11/20\n",
"313/313 [==============================] - 9s 28ms/step - loss: 0.0669 - accuracy: 0.9871 - val_loss: 0.3368 - val_accuracy: 0.8871\n",
"Epoch 12/20\n",
"313/313 [==============================] - 10s 32ms/step - loss: 0.1551 - accuracy: 0.9662 - val_loss: 0.0398 - val_accuracy: 0.9949\n",
"Epoch 13/20\n",
"313/313 [==============================] - 9s 29ms/step - loss: 0.0291 - accuracy: 0.9969 - val_loss: 0.0240 - val_accuracy: 0.9984\n",
"Epoch 14/20\n",
"313/313 [==============================] - 9s 30ms/step - loss: 0.0182 - accuracy: 0.9986 - val_loss: 0.0161 - val_accuracy: 0.9993\n",
"Epoch 15/20\n",
"313/313 [==============================] - 9s 30ms/step - loss: 0.0119 - accuracy: 0.9995 - val_loss: 0.0112 - val_accuracy: 0.9997\n",
"Epoch 16/20\n",
"313/313 [==============================] - 10s 32ms/step - loss: 0.0082 - accuracy: 0.9998 - val_loss: 0.0083 - val_accuracy: 0.9999\n",
"Epoch 17/20\n",
"313/313 [==============================] - 10s 33ms/step - loss: 0.0059 - accuracy: 0.9999 - val_loss: 0.0058 - val_accuracy: 0.9999\n",
"Epoch 18/20\n",
"313/313 [==============================] - 11s 34ms/step - loss: 0.0042 - accuracy: 1.0000 - val_loss: 0.0043 - val_accuracy: 0.9999\n",
"Epoch 19/20\n",
"313/313 [==============================] - 10s 33ms/step - loss: 0.0031 - accuracy: 1.0000 - val_loss: 0.0034 - val_accuracy: 0.9999\n",
"Epoch 20/20\n",
"313/313 [==============================] - 12s 40ms/step - loss: 0.0024 - accuracy: 1.0000 - val_loss: 0.0026 - val_accuracy: 1.0000\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"embedding_size = 32\n",
"max_output_length = Y_train.shape[1]\n",
"\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
2021-10-17 04:04:08 +02:00
"encoder = tf.keras.Sequential([\n",
" tf.keras.layers.Embedding(input_dim=len(INPUT_CHARS) + 1,\n",
2020-04-19 06:01:14 +02:00
" output_dim=embedding_size,\n",
" input_shape=[None]),\n",
2021-10-17 04:04:08 +02:00
" tf.keras.layers.LSTM(128)\n",
2020-04-19 06:01:14 +02:00
"])\n",
"\n",
2021-10-17 04:04:08 +02:00
"decoder = tf.keras.Sequential([\n",
" tf.keras.layers.LSTM(128, return_sequences=True),\n",
" tf.keras.layers.Dense(len(OUTPUT_CHARS) + 1, activation=\"softmax\")\n",
2020-04-19 06:01:14 +02:00
"])\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential([\n",
2020-04-19 06:01:14 +02:00
" encoder,\n",
2021-10-17 04:04:08 +02:00
" tf.keras.layers.RepeatVector(max_output_length),\n",
2020-04-19 06:01:14 +02:00
" decoder\n",
"])\n",
"\n",
2021-10-17 04:04:08 +02:00
"optimizer = tf.keras.optimizers.Nadam()\n",
2020-04-19 06:01:14 +02:00
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(X_train, Y_train, epochs=20,\n",
2020-04-22 09:21:56 +02:00
" validation_data=(X_valid, Y_valid))"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks great, we reach 100% validation accuracy! Let's use the model to make some predictions. We will need to be able to convert a sequence of character IDs to a readable string:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 123,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"def ids_to_date_strs(ids, chars=OUTPUT_CHARS):\n",
" return [\"\".join([(\"?\" + chars)[index] for index in sequence])\n",
" for sequence in ids]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can use the model to convert some dates"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 124,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"X_new = prepare_date_strs([\"September 17, 2009\", \"July 14, 1789\"])"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 125,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2009-09-17\n",
"1789-07-14\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
2022-02-19 10:09:28 +01:00
"ids = model.predict(X_new).argmax(axis=-1)\n",
2020-04-19 06:01:14 +02:00
"for date_str in ids_to_date_strs(ids):\n",
" print(date_str)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perfect! :)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, since the model was only trained on input strings of length 18 (which is the length of the longest date), it does not perform well if we try to use it to make predictions on shorter sequences:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 126,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"X_new = prepare_date_strs([\"May 02, 2020\", \"July 14, 1789\"])"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 127,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2020-02-02\n",
"1789-01-14\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
2022-02-19 10:09:28 +01:00
"ids = model.predict(X_new).argmax(axis=-1)\n",
2020-04-19 06:01:14 +02:00
"for date_str in ids_to_date_strs(ids):\n",
" print(date_str)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Oops! We need to ensure that we always pass sequences of the same length as during training, using padding if necessary. Let's write a little helper function for that:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 128,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"max_input_length = X_train.shape[1]\n",
"\n",
"def prepare_date_strs_padded(date_strs):\n",
" X = prepare_date_strs(date_strs)\n",
" if X.shape[1] < max_input_length:\n",
" X = tf.pad(X, [[0, 0], [0, max_input_length - X.shape[1]]])\n",
" return X\n",
"\n",
"def convert_date_strs(date_strs):\n",
" X = prepare_date_strs_padded(date_strs)\n",
2022-02-19 10:09:28 +01:00
" ids = model.predict(X).argmax(axis=-1)\n",
2020-04-19 06:01:14 +02:00
" return ids_to_date_strs(ids)"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 129,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"['2020-05-02', '1789-07-14']"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 129,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"convert_date_strs([\"May 02, 2020\", \"July 14, 1789\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Cool! Granted, there are certainly much easier ways to write a date conversion tool (e.g., using regular expressions or even basic string manipulation), but you have to admit that using neural networks is way cooler. ;-)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"However, real-life sequence-to-sequence problems will usually be harder, so for the sake of completeness, let's build a more powerful model."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2020-04-22 09:21:56 +02:00
"### Second version: feeding the shifted targets to the decoder (teacher forcing)"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Instead of feeding the decoder a simple repetition of the encoder's output vector, we can feed it the target sequence, shifted by one time step to the right. This way, at each time step the decoder will know what the previous target character was. This should help is tackle more complex sequence-to-sequence problems.\n",
"\n",
"Since the first output character of each target sequence has no previous character, we will need a new token to represent the start-of-sequence (sos).\n",
"\n",
"During inference, we won't know the target, so what will we feed the decoder? We can just predict one character at a time, starting with an sos token, then feeding the decoder all the characters that were predicted so far (we will look at this in more details later in this notebook).\n",
"\n",
"But if the decoder's LSTM expects to get the previous target as input at each step, how shall we pass it it the vector output by the encoder? Well, one option is to ignore the output vector, and instead use the encoder's LSTM state as the initial state of the decoder's LSTM (which requires that encoder's LSTM must have the same number of units as the decoder's LSTM).\n",
"\n",
"Now let's create the decoder's inputs (for training, validation and testing). The sos token will be represented using the last possible output character's ID + 1."
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 130,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"sos_id = len(OUTPUT_CHARS) + 1\n",
"\n",
"def shifted_output_sequences(Y):\n",
" sos_tokens = tf.fill(dims=(len(Y), 1), value=sos_id)\n",
" return tf.concat([sos_tokens, Y[:, :-1]], axis=1)\n",
"\n",
"X_train_decoder = shifted_output_sequences(Y_train)\n",
"X_valid_decoder = shifted_output_sequences(Y_valid)\n",
"X_test_decoder = shifted_output_sequences(Y_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's take a look at the decoder's training inputs:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 131,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(10000, 10), dtype=int32, numpy=\n",
"array([[12, 8, 1, ..., 10, 11, 3],\n",
" [12, 9, 6, ..., 6, 11, 2],\n",
" [12, 8, 2, ..., 2, 11, 2],\n",
" ...,\n",
" [12, 10, 8, ..., 2, 11, 4],\n",
" [12, 2, 2, ..., 3, 11, 3],\n",
" [12, 8, 9, ..., 8, 11, 3]], dtype=int32)>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 131,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"X_train_decoder"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's build the model. It's not a simple sequential model anymore, so let's use the functional API:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 132,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10\n",
"313/313 [==============================] - 11s 27ms/step - loss: 1.6824 - accuracy: 0.3734 - val_loss: 1.4054 - val_accuracy: 0.4681\n",
"Epoch 2/10\n",
"313/313 [==============================] - 8s 26ms/step - loss: 1.1935 - accuracy: 0.5550 - val_loss: 0.8868 - val_accuracy: 0.6750\n",
"Epoch 3/10\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.6403 - accuracy: 0.7700 - val_loss: 0.3493 - val_accuracy: 0.8978\n",
"Epoch 4/10\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.2292 - accuracy: 0.9423 - val_loss: 0.1254 - val_accuracy: 0.9782\n",
"Epoch 5/10\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.0694 - accuracy: 0.9932 - val_loss: 0.0441 - val_accuracy: 0.9982\n",
"Epoch 6/10\n",
"313/313 [==============================] - 9s 29ms/step - loss: 0.0576 - accuracy: 0.9923 - val_loss: 0.0280 - val_accuracy: 0.9988\n",
"Epoch 7/10\n",
"313/313 [==============================] - 8s 26ms/step - loss: 0.0179 - accuracy: 0.9998 - val_loss: 0.0143 - val_accuracy: 0.9999\n",
"Epoch 8/10\n",
"313/313 [==============================] - 6s 18ms/step - loss: 0.0107 - accuracy: 0.9999 - val_loss: 0.0092 - val_accuracy: 0.9999\n",
"Epoch 9/10\n",
"313/313 [==============================] - 6s 20ms/step - loss: 0.0070 - accuracy: 1.0000 - val_loss: 0.0065 - val_accuracy: 0.9999\n",
"Epoch 10/10\n",
"313/313 [==============================] - 6s 18ms/step - loss: 0.0050 - accuracy: 1.0000 - val_loss: 0.0047 - val_accuracy: 0.9999\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"encoder_embedding_size = 32\n",
"decoder_embedding_size = 32\n",
"lstm_units = 128\n",
"\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
2021-10-17 04:04:08 +02:00
"encoder_input = tf.keras.layers.Input(shape=[None], dtype=tf.int32)\n",
"encoder_embedding = tf.keras.layers.Embedding(\n",
2020-04-19 06:01:14 +02:00
" input_dim=len(INPUT_CHARS) + 1,\n",
" output_dim=encoder_embedding_size)(encoder_input)\n",
2021-10-17 04:04:08 +02:00
"_, encoder_state_h, encoder_state_c = tf.keras.layers.LSTM(\n",
2020-04-19 06:01:14 +02:00
" lstm_units, return_state=True)(encoder_embedding)\n",
"encoder_state = [encoder_state_h, encoder_state_c]\n",
"\n",
2021-10-17 04:04:08 +02:00
"decoder_input = tf.keras.layers.Input(shape=[None], dtype=tf.int32)\n",
"decoder_embedding = tf.keras.layers.Embedding(\n",
2020-04-19 06:01:14 +02:00
" input_dim=len(OUTPUT_CHARS) + 2,\n",
" output_dim=decoder_embedding_size)(decoder_input)\n",
2021-10-17 04:04:08 +02:00
"decoder_lstm_output = tf.keras.layers.LSTM(lstm_units, return_sequences=True)(\n",
2020-04-19 06:01:14 +02:00
" decoder_embedding, initial_state=encoder_state)\n",
2021-10-17 04:04:08 +02:00
"decoder_output = tf.keras.layers.Dense(len(OUTPUT_CHARS) + 1,\n",
2020-04-19 06:01:14 +02:00
" activation=\"softmax\")(decoder_lstm_output)\n",
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Model(inputs=[encoder_input, decoder_input],\n",
2020-04-19 06:01:14 +02:00
" outputs=[decoder_output])\n",
"\n",
2021-10-17 04:04:08 +02:00
"optimizer = tf.keras.optimizers.Nadam()\n",
2020-04-19 06:01:14 +02:00
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=optimizer,\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit([X_train, X_train_decoder], Y_train, epochs=10,\n",
2020-04-22 09:21:56 +02:00
" validation_data=([X_valid, X_valid_decoder], Y_valid))"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This model also reaches 100% validation accuracy, but it does so even faster."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's once again use the model to make some predictions. This time we need to predict characters one by one."
]
},
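{
"cell_type": "markdown",
"metadata": {},
"source": [
"The idea of predicting one character at a time can be sketched without any neural network. In the toy example below, `toy_next_token()` is a hypothetical stand-in for the trained decoder (it simply replays a fixed sequence), and `greedy_decode()` is an illustrative name for the loop: start with the sos token, predict the next token, append it to the decoder inputs, and repeat.\n",
"\n",
"```python\n",
"SOS, TARGET = 12, [3, 1, 3, 1]  # toy values: sos ID and the sequence to emit\n",
"\n",
"def toy_next_token(decoder_inputs):\n",
"    # hypothetical stand-in for the trained decoder: given the sos token plus\n",
"    # the tokens predicted so far, return the most likely next token\n",
"    return TARGET[len(decoder_inputs) - 1]\n",
"\n",
"def greedy_decode(next_token_fn, sos_id, max_length):\n",
"    decoder_inputs = [sos_id]            # start with just the sos token\n",
"    for _ in range(max_length):\n",
"        next_id = next_token_fn(decoder_inputs)\n",
"        decoder_inputs.append(next_id)   # feed the prediction back in\n",
"    return decoder_inputs[1:]            # drop the sos token\n",
"\n",
"print(greedy_decode(toy_next_token, SOS, max_length=len(TARGET)))\n",
"```\n",
"\n",
"With a real model, the next-token function would run the network on the encoder inputs and the tokens predicted so far, then take the argmax of the probabilities at the current position."
]
},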
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 133,
2020-04-19 06:01:14 +02:00
"metadata": {},
"outputs": [],
"source": [
"sos_id = len(OUTPUT_CHARS) + 1\n",
"\n",
"def predict_date_strs(date_strs):\n",
" X = prepare_date_strs_padded(date_strs)\n",
" Y_pred = tf.fill(dims=(len(X), 1), value=sos_id)\n",
" for index in range(max_output_length):\n",
" pad_size = max_output_length - Y_pred.shape[1]\n",
" X_decoder = tf.pad(Y_pred, [[0, 0], [0, pad_size]])\n",
" Y_probas_next = model.predict([X, X_decoder])[:, index:index+1]\n",
" Y_pred_next = tf.argmax(Y_probas_next, axis=-1, output_type=tf.int32)\n",
" Y_pred = tf.concat([Y_pred, Y_pred_next], axis=1)\n",
" return ids_to_date_strs(Y_pred[:, 1:])"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 134,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"['1789-07-14', '2020-05-01']"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 134,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"predict_date_strs([\"July 14, 1789\", \"May 01, 2020\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"Works fine! Next, feel free to write a Transformer version. :)"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10.\n",
2022-03-21 09:43:01 +01:00
"_Exercise: Go through Keras's tutorial for [Natural language image search with a Dual Encoder](https://homl.info/dualtuto). You will learn how to build a model capable of representing both images and text within the same embedding space. This makes it possible to search for images using a text prompt, like in the [CLIP model](https://openai.com/blog/clip/) by OpenAI._ "
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-03-21 09:43:01 +01:00
"Just click the link and follow the instructions."
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 11.\n",
2022-03-21 09:43:01 +01:00
"_Exercise: Use the Transformers library to download a pretrained language model capable of generating text (e.g., GPT), and try generating more convincing Shakespearean text. You will need to use the model's `generate()` method—see Hugging Face's documentation for more details._"
2020-04-19 06:01:14 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, let's load a pretrained model. In this example, we will use OpenAI's GPT model, with an additional Language Model on top (just a linear layer with weights tied to the input embeddings). Let's import it and load the pretrained weights (this will download about 445MB of data to `~/.cache/torch/transformers`):"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 135,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"All model checkpoint layers were used when initializing TFOpenAIGPTLMHeadModel.\n",
"\n",
"All the layers of TFOpenAIGPTLMHeadModel were initialized from the model checkpoint at openai-gpt.\n",
"If your task is similar to the task the model of the checkpoint was trained on, you can already use TFOpenAIGPTLMHeadModel for predictions without further training.\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"from transformers import TFOpenAIGPTLMHeadModel\n",
"\n",
"model = TFOpenAIGPTLMHeadModel.from_pretrained(\"openai-gpt\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next we will need a specialized tokenizer for this model. This one will try to use the [spaCy](https://spacy.io/) and [ftfy](https://pypi.org/project/ftfy/) libraries if they are installed, or else it will fall back to BERT's `BasicTokenizer` followed by Byte-Pair Encoding (which should be fine for most use cases)."
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 136,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.\n"
]
}
],
2020-04-19 06:01:14 +02:00
"source": [
"from transformers import OpenAIGPTTokenizer\n",
"\n",
"tokenizer = OpenAIGPTTokenizer.from_pretrained(\"openai-gpt\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's use the tokenizer to tokenize and encode the prompt text:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 137,
2020-04-19 06:01:14 +02:00
"metadata": {},
2022-03-21 09:43:01 +01:00
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [3570, 1473], 'attention_mask': [1, 1]}"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 137,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer(\"hello everyone\")"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 138,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(1, 10), dtype=int32, numpy=\n",
"array([[ 616, 5751, 6404, 498, 9606, 240, 616, 26271, 7428,\n",
" 16187]], dtype=int32)>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 138,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"prompt_text = \"This royal throne of kings, this sceptred isle\"\n",
"encoded_prompt = tokenizer.encode(prompt_text,\n",
" add_special_tokens=False,\n",
" return_tensors=\"tf\")\n",
"encoded_prompt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Easy! Next, let's use the model to generate text after the prompt. We will generate 5 different sentences, each starting with the prompt text, followed by 40 additional tokens. For an explanation of what all the hyperparameters do, make sure to check out this great [blog post](https://huggingface.co/blog/how-to-generate) by Patrick von Platen (from Hugging Face). You can play around with the hyperparameters to try to obtain better results."
]
},
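{
"cell_type": "markdown",
"metadata": {},
"source": [
"To give a rough idea of what the `temperature` and `top_p` hyperparameters do, here is a simplified pure-Python sketch of nucleus (top-p) sampling. It is not the Transformers library's implementation: the function names `top_p_filter()` and `sample_top_p()` and the toy logits are made up for illustration. The logits are divided by the temperature, converted to probabilities, and only the most likely tokens whose cumulative probability reaches `top_p` are kept for sampling.\n",
"\n",
"```python\n",
"import math\n",
"import random\n",
"\n",
"def top_p_filter(logits, top_p=0.9, temperature=1.0):\n",
"    # softmax with temperature (subtracting the max for numerical stability)\n",
"    scaled = [x / temperature for x in logits]\n",
"    m = max(scaled)\n",
"    exps = [math.exp(x - m) for x in scaled]\n",
"    total = sum(exps)\n",
"    probas = [e / total for e in exps]\n",
"    # keep the most likely tokens until their cumulative probability reaches top_p\n",
"    order = sorted(range(len(probas)), key=lambda i: probas[i], reverse=True)\n",
"    kept, cumulative = [], 0.0\n",
"    for i in order:\n",
"        kept.append(i)\n",
"        cumulative += probas[i]\n",
"        if cumulative >= top_p:\n",
"            break\n",
"    # renormalize over the kept tokens; all the others are discarded\n",
"    kept_total = sum(probas[i] for i in kept)\n",
"    return {i: probas[i] / kept_total for i in kept}\n",
"\n",
"def sample_top_p(logits, top_p=0.9, temperature=1.0, rng=random):\n",
"    nucleus = top_p_filter(logits, top_p, temperature)\n",
"    token_ids = list(nucleus)\n",
"    return rng.choices(token_ids, weights=[nucleus[i] for i in token_ids])[0]\n",
"\n",
"toy_logits = [2.0, 1.0, 0.5, -1.0, -3.0]  # made-up scores for 5 tokens\n",
"print(top_p_filter(toy_logits))  # only tokens 0, 1 and 2 survive\n",
"print(sample_top_p(toy_logits))\n",
"```\n",
"\n",
"A lower `top_p` or a lower `temperature` concentrates the sampling on fewer tokens, which tends to make the generated text more conservative, while higher values make it more diverse."
]
},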
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 139,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(5, 50), dtype=int32, numpy=\n",
"array([[ 616, 5751, 6404, 498, 9606, 240, 616, 26271, 7428,\n",
" 16187, 498, 481, 550, 12974, 554, 20275, 544, 481,\n",
" 808, 1082, 525, 759, 13717, 507, 617, 616, 1294,\n",
" 1276, 239, 40477, 249, 1048, 2210, 525, 249, 880,\n",
" 694, 817, 485, 788, 507, 240, 244, 481, 762,\n",
" 4049, 3983, 6474, 1387, 485],\n",
" [ 616, 5751, 6404, 498, 9606, 240, 616, 26271, 7428,\n",
" 16187, 509, 1163, 485, 1272, 8660, 3380, 14760, 240,\n",
" 1389, 557, 481, 7232, 8, 789, 3408, 239, 754,\n",
" 10253, 558, 694, 2556, 488, 2093, 485, 2185, 917,\n",
" 11, 5272, 6372, 562, 1272, 11413, 239, 40477, 481,\n",
" 1583, 618, 558, 524, 1074],\n",
" [ 616, 5751, 6404, 498, 9606, 240, 616, 26271, 7428,\n",
" 16187, 544, 597, 622, 1163, 488, 481, 1594, 498,\n",
" 622, 11547, 267, 256, 616, 509, 885, 481, 7789,\n",
" 498, 481, 588, 1917, 240, 984, 544, 491, 618,\n",
" 4647, 681, 535, 4244, 239, 40477, 616, 509, 481,\n",
" 12194, 1734, 481, 588, 1917],\n",
" [ 616, 5751, 6404, 498, 9606, 240, 616, 26271, 7428,\n",
" 16187, 980, 246, 3128, 4321, 525, 759, 595, 580,\n",
" 12563, 522, 15668, 239, 507, 812, 16841, 1073, 655,\n",
" 544, 664, 3409, 500, 622, 6903, 522, 481, 1092,\n",
" 812, 7629, 617, 481, 1988, 240, 488, 481, 4814,\n",
" 812, 580, 7752, 498, 987],\n",
" [ 616, 5751, 6404, 498, 9606, 240, 616, 26271, 7428,\n",
" 16187, 812, 580, 704, 3360, 4034, 485, 618, 6099,\n",
" 33974, 239, 40477, 870, 3754, 240, 547, 3089, 239,\n",
" 40477, 269, 269, 269, 40477, 246, 1092, 1882, 504,\n",
" 513, 1188, 3761, 27661, 485, 10525, 239, 244, 848,\n",
" 504, 239, 249, 825, 512]], dtype=int32)>"
]
},
2022-04-17 02:17:13 +02:00
"execution_count": 139,
2022-03-21 09:43:01 +01:00
"metadata": {},
"output_type": "execute_result"
}
],
2020-04-19 06:01:14 +02:00
"source": [
"num_sequences = 5\n",
"length = 40\n",
"\n",
"generated_sequences = model.generate(\n",
" input_ids=encoded_prompt,\n",
" do_sample=True,\n",
" max_length=length + len(encoded_prompt[0]),\n",
" temperature=1.0,\n",
" top_k=0,\n",
" top_p=0.9,\n",
" repetition_penalty=1.0,\n",
" num_return_sequences=num_sequences,\n",
")\n",
"\n",
"generated_sequences"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's decode the generated sequences and print them:"
]
},
{
"cell_type": "code",
2022-04-17 02:17:13 +02:00
"execution_count": 140,
2022-03-21 09:43:01 +01:00
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"this royal throne of kings, this sceptred isle of the necronomicon is the only place that can unlock it from this dark world. \n",
" i am surprised that i've been able to see it, \" the man named dallon says to\n",
"--------------------------------------------------------------------------------\n",
"this royal throne of kings, this sceptred isle was home to many beloved possessors, such as the mighty astaroth. their wives had been husband and wife to lord teixiara for many generations. \n",
" the high king had his own\n",
"--------------------------------------------------------------------------------\n",
"this royal throne of kings, this sceptred isle is now our home and the land of our fathers!'this was made the standard of the coates, which is at king celebrant's command. \n",
" this was the longest story the coates\n",
"--------------------------------------------------------------------------------\n",
"this royal throne of kings, this sceptred isle has a powerful spirit that can not be severed or erased. it will reign until there is no army in our realm or the light will fade from the sky, and the lands will be stripped of its\n",
"--------------------------------------------------------------------------------\n",
"this royal throne of kings, this sceptred isle will be your final gift to king dragomir. \n",
" good luck, my guards. \n",
" * * * \n",
" a light touch on her arm caused aleria to jolt. \" come on. i think you\n",
"--------------------------------------------------------------------------------\n"
]
}
],
"source": [
"for sequence in generated_sequences:\n",
" text = tokenizer.decode(sequence, clean_up_tokenization_spaces=True)\n",
" print(text)\n",
" print(\"-\" * 80)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can try more recent (and larger) models, such as GPT-2, CTRL, Transformer-XL or XLNet, which are all available as pretrained models in the transformers library, including variants with a language modeling head on top. The preprocessing steps vary slightly between models, so make sure to check out this [generation example](https://github.com/huggingface/transformers/blob/master/examples/run_generation.py) from the transformers documentation. This example uses PyTorch, but it will work with very few tweaks, such as adding `TF` at the beginning of the model class name, removing the `.to()` method calls, and using `return_tensors=\"tf\"` instead of `\"pt\"`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Hope you enjoyed this chapter! :)"
]
}
],
"metadata": {
"accelerator": "GPU",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"nav_menu": {},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}