Remove dataset.repeat() and stop using steps_per_epoch when calling model.fit(), fixes #431

main
Aurélien Geron 2021-05-26 15:41:56 +12:00
parent 1b96533668
commit 2330a4dea3
1 changed file with 17 additions and 12 deletions

@ -237,6 +237,13 @@
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])"
]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"**Note**: in previous versions of this code, we used `dataset.repeat()` to make the dataset \"infinite\", and later in the notebook we set the `steps_per_epoch` argument when calling the `model.fit()` method. This was needed to work around some TensorFlow bugs. However, since these bugs have now been fixed, we can simplify the code: no need for `dataset.repeat()` or `steps_per_epoch` anymore."
+]
+},
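For context, here is a minimal sketch contrasting the two patterns, using hypothetical toy data (not part of this commit): with `repeat()` the dataset is infinite, so Keras needs `steps_per_epoch` to know where an epoch ends; with a finite dataset it can infer the epoch length itself.

    import tensorflow as tf

    X = tf.random.uniform([1000, 10])
    y = tf.random.uniform([1000], maxval=2, dtype=tf.int32)

    # Old workaround: infinite dataset, epoch length given explicitly
    old_set = tf.data.Dataset.from_tensor_slices((X, y)).repeat().batch(32)
    # model.fit(old_set, steps_per_epoch=1000 // 32, epochs=5)

    # Simplified: finite dataset, Keras infers the number of steps per epoch
    new_set = tf.data.Dataset.from_tensor_slices((X, y)).batch(32)
    # model.fit(new_set, epochs=5)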
{
"cell_type": "code",
"execution_count": 11,
@ -245,7 +252,7 @@
"source": [
"n_steps = 100\n",
"window_length = n_steps + 1 # target = input shifted 1 character ahead\n",
"dataset = dataset.repeat().window(window_length, shift=1, drop_remainder=True)"
"dataset = dataset.window(window_length, shift=1, drop_remainder=True)"
]
},
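A quick illustration of what `window()` yields, since it returns a dataset of nested window datasets, which is why the notebook later flattens them with `flat_map()`. This is a toy sketch, not part of the commit:

    import tensorflow as tf

    ds = tf.data.Dataset.range(6)
    ds = ds.window(3, shift=1, drop_remainder=True)   # nested datasets
    ds = ds.flat_map(lambda window: window.batch(3))  # flatten to tensors
    for tensor in ds:
        print(tensor.numpy())  # [0 1 2], [1 2 3], [2 3 4], [3 4 5]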
{
@ -345,8 +352,7 @@
" activation=\"softmax\"))\n",
"])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
"history = model.fit(dataset, steps_per_epoch=train_size // batch_size,\n",
" epochs=10)"
"history = model.fit(dataset, epochs=10)"
]
},
{
@ -488,7 +494,7 @@
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])\n",
"dataset = dataset.window(window_length, shift=n_steps, drop_remainder=True)\n",
"dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
"dataset = dataset.repeat().batch(1)\n",
"dataset = dataset.batch(1)\n",
"dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
"dataset = dataset.map(\n",
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
@ -510,7 +516,7 @@
" dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
" datasets.append(dataset)\n",
"dataset = tf.data.Dataset.zip(tuple(datasets)).map(lambda *windows: tf.stack(windows))\n",
"dataset = dataset.repeat().map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
"dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
"dataset = dataset.map(\n",
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
"dataset = dataset.prefetch(1)"
@ -560,8 +566,7 @@
"outputs": [],
"source": [
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
"steps_per_epoch = train_size // batch_size // n_steps\n",
"history = model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=50,\n",
"history = model.fit(dataset, epochs=50,\n",
" callbacks=[ResetStatesCallback()])"
]
},
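`ResetStatesCallback` is defined earlier in the notebook; the idea is to reset the RNN's hidden states at the start of each epoch, since a stateful RNN otherwise carries state across epoch boundaries. A sketch along those lines:

    from tensorflow import keras

    class ResetStatesCallback(keras.callbacks.Callback):
        def on_epoch_begin(self, epoch, logs=None):
            # Start each epoch from a clean hidden state
            self.model.reset_states()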
@ -837,7 +842,7 @@
"def encode_words(X_batch, y_batch):\n",
" return table.lookup(X_batch), y_batch\n",
"\n",
"train_set = datasets[\"train\"].repeat().batch(32).map(preprocess)\n",
"train_set = datasets[\"train\"].batch(32).map(preprocess)\n",
"train_set = train_set.map(encode_words).prefetch(1)"
]
},
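`preprocess` and `table` come from earlier cells: roughly, `preprocess` truncates and tokenizes the raw review strings, and `table` is a lookup table mapping words to IDs. A hypothetical minimal version of `preprocess`, for orientation only:

    import tensorflow as tf

    def preprocess(X_batch, y_batch):
        X_batch = tf.strings.substr(X_batch, 0, 300)   # keep the start of each review
        X_batch = tf.strings.regex_replace(X_batch, rb"<br\s*/?>", b" ")
        X_batch = tf.strings.regex_replace(X_batch, rb"[^a-zA-Z']", b" ")
        X_batch = tf.strings.split(X_batch)            # ragged batch of word tokens
        return X_batch.to_tensor(default_value=b"<pad>"), y_batch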
@ -868,7 +873,7 @@
" keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
"history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)"
"history = model.fit(train_set, epochs=5)"
]
},
{
@ -894,7 +899,7 @@
"outputs = keras.layers.Dense(1, activation=\"sigmoid\")(z)\n",
"model = keras.models.Model(inputs=[inputs], outputs=[outputs])\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
"history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)"
"history = model.fit(train_set, epochs=5)"
]
},
{
@ -963,8 +968,8 @@
"datasets, info = tfds.load(\"imdb_reviews\", as_supervised=True, with_info=True)\n",
"train_size = info.splits[\"train\"].num_examples\n",
"batch_size = 32\n",
"train_set = datasets[\"train\"].repeat().batch(batch_size).prefetch(1)\n",
"history = model.fit(train_set, steps_per_epoch=train_size // batch_size, epochs=5)"
"train_set = datasets[\"train\"].batch(batch_size).prefetch(1)\n",
"history = model.fit(train_set, epochs=5)"
]
},
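Because `tfds.load(..., as_supervised=True)` yields `(text, label)` pairs, the batched dataset can be passed straight to `fit()`. A quick inspection sketch, assuming `tensorflow_datasets` is installed:

    import tensorflow_datasets as tfds

    datasets, info = tfds.load("imdb_reviews", as_supervised=True, with_info=True)
    for review, label in datasets["train"].take(1):
        print(label.numpy(), review.numpy()[:60])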
{