Remove dataset.repeat() and stop using steps_per_epoch when calling mode.fit(), fixes #431
parent
1b96533668
commit
2330a4dea3
|
@ -237,6 +237,13 @@
|
||||||
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])"
|
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"**Note**: in previous versions of this code, we used `dataset.repeat()` now to make the dataset \"infinite\", and later in the notebook we set the `steps_per_epoch` argument when calling the `model.fit()` method. This was needed to work around some TensorFlow bugs. However, since these bugs have now been fixed, we can simplify the code: no need for `dataset.repeat()` or `steps_per_epoch` anymore."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 11,
|
||||||
|
@ -245,7 +252,7 @@
|
||||||
"source": [
|
"source": [
|
||||||
"n_steps = 100\n",
|
"n_steps = 100\n",
|
||||||
"window_length = n_steps + 1 # target = input shifted 1 character ahead\n",
|
"window_length = n_steps + 1 # target = input shifted 1 character ahead\n",
|
||||||
"dataset = dataset.repeat().window(window_length, shift=1, drop_remainder=True)"
|
"dataset = dataset.window(window_length, shift=1, drop_remainder=True)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -345,8 +352,7 @@
|
||||||
" activation=\"softmax\"))\n",
|
" activation=\"softmax\"))\n",
|
||||||
"])\n",
|
"])\n",
|
||||||
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
|
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
|
||||||
"history = model.fit(dataset, steps_per_epoch=train_size // batch_size,\n",
|
"history = model.fit(dataset, epochs=10)"
|
||||||
" epochs=10)"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -488,7 +494,7 @@
|
||||||
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])\n",
|
"dataset = tf.data.Dataset.from_tensor_slices(encoded[:train_size])\n",
|
||||||
"dataset = dataset.window(window_length, shift=n_steps, drop_remainder=True)\n",
|
"dataset = dataset.window(window_length, shift=n_steps, drop_remainder=True)\n",
|
||||||
"dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
|
"dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
|
||||||
"dataset = dataset.repeat().batch(1)\n",
|
"dataset = dataset.batch(1)\n",
|
||||||
"dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
|
"dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
|
||||||
"dataset = dataset.map(\n",
|
"dataset = dataset.map(\n",
|
||||||
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
|
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
|
||||||
|
@ -510,7 +516,7 @@
|
||||||
" dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
|
" dataset = dataset.flat_map(lambda window: window.batch(window_length))\n",
|
||||||
" datasets.append(dataset)\n",
|
" datasets.append(dataset)\n",
|
||||||
"dataset = tf.data.Dataset.zip(tuple(datasets)).map(lambda *windows: tf.stack(windows))\n",
|
"dataset = tf.data.Dataset.zip(tuple(datasets)).map(lambda *windows: tf.stack(windows))\n",
|
||||||
"dataset = dataset.repeat().map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
|
"dataset = dataset.map(lambda windows: (windows[:, :-1], windows[:, 1:]))\n",
|
||||||
"dataset = dataset.map(\n",
|
"dataset = dataset.map(\n",
|
||||||
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
|
" lambda X_batch, Y_batch: (tf.one_hot(X_batch, depth=max_id), Y_batch))\n",
|
||||||
"dataset = dataset.prefetch(1)"
|
"dataset = dataset.prefetch(1)"
|
||||||
|
@ -560,8 +566,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
|
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"adam\")\n",
|
||||||
"steps_per_epoch = train_size // batch_size // n_steps\n",
|
"history = model.fit(dataset, epochs=50,\n",
|
||||||
"history = model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=50,\n",
|
|
||||||
" callbacks=[ResetStatesCallback()])"
|
" callbacks=[ResetStatesCallback()])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -837,7 +842,7 @@
|
||||||
"def encode_words(X_batch, y_batch):\n",
|
"def encode_words(X_batch, y_batch):\n",
|
||||||
" return table.lookup(X_batch), y_batch\n",
|
" return table.lookup(X_batch), y_batch\n",
|
||||||
"\n",
|
"\n",
|
||||||
"train_set = datasets[\"train\"].repeat().batch(32).map(preprocess)\n",
|
"train_set = datasets[\"train\"].batch(32).map(preprocess)\n",
|
||||||
"train_set = train_set.map(encode_words).prefetch(1)"
|
"train_set = train_set.map(encode_words).prefetch(1)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -868,7 +873,7 @@
|
||||||
" keras.layers.Dense(1, activation=\"sigmoid\")\n",
|
" keras.layers.Dense(1, activation=\"sigmoid\")\n",
|
||||||
"])\n",
|
"])\n",
|
||||||
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
|
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
|
||||||
"history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)"
|
"history = model.fit(train_set, epochs=5)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -894,7 +899,7 @@
|
||||||
"outputs = keras.layers.Dense(1, activation=\"sigmoid\")(z)\n",
|
"outputs = keras.layers.Dense(1, activation=\"sigmoid\")(z)\n",
|
||||||
"model = keras.models.Model(inputs=[inputs], outputs=[outputs])\n",
|
"model = keras.models.Model(inputs=[inputs], outputs=[outputs])\n",
|
||||||
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
|
"model.compile(loss=\"binary_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"])\n",
|
||||||
"history = model.fit(train_set, steps_per_epoch=train_size // 32, epochs=5)"
|
"history = model.fit(train_set, epochs=5)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -963,8 +968,8 @@
|
||||||
"datasets, info = tfds.load(\"imdb_reviews\", as_supervised=True, with_info=True)\n",
|
"datasets, info = tfds.load(\"imdb_reviews\", as_supervised=True, with_info=True)\n",
|
||||||
"train_size = info.splits[\"train\"].num_examples\n",
|
"train_size = info.splits[\"train\"].num_examples\n",
|
||||||
"batch_size = 32\n",
|
"batch_size = 32\n",
|
||||||
"train_set = datasets[\"train\"].repeat().batch(batch_size).prefetch(1)\n",
|
"train_set = datasets[\"train\"].batch(batch_size).prefetch(1)\n",
|
||||||
"history = model.fit(train_set, steps_per_epoch=train_size // batch_size, epochs=5)"
|
"history = model.fit(train_set, epochs=5)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
Loading…
Reference in New Issue