From 2330a4dea3bfab67aec71ed252e477b1fbce2b91 Mon Sep 17 00:00:00 2001
From: Aurélien Geron
Date: Wed, 26 May 2021 15:41:56 +1200
Subject: [PATCH] Remove dataset.repeat() and stop using steps_per_epoch when
 calling model.fit(), fixes #431

---
 16_nlp_with_rnns_and_attention.ipynb | 29 ++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/16_nlp_with_rnns_and_attention.ipynb b/16_nlp_with_rnns_and_attention.ipynb
index 1b60381..57d4da3 100644
--- a/16_nlp_with_rnns_and_attention.ipynb
+++ b/16_nlp_with_rnns_and_attention.ipynb
@@ -237,6 +237,13 @@
     "dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Note**: in previous versions of this code, we used `dataset.repeat()` to make the dataset \"infinite\", and later in the notebook we set the `steps_per_epoch` argument when calling the `model.fit()` method. This was needed to work around some TensorFlow bugs. However, since these bugs have now been fixed, we can simplify the code: no need for `dataset.repeat()` or `steps_per_epoch` anymore."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 11,
@@ -245,7 +252,7 @@
    "source": [
     "n_steps = 100\n",
     "window_length = n_steps + 1 # target = input shifted 1 character ahead\n",
-    "dataset = dataset.repeat().window(window_length, shift=1, drop_remainder=True)"
+    "dataset = dataset.window(window_length, shift=1, drop_remainder=True)"
    ]
   },
   {
@@ -345,8 +352,7 @@
     " activation=\"softmax\"))\n",
     "])\n",
     "model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
-    "history = model.fit(dataset, steps_per_epoch=train_size // batch_size,\n",
-    " epochs=10)"
+    "history = model.fit(dataset, epochs=10)"
    ]
   },
   {
@@ -488,7 +494,7 @@
     "dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])\n",
     "dataset = dataset.window(window_length, shift=n_steps, drop_remainder=True)\n",
     "dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
-    "dataset = dataset.repeat().batch(1)\n",
+    "dataset = dataset.batch(1)\n",
     "dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
     "dataset = dataset.map(\n",
     " lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
@@ -510,7 +516,7 @@
     " dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
     " datasets.append(dataset)\n",
     "dataset = tf.data.Dataset.zip(tuple(datasets)).map(lambda *windows: tf.stack(windows))\n",
-    "dataset = dataset.repeat().map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
+    "dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
     "dataset = dataset.map(\n",
     " lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
     "dataset = dataset.prefetch(1)"
@@ -560,8 +566,7 @@
    "outputs": [],
    "source": [
     "model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
-    "steps_per_epoch = train_size // batch_size // n_steps\n",
-    "history = model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=50,\n",
+    "history = model.fit(dataset, epochs=50,\n",
     " callbacks=[ResetStatesCallback()])"
    ]
   },
@@ -837,7 +842,7 @@
     "def encode_words(X_batch, y_batch):\n",
     " return table.lookup(X_batch), y_batch\n",
     "\n",
-    "train_set = datasets[\"train\"].repeat().batch(32).map(preprocess)\n",
+    "train_set = datasets[\"train\"].batch(32).map(preprocess)\n",
     "train_set = train_set.map(encode_words).prefetch(1)"
    ]
   },
@@ -868,7 +873,7 @@
     " keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n", "model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n", - "history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)" + "history = model.fit(train_set, epochs=5)" ] }, { @@ -894,7 +899,7 @@ "outputs = keras.layers.Dense(1, activation=\"sigmoid\")(z)\n", "model = keras.models.Model(inputs=[inputs], outputs=[outputs])\n", "model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n", - "history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)" + "history = model.fit(train_set, epochs=5)" ] }, { @@ -963,8 +968,8 @@ "datasets, info = tfds.load(\"imdb_reviews\", as_supervised=True, with_info=True)\n", "train_size = info.splits[\"train\"].num_examples\n", "batch_size = 32\n", - "train_set = datasets[\"train\"].repeat().batch(batch_size).prefetch(1)\n", - "history = model.fit(train_set, steps_per_epoch=train_size // batch_size, epochs=5)" + "train_set = datasets[\"train\"].batch(batch_size).prefetch(1)\n", + "history = model.fit(train_set, epochs=5)" ] }, {