From 4eb3e7d947068f733088b5a3d59447504be6ed58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Tue, 25 May 2021 21:54:03 +1200 Subject: [PATCH] Fix exercise 10.e: exclude the padding tokens when computing the mean of the word embeddings --- 13_loading_and_preprocessing_data.ipynb | 34 ++++++++++++++++++------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/13_loading_and_preprocessing_data.ipynb b/13_loading_and_preprocessing_data.ipynb index ff2cfac..8eb137f 100644 --- a/13_loading_and_preprocessing_data.ipynb +++ b/13_loading_and_preprocessing_data.ipynb @@ -2609,7 +2609,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We get about 73.7% accuracy on the validation set after just the first epoch, but after that the model makes no significant progress. We will do better in Chapter 16. For now the point is just to perform efficient preprocessing using `tf.data` and Keras preprocessing layers." + "We get about 73.5% accuracy on the validation set after just the first epoch, but after that the model makes no significant progress. We will do better in Chapter 16. For now the point is just to perform efficient preprocessing using `tf.data` and Keras preprocessing layers." ] }, { @@ -2624,7 +2624,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To compute the mean embedding for each review, and multiply it by the square root of the number of words in that review, we will need a little function:" + "To compute the mean embedding for each review, and multiply it by the square root of the number of words in that review, we will need a little function. For each sentence, this function needs to compute $M \\times \\sqrt N$, where $M$ is the mean of all the word embeddings in the sentence (excluding padding tokens), and $N$ is the number of words in the sentence (also excluding padding tokens). We can rewrite $M$ as $\\dfrac{S}{N}$, where $S$ is the sum of all word embeddings (it does not matter whether or not we include the padding tokens in this sum, since their representation is a zero vector). So the function must return $M \\times \\sqrt N = \\dfrac{S}{N} \\times \\sqrt N = \\dfrac{S}{\\sqrt N \\times \\sqrt N} \\times \\sqrt N= \\dfrac{S}{\\sqrt N}$." ] }, { @@ -2637,7 +2637,7 @@ " not_pad = tf.math.count_nonzero(inputs, axis=-1)\n", " n_words = tf.math.count_nonzero(not_pad, axis=-1, keepdims=True) \n", " sqrt_n_words = tf.math.sqrt(tf.cast(n_words, tf.float32))\n", - " return tf.reduce_mean(inputs, axis=1) * sqrt_n_words\n", + " return tf.reduce_sum(inputs, axis=1) / sqrt_n_words\n", "\n", "another_example = tf.constant([[[1., 2., 3.], [4., 5., 0.], [0., 0., 0.]],\n", " [[6., 0., 0.], [0., 0., 0.], [0., 0., 0.]]])\n", @@ -2648,7 +2648,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's check that this is correct. The first review contains 2 words (the last token is a zero vector, which represents the `` token). The second review contains 1 word. So we need to compute the mean embedding for each review, and multiply the first one by the square root of 2, and the second one by the square root of 1:" + "Let's check that this is correct. The first review contains 2 words (the last token is a zero vector, which represents the `` token). Let's compute the mean embedding for these 2 words, and multiply the result by the square root of 2:" ] }, { @@ -2657,7 +2657,23 @@ "metadata": {}, "outputs": [], "source": [ - "tf.reduce_mean(another_example, axis=1) * tf.sqrt([[2.], [1.]])" + "tf.reduce_mean(another_example[0:1, :2], axis=1) * tf.sqrt(2.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Looks good! Now let's check the second review, which contains just one word (we ignore the two padding tokens):" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [], + "source": [ + "tf.reduce_mean(another_example[1:2, :1], axis=1) * tf.sqrt(1.)" ] }, { @@ -2669,7 +2685,7 @@ }, { "cell_type": "code", - "execution_count": 156, + "execution_count": 157, "metadata": {}, "outputs": [], "source": [ @@ -2696,7 +2712,7 @@ }, { "cell_type": "code", - "execution_count": 157, + "execution_count": 158, "metadata": {}, "outputs": [], "source": [ @@ -2721,7 +2737,7 @@ }, { "cell_type": "code", - "execution_count": 158, + "execution_count": 159, "metadata": {}, "outputs": [], "source": [ @@ -2733,7 +2749,7 @@ }, { "cell_type": "code", - "execution_count": 159, + "execution_count": 160, "metadata": {}, "outputs": [], "source": [