Add 1cycle scheduling
parent
d2a518cdb1
commit
fd1e088dab
|
@ -1536,7 +1536,130 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"learning_rate = keras.optimizers.schedules.PiecewiseConstantDecay(\n",
|
"learning_rate = keras.optimizers.schedules.PiecewiseConstantDecay(\n",
|
||||||
" boundaries=[5. * n_steps_per_epoch, 15. * n_steps_per_epoch], values=[0.01, 0.005, 0.001])"
|
" boundaries=[5. * n_steps_per_epoch, 15. * n_steps_per_epoch],\n",
|
||||||
|
" values=[0.01, 0.005, 0.001])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### 1Cycle scheduling"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 95,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"K = keras.backend\n",
|
||||||
|
"\n",
|
||||||
|
"class ExponentialLearningRate(keras.callbacks.Callback):\n",
|
||||||
|
" def __init__(self, factor):\n",
|
||||||
|
" self.factor = factor\n",
|
||||||
|
" self.rates = []\n",
|
||||||
|
" self.losses = []\n",
|
||||||
|
" def on_batch_end(self, batch, logs):\n",
|
||||||
|
" self.rates.append(K.get_value(self.model.optimizer.lr))\n",
|
||||||
|
" self.losses.append(logs[\"loss\"])\n",
|
||||||
|
" K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)\n",
|
||||||
|
"\n",
|
||||||
|
"def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=10**-5, max_rate=10):\n",
|
||||||
|
" init_weights = model.get_weights()\n",
|
||||||
|
" iterations = len(X) // batch_size * epochs\n",
|
||||||
|
" factor = np.exp(np.log(max_rate / min_rate) / iterations)\n",
|
||||||
|
" init_lr = K.get_value(model.optimizer.lr)\n",
|
||||||
|
" K.set_value(model.optimizer.lr, min_rate)\n",
|
||||||
|
" exp_lr = ExponentialLearningRate(factor)\n",
|
||||||
|
" history = model.fit(X, y, epochs=epochs, batch_size=batch_size,\n",
|
||||||
|
" callbacks=[exp_lr])\n",
|
||||||
|
" K.set_value(model.optimizer.lr, init_lr)\n",
|
||||||
|
" model.set_weights(init_weights)\n",
|
||||||
|
" return exp_lr.rates, exp_lr.losses\n",
|
||||||
|
"\n",
|
||||||
|
"def plot_lr_vs_loss(rates, losses):\n",
|
||||||
|
" plt.plot(rates, losses)\n",
|
||||||
|
" plt.gca().set_xscale('log')\n",
|
||||||
|
" plt.hlines(min(losses), min(rates), max(rates))\n",
|
||||||
|
" plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])\n",
|
||||||
|
" plt.xlabel(\"Learning rate\")\n",
|
||||||
|
" plt.ylabel(\"Loss\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 96,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tf.random.set_seed(42)\n",
|
||||||
|
"np.random.seed(42)\n",
|
||||||
|
"\n",
|
||||||
|
"model = keras.models.Sequential([\n",
|
||||||
|
" keras.layers.Flatten(input_shape=[28, 28]),\n",
|
||||||
|
" keras.layers.Dense(300, activation=\"selu\", kernel_initializer=\"lecun_normal\"),\n",
|
||||||
|
" keras.layers.Dense(100, activation=\"selu\", kernel_initializer=\"lecun_normal\"),\n",
|
||||||
|
" keras.layers.Dense(10, activation=\"softmax\")\n",
|
||||||
|
"])\n",
|
||||||
|
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\", metrics=[\"accuracy\"])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 97,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"batch_size = 128\n",
|
||||||
|
"rates, losses = find_learning_rate(model, X_train_scaled, y_train, epochs=1, batch_size=batch_size)\n",
|
||||||
|
"plot_lr_vs_loss(rates, losses)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 98,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"class OneCycleScheduler(keras.callbacks.Callback):\n",
|
||||||
|
" def __init__(self, iterations, max_rate, start_rate=None,\n",
|
||||||
|
" last_iterations=None, last_rate=None):\n",
|
||||||
|
" self.iterations = iterations\n",
|
||||||
|
" self.max_rate = max_rate\n",
|
||||||
|
" self.start_rate = start_rate or max_rate / 10\n",
|
||||||
|
" self.last_iterations = last_iterations or iterations // 10 + 1\n",
|
||||||
|
" self.half_iteration = (iterations - self.last_iterations) // 2\n",
|
||||||
|
" self.last_rate = last_rate or self.start_rate / 1000\n",
|
||||||
|
" self.iteration = 0\n",
|
||||||
|
" def _interpolate(self, iter1, iter2, rate1, rate2):\n",
|
||||||
|
" return ((rate2 - rate1) * (iter2 - self.iteration)\n",
|
||||||
|
" / (iter2 - iter1) + rate1)\n",
|
||||||
|
" def on_batch_begin(self, batch, logs):\n",
|
||||||
|
" if self.iteration < self.half_iteration:\n",
|
||||||
|
" rate = self._interpolate(0, self.half_iteration, self.start_rate, self.max_rate)\n",
|
||||||
|
" elif self.iteration < 2 * self.half_iteration:\n",
|
||||||
|
" rate = self._interpolate(self.half_iteration, 2 * self.half_iteration,\n",
|
||||||
|
" self.max_rate, self.start_rate)\n",
|
||||||
|
" else:\n",
|
||||||
|
" rate = self._interpolate(2 * self.half_iteration, self.iterations,\n",
|
||||||
|
" self.start_rate, self.last_rate)\n",
|
||||||
|
" rate = max(rate, self.last_rate)\n",
|
||||||
|
" self.iteration += 1\n",
|
||||||
|
" K.set_value(self.model.optimizer.lr, rate)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 99,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"n_epochs = 25\n",
|
||||||
|
"onecycle = OneCycleScheduler(len(X_train) // batch_size * n_epochs, max_rate=0.05)\n",
|
||||||
|
"history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
|
||||||
|
" validation_data=(X_valid_scaled, y_valid),\n",
|
||||||
|
" callbacks=[onecycle])"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -1555,7 +1678,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 95,
|
"execution_count": 100,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1568,7 +1691,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 96,
|
"execution_count": 101,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1591,7 +1714,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 97,
|
"execution_count": 102,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1623,7 +1746,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 98,
|
"execution_count": 103,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1651,7 +1774,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 99,
|
"execution_count": 104,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1661,7 +1784,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 100,
|
"execution_count": 105,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1683,7 +1806,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 101,
|
"execution_count": 106,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1692,7 +1815,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 102,
|
"execution_count": 107,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1701,7 +1824,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 103,
|
"execution_count": 108,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1718,7 +1841,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 104,
|
"execution_count": 109,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1728,7 +1851,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 105,
|
"execution_count": 110,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1740,7 +1863,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 106,
|
"execution_count": 111,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1749,7 +1872,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 107,
|
"execution_count": 112,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1758,7 +1881,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 108,
|
"execution_count": 113,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1767,7 +1890,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 109,
|
"execution_count": 114,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1777,7 +1900,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 110,
|
"execution_count": 115,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1786,7 +1909,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 111,
|
"execution_count": 116,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1796,7 +1919,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 112,
|
"execution_count": 117,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1811,7 +1934,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 113,
|
"execution_count": 118,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1821,7 +1944,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 114,
|
"execution_count": 119,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1833,7 +1956,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 115,
|
"execution_count": 120,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1842,7 +1965,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 116,
|
"execution_count": 121,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1852,7 +1975,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 117,
|
"execution_count": 122,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1868,7 +1991,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 118,
|
"execution_count": 123,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1884,7 +2007,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 119,
|
"execution_count": 124,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1894,7 +2017,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 120,
|
"execution_count": 125,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
Loading…
Reference in New Issue