From 9d33d65ef0c54d05f373132b4a9d2ed3264c57e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20Geron?=
Date: Fri, 19 Mar 2021 10:50:13 +1300
Subject: [PATCH] logs["loss"] is the mean loss, not the batch loss anymore
 (since TF 2.2), fixes #188

---
 11_training_deep_neural_networks.ipynb | 23 +++++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/11_training_deep_neural_networks.ipynb b/11_training_deep_neural_networks.ipynb
index 74a4f30..e7d3c36 100644
--- a/11_training_deep_neural_networks.ipynb
+++ b/11_training_deep_neural_networks.ipynb
@@ -1652,6 +1652,29 @@
     "plt.ylabel(\"Loss\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "**Warning**: In the `on_batch_end()` method, `logs[\"loss\"]` used to contain the batch loss, but in TensorFlow 2.2.0 it was replaced with the mean loss (since the start of the epoch). This explains why the graph below is much smoother than in the book (if you are using TF 2.2 or above). It also means that there is a lag between the moment the batch loss starts exploding and the moment the explosion becomes clear in the graph. So you should choose a slightly smaller learning rate than you would have chosen with the \"noisy\" graph. Alternatively, you can tweak the `ExponentialLearningRate` callback above so it computes the batch loss (based on the current mean loss and the previous mean loss):\n",
+    "\n",
+    "```python\n",
+    "class ExponentialLearningRate(keras.callbacks.Callback):\n",
+    "    def __init__(self, factor):\n",
+    "        self.factor = factor\n",
+    "        self.rates = []\n",
+    "        self.losses = []\n",
+    "    def on_epoch_begin(self, epoch, logs=None):\n",
+    "        self.prev_loss = 0\n",
+    "    def on_batch_end(self, batch, logs=None):\n",
+    "        batch_loss = logs[\"loss\"] * (batch + 1) - self.prev_loss * batch\n",
+    "        self.prev_loss = logs[\"loss\"]\n",
+    "        self.rates.append(K.get_value(self.model.optimizer.lr))\n",
+    "        self.losses.append(batch_loss)\n",
+    "        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)\n",
+    "```"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 97,
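The batch-loss recovery used in the callback above (`batch_loss = mean_loss * (batch + 1) - prev_loss * batch`) can be sanity-checked outside Keras. This is a minimal sketch with made-up per-batch losses (not values from the notebook), simulating the running mean that `logs["loss"]` reports since TF 2.2 and recovering the individual batch losses from it:

```python
import numpy as np

# Hypothetical per-batch losses (illustration only): the last one "explodes".
batch_losses = [1.0, 0.8, 0.6, 0.5, 5.0]

# What logs["loss"] reports since TF 2.2: the mean loss since the epoch start.
mean_losses = np.cumsum(batch_losses) / np.arange(1, len(batch_losses) + 1)

# Recover each batch loss from consecutive means, as the tweaked callback does:
# batch_loss = mean_n * (n + 1) - mean_{n-1} * n, with the "previous mean"
# initialized to 0 at the start of the epoch (on_epoch_begin).
recovered = []
prev_loss = 0.0
for batch, mean_loss in enumerate(mean_losses):
    recovered.append(mean_loss * (batch + 1) - prev_loss * batch)
    prev_loss = mean_loss

print(recovered)  # matches batch_losses up to floating-point error
```

The recovered values track the raw batch losses, so the spike shows up immediately rather than being smoothed out by the running mean.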