diff --git a/11_training_deep_neural_networks.ipynb b/11_training_deep_neural_networks.ipynb index 9621268..fd1af12 100644 --- a/11_training_deep_neural_networks.ipynb +++ b/11_training_deep_neural_networks.ipynb @@ -1233,10 +1233,12 @@ "metadata": {}, "outputs": [], "source": [ + "import math\n", + "\n", "learning_rate = 0.01\n", "decay = 1e-4\n", "batch_size = 32\n", - "n_steps_per_epoch = len(X_train) // batch_size\n", + "n_steps_per_epoch = math.ceil(len(X_train) / batch_size)\n", "epochs = np.arange(n_epochs)\n", "lrs = learning_rate / (1 + decay * epochs * n_steps_per_epoch)\n", "\n", @@ -1630,7 +1632,7 @@ "\n", "def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=10**-5, max_rate=10):\n", " init_weights = model.get_weights()\n", - " iterations = len(X) // batch_size * epochs\n", + " iterations = math.ceil(len(X) / batch_size) * epochs\n", " factor = np.exp(np.log(max_rate / min_rate) / iterations)\n", " init_lr = K.get_value(model.optimizer.lr)\n", " K.set_value(model.optimizer.lr, min_rate)\n", @@ -1709,7 +1711,6 @@ " else:\n", " rate = self._interpolate(2 * self.half_iteration, self.iterations,\n", " self.start_rate, self.last_rate)\n", - " rate = max(rate, self.last_rate)\n", " self.iteration += 1\n", " K.set_value(self.model.optimizer.lr, rate)" ] @@ -1721,7 +1722,7 @@ "outputs": [], "source": [ "n_epochs = 25\n", - "onecycle = OneCycleScheduler(len(X_train) // batch_size * n_epochs, max_rate=0.05)\n", + "onecycle = OneCycleScheduler(math.ceil(len(X_train) / batch_size) * n_epochs, max_rate=0.05)\n", "history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n", " validation_data=(X_valid_scaled, y_valid),\n", " callbacks=[onecycle])" @@ -2645,7 +2646,7 @@ "outputs": [], "source": [ "n_epochs = 15\n", - "onecycle = OneCycleScheduler(len(X_train_scaled) // batch_size * n_epochs, max_rate=0.05)\n", + "onecycle = OneCycleScheduler(math.ceil(len(X_train_scaled) / batch_size) * n_epochs, max_rate=0.05)\n", "history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n", " validation_data=(X_valid_scaled, y_valid),\n", " callbacks=[onecycle])"