Merge pull request #275 from ibeauregard/changes-chap11

(Chapter 11) Adjust computation of steps per epoch
main
Aurélien Geron 2021-03-02 22:12:35 +13:00 committed by GitHub
commit 55ee303e56
1 changed files with 6 additions and 5 deletions

View File

@ -1233,10 +1233,12 @@
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"learning_rate = 0.01\n",
"decay = 1e-4\n",
"batch_size = 32\n",
"n_steps_per_epoch = len(X_train) // batch_size\n",
"n_steps_per_epoch = math.ceil(len(X_train) / batch_size)\n",
"epochs = np.arange(n_epochs)\n",
"lrs = learning_rate / (1 + decay * epochs * n_steps_per_epoch)\n",
"\n",
@ -1630,7 +1632,7 @@
"\n",
"def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=10**-5, max_rate=10):\n",
" init_weights = model.get_weights()\n",
" iterations = len(X) // batch_size * epochs\n",
" iterations = math.ceil(len(X) / batch_size) * epochs\n",
" factor = np.exp(np.log(max_rate / min_rate) / iterations)\n",
" init_lr = K.get_value(model.optimizer.lr)\n",
" K.set_value(model.optimizer.lr, min_rate)\n",
@ -1709,7 +1711,6 @@
" else:\n",
" rate = self._interpolate(2 * self.half_iteration, self.iterations,\n",
" self.start_rate, self.last_rate)\n",
" rate = max(rate, self.last_rate)\n",
" self.iteration += 1\n",
" K.set_value(self.model.optimizer.lr, rate)"
]
@ -1721,7 +1722,7 @@
"outputs": [],
"source": [
"n_epochs = 25\n",
"onecycle = OneCycleScheduler(len(X_train) // batch_size * n_epochs, max_rate=0.05)\n",
"onecycle = OneCycleScheduler(math.ceil(len(X_train) / batch_size) * n_epochs, max_rate=0.05)\n",
"history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=[onecycle])"
@ -2645,7 +2646,7 @@
"outputs": [],
"source": [
"n_epochs = 15\n",
"onecycle = OneCycleScheduler(len(X_train_scaled) // batch_size * n_epochs, max_rate=0.05)\n",
"onecycle = OneCycleScheduler(math.ceil(len(X_train_scaled) / batch_size) * n_epochs, max_rate=0.05)\n",
"history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=[onecycle])"