1232 lines
35 KiB
Plaintext
1232 lines
35 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"**Chapter 16 – Natural Language Processing with RNNs and Attention**"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"_This notebook contains all the sample code in chapter 16._"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Setup"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"First, let's import a few common modules, ensure Matplotlib plots figures inline and prepare a function to save the figures. We also check that Python 3.5 or later is installed (although Python 2.x may work, it is deprecated so we strongly recommend you use Python 3 instead), as well as Scikit-Learn ≥0.20 and TensorFlow ≥2.0-preview."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# Python ≥3.5 is required\n",
|
||
"import sys\n",
|
||
"assert sys.version_info >= (3, 5)\n",
|
||
"\n",
|
||
"# Scikit-Learn ≥0.20 is required\n",
|
||
"import sklearn\n",
|
||
"assert sklearn.__version__ >= \"0.20\"\n",
|
||
"\n",
|
||
"# TensorFlow ≥2.0-preview is required\n",
|
||
"import tensorflow as tf\n",
|
||
"from tensorflow import keras\n",
|
||
"assert tf.__version__ >= \"2.0\"\n",
|
||
"\n",
|
||
"# Common imports\n",
|
||
"import numpy as np\n",
|
||
"import os\n",
|
||
"\n",
|
||
"# to make this notebook's output stable across runs\n",
|
||
"np.random.seed(42)\n",
|
||
"tf.random.set_seed(42)\n",
|
||
"\n",
|
||
"# To plot pretty figures\n",
|
||
"%matplotlib inline\n",
|
||
"import matplotlib as mpl\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"mpl.rc('axes', labelsize=14)\n",
|
||
"mpl.rc('xtick', labelsize=12)\n",
|
||
"mpl.rc('ytick', labelsize=12)\n",
|
||
"\n",
|
||
"# Where to save the figures\n",
|
||
"PROJECT_ROOT_DIR = \".\"\n",
|
||
"CHAPTER_ID = \"nlp\"\n",
|
||
"IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID)\n",
|
||
"os.makedirs(IMAGES_PATH, exist_ok=True)\n",
|
||
"\n",
|
||
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
|
||
" path = os.path.join(IMAGES_PATH, fig_id + \".\" + fig_extension)\n",
|
||
" print(\"Saving figure\", fig_id)\n",
|
||
" if tight_layout:\n",
|
||
" plt.tight_layout()\n",
|
||
" plt.savefig(path, format=fig_extension, dpi=resolution)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Char-RNN"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Splitting a sequence into batches of shuffled windows"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"For example, let's split the sequence 0 to 14 into windows of length 5, each shifted by 2 (e.g., `[0, 1, 2, 3, 4]`, `[2, 3, 4, 5, 6]`, etc.), then shuffle them, and split them into inputs (the first 4 steps) and targets (the last 4 steps) (e.g., `[2, 3, 4, 5, 6]` would be split into `[[2, 3, 4, 5], [3, 4, 5, 6]]`), then create batches of 3 such input/target pairs:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"scrolled": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"np.random.seed(42)\n",
|
||
"tf.random.set_seed(42)\n",
|
||
"\n",
|
||
"n_steps = 5\n",
|
||
"dataset = tf.data.Dataset.from_tensor_slices(tf.range(15))\n",
|
||
"dataset = dataset.window(n_steps, shift=2, drop_remainder=True)\n",
|
||
"dataset = dataset.flat_map(lambda window: window.batch(n_steps))\n",
|
||
"dataset = dataset.shuffle(10).map(lambda window: (window[:-1], window[1:]))\n",
|
||
"dataset = dataset.batch(3).prefetch(1)\n",
|
||
"for index, (X_batch, Y_batch) in enumerate(dataset):\n",
|
||
" print(\"_\" * 20, \"Batch\", index, \"\\nX_batch\")\n",
|
||
" print(X_batch.numpy())\n",
|
||
" print(\"=\" * 5, \"\\nY_batch\")\n",
|
||
" print(Y_batch.numpy())"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Loading the Data and Preparing the Dataset"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"shakespeare_url = \"https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\"\n",
|
||
"filepath = keras.utils.get_file(\"shakespeare.txt\", shakespeare_url)\n",
|
||
"with open(filepath) as f:\n",
|
||
" shakespeare_text = f.read()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"print(shakespeare_text[:148])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"\"\".join(sorted(set(shakespeare_text.lower())))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tokenizer = keras.preprocessing.text.Tokenizer(char_level=True)\n",
|
||
"tokenizer.fit_on_texts(shakespeare_text)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tokenizer.texts_to_sequences([\"First\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tokenizer.sequences_to_texts([[20, 6, 9, 8, 3]])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"max_id = len(tokenizer.word_index) # number of distinct characters\n",
|
||
"dataset_size = tokenizer.document_count # total number of characters"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"[encoded] = np.array(tokenizer.texts_to_sequences([shakespeare_text])) - 1\n",
|
||
"train_size = dataset_size * 90 // 100\n",
|
||
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"n_steps = 100\n",
|
||
"window_length = n_steps + 1 # target = input shifted 1 character ahead\n",
|
||
"dataset = dataset.repeat().window(window_length, shift=1, drop_remainder=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dataset = dataset.flat_map(lambda window: window.batch(window_length))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"np.random.seed(42)\n",
|
||
"tf.random.set_seed(42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"batch_size = 32\n",
|
||
"dataset = dataset.shuffle(10000).batch(batch_size)\n",
|
||
"dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dataset = dataset.map(\n",
|
||
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dataset = dataset.prefetch(1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"for X_batch, Y_batch in dataset.take(1):\n",
|
||
" print(X_batch.shape, Y_batch.shape)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Creating and Training the Model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"model = keras.models.Sequential([\n",
|
||
" keras.layers.GRU(128, return_sequences=True, input_shape=[None, max_id],\n",
|
||
" # no dropout in stateful RNN (https://github.com/ageron/handson-ml2/issues/32)\n",
|
||
" # dropout=0.2, recurrent_dropout=0.2,\n",
|
||
" ),\n",
|
||
" keras.layers.GRU(128, return_sequences=True,\n",
|
||
" # dropout=0.2, recurrent_dropout=0.2\n",
|
||
" ),\n",
|
||
" keras.layers.TimeDistributed(keras.layers.Dense(max_id,\n",
|
||
" activation=\"softmax\"))\n",
|
||
"])\n",
|
||
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
|
||
"history = model.fit(dataset, steps_per_epoch=train_size // batch_size,\n",
|
||
" epochs=10)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Using the Model to Generate Text"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def preprocess(texts):\n",
|
||
" X = np.array(tokenizer.texts_to_sequences(texts)) - 1\n",
|
||
" return tf.one_hot(X, max_id)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"X_new = preprocess([\"How are yo\"])\n",
|
||
"# Sequential.predict_classes() was removed in TF 2.6; taking the argmax over the\n",
"# last (softmax) axis of predict() yields the same class ids.\n",
"Y_pred = np.argmax(model.predict(X_new), axis=-1)\n",
|
||
"tokenizer.sequences_to_texts(Y_pred + 1)[0][-1] # 1st sentence, last char"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 21,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)\n",
|
||
"\n",
|
||
"tf.random.categorical([[np.log(0.5), np.log(0.4), np.log(0.1)]], num_samples=40).numpy()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def next_char(text, temperature=1):\n",
|
||
" X_new = preprocess([text])\n",
|
||
" y_proba = model.predict(X_new)[0, -1:, :]\n",
|
||
" rescaled_logits = tf.math.log(y_proba) / temperature\n",
|
||
" char_id = tf.random.categorical(rescaled_logits, num_samples=1) + 1\n",
|
||
" return tokenizer.sequences_to_texts(char_id.numpy())[0]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 23,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)\n",
|
||
"\n",
|
||
"next_char(\"How are yo\", temperature=1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 24,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def complete_text(text, n_chars=50, temperature=1):\n",
|
||
" for _ in range(n_chars):\n",
|
||
" text += next_char(text, temperature)\n",
|
||
" return text"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 25,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)\n",
|
||
"\n",
|
||
"print(complete_text(\"t\", temperature=0.2))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 26,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"print(complete_text(\"t\", temperature=1))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 27,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"print(complete_text(\"t\", temperature=2))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Stateful RNN"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 28,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 29,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])\n",
|
||
"dataset = dataset.window(window_length, shift=n_steps, drop_remainder=True)\n",
|
||
"dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
|
||
"dataset = dataset.repeat().batch(1)\n",
|
||
"dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
|
||
"dataset = dataset.map(\n",
|
||
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
|
||
"dataset = dataset.prefetch(1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"batch_size = 32\n",
|
||
"encoded_parts = np.array_split(encoded[:train_size], batch_size)\n",
|
||
"datasets = []\n",
|
||
"for encoded_part in encoded_parts:\n",
|
||
" dataset = tf.data.Dataset.from_tensor_slices(encoded_part)\n",
|
||
" dataset = dataset.window(window_length, shift=n_steps, drop_remainder=True)\n",
|
||
" dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
|
||
" datasets.append(dataset)\n",
|
||
"dataset = tf.data.Dataset.zip(tuple(datasets)).map(lambda *windows: tf.stack(windows))\n",
|
||
"dataset = dataset.repeat().map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
|
||
"dataset = dataset.map(\n",
|
||
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
|
||
"dataset = dataset.prefetch(1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"model = keras.models.Sequential([\n",
|
||
" keras.layers.GRU(128, return_sequences=True, stateful=True,\n",
|
||
" dropout=0.2, recurrent_dropout=0.2,\n",
|
||
" batch_input_shape=[batch_size, None, max_id]),\n",
|
||
" keras.layers.GRU(128, return_sequences=True, stateful=True,\n",
|
||
" dropout=0.2, recurrent_dropout=0.2),\n",
|
||
" keras.layers.TimeDistributed(keras.layers.Dense(max_id,\n",
|
||
" activation=\"softmax\"))\n",
|
||
"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class ResetStatesCallback(keras.callbacks.Callback):\n",
|
||
" def on_epoch_begin(self, epoch, logs):\n",
|
||
" self.model.reset_states()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
|
||
"steps_per_epoch = train_size // batch_size // n_steps\n",
|
||
"model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=50,\n",
|
||
" callbacks=[ResetStatesCallback()])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"To use the model with different batch sizes, we need to create a stateless copy. We can get rid of dropout since it is only used during training:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"stateless_model = keras.models.Sequential([\n",
|
||
" keras.layers.GRU(128, return_sequences=True, input_shape=[None, max_id]),\n",
|
||
" keras.layers.GRU(128, return_sequences=True),\n",
|
||
" keras.layers.TimeDistributed(keras.layers.Dense(max_id,\n",
|
||
" activation=\"softmax\"))\n",
|
||
"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"To set the weights, we first need to build the model (so the weights get created):"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"stateless_model.build(tf.TensorShape([None, None, max_id]))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 36,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"stateless_model.set_weights(model.get_weights())\n",
|
||
"model = stateless_model"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 37,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)\n",
|
||
"\n",
|
||
"print(complete_text(\"t\"))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Sentiment Analysis"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 38,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"You can load the IMDB dataset easily:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 39,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# load_data() returns two (inputs, labels) pairs: the training split, then the test split.\n",
"(X_train, y_train), (X_test, y_test) = keras.datasets.imdb.load_data()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 40,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"X_train[0][:10]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 41,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"word_index = keras.datasets.imdb.get_word_index()\n",
|
||
"id_to_word = {id_ + 3: word for word, id_ in word_index.items()}\n",
|
||
"for id_, token in enumerate((\"<pad>\", \"<sos>\", \"<unk>\")):\n",
|
||
" id_to_word[id_] = token\n",
|
||
"\" \".join([id_to_word[id_] for id_ in X_train[0][:10]])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 42,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import tensorflow_datasets as tfds\n",
|
||
"\n",
|
||
"datasets, info = tfds.load(\"imdb_reviews\", as_supervised=True, with_info=True)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 43,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"datasets.keys()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 44,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"train_size = info.splits[\"train\"].num_examples\n",
|
||
"test_size = info.splits[\"test\"].num_examples"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 45,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"train_size, test_size"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 46,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"for X_batch, y_batch in datasets[\"train\"].batch(2).take(1):\n",
|
||
" for review, label in zip(X_batch.numpy(), y_batch.numpy()):\n",
|
||
" print(\"Review:\", review.decode(\"utf-8\")[:200], \"...\")\n",
|
||
" print(\"Label:\", label, \"= Positive\" if label else \"= Negative\")\n",
|
||
" print()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 47,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def preprocess(X_batch, y_batch):\n",
|
||
" X_batch = tf.strings.substr(X_batch, 0, 300)\n",
|
||
" X_batch = tf.strings.regex_replace(X_batch, rb\"<br\\s*/?>\", b\" \")\n",
|
||
" X_batch = tf.strings.regex_replace(X_batch, b\"[^a-zA-Z']\", b\" \")\n",
|
||
" X_batch = tf.strings.split(X_batch)\n",
|
||
" return X_batch.to_tensor(default_value=b\"<pad>\"), y_batch"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 48,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"preprocess(X_batch, y_batch)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 49,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"from collections import Counter\n",
|
||
"\n",
|
||
"vocabulary = Counter()\n",
|
||
"for X_batch, y_batch in datasets[\"train\"].batch(32).map(preprocess):\n",
|
||
" for review in X_batch:\n",
|
||
" vocabulary.update(list(review.numpy()))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 50,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"vocabulary.most_common()[:3]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 51,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"len(vocabulary)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 52,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"vocab_size = 10000\n",
|
||
"truncated_vocabulary = [\n",
|
||
" word for word, count in vocabulary.most_common()[:vocab_size]]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 53,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"word_to_id = {word: index for index, word in enumerate(truncated_vocabulary)}\n",
|
||
"for word in b\"This movie was faaaaaantastic\".split():\n",
|
||
"    # Use dict.get's default rather than `or`: id 0 is a valid id but is falsy.\n",
"    print(word_to_id.get(word, vocab_size))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 54,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"words = tf.constant(truncated_vocabulary)\n",
|
||
"word_ids = tf.range(len(truncated_vocabulary), dtype=tf.int64)\n",
|
||
"vocab_init = tf.lookup.KeyValueTensorInitializer(words, word_ids)\n",
|
||
"num_oov_buckets = 1000\n",
|
||
"table = tf.lookup.StaticVocabularyTable(vocab_init, num_oov_buckets)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 55,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"table.lookup(tf.constant([b\"This movie was faaaaaantastic\".split()]))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 56,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"def encode_words(X_batch, y_batch):\n",
|
||
" return table.lookup(X_batch), y_batch\n",
|
||
"\n",
|
||
"train_set = datasets[\"train\"].repeat().batch(32).map(preprocess)\n",
|
||
"train_set = train_set.map(encode_words).prefetch(1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 57,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"for X_batch, y_batch in train_set.take(1):\n",
|
||
" print(X_batch)\n",
|
||
" print(y_batch)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 58,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"embed_size = 128\n",
|
||
"model = keras.models.Sequential([\n",
|
||
" keras.layers.Embedding(vocab_size + num_oov_buckets, embed_size,\n",
|
||
" mask_zero=True, # not shown in the book\n",
|
||
" input_shape=[None]),\n",
|
||
" keras.layers.GRU(128, return_sequences=True),\n",
|
||
" keras.layers.GRU(128),\n",
|
||
" keras.layers.Dense(1, activation=\"sigmoid\")\n",
|
||
"])\n",
|
||
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
|
||
"history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Or using manual masking:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 59,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"K = keras.backend\n",
|
||
"embed_size = 128\n",
|
||
"inputs = keras.layers.Input(shape=[None])\n",
|
||
"mask = keras.layers.Lambda(lambda inputs: K.not_equal(inputs, 0))(inputs)\n",
|
||
"z = keras.layers.Embedding(vocab_size + num_oov_buckets, embed_size)(inputs)\n",
|
||
"z = keras.layers.GRU(128, return_sequences=True)(z, mask=mask)\n",
|
||
"z = keras.layers.GRU(128)(z, mask=mask)\n",
|
||
"outputs = keras.layers.Dense(1, activation=\"sigmoid\")(z)\n",
|
||
"model = keras.models.Model(inputs=[inputs], outputs=[outputs])\n",
|
||
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
|
||
"history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Reusing Pretrained Embeddings"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 60,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 61,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"TFHUB_CACHE_DIR = os.path.join(os.curdir, \"my_tfhub_cache\")\n",
|
||
"os.environ[\"TFHUB_CACHE_DIR\"] = TFHUB_CACHE_DIR"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 62,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import tensorflow_hub as hub\n",
|
||
"\n",
|
||
"model = keras.Sequential([\n",
|
||
" hub.KerasLayer(\"https://tfhub.dev/google/tf2-preview/nnlm-en-dim50/1\",\n",
|
||
" dtype=tf.string, input_shape=[], output_shape=[50]),\n",
|
||
" keras.layers.Dense(128, activation=\"relu\"),\n",
|
||
" keras.layers.Dense(1, activation=\"sigmoid\")\n",
|
||
"])\n",
|
||
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\",\n",
|
||
" metrics=[\"accuracy\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 63,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"for dirpath, dirnames, filenames in os.walk(TFHUB_CACHE_DIR):\n",
|
||
" for filename in filenames:\n",
|
||
" print(os.path.join(dirpath, filename))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 64,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import tensorflow_datasets as tfds\n",
|
||
"\n",
|
||
"datasets, info = tfds.load(\"imdb_reviews\", as_supervised=True, with_info=True)\n",
|
||
"train_size = info.splits[\"train\"].num_examples\n",
|
||
"batch_size = 32\n",
|
||
"train_set = datasets[\"train\"].repeat().batch(batch_size).prefetch(1)\n",
|
||
"history = model.fit(train_set, steps_per_epoch=train_size // batch_size, epochs=5)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"## Automatic Translation"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 65,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"tf.random.set_seed(42)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 66,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"vocab_size = 100\n",
|
||
"embed_size = 10"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 67,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import tensorflow_addons as tfa\n",
|
||
"\n",
|
||
"encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)\n",
|
||
"decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)\n",
|
||
"sequence_lengths = keras.layers.Input(shape=[], dtype=np.int32)\n",
|
||
"\n",
|
||
"embeddings = keras.layers.Embedding(vocab_size, embed_size)\n",
|
||
"encoder_embeddings = embeddings(encoder_inputs)\n",
|
||
"decoder_embeddings = embeddings(decoder_inputs)\n",
|
||
"\n",
|
||
"encoder = keras.layers.LSTM(512, return_state=True)\n",
|
||
"encoder_outputs, state_h, state_c = encoder(encoder_embeddings)\n",
|
||
"encoder_state = [state_h, state_c]\n",
|
||
"\n",
|
||
"sampler = tfa.seq2seq.sampler.TrainingSampler()\n",
|
||
"\n",
|
||
"decoder_cell = keras.layers.LSTMCell(512)\n",
|
||
"output_layer = keras.layers.Dense(vocab_size)\n",
|
||
"decoder = tfa.seq2seq.basic_decoder.BasicDecoder(decoder_cell, sampler,\n",
|
||
" output_layer=output_layer)\n",
|
||
"final_outputs, final_state, final_sequence_lengths = decoder(\n",
|
||
" decoder_embeddings, initial_state=encoder_state,\n",
|
||
" sequence_length=sequence_lengths)\n",
|
||
"Y_proba = tf.nn.softmax(final_outputs.rnn_output)\n",
|
||
"\n",
|
||
"model = keras.models.Model(\n",
|
||
" inputs=[encoder_inputs, decoder_inputs, sequence_lengths],\n",
|
||
" outputs=[Y_proba])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 68,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 69,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"X = np.random.randint(100, size=10*1000).reshape(1000, 10)\n",
|
||
"Y = np.random.randint(100, size=15*1000).reshape(1000, 15)\n",
|
||
"X_decoder = np.c_[np.zeros((1000, 1)), Y[:, :-1]]\n",
|
||
"seq_lengths = np.full([1000], 15)\n",
|
||
"\n",
|
||
"history = model.fit([X, X_decoder, seq_lengths], Y, epochs=2)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Bidirectional Recurrent Layers"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 70,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"model = keras.models.Sequential([\n",
|
||
" keras.layers.GRU(10, return_sequences=True, input_shape=[None, 10]),\n",
|
||
" keras.layers.Bidirectional(keras.layers.GRU(10, return_sequences=True))\n",
|
||
"])\n",
|
||
"\n",
|
||
"model.summary()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"### Positional Encoding"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 71,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class PositionalEncoding(keras.layers.Layer):\n",
|
||
" def __init__(self, max_steps, max_dims, dtype=tf.float32, **kwargs):\n",
|
||
" super().__init__(dtype=dtype, **kwargs)\n",
|
||
" if max_dims % 2 == 1: max_dims += 1 # max_dims must be even\n",
|
||
" p, i = np.meshgrid(np.arange(max_steps), np.arange(max_dims // 2))\n",
|
||
" pos_emb = np.empty((1, max_steps, max_dims))\n",
|
||
" pos_emb[0, :, ::2] = np.sin(p / 10000**(2 * i / max_dims)).T\n",
|
||
" pos_emb[0, :, 1::2] = np.cos(p / 10000**(2 * i / max_dims)).T\n",
|
||
" self.positional_embedding = tf.constant(pos_emb.astype(self.dtype))\n",
|
||
" def call(self, inputs):\n",
|
||
" shape = tf.shape(inputs)\n",
|
||
" return inputs + self.positional_embedding[:, :shape[-2], :shape[-1]]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 72,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"max_steps = 201\n",
|
||
"max_dims = 512\n",
|
||
"pos_emb = PositionalEncoding(max_steps, max_dims)\n",
|
||
"PE = pos_emb(np.zeros((1, max_steps, max_dims), np.float32))[0].numpy()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 73,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"i1, i2, crop_i = 100, 101, 150\n",
|
||
"p1, p2, p3 = 22, 60, 35\n",
|
||
"fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(9, 5))\n",
|
||
"ax1.plot([p1, p1], [-1, 1], \"k--\", label=\"$p = {}$\".format(p1))\n",
|
||
"ax1.plot([p2, p2], [-1, 1], \"k--\", label=\"$p = {}$\".format(p2), alpha=0.5)\n",
|
||
"ax1.plot(p3, PE[p3, i1], \"bx\", label=\"$p = {}$\".format(p3))\n",
|
||
"ax1.plot(PE[:,i1], \"b-\", label=\"$i = {}$\".format(i1))\n",
|
||
"ax1.plot(PE[:,i2], \"r-\", label=\"$i = {}$\".format(i2))\n",
|
||
"ax1.plot([p1, p2], [PE[p1, i1], PE[p2, i1]], \"bo\")\n",
|
||
"ax1.plot([p1, p2], [PE[p1, i2], PE[p2, i2]], \"ro\")\n",
|
||
"ax1.legend(loc=\"center right\", fontsize=14, framealpha=0.95)\n",
|
||
"ax1.set_ylabel(\"$P_{(p,i)}$\", rotation=0, fontsize=16)\n",
|
||
"ax1.grid(True, alpha=0.3)\n",
|
||
"ax1.hlines(0, 0, max_steps - 1, color=\"k\", linewidth=1, alpha=0.3)\n",
|
||
"ax1.axis([0, max_steps - 1, -1, 1])\n",
|
||
"ax2.imshow(PE.T[:crop_i], cmap=\"gray\", interpolation=\"bilinear\", aspect=\"auto\")\n",
|
||
"ax2.hlines(i1, 0, max_steps - 1, color=\"b\")\n",
|
||
"cheat = 2 # need to raise the red line a bit, or else it hides the blue one\n",
|
||
"ax2.hlines(i2+cheat, 0, max_steps - 1, color=\"r\")\n",
|
||
"ax2.plot([p1, p1], [0, crop_i], \"k--\")\n",
|
||
"ax2.plot([p2, p2], [0, crop_i], \"k--\", alpha=0.5)\n",
|
||
"ax2.plot([p1, p2], [i2+cheat, i2+cheat], \"ro\")\n",
|
||
"ax2.plot([p1, p2], [i1, i1], \"bo\")\n",
|
||
"ax2.axis([0, max_steps - 1, 0, crop_i])\n",
|
||
"ax2.set_xlabel(\"$p$\", fontsize=16)\n",
|
||
"ax2.set_ylabel(\"$i$\", rotation=0, fontsize=16)\n",
|
||
"# Use the notebook's save_fig() helper so the figure is written under IMAGES_PATH.\n",
"save_fig(\"positional_embedding_plot\")\n",
|
||
"plt.show()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 74,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"embed_size = 512; max_steps = 500; vocab_size = 10000\n",
|
||
"encoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)\n",
|
||
"decoder_inputs = keras.layers.Input(shape=[None], dtype=np.int32)\n",
|
||
"embeddings = keras.layers.Embedding(vocab_size, embed_size)\n",
|
||
"encoder_embeddings = embeddings(encoder_inputs)\n",
|
||
"decoder_embeddings = embeddings(decoder_inputs)\n",
|
||
"positional_encoding = PositionalEncoding(max_steps, max_dims=embed_size)\n",
|
||
"encoder_in = positional_encoding(encoder_embeddings)\n",
|
||
"decoder_in = positional_encoding(decoder_embeddings)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Here is a (very) simplified Transformer. The actual architecture also has skip connections, layer normalization, and dense networks, and, most importantly, it uses Multi-Head Attention instead of regular Attention:"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 75,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"Z = encoder_in\n",
|
||
"for N in range(6):\n",
|
||
" Z = keras.layers.Attention(use_scale=True)([Z, Z])\n",
|
||
"\n",
|
||
"encoder_outputs = Z\n",
|
||
"Z = decoder_in\n",
|
||
"for N in range(6):\n",
|
||
" Z = keras.layers.Attention(use_scale=True, causal=True)([Z, Z])\n",
|
||
" Z = keras.layers.Attention(use_scale=True)([Z, encoder_outputs])\n",
|
||
"\n",
|
||
"outputs = keras.layers.TimeDistributed(\n",
|
||
" keras.layers.Dense(vocab_size, activation=\"softmax\"))(Z)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"Here's a basic implementation of the `MultiHeadAttention` layer. One will likely be added to `keras.layers` in the near future. Note that a `Conv1D` layer with `kernel_size=1` (and the default `padding=\"valid\"` and `strides=1`) is equivalent to a `TimeDistributed(Dense(...))` layer."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 76,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"K = keras.backend\n",
|
||
"\n",
|
||
"class MultiHeadAttention(keras.layers.Layer):\n",
|
||
" def __init__(self, n_heads, causal=False, use_scale=False, **kwargs):\n",
|
||
" self.n_heads = n_heads\n",
|
||
" self.causal = causal\n",
|
||
" self.use_scale = use_scale\n",
|
||
" super().__init__(**kwargs)\n",
|
||
" def build(self, batch_input_shape):\n",
|
||
" self.dims = batch_input_shape[0][-1]\n",
|
||
" self.q_dims, self.v_dims, self.k_dims = [self.dims // self.n_heads] * 3 # could be hyperparameters instead\n",
|
||
" self.q_linear = keras.layers.Conv1D(self.n_heads * self.q_dims, kernel_size=1, use_bias=False)\n",
|
||
" self.v_linear = keras.layers.Conv1D(self.n_heads * self.v_dims, kernel_size=1, use_bias=False)\n",
|
||
" self.k_linear = keras.layers.Conv1D(self.n_heads * self.k_dims, kernel_size=1, use_bias=False)\n",
|
||
" self.attention = keras.layers.Attention(causal=self.causal, use_scale=self.use_scale)\n",
|
||
" self.out_linear = keras.layers.Conv1D(self.dims, kernel_size=1, use_bias=False)\n",
|
||
" super().build(batch_input_shape)\n",
|
||
" def _multi_head_linear(self, inputs, linear):\n",
|
||
" shape = K.concatenate([K.shape(inputs)[:-1], [self.n_heads, -1]])\n",
|
||
" projected = K.reshape(linear(inputs), shape)\n",
|
||
" perm = K.permute_dimensions(projected, [0, 2, 1, 3])\n",
|
||
" return K.reshape(perm, [shape[0] * self.n_heads, shape[1], -1])\n",
|
||
" def call(self, inputs):\n",
|
||
" q = inputs[0]\n",
|
||
" v = inputs[1]\n",
|
||
" k = inputs[2] if len(inputs) > 2 else v\n",
|
||
" shape = K.shape(q)\n",
|
||
" q_proj = self._multi_head_linear(q, self.q_linear)\n",
|
||
" v_proj = self._multi_head_linear(v, self.v_linear)\n",
|
||
" k_proj = self._multi_head_linear(k, self.k_linear)\n",
|
||
" multi_attended = self.attention([q_proj, v_proj, k_proj])\n",
|
||
" shape_attended = K.shape(multi_attended)\n",
|
||
" reshaped_attended = K.reshape(multi_attended, [shape[0], self.n_heads, shape_attended[1], shape_attended[2]])\n",
|
||
" perm = K.permute_dimensions(reshaped_attended, [0, 2, 1, 3])\n",
|
||
" concat = K.reshape(perm, [shape[0], shape_attended[1], -1])\n",
|
||
" return self.out_linear(concat)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 77,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"Q = np.random.rand(2, 50, 512)\n",
|
||
"V = np.random.rand(2, 80, 512)\n",
|
||
"multi_attn = MultiHeadAttention(8)\n",
|
||
"multi_attn([Q, V]).shape"
|
||
]
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.6.8"
|
||
},
|
||
"nav_menu": {},
|
||
"toc": {
|
||
"navigate_menu": true,
|
||
"number_sections": true,
|
||
"sideBar": true,
|
||
"threshold": 6,
|
||
"toc_cell": false,
|
||
"toc_section_display": "block",
|
||
"toc_window_display": false
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 1
|
||
}
|