Add 1cycle scheduling

main
Aurélien Geron 2019-05-05 12:42:08 +08:00
parent d2a518cdb1
commit fd1e088dab
1 changed files with 150 additions and 27 deletions

View File

@ -1536,7 +1536,130 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"learning_rate = keras.optimizers.schedules.PiecewiseConstantDecay(\n", "learning_rate = keras.optimizers.schedules.PiecewiseConstantDecay(\n",
" boundaries=[5. * n_steps_per_epoch, 15. * n_steps_per_epoch], values=[0.01, 0.005, 0.001])" " boundaries=[5. * n_steps_per_epoch, 15. * n_steps_per_epoch],\n",
" values=[0.01, 0.005, 0.001])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1Cycle scheduling"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
"K = keras.backend\n",
"\n",
"class ExponentialLearningRate(keras.callbacks.Callback):\n",
" def __init__(self, factor):\n",
" self.factor = factor\n",
" self.rates = []\n",
" self.losses = []\n",
" def on_batch_end(self, batch, logs):\n",
" self.rates.append(K.get_value(self.model.optimizer.lr))\n",
" self.losses.append(logs[\"loss\"])\n",
" K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)\n",
"\n",
"def find_learning_rate(model, X, y, epochs=1, batch_size=32, min_rate=10**-5, max_rate=10):\n",
" init_weights = model.get_weights()\n",
" iterations = len(X) // batch_size * epochs\n",
" factor = np.exp(np.log(max_rate / min_rate) / iterations)\n",
" init_lr = K.get_value(model.optimizer.lr)\n",
" K.set_value(model.optimizer.lr, min_rate)\n",
" exp_lr = ExponentialLearningRate(factor)\n",
" history = model.fit(X, y, epochs=epochs, batch_size=batch_size,\n",
" callbacks=[exp_lr])\n",
" K.set_value(model.optimizer.lr, init_lr)\n",
" model.set_weights(init_weights)\n",
" return exp_lr.rates, exp_lr.losses\n",
"\n",
"def plot_lr_vs_loss(rates, losses):\n",
" plt.plot(rates, losses)\n",
" plt.gca().set_xscale('log')\n",
" plt.hlines(min(losses), min(rates), max(rates))\n",
" plt.axis([min(rates), max(rates), min(losses), (losses[0] + min(losses)) / 2])\n",
" plt.xlabel(\"Learning rate\")\n",
" plt.ylabel(\"Loss\")"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
"model = keras.models.Sequential([\n",
" keras.layers.Flatten(input_shape=[28, 28]),\n",
" keras.layers.Dense(300, activation=\"selu\", kernel_initializer=\"lecun_normal\"),\n",
" keras.layers.Dense(100, activation=\"selu\", kernel_initializer=\"lecun_normal\"),\n",
" keras.layers.Dense(10, activation=\"softmax\")\n",
"])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"sgd\", metrics=[\"accuracy\"])"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 128\n",
"rates, losses = find_learning_rate(model, X_train_scaled, y_train, epochs=1, batch_size=batch_size)\n",
"plot_lr_vs_loss(rates, losses)"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
"class OneCycleScheduler(keras.callbacks.Callback):\n",
" def __init__(self, iterations, max_rate, start_rate=None,\n",
" last_iterations=None, last_rate=None):\n",
" self.iterations = iterations\n",
" self.max_rate = max_rate\n",
" self.start_rate = start_rate or max_rate / 10\n",
" self.last_iterations = last_iterations or iterations // 10 + 1\n",
" self.half_iteration = (iterations - self.last_iterations) // 2\n",
" self.last_rate = last_rate or self.start_rate / 1000\n",
" self.iteration = 0\n",
" def _interpolate(self, iter1, iter2, rate1, rate2):\n",
" return ((rate2 - rate1) * (iter2 - self.iteration)\n",
" / (iter2 - iter1) + rate1)\n",
" def on_batch_begin(self, batch, logs):\n",
" if self.iteration < self.half_iteration:\n",
" rate = self._interpolate(0, self.half_iteration, self.start_rate, self.max_rate)\n",
" elif self.iteration < 2 * self.half_iteration:\n",
" rate = self._interpolate(self.half_iteration, 2 * self.half_iteration,\n",
" self.max_rate, self.start_rate)\n",
" else:\n",
" rate = self._interpolate(2 * self.half_iteration, self.iterations,\n",
" self.start_rate, self.last_rate)\n",
" rate = max(rate, self.last_rate)\n",
" self.iteration += 1\n",
" K.set_value(self.model.optimizer.lr, rate)"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"n_epochs = 25\n",
"onecycle = OneCycleScheduler(len(X_train) // batch_size * n_epochs, max_rate=0.05)\n",
"history = model.fit(X_train_scaled, y_train, epochs=n_epochs, batch_size=batch_size,\n",
" validation_data=(X_valid_scaled, y_valid),\n",
" callbacks=[onecycle])"
] ]
}, },
{ {
@ -1555,7 +1678,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 95, "execution_count": 100,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1568,7 +1691,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 96, "execution_count": 101,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1591,7 +1714,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 97, "execution_count": 102,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1623,7 +1746,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 98, "execution_count": 103,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1651,7 +1774,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 99, "execution_count": 104,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1661,7 +1784,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 100, "execution_count": 105,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1683,7 +1806,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 101, "execution_count": 106,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1692,7 +1815,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 102, "execution_count": 107,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1701,7 +1824,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 103, "execution_count": 108,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1718,7 +1841,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 104, "execution_count": 109,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1728,7 +1851,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 105, "execution_count": 110,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1740,7 +1863,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 106, "execution_count": 111,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1749,7 +1872,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 107, "execution_count": 112,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1758,7 +1881,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 108, "execution_count": 113,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1767,7 +1890,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 109, "execution_count": 114,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1777,7 +1900,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 110, "execution_count": 115,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1786,7 +1909,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 111, "execution_count": 116,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1796,7 +1919,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 112, "execution_count": 117,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1811,7 +1934,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 113, "execution_count": 118,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1821,7 +1944,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 114, "execution_count": 119,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1833,7 +1956,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 115, "execution_count": 120,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1842,7 +1965,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 116, "execution_count": 121,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1852,7 +1975,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 117, "execution_count": 122,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1868,7 +1991,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 118, "execution_count": 123,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1884,7 +2007,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 119, "execution_count": 124,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1894,7 +2017,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 120, "execution_count": 125,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [