diff --git a/11_training_deep_neural_networks.ipynb b/11_training_deep_neural_networks.ipynb
index 0e8ed3b..9c7345c 100644
--- a/11_training_deep_neural_networks.ipynb
+++ b/11_training_deep_neural_networks.ipynb
@@ -1536,7 +1536,130 @@
    "outputs": [],
    "source": [
     "learning_rate = keras.optimizers.schedules.PiecewiseConstantDecay(\n",
-    "    boundaries=[5. * n_steps_per_epoch, 15. * n_steps_per_epoch], values=[0.01, 0.005, 0.001])"
+    "    boundaries=[5. * n_steps_per_epoch, 15. * n_steps_per_epoch],\n",
+    "    values=[0.01, 0.005, 0.001])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1Cycle scheduling"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 95,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "K = keras.backend\n",
+    "\n",
+    "class ExponentialLearningRate(keras.callbacks.Callback):\n",
+    "    def __init__(self, factor):\n",
+    "        self.factor = factor\n",
+    "        self.rates = []\n",
+    "        self.losses = []\n",
+    "    def on_batch_end(self, batch, logs):\n",
+    "        self.rates.append(K.get_value(self.model.optimizer.lr))\n",
+    "        self.losses.append(logs[\"loss\"])\n",
+    "        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)\n",
+    "\n",
+    "def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=10**-5, max_rate=10):\n",
+    "    init_weights = model.get_weights()\n",
+    "    iterations = len(X) // batch_size * epochs\n",
+    "    factor = np.exp(np.log(max_rate / min_rate) / iterations)\n",
+    "    init_lr = K.get_value(model.optimizer.lr)\n",
+    "    K.set_value(model.optimizer.lr, min_rate)\n",
+    "    exp_lr = ExponentialLearningRate(factor)\n",
+    "    history = model.fit(X, y, epochs=epochs, batch_size=batch_size,\n",
+    "                        callbacks=[exp_lr])\n",
+    "    K.set_value(model.optimizer.lr, init_lr)\n",
+    "    model.set_weights(init_weights)\n",
+    "    return exp_lr.rates, exp_lr.losses\n",
+    "\n",
+    "def plot_lr_vs_loss(rates, losses):\n",
+    "    plt.plot(rates, losses)\n",
+    "    plt.gca().set_xscale('log')\n",
+    "    plt.hlines(min(losses), min(rates), max(rates))\n",
+    "    plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])\n",
+    "    plt.xlabel(\"Learning rate\")\n",
+    "    plt.ylabel(\"Loss\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 96,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "tf.random.set_seed(42)\n",
+    "np.random.seed(42)\n",
+    "\n",
+    "model = keras.models.Sequential([\n",
+    "    keras.layers.Flatten(input_shape=[28, 28]),\n",
+    "    keras.layers.Dense(300, activation=\"selu\", kernel_initializer=\"lecun_normal\"),\n",
+    "    keras.layers.Dense(100, activation=\"selu\", kernel_initializer=\"lecun_normal\"),\n",
+    "    keras.layers.Dense(10, activation=\"softmax\")\n",
+    "])\n",
+    "model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\", metrics=[\"accuracy\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 97,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size = 128\n",
+    "rates, losses = find_learning_rate(model, X_train_scaled, y_train, epochs=1, batch_size=batch_size)\n",
+    "plot_lr_vs_loss(rates, losses)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 98,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class OneCycleScheduler(keras.callbacks.Callback):\n",
+    "    def __init__(self, iterations, max_rate, start_rate=None,\n",
+    "                 last_iterations=None, last_rate=None):\n",
+    "        self.iterations = iterations\n",
+    "        self.max_rate = max_rate\n",
+    "        self.start_rate = start_rate or max_rate / 10\n",
+    "        self.last_iterations = last_iterations or iterations // 10 + 1\n",
+    "        self.half_iteration = (iterations - self.last_iterations) // 2\n",
+    "        self.last_rate = last_rate or self.start_rate / 1000\n",
+    "        self.iteration = 0\n",
+    "    def _interpolate(self, iter1, iter2, rate1, rate2):\n",
+    "        return ((rate2 - rate1) * (self.iteration - iter1)\n",
+    "                / (iter2 - iter1) + rate1)\n",
+    "    def on_batch_begin(self, batch, logs):\n",
+    "        if self.iteration < self.half_iteration:\n",
+    "            rate = self._interpolate(0, self.half_iteration, self.start_rate, self.max_rate)\n",
+    "        elif self.iteration < 2 * self.half_iteration:\n",
+    "            rate = self._interpolate(self.half_iteration, 2 * self.half_iteration,\n",
+    "                                     self.max_rate, self.start_rate)\n",
+    "        else:\n",
+    "            rate = self._interpolate(2 * self.half_iteration, self.iterations,\n",
+    "                                     self.start_rate, self.last_rate)\n",
+    "            rate = max(rate, self.last_rate)\n",
+    "        self.iteration += 1\n",
+    "        K.set_value(self.model.optimizer.lr, rate)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 99,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "n_epochs = 25\n",
+    "onecycle = OneCycleScheduler(len(X_train) // batch_size * n_epochs, max_rate=0.05)\n",
+    "history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
+    "                    validation_data=(X_valid_scaled, y_valid),\n",
+    "                    callbacks=[onecycle])"
    ]
   },
   {
@@ -1555,7 +1678,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 95,
+   "execution_count": 100,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1568,7 +1691,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 96,
+   "execution_count": 101,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1591,7 +1714,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 97,
+   "execution_count": 102,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1623,7 +1746,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 98,
+   "execution_count": 103,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1651,7 +1774,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 99,
+   "execution_count": 104,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1661,7 +1784,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 100,
+   "execution_count": 105,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1683,7 +1806,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 101,
+   "execution_count": 106,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1692,7 +1815,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 102,
+   "execution_count": 107,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1701,7 +1824,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 103,
+   "execution_count": 108,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1718,7 +1841,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 104,
+   "execution_count": 109,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1728,7 +1851,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 105,
+   "execution_count": 110,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1740,7 +1863,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 106,
+   "execution_count": 111,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1749,7 +1872,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 107,
+   "execution_count": 112,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1758,7 +1881,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 108,
+   "execution_count": 113,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1767,7 +1890,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 109,
+   "execution_count": 114,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1777,7 +1900,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 110,
+   "execution_count": 115,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1786,7 +1909,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 111,
+   "execution_count": 116,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1796,7 +1919,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 112,
+   "execution_count": 117,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1811,7 +1934,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 113,
+   "execution_count": 118,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1821,7 +1944,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 114,
+   "execution_count": 119,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1833,7 +1956,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 115,
+   "execution_count": 120,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1842,7 +1965,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 116,
+   "execution_count": 121,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1852,7 +1975,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 117,
+   "execution_count": 122,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1868,7 +1991,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 118,
+   "execution_count": 123,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1884,7 +2007,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 119,
+   "execution_count": 124,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1894,7 +2017,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 120,
+   "execution_count": 125,
    "metadata": {},
    "outputs": [],
    "source": [