diff --git a/16_nlp_with_rnns_and_attention.ipynb b/16_nlp_with_rnns_and_attention.ipynb index d9bd855..831c835 100644 --- a/16_nlp_with_rnns_and_attention.ipynb +++ b/16_nlp_with_rnns_and_attention.ipynb @@ -4,14 +4,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Chapter 15 – Natural Language Processing with RNNs and Attention**" + "**Chapter 16 – Natural Language Processing with RNNs and Attention**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "_This notebook contains all the sample code and solutions to the exercises in chapter 15._" + "_This notebook contains all the sample code and solutions to the exercises in chapter 16._" ] }, { @@ -20,10 +20,10 @@ "source": [ "\n", " \n", " \n", "
\n", - " \"Open\n", + " \"Open\n", " \n", - " \n", + " \n", "
" ] @@ -31,67 +31,134 @@ { "cell_type": "markdown", "metadata": {}, + "source": [ + "# WORK IN PROGRESS\n", + "\n", + "\n", + "**I'm still working on updating this chapter to the 3rd edition. Please come back in a few weeks.**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dFXIv9qNpKzt", + "tags": [] + }, "source": [ "# Setup" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "8IPbJEmZpKzu" + }, "source": [ - "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures." + "This project requires Python 3.8 or above:" ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "execution_count": null, + "metadata": { + "id": "TFSU3FCOpKzu" + }, "outputs": [], "source": [ - "# Python ≥3.8 is required\n", "import sys\n", - "assert sys.version_info >= (3, 8)\n", "\n", - "# Is this notebook running on Colab or Kaggle?\n", - "IS_COLAB = \"google.colab\" in sys.modules\n", - "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", - "\n", - "if IS_COLAB:\n", - " %pip install -q -U tensorflow-addons\n", - " %pip install -q -U transformers\n", - "\n", - "# Common imports\n", - "import numpy as np\n", - "from pathlib import Path\n", - "\n", - "# Scikit-Learn ≥1.0 is required\n", + "assert sys.version_info >= (3, 8)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TAlKky09pKzv" + }, + "source": [ + "It also requires Scikit-Learn ≥ 1.0.1:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YqCwW7cMpKzw" + }, + "outputs": [], + "source": [ "import sklearn\n", - "assert sklearn.__version__ >= \"1.0\"\n", "\n", - "# TensorFlow ≥2.6 is required\n", + "assert sklearn.__version__ >= \"1.0.1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GJtVEqxfpKzw" + }, + "source": [ + "And TensorFlow ≥ 2.6:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0Piq5se2pKzx" + }, + "outputs": [], + "source": [ "import tensorflow as tf\n", - "assert tf.__version__ >= \"2.6\"\n", "\n", - "if not tf.config.list_physical_devices('GPU'):\n", - " print(\"No GPU was detected. 
Neural nets can be very slow without a GPU.\")\n", - " if IS_COLAB:\n", - " print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n", - " if IS_KAGGLE:\n", - " print(\"Go to Settings > Accelerator and select GPU.\")\n", + "assert tf.__version__ >= \"2.6.0\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DDaDoLQTpKzx" + }, + "source": [ + "As we did in earlier chapters, let's define the default font sizes to make the figures prettier:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8d4TH3NbpKzx" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", "\n", - "# Common imports\n", - "import numpy as np\n", + "plt.rc('font', size=14)\n", + "plt.rc('axes', labelsize=14, titlesize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "plt.rc('xtick', labelsize=10)\n", + "plt.rc('ytick', labelsize=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RcoUIRsvpKzy" + }, + "source": [ + "And let's create the `images/nlp` folder (if it doesn't already exist), and define the `save_fig()` function which is used through this notebook to save the figures in high-res for the book:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PQFH5Y9PpKzy" + }, + "outputs": [], + "source": [ "from pathlib import Path\n", "\n", - "# To plot pretty figures\n", - "%matplotlib inline\n", - "import matplotlib as mpl\n", - "import matplotlib.pyplot as plt\n", - "mpl.rc('axes', labelsize=14)\n", - "mpl.rc('xtick', labelsize=12)\n", - "mpl.rc('ytick', labelsize=12)\n", - "\n", - "# Where to save the figures\n", "IMAGES_PATH = Path() / \"images\" / \"nlp\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", @@ -102,6 +169,57 @@ " plt.savefig(path, format=fig_extension, dpi=resolution)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "YTsawKlapKzy" + }, + "source": [ + "This chapter can be very slow without a GPU, so let's make sure there's one, or else issue a warning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ekxzo6pOpKzy" + }, + "outputs": [], + "source": [ + "if not tf.config.list_physical_devices('GPU'):\n", + " print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n", + " if \"google.colab\" in sys.modules:\n", + " print(\"Go to Runtime > Change runtime and select a GPU hardware \"\n", + " \"accelerator.\")\n", + " if \"kaggle_secrets\" in sys.modules:\n", + " print(\"Go to Settings > Accelerator and select GPU.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebooks uses the TensorFlow Addons library, and the Transformers library. If you're running on Colab, then we need to install them now:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if \"google.colab\" in sys.modules:\n", + " %pip install -q -U tensorflow-addons\n", + " %pip install -q -U transformers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures." 
+ ] + }, { "cell_type": "markdown", "metadata": {}, @@ -125,7 +243,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "scrolled": true }, @@ -156,7 +274,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -168,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -177,7 +295,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -186,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -196,7 +314,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -205,7 +323,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -214,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -224,7 +342,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -242,7 +360,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -253,7 +371,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -262,7 +380,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -272,7 +390,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -283,7 +401,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -293,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -302,7 +420,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -333,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -360,7 +478,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -373,24 +491,24 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Warning**: the `predict_classes()` method is deprecated. Instead, we must use `np.argmax(model(X_new), axis=-1)`." + "**Warning**: the `predict_classes()` method is deprecated. Instead, we must use `model(X_new).argmax(axis=-1)`." 
] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "X_new = preprocess([\"How are yo\"])\n", "#Y_pred = model.predict_classes(X_new)\n", - "Y_pred = np.argmax(model(X_new), axis=-1)\n", + "Y_pred = model(X_new).argmax(axis=-1)\n", "tokenizer.sequences_to_texts(Y_pred + 1)[0][-1] # 1st sentence, last char" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -401,7 +519,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -415,7 +533,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -426,7 +544,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -438,7 +556,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -449,7 +567,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -458,7 +576,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -474,7 +592,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -483,7 +601,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -499,7 +617,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -527,7 +645,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -546,7 +664,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -557,7 +675,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -575,7 +693,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -596,7 +714,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -605,7 +723,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -615,7 +733,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -633,7 +751,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -649,7 +767,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -658,7 +776,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -667,7 +785,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -680,7 +798,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -691,7 +809,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ @@ -700,7 +818,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -710,7 +828,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -719,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -732,7 +850,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -746,7 +864,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -755,7 +873,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -769,7 +887,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -778,7 +896,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -787,7 +905,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -798,7 +916,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -809,7 +927,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -822,7 +940,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -831,7 +949,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -844,7 +962,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -855,7 +973,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -881,7 +999,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -907,7 +1025,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -916,7 +1034,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -926,7 +1044,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -951,7 +1069,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -972,7 +1090,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -981,7 +1099,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1003,7 +1121,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1012,7 +1130,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1022,7 +1140,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": null, "metadata": 
{}, "outputs": [], "source": [ @@ -1058,7 +1176,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1067,14 +1185,14 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "X = np.random.randint(100, size=10*1000).reshape(1000, 10)\n", - "Y = np.random.randint(100, size=15*1000).reshape(1000, 15)\n", + "X = np.random.randint(100, size=1000 * 10).reshape(1000, 10)\n", + "Y = np.random.randint(100, size=1000 * 15).reshape(1000, 15)\n", "X_decoder = np.c_[np.zeros((1000, 1)), Y[:, :-1]]\n", - "seq_lengths = np.full([1000], 15)\n", + "seq_lengths = np.full([1000], 10)\n", "\n", "history = model.fit([X, X_decoder, seq_lengths], Y, epochs=2)" ] @@ -1088,7 +1206,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1109,7 +1227,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1129,7 +1247,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1141,7 +1259,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1177,7 +1295,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1201,7 +1319,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1228,7 +1346,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1272,7 +1390,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1300,7 +1418,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "See Appendix A." + "1. Stateless RNNs can only capture patterns whose length is less than, or equal to, the size of the windows the RNN is trained on. Conversely, stateful RNNs can capture longer-term patterns. However, implementing a stateful RNN is much harder⁠—especially preparing the dataset properly. Moreover, stateful RNNs do not always work better, in part because consecutive batches are not independent and identically distributed (IID). Gradient Descent is not fond of non-IID datasets.\n", + "2. In general, if you translate a sentence one word at a time, the result will be terrible. For example, the French sentence \"Je vous en prie\" means \"You are welcome,\" but if you translate it one word at a time, you get \"I you in pray.\" Huh? It is much better to read the whole sentence first and then translate it. A plain sequence-to-sequence RNN would start translating a sentence immediately after reading the first word, while an Encoder–Decoder RNN will first read the whole sentence and then translate it. That said, one could imagine a plain sequence-to-sequence RNN that would output silence whenever it is unsure about what to say next (just like human translators do when they must translate a live broadcast).\n", + "3. Variable-length input sequences can be handled by padding the shorter sequences so that all sequences in a batch have the same length, and using masking to ensure the RNN ignores the padding token. 
For better performance, you may also want to create batches containing sequences of similar sizes. Ragged tensors can hold sequences of variable lengths, and tf.keras will likely support them eventually, which will greatly simplify handling variable-length input sequences (at the time of this writing, it is not the case yet). Regarding variable-length output sequences, if the length of the output sequence is known in advance (e.g., if you know that it is the same as the input sequence), then you just need to configure the loss function so that it ignores tokens that come after the end of the sequence. Similarly, the code that will use the model should ignore tokens beyond the end of the sequence. But generally the length of the output sequence is not known ahead of time, so the solution is to train the model so that it outputs an end-of-sequence token at the end of each sequence.\n", + "4. Beam search is a technique used to improve the performance of a trained Encoder–Decoder model, for example in a neural machine translation system. The algorithm keeps track of a short list of the _k_ most promising output sentences (say, the top three), and at each decoder step it tries to extend them by one word; then it keeps only the _k_ most likely sentences. The parameter _k_ is called the _beam width_: the larger it is, the more CPU and RAM will be used, but also the more accurate the system will be. Instead of greedily choosing the most likely next word at each step to extend a single sentence, this technique allows the system to explore several promising sentences simultaneously. Moreover, this technique lends itself well to parallelization. You can implement beam search fairly easily using TensorFlow Addons.\n", + "5. An attention mechanism is a technique initially used in Encoder–Decoder models to give the decoder more direct access to the input sequence, allowing it to deal with longer input sequences. At each decoder time step, the current decoder's state and the full output of the encoder are processed by an alignment model that outputs an alignment score for each input time step. This score indicates which part of the input is most relevant to the current decoder time step. The weighted sum of the encoder output (weighted by their alignment score) is then fed to the decoder, which produces the next decoder state and the output for this time step. The main benefit of using an attention mechanism is the fact that the Encoder–Decoder model can successfully process longer input sequences. Another benefit is that the alignment scores make the model easier to debug and interpret: for example, if the model makes a mistake, you can look at which part of the input it was paying attention to, and this can help diagnose the issue. An attention mechanism is also at the core of the Transformer architecture, in the Multi-Head Attention layers. See the next answer.\n", + "6. The most important layer in the Transformer architecture is the Multi-Head Attention layer (the original Transformer architecture contains 18 of them, including 6 Masked Multi-Head Attention layers). It is at the core of language models such as BERT and GPT-2. Its purpose is to allow the model to identify which words are most aligned with each other, and then improve each word's representation using these contextual clues.\n", + "7. Sampled softmax is used when training a classification model when there are many classes (e.g., thousands). 
It computes an approximation of the cross-entropy loss based on the logit predicted by the model for the correct class, and the predicted logits for a sample of incorrect words. This speeds up training considerably compared to computing the softmax over all logits and then estimating the cross-entropy loss. After training, the model can be used normally, using the regular softmax function to compute all the class probabilities based on all the logits." ] }, { @@ -1320,7 +1444,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1363,7 +1487,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1382,7 +1506,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1401,7 +1525,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1424,7 +1548,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1443,7 +1567,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1453,7 +1577,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1469,7 +1593,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1487,7 +1611,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1506,7 +1630,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1522,7 +1646,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1538,7 +1662,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1567,7 +1691,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1606,7 +1730,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1637,7 +1761,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1660,7 +1784,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1677,7 +1801,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1693,7 +1817,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1703,7 +1827,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1712,7 +1836,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1721,7 +1845,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1737,7 +1861,7 @@ }, { "cell_type": "code", - "execution_count": 99, + 
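To make exercise answer 7 above more concrete, here is a minimal sketch of sampled softmax in TensorFlow. All sizes and variable names are illustrative assumptions, not part of the notebook: only the training loss uses sampling, while inference computes the full softmax over all classes.

```python
# Minimal sampled-softmax sketch (illustrating exercise answer 7 above).
# All sizes and variable names here are illustrative assumptions.
import tensorflow as tf

n_classes, embed_dim, n_sampled = 50_000, 128, 64
output_w = tf.Variable(tf.random.normal([n_classes, embed_dim]))  # output weights
output_b = tf.Variable(tf.zeros([n_classes]))                     # output biases

def training_loss(hidden, labels):
    # hidden: [batch, embed_dim] activations from the model's last hidden layer
    # labels: [batch, 1] indices of the correct classes
    return tf.reduce_mean(tf.nn.sampled_softmax_loss(
        weights=output_w, biases=output_b, labels=labels, inputs=hidden,
        num_sampled=n_sampled, num_classes=n_classes))

def inference_probas(hidden):
    # At inference time, use the regular (full) softmax over all classes.
    logits = tf.matmul(hidden, output_w, transpose_b=True) + output_b
    return tf.nn.softmax(logits)
```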
"execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1750,7 +1874,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1770,12 +1894,12 @@ "source": [ "Let's first try the simplest possible model: we feed in the input sequence, which first goes through the encoder (an embedding layer followed by a single LSTM layer), which outputs a vector, then it goes through a decoder (a single LSTM layer, followed by a dense output layer), which outputs a sequence of vectors, each representing the estimated probabilities for all possible output character.\n", "\n", - "Since the decoder expects a sequence as input, we repeat the vector (which is output by the decoder) as many times as the longest possible output sequence." + "Since the decoder expects a sequence as input, we repeat the vector (which is output by the encoder) as many times as the longest possible output sequence." ] }, { "cell_type": "code", - "execution_count": 101, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1819,7 +1943,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1837,7 +1961,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1846,12 +1970,12 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#ids = model.predict_classes(X_new)\n", - "ids = np.argmax(model.predict(X_new), axis=-1)\n", + "ids = model.predict(X_new).argmax(axis=-1)\n", "for date_str in ids_to_date_strs(ids):\n", " print(date_str)" ] @@ -1872,7 +1996,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1881,12 +2005,12 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#ids = model.predict_classes(X_new)\n", - "ids = np.argmax(model.predict(X_new), axis=-1)\n", + "ids = model.predict(X_new).argmax(axis=-1)\n", "for date_str in ids_to_date_strs(ids):\n", " print(date_str)" ] @@ -1900,7 +2024,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1915,13 +2039,13 @@ "def convert_date_strs(date_strs):\n", " X = prepare_date_strs_padded(date_strs)\n", " #ids = model.predict_classes(X)\n", - " ids = np.argmax(model.predict(X), axis=-1)\n", + " ids = model.predict(X).argmax(axis=-1)\n", " return ids_to_date_strs(ids)" ] }, { "cell_type": "code", - "execution_count": 108, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1966,7 +2090,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1990,7 +2114,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2006,7 +2130,7 @@ }, { "cell_type": "code", - "execution_count": 111, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2060,7 +2184,7 @@ }, { "cell_type": "code", - "execution_count": 112, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2080,7 +2204,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2110,7 +2234,7 @@ }, { "cell_type": 
"code", - "execution_count": 114, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2169,7 +2293,7 @@ }, { "cell_type": "code", - "execution_count": 115, + "execution_count": null, "metadata": { "scrolled": true }, @@ -2189,7 +2313,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2230,7 +2354,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2242,7 +2366,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": null, "metadata": { "scrolled": true }, @@ -2260,7 +2384,7 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2269,7 +2393,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2306,7 +2430,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2387,7 +2511,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2413,7 +2537,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2426,7 +2550,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2444,7 +2568,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2486,7 +2610,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2556,7 +2680,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2582,7 +2706,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2596,7 +2720,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2608,7 +2732,7 @@ "metadata": {}, "source": [ "There are still a few interesting features from TF-Addons that you may want to look at:\n", - "* Using a `BeamSearchDecoder` rather than a `BasicDecoder` for inference. Instead of outputing the character with the highest probability, this decoder keeps track of the several candidates, and keeps only the most likely sequences of candidates (see chapter 15 in the book for more details).\n", + "* Using a `BeamSearchDecoder` rather than a `BasicDecoder` for inference. Instead of outputing the character with the highest probability, this decoder keeps track of the several candidates, and keeps only the most likely sequences of candidates (see chapter 16 in the book for more details).\n", "* Setting masks or specifying `sequence_length` if the input or target sequences may have very different lengths.\n", "* Using a `ScheduledOutputTrainingSampler`, which gives you more flexibility than the `ScheduledEmbeddingTrainingSampler` to decide how to feed the output at time _t_ to the cell at time _t_+1. By default it feeds the outputs directly to cell, without computing the argmax ID and passing it through an embedding layer. 
Alternatively, you specify a `next_inputs_fn` function that will be used to convert the cell outputs to inputs at the next step." ] @@ -2652,7 +2776,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2670,7 +2794,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2688,7 +2812,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2708,7 +2832,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2738,7 +2862,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2765,7 +2889,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, diff --git a/17_autoencoders_and_gans.ipynb b/17_autoencoders_and_gans.ipynb index 596b2ab..dcc4d03 100644 --- a/17_autoencoders_and_gans.ipynb +++ b/17_autoencoders_and_gans.ipynb @@ -4,14 +4,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Chapter 16 – Autoencoders and GANs**" + "**Chapter 17 – Autoencoders and GANs**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "_This notebook contains all the sample code and solutions to the exercises in chapter 16._" + "_This notebook contains all the sample code and solutions to the exercises in chapter 17._" ] }, { @@ -20,74 +20,147 @@ "source": [ "\n", " \n", " \n", "
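The beam search procedure described in exercise answer 4 and in the `BeamSearchDecoder` note above can be sketched without any framework. This is a toy illustration under stated assumptions: `next_token_probas` is a placeholder for any callable that maps a prefix of token ids to a dict of next-token probabilities.

```python
# Toy beam search over a next-token probability function (see exercise
# answer 4 above). `next_token_probas` is an assumed placeholder.
import math

def beam_search(next_token_probas, start_token, end_token, beam_width=3, max_len=20):
    # Each candidate beam is a pair (log-probability, sequence of token ids).
    beams = [(0.0, (start_token,))]
    for _ in range(max_len):
        candidates = []
        for log_proba, seq in beams:
            if seq[-1] == end_token:       # finished sequences are kept as-is
                candidates.append((log_proba, seq))
                continue
            for token, proba in next_token_probas(seq).items():
                candidates.append((log_proba + math.log(proba), seq + (token,)))
        # Keep only the `beam_width` most likely candidates.
        beams = sorted(candidates, reverse=True)[:beam_width]
    return beams
```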
\n", - " \"Open\n", + " \"Open\n", " \n", - " \n", + " \n", "
" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "tags": [] + }, + "source": [ + "# WORK IN PROGRESS\n", + "\n", + "\n", + "**I'm still working on updating this chapter to the 3rd edition. Please come back in a few weeks.**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dFXIv9qNpKzt", + "tags": [] + }, "source": [ "# Setup" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "8IPbJEmZpKzu" + }, "source": [ - "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures." + "This project requires Python 3.8 or above:" ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "execution_count": null, + "metadata": { + "id": "TFSU3FCOpKzu" + }, "outputs": [], "source": [ - "# Python ≥3.8 is required\n", "import sys\n", - "assert sys.version_info >= (3, 8)\n", "\n", - "# Is this notebook running on Colab or Kaggle?\n", - "IS_COLAB = \"google.colab\" in sys.modules\n", - "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", + "assert sys.version_info >= (3, 8)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TAlKky09pKzv" + }, + "source": [ + "It also requires Scikit-Learn ≥ 1.0.1:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YqCwW7cMpKzw" + }, + "outputs": [], + "source": [ + "import sklearn\n", "\n", - "# Common imports\n", - "import numpy as np\n", + "assert sklearn.__version__ >= \"1.0.1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GJtVEqxfpKzw" + }, + "source": [ + "And TensorFlow ≥ 2.6:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0Piq5se2pKzx" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "\n", + "assert tf.__version__ >= \"2.6.0\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DDaDoLQTpKzx" + }, + "source": [ + "As we did in earlier chapters, let's define the default font sizes to make the figures prettier:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8d4TH3NbpKzx" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.rc('font', size=14)\n", + "plt.rc('axes', labelsize=14, titlesize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "plt.rc('xtick', labelsize=10)\n", + "plt.rc('ytick', labelsize=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RcoUIRsvpKzy" + }, + "source": [ + "And let's create the `images/autoencoders` folder (if it doesn't already exist), and define the `save_fig()` function which is used through this notebook to save the figures in high-res for the book:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PQFH5Y9PpKzy" + }, + "outputs": [], + "source": [ "from pathlib import Path\n", "\n", - "# Scikit-Learn ≥1.0 is required\n", - "import sklearn\n", - "assert sklearn.__version__ >= \"1.0\"\n", - "\n", - "# TensorFlow ≥2.6 is required\n", - "import tensorflow as tf\n", - "assert tf.__version__ >= \"2.6\"\n", - "\n", - "# to make this notebook's output stable across runs\n", - "np.random.seed(42)\n", - "tf.random.set_seed(42)\n", - "\n", - "if not tf.config.list_physical_devices('GPU'):\n", - " print(\"No GPU was detected. 
Neural nets can be very slow without a GPU.\")\n", - " if IS_COLAB:\n", - " print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n", - " if IS_KAGGLE:\n", - " print(\"Go to Settings > Accelerator and select GPU.\")\n", - "\n", - "# To plot pretty figures\n", - "%matplotlib inline\n", - "import matplotlib as mpl\n", - "import matplotlib.pyplot as plt\n", - "mpl.rc('axes', labelsize=14)\n", - "mpl.rc('xtick', labelsize=12)\n", - "mpl.rc('ytick', labelsize=12)\n", - "\n", - "# Where to save the figures\n", "IMAGES_PATH = Path() / \"images\" / \"autoencoders\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", @@ -100,20 +173,28 @@ }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "YTsawKlapKzy" + }, "source": [ - "A couple utility functions to plot grayscale 28x28 image:" + "This chapter can be very slow without a GPU, so let's make sure there's one, or else issue a warning:" ] }, { "cell_type": "code", - "execution_count": 2, - "metadata": {}, + "execution_count": null, + "metadata": { + "id": "Ekxzo6pOpKzy" + }, "outputs": [], "source": [ - "def plot_image(image):\n", - " plt.imshow(image, cmap=\"binary\")\n", - " plt.axis(\"off\")" + "if not tf.config.list_physical_devices('GPU'):\n", + " print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n", + " if \"google.colab\" in sys.modules:\n", + " print(\"Go to Runtime > Change runtime and select a GPU hardware \"\n", + " \"accelerator.\")\n", + " if \"kaggle_secrets\" in sys.modules:\n", + " print(\"Go to Settings > Accelerator and select GPU.\")" ] }, { @@ -132,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -159,7 +240,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -175,7 +256,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -184,7 +265,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -193,7 +274,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -222,7 +303,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -249,7 +330,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -259,7 +340,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -292,7 +373,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -308,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -325,7 +406,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -341,7 +422,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -359,7 +440,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -369,8 +450,8 @@ "plt.scatter(X_valid_2D[:, 0], X_valid_2D[:, 1], c=y_valid, s=10, cmap=cmap)\n", "image_positions = 
np.array([[1., 1.]])\n", "for index, position in enumerate(X_valid_2D):\n", - " dist = np.sum((position - image_positions) ** 2, axis=1)\n", - " if np.min(dist) > 0.02: # if far enough from other images\n", + " dist = ((position - image_positions) ** 2).sum(axis=1)\n", + " if dist.min() > 0.02: # if far enough from other images\n", " image_positions = np.r_[image_positions, [position]]\n", " imagebox = mpl.offsetbox.AnnotationBbox(\n", " mpl.offsetbox.OffsetImage(X_valid[index], cmap=\"binary\"),\n", @@ -397,7 +478,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -418,7 +499,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -451,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": { "scrolled": true }, @@ -470,7 +551,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -492,7 +573,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -513,7 +594,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -526,7 +607,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -536,7 +617,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -548,7 +629,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -572,7 +653,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -605,7 +686,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -615,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -632,7 +713,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -652,7 +733,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -661,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -685,7 +766,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -712,7 +793,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -733,7 +814,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -760,7 +841,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -788,7 +869,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -814,7 +895,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -831,7 +912,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ @@ -842,13 +923,13 @@ " ax.bar(x, counts / len(data), width=widths*0.8)\n", " ax.xaxis.set_ticks(bins)\n", " ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(\n", - " lambda y, position: \"{}%\".format(int(np.round(100 * y)))))\n", + " lambda y, position: \"{}%\".format(round(100 * y))))\n", " ax.grid(True)" ] }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -881,7 +962,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -898,7 +979,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -927,7 +1008,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -936,7 +1017,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -953,7 +1034,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -976,7 +1057,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -996,7 +1077,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1023,7 +1104,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1032,7 +1113,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1050,7 +1131,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1062,7 +1143,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1103,7 +1184,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": { "scrolled": true }, @@ -1122,7 +1203,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1130,7 +1211,7 @@ " n_cols = n_cols or len(images)\n", " n_rows = (len(images) - 1) // n_cols + 1\n", " if images.shape[-1] == 1:\n", - " images = np.squeeze(images, axis=-1)\n", + " images = images.squeeze(axis=-1)\n", " plt.figure(figsize=(n_cols, n_rows))\n", " for index, image in enumerate(images):\n", " plt.subplot(n_rows, n_cols, index + 1)\n", @@ -1147,7 +1228,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1168,7 +1249,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1201,7 +1282,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1227,7 +1308,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1238,7 +1319,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1249,7 +1330,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1276,7 
+1357,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1285,7 +1366,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1300,7 +1381,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1316,7 +1397,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1351,7 +1432,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1362,7 +1443,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1371,7 +1452,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1383,7 +1464,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1392,7 +1473,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1428,7 +1509,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1448,7 +1529,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1482,7 +1563,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1499,7 +1580,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1516,11 +1597,11 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "hashes = np.round(hashing_encoder.predict(X_valid)).astype(np.int32)\n", + "hashes = hashing_encoder.predict(X_valid).round().astype(np.int32)\n", "hashes *= np.array([[2**bit for bit in range(16)]])\n", "hashes = hashes.sum(axis=1)\n", "for h in hashes[:5]:\n", @@ -1537,7 +1618,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1568,9 +1649,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. to 8.\n", - "\n", - "See Appendix A." + "## 1. to 8." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Here are some of the main tasks that autoencoders are used for:\n", + " * Feature extraction\n", + " * Unsupervised pretraining\n", + " * Dimensionality reduction\n", + " * Generative models\n", + " * Anomaly detection (an autoencoder is generally bad at reconstructing outliers)\n", + "2. If you want to train a classifier and you have plenty of unlabeled training data but only a few thousand labeled instances, then you could first train a deep autoencoder on the full dataset (labeled + unlabeled), then reuse its lower half for the classifier (i.e., reuse the layers up to the codings layer, included) and train the classifier using the labeled data. If you have little labeled data, you probably want to freeze the reused layers when training the classifier.\n", + "3. 
The fact that an autoencoder perfectly reconstructs its inputs does not necessarily mean that it is a good autoencoder; perhaps it is simply an overcomplete autoencoder that learned to copy its inputs to the codings layer and then to the outputs. In fact, even if the codings layer contained a single neuron, it would be possible for a very deep autoencoder to learn to map each training instance to a different coding (e.g., the first instance could be mapped to 0.001, the second to 0.002, the third to 0.003, and so on), and it could learn \"by heart\" to reconstruct the right training instance for each coding. It would perfectly reconstruct its inputs without really learning any useful pattern in the data. In practice such a mapping is unlikely to happen, but it illustrates the fact that perfect reconstructions are not a guarantee that the autoencoder learned anything useful. However, if it produces very bad reconstructions, then it is almost guaranteed to be a bad autoencoder. To evaluate the performance of an autoencoder, one option is to measure the reconstruction loss (e.g., compute the MSE, or the mean square of the outputs minus the inputs). Again, a high reconstruction loss is a good sign that the autoencoder is bad, but a low reconstruction loss is not a guarantee that it is good. You should also evaluate the autoencoder according to what it will be used for. For example, if you are using it for unsupervised pretraining of a classifier, then you should also evaluate the classifier's performance.\n", + "4. An undercomplete autoencoder is one whose codings layer is smaller than the input and output layers. If it is larger, then it is an overcomplete autoencoder. The main risk of an excessively undercomplete autoencoder is that it may fail to reconstruct the inputs. The main risk of an overcomplete autoencoder is that it may just copy the inputs to the outputs, without learning any useful features.\n", + "5. To tie the weights of an encoder layer and its corresponding decoder layer, you simply make the decoder weights equal to the transpose of the encoder weights. This reduces the number of parameters in the model by half, often making training converge faster with less training data and reducing the risk of overfitting the training set.\n", + "6. A generative model is a model capable of randomly generating outputs that resemble the training instances. For example, once trained successfully on the MNIST dataset, a generative model can be used to randomly generate realistic images of digits. The output distribution is typically similar to the training data. For example, since MNIST contains many images of each digit, the generative model would output roughly the same number of images of each digit. Some generative models can be parametrized—for example, to generate only some kinds of outputs. An example of a generative autoencoder is the variational autoencoder.\n", + "7. A generative adversarial network is a neural network architecture composed of two parts, the generator and the discriminator, which have opposing objectives. The generator's goal is to generate instances similar to those in the training set, to fool the discriminator. The discriminator must distinguish the real instances from the generated ones. At each training iteration, the discriminator is trained like a normal binary classifier, then the generator is trained to maximize the discriminator's error. 
GANs are used for advanced image processing tasks such as super resolution, colorization, image editing (replacing objects with realistic background), turning a simple sketch into a photorealistic image, or predicting the next frames in a video. They are also used to augment a dataset (to train other models), to generate other types of data (such as text, audio, and time series), and to identify the weaknesses in other models and strengthen them.\n", + "8. Training GANs is notoriously difficult, because of the complex dynamics between the generator and the discriminator. The biggest difficulty is mode collapse, where the generator produces outputs with very little diversity. Moreover, training can be terribly unstable: it may start out fine and then suddenly start oscillating or diverging, without any apparent reason. GANs are also very sensitive to the choice of hyperparameters." ] }, { @@ -1586,7 +1684,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1597,7 +1695,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1615,7 +1713,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1624,7 +1722,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1638,7 +1736,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1647,7 +1745,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1660,7 +1758,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1677,7 +1775,7 @@ " if index == 0:\n", " plt.title(\"Original\")\n", " plt.subplot(n_images, 3, index * 3 + 2)\n", - " plt.imshow(np.clip(new_images_noisy[index], 0., 1.))\n", + " plt.imshow(new_images_noisy[index].clip(0., 1.))\n", " plt.axis('off')\n", " if index == 0:\n", " plt.title(\"Noisy\")\n", @@ -1757,7 +1855,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb index 12f185e..220f5b0 100644 --- a/18_reinforcement_learning.ipynb +++ b/18_reinforcement_learning.ipynb @@ -4,14 +4,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Chapter 17 – Reinforcement Learning**" + "**Chapter 18 – Reinforcement Learning**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "_This notebook contains all the sample code and solutions to the exercises in chapter 17._" + "_This notebook contains all the sample code and solutions to the exercises in chapter 18._" ] }, { @@ -20,84 +20,147 @@ "source": [ "\n", " \n", " \n", "
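To ground the two-phase GAN training procedure described in answer 7 of the autoencoders chapter above, here is a minimal training-loop sketch in the same spirit. It assumes `gan` is a `tf.keras.Sequential([generator, discriminator])` whose discriminator (and the full `gan`) is compiled with binary cross-entropy, and that `dataset` yields batches of real images; all names are assumptions, not code from the notebook.

```python
# Minimal two-phase GAN training loop (illustrating exercise answer 7 above).
# Assumes: gan = tf.keras.Sequential([generator, discriminator]), with the
# discriminator and the full gan both compiled with binary cross-entropy.
import tensorflow as tf

def train_gan(gan, dataset, batch_size, codings_size, n_epochs):
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        for X_batch in dataset:
            # Phase 1: train the discriminator on fake + real images.
            noise = tf.random.normal(shape=[batch_size, codings_size])
            fake_images = generator(noise)
            X_fake_and_real = tf.concat([fake_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)
            # Phase 2: train the generator to fool the (frozen) discriminator.
            noise = tf.random.normal(shape=[batch_size, codings_size])
            y2 = tf.constant([[1.]] * batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)
```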
\n", - " \"Open\n", + " \"Open\n", " \n", - " \n", + " \n", "
" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "tags": [] + }, + "source": [ + "# WORK IN PROGRESS\n", + "\n", + "\n", + "**I'm still working on updating this chapter to the 3rd edition. Please come back in a few weeks.**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dFXIv9qNpKzt", + "tags": [] + }, "source": [ "# Setup" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "8IPbJEmZpKzu" + }, "source": [ - "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures." + "This project requires Python 3.8 or above:" ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "execution_count": null, + "metadata": { + "id": "TFSU3FCOpKzu" + }, "outputs": [], "source": [ - "# Python ≥3.8 is required\n", "import sys\n", - "assert sys.version_info >= (3, 8)\n", "\n", - "# Is this notebook running on Colab or Kaggle?\n", - "IS_COLAB = \"google.colab\" in sys.modules\n", - "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", + "assert sys.version_info >= (3, 8)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TAlKky09pKzv" + }, + "source": [ + "It also requires Scikit-Learn ≥ 1.0.1:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YqCwW7cMpKzw" + }, + "outputs": [], + "source": [ + "import sklearn\n", "\n", - "if IS_COLAB or IS_KAGGLE:\n", - " !apt update && apt install -y libpq-dev libsdl2-dev swig xorg-dev xvfb\n", - " %pip install -U tf-agents pyvirtualdisplay\n", - " %pip install -U gym>=0.21.0\n", - " %pip install -U gym[box2d,atari,accept-rom-license]\n", + "assert sklearn.__version__ >= \"1.0.1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GJtVEqxfpKzw" + }, + "source": [ + "And TensorFlow ≥ 2.6:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0Piq5se2pKzx" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", "\n", - "# Common imports\n", - "import numpy as np\n", + "assert tf.__version__ >= \"2.6.0\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DDaDoLQTpKzx" + }, + "source": [ + "As we did in earlier chapters, let's define the default font sizes to make the figures prettier:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8d4TH3NbpKzx" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.rc('font', size=14)\n", + "plt.rc('axes', labelsize=14, titlesize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "plt.rc('xtick', labelsize=10)\n", + "plt.rc('ytick', labelsize=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RcoUIRsvpKzy" + }, + "source": [ + "And let's create the `images/rl` folder (if it doesn't already exist), and define the `save_fig()` function which is used through this notebook to save the figures in high-res for the book:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PQFH5Y9PpKzy" + }, + "outputs": [], + "source": [ "from pathlib import Path\n", "\n", - "# Scikit-Learn ≥1.0 is required\n", - "import sklearn\n", - "assert sklearn.__version__ >= \"1.0\"\n", - "\n", - "# TensorFlow ≥2.6 is required\n", - "import tensorflow as tf\n", - "assert tf.__version__ >= \"2.6\"\n", - "\n", - "# to make this notebook's output stable across runs\n", - "np.random.seed(42)\n", - "tf.random.set_seed(42)\n", - "\n", - "if not tf.config.list_physical_devices('GPU'):\n", - " 
print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n", - " if IS_COLAB:\n", - " print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n", - " if IS_KAGGLE:\n", - " print(\"Go to Settings > Accelerator and select GPU.\")\n", - "\n", - "# To plot pretty figures\n", - "%matplotlib inline\n", - "import matplotlib as mpl\n", - "import matplotlib.pyplot as plt\n", - "mpl.rc('axes', labelsize=14)\n", - "mpl.rc('xtick', labelsize=12)\n", - "mpl.rc('ytick', labelsize=12)\n", - "\n", - "# To get smooth animations\n", - "import matplotlib.animation as animation\n", - "mpl.rc('animation', html='jshtml')\n", - "\n", - "# Where to save the figures\n", "IMAGES_PATH = Path() / \"images\" / \"rl\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", @@ -108,6 +171,51 @@ " plt.savefig(path, format=fig_extension, dpi=resolution)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "YTsawKlapKzy" + }, + "source": [ + "This chapter can be very slow without a GPU, so let's make sure there's one, or else issue a warning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ekxzo6pOpKzy" + }, + "outputs": [], + "source": [ + "if not tf.config.list_physical_devices('GPU'):\n", + " print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n", + " if \"google.colab\" in sys.modules:\n", + " print(\"Go to Runtime > Change runtime and select a GPU hardware \"\n", + " \"accelerator.\")\n", + " if \"kaggle_secrets\" in sys.modules:\n", + " print(\"Go to Settings > Accelerator and select GPU.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's install the gym library, which provides many environments for Reinforcement Learning. Some of these environments require an X server to plot graphics, so we need to install xvfb on Colab or Kaggle (that's an in-memory X server, since the runtimes are not hooked to a screen). We also need to install pyvirtualdisplay, which provides a Python interface to xvfb. And let's also install the Box2D and Atari environments. By running the following cell, you also accept the Atari ROM license." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if \"google.colab\" in sys.modules or \"kaggle_secrets\" in sys.modules:\n", + " !apt update &> /dev/null && apt install -y xvfb &> /dev/null\n", + " %pip install -q -U gym pyglet pyvirtualdisplay\n", + " %pip install -q -U gym[box2d,atari,accept-rom-license]" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -124,7 +232,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -140,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -156,7 +264,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -172,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -189,7 +297,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -231,7 +339,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -244,7 +352,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -260,7 +368,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -270,7 +378,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -284,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -301,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -324,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -342,7 +450,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -366,7 +474,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -382,7 +490,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -398,7 +506,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -414,7 +522,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -445,7 +553,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -470,11 +578,11 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "np.mean(totals), np.std(totals), np.min(totals), np.max(totals)" + "np.mean(totals), np.std(totals), min(totals), max(totals)" ] }, { @@ -493,7 +601,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -521,7 +629,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -542,7 +650,7 @@ }, { 
"cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -572,7 +680,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -611,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -641,7 +749,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -665,7 +773,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -701,7 +809,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -741,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -772,7 +880,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -803,7 +911,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -832,7 +940,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -848,7 +956,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": { "scrolled": true }, @@ -859,7 +967,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -871,7 +979,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -881,7 +989,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -897,7 +1005,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -926,7 +1034,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -943,7 +1051,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -989,7 +1097,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1013,7 +1121,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1024,7 +1132,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1038,7 +1146,7 @@ " for a in possible_actions[s]:\n", " Q_values[s, a] = np.sum([\n", " transition_probabilities[s][a][sp]\n", - " * (rewards[s][a][sp] + gamma * np.max(Q_prev[sp]))\n", + " * (rewards[s][a][sp] + gamma * Q_prev[sp].max())\n", " for sp in range(3)])\n", "\n", "history1 = np.array(history1) # Not shown" @@ -1046,7 +1154,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1055,11 +1163,11 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "np.argmax(Q_values, axis=1)" + "Q_values.argmax(axis=1)" ] }, { @@ -1078,7 +1186,7 @@ }, { "cell_type": "code", - "execution_count": 45, + 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1089,7 +1197,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1101,13 +1209,13 @@ " for a in possible_actions[s]:\n", " Q_values[s, a] = np.sum([\n", " transition_probabilities[s][a][sp]\n", - " * (rewards[s][a][sp] + gamma * np.max(Q_prev[sp]))\n", + " * (rewards[s][a][sp] + gamma * Q_prev[sp].max())\n", " for sp in range(3)])" ] }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1116,11 +1224,11 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "np.argmax(Q_values, axis=1)" + "Q_values.argmax(axis=1)" ] }, { @@ -1153,7 +1261,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1173,7 +1281,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1190,7 +1298,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1210,7 +1318,7 @@ " history2.append(Q_values.copy()) # Not shown\n", " action = exploration_policy(state)\n", " next_state, reward = step(state, action)\n", - " next_value = np.max(Q_values[next_state]) # greedy policy at the next step\n", + " next_value = Q_values[next_state].max() # greedy policy at the next step\n", " alpha = alpha0 / (1 + iteration * decay)\n", " Q_values[state, action] *= 1 - alpha\n", " Q_values[state, action] += alpha * (reward + gamma * next_value)\n", @@ -1221,7 +1329,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1230,16 +1338,16 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "np.argmax(Q_values, axis=1) # optimal action for each state" + "Q_values.argmax(axis=1) # optimal action for each state" ] }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1274,7 +1382,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1302,7 +1410,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1311,7 +1419,7 @@ " return np.random.randint(n_outputs)\n", " else:\n", " Q_values = model.predict(state[np.newaxis])\n", - " return np.argmax(Q_values[0])" + " return Q_values[0].argmax()" ] }, { @@ -1323,7 +1431,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1341,7 +1449,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1363,7 +1471,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1387,7 +1495,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1400,7 +1508,7 @@ " experiences = sample_experiences(batch_size)\n", " states, actions, rewards, next_states, dones = experiences\n", " next_Q_values = model.predict(next_states)\n", - " max_next_Q_values = 
np.max(next_Q_values, axis=1)\n", + " max_next_Q_values = next_Q_values.max(axis=1)\n", " target_Q_values = (rewards +\n", " (1 - dones) * discount_rate * max_next_Q_values)\n", " target_Q_values = target_Q_values.reshape(-1, 1)\n", @@ -1422,7 +1530,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1436,7 +1544,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": null, "metadata": { "scrolled": true }, @@ -1462,7 +1570,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1476,7 +1584,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1512,7 +1620,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1532,7 +1640,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1545,7 +1653,7 @@ " experiences = sample_experiences(batch_size)\n", " states, actions, rewards, next_states, dones = experiences\n", " next_Q_values = model.predict(next_states)\n", - " best_next_actions = np.argmax(next_Q_values, axis=1)\n", + " best_next_actions = next_Q_values.argmax(axis=1)\n", " next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n", " next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n", " target_Q_values = (rewards + \n", @@ -1562,7 +1670,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1571,7 +1679,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1611,7 +1719,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1625,7 +1733,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": null, "metadata": { "scrolled": true }, @@ -1656,7 +1764,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1680,7 +1788,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1693,7 +1801,7 @@ " experiences = sample_experiences(batch_size)\n", " states, actions, rewards, next_states, dones = experiences\n", " next_Q_values = model.predict(next_states)\n", - " best_next_actions = np.argmax(next_Q_values, axis=1)\n", + " best_next_actions = next_Q_values.argmax(axis=1)\n", " next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n", " next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n", " target_Q_values = (rewards + \n", @@ -1710,7 +1818,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1719,7 +1827,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1752,7 +1860,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1764,7 +1872,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": null, "metadata": { "scrolled": true }, @@ -1795,1021 +1903,13 @@ }, { "cell_type": "code", - "execution_count": 77, + 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "env.close()" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Using TF-Agents to Beat Breakout" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's use TF-Agents to create an agent that will learn to play Breakout. We will use the Deep Q-Learning algorithm, so you can easily compare the components with the previous implementation, but TF-Agents implements many other (and more sophisticated) algorithms!" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "First, we need to download and install the Atari ROMs. This can be done very easily using the AutoROM tool, if you accept the license:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "!{sys.prefix}/bin/AutoROM --quiet --accept-license" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## TF-Agents Environments" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [], - "source": [ - "tf.random.set_seed(42)\n", - "np.random.seed(42)" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.environments import suite_gym\n", - "\n", - "env = suite_gym.load(\"Breakout-v4\")\n", - "env" - ] - }, - { - "cell_type": "code", - "execution_count": 80, - "metadata": {}, - "outputs": [], - "source": [ - "env.gym" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": {}, - "outputs": [], - "source": [ - "env.seed(42)\n", - "env.reset()" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "metadata": {}, - "outputs": [], - "source": [ - "env.step(1) # Fire" - ] - }, - { - "cell_type": "code", - "execution_count": 83, - "metadata": {}, - "outputs": [], - "source": [ - "img = env.render(mode=\"rgb_array\")\n", - "\n", - "plt.figure(figsize=(6, 8))\n", - "plt.imshow(img)\n", - "plt.axis(\"off\")\n", - "save_fig(\"breakout_plot\")\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 84, - "metadata": {}, - "outputs": [], - "source": [ - "env.current_time_step()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Environment Specifications" - ] - }, - { - "cell_type": "code", - "execution_count": 85, - "metadata": {}, - "outputs": [], - "source": [ - "env.observation_spec()" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "metadata": {}, - "outputs": [], - "source": [ - "env.action_spec()" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "metadata": {}, - "outputs": [], - "source": [ - "env.time_step_spec()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Environment Wrappers" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can wrap a TF-Agents environments in a TF-Agents wrapper:" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.environments.wrappers import ActionRepeat\n", - "\n", - "repeating_env = ActionRepeat(env, times=4)\n", - "repeating_env" - ] - }, - { - "cell_type": "code", - "execution_count": 89, - "metadata": {}, - "outputs": [], - "source": [ - "repeating_env.unwrapped" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Here is the list of available wrappers:" - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "metadata": 
{}, - "outputs": [], - "source": [ - "import tf_agents.environments.wrappers\n", - "\n", - "for name in dir(tf_agents.environments.wrappers):\n", - " obj = getattr(tf_agents.environments.wrappers, name)\n", - " if hasattr(obj, \"__base__\") and issubclass(obj, tf_agents.environments.wrappers.PyEnvironmentBaseWrapper):\n", - " print(\"{:27s} {}\".format(name, obj.__doc__.split(\"\\n\")[0]))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `suite_gym.load()` function can create an env and wrap it for you, both with TF-Agents environment wrappers and Gym environment wrappers (the latter are applied first)." - ] - }, - { - "cell_type": "code", - "execution_count": 91, - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "from gym.wrappers import TimeLimit\n", - "\n", - "limited_repeating_env = suite_gym.load(\n", - " \"Breakout-v4\",\n", - " gym_env_wrappers=[partial(TimeLimit, max_episode_steps=10000)],\n", - " env_wrappers=[partial(ActionRepeat, times=4)],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 92, - "metadata": {}, - "outputs": [], - "source": [ - "limited_repeating_env" - ] - }, - { - "cell_type": "code", - "execution_count": 93, - "metadata": {}, - "outputs": [], - "source": [ - "limited_repeating_env.unwrapped" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create an Atari Breakout environment, and wrap it to apply the default Atari preprocessing steps:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Warning**: Breakout requires the player to press the FIRE button at the start of the game and after each life lost. The agent may take a very long time learning this because at first it seems that pressing FIRE just means losing faster. To speed up training considerably, we create and use a subclass of the `AtariPreprocessing` wrapper class called `AtariPreprocessingWithAutoFire` which presses FIRE (i.e., plays action 1) automatically at the start of the game and after each life lost. This is different from the book which uses the regular `AtariPreprocessing` wrapper." 
- ] - }, - { - "cell_type": "code", - "execution_count": 94, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.environments import suite_atari\n", - "from tf_agents.environments.atari_preprocessing import AtariPreprocessing\n", - "from tf_agents.environments.atari_wrappers import FrameStack4\n", - "\n", - "max_episode_steps = 27000 # <=> 108k ALE frames since 1 step = 4 frames\n", - "environment_name = \"BreakoutNoFrameskip-v4\"\n", - "\n", - "class AtariPreprocessingWithAutoFire(AtariPreprocessing):\n", - " def reset(self, **kwargs):\n", - " obs = super().reset(**kwargs)\n", - " super().step(1) # FIRE to start\n", - " return obs\n", - " def step(self, action):\n", - " lives_before_action = self.ale.lives()\n", - " obs, rewards, done, info = super().step(action)\n", - " if self.ale.lives() < lives_before_action and not done:\n", - " super().step(1) # FIRE to start after life lost\n", - " return obs, rewards, done, info\n", - "\n", - "env = suite_atari.load(\n", - " environment_name,\n", - " max_episode_steps=max_episode_steps,\n", - " gym_env_wrappers=[AtariPreprocessingWithAutoFire, FrameStack4])" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "metadata": {}, - "outputs": [], - "source": [ - "env" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Play a few steps just to see what happens:" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "metadata": {}, - "outputs": [], - "source": [ - "env.seed(42)\n", - "env.reset()\n", - "for _ in range(4):\n", - " time_step = env.step(3) # LEFT" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_observation(obs):\n", - " # Since there are only 3 color channels, you cannot display 4 frames\n", - " # with one primary color per frame. So this code computes the delta between\n", - " # the current frame and the mean of the other frames, and it adds this delta\n", - " # to the red and blue channels to get a pink color for the current frame.\n", - " obs = obs.astype(np.float32)\n", - " img = obs[..., :3]\n", - " current_frame_delta = np.maximum(obs[..., 3] - obs[..., :3].mean(axis=-1), 0.)\n", - " img[..., 0] += current_frame_delta\n", - " img[..., 2] += current_frame_delta\n", - " img = np.clip(img / 150, 0, 1)\n", - " plt.imshow(img)\n", - " plt.axis(\"off\")" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(6, 6))\n", - "plot_observation(time_step.observation)\n", - "save_fig(\"preprocessed_breakout_plot\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Convert the Python environment to a TF environment:" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.environments.tf_py_environment import TFPyEnvironment\n", - "\n", - "tf_env = TFPyEnvironment(env)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Creating the DQN" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create a small class to normalize the observations. 
Images are stored using bytes from 0 to 255 to use less RAM, but we want to pass floats from 0.0 to 1.0 to the neural network:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create the Q-Network:" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.networks.q_network import QNetwork\n", - "\n", - "preprocessing_layer = tf.keras.layers.Lambda(\n", - " lambda obs: tf.cast(obs, np.float32) / 255.)\n", - "conv_layer_params=[(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]\n", - "fc_layer_params=[512]\n", - "\n", - "q_net = QNetwork(\n", - " tf_env.observation_spec(),\n", - " tf_env.action_spec(),\n", - " preprocessing_layers=preprocessing_layer,\n", - " conv_layer_params=conv_layer_params,\n", - " fc_layer_params=fc_layer_params)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create the DQN Agent:" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.agents.dqn.dqn_agent import DqnAgent\n", - "\n", - "train_step = tf.Variable(0)\n", - "update_period = 4 # run a training step every 4 collect steps\n", - "optimizer = tf.keras.optimizers.RMSprop(learning_rate=2.5e-4, rho=0.95, momentum=0.0,\n", - " epsilon=0.00001, centered=True)\n", - "epsilon_fn = tf.keras.optimizers.schedules.PolynomialDecay(\n", - " initial_learning_rate=1.0, # initial ε\n", - " decay_steps=250000 // update_period, # <=> 1,000,000 ALE frames\n", - " end_learning_rate=0.01) # final ε\n", - "agent = DqnAgent(tf_env.time_step_spec(),\n", - " tf_env.action_spec(),\n", - " q_network=q_net,\n", - " optimizer=optimizer,\n", - " target_update_period=2000, # <=> 32,000 ALE frames\n", - " td_errors_loss_fn=tf.keras.losses.Huber(reduction=\"none\"),\n", - " gamma=0.99, # discount factor\n", - " train_step_counter=train_step,\n", - " epsilon_greedy=lambda: epsilon_fn(train_step))\n", - "agent.initialize()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create the replay buffer (this will use a lot of RAM, so please reduce the buffer size if you get an out-of-memory error):" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Warning**: we use a replay buffer of size 100,000 instead of 1,000,000 (as used in the book) since many people were getting OOM (Out-Of-Memory) errors." 
- ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.replay_buffers import tf_uniform_replay_buffer\n", - "\n", - "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", - " data_spec=agent.collect_data_spec,\n", - " batch_size=tf_env.batch_size,\n", - " max_length=100000) # reduce if OOM error\n", - "\n", - "replay_buffer_observer = replay_buffer.add_batch" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create a simple custom observer that counts and displays the number of times it is called (except when it is passed a trajectory that represents the boundary between two episodes, as this does not count as a step):" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "metadata": {}, - "outputs": [], - "source": [ - "class ShowProgress:\n", - " def __init__(self, total):\n", - " self.counter = 0\n", - " self.total = total\n", - " def __call__(self, trajectory):\n", - " if not trajectory.is_boundary():\n", - " self.counter += 1\n", - " if self.counter % 100 == 0:\n", - " print(\"\\r{}/{}\".format(self.counter, self.total), end=\"\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's add some training metrics:" - ] - }, - { - "cell_type": "code", - "execution_count": 104, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.metrics import tf_metrics\n", - "\n", - "train_metrics = [\n", - " tf_metrics.NumberOfEpisodes(),\n", - " tf_metrics.EnvironmentSteps(),\n", - " tf_metrics.AverageReturnMetric(),\n", - " tf_metrics.AverageEpisodeLengthMetric(),\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "metadata": {}, - "outputs": [], - "source": [ - "train_metrics[0].result()" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.eval.metric_utils import log_metrics\n", - "import logging\n", - "logging.getLogger().setLevel(logging.INFO)\n", - "log_metrics(train_metrics)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Create the collect driver:" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver\n", - "\n", - "collect_driver = DynamicStepDriver(\n", - " tf_env,\n", - " agent.collect_policy,\n", - " observers=[replay_buffer_observer] + train_metrics,\n", - " num_steps=update_period) # collect 4 steps for each training iteration" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Collect the initial experiences, before training:" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.policies.random_tf_policy import RandomTFPolicy\n", - "\n", - "initial_collect_policy = RandomTFPolicy(tf_env.time_step_spec(),\n", - " tf_env.action_spec())\n", - "init_driver = DynamicStepDriver(\n", - " tf_env,\n", - " initial_collect_policy,\n", - " observers=[replay_buffer.add_batch, ShowProgress(20000)],\n", - " num_steps=20000) # <=> 80,000 ALE frames\n", - "final_time_step, final_policy_state = init_driver.run()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Let's sample 2 sub-episodes, with 3 time steps each and display them:" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Note**: `replay_buffer.get_next()` is deprecated. 
We must use `replay_buffer.as_dataset(..., single_deterministic_pass=False)` instead." - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "metadata": {}, - "outputs": [], - "source": [ - "tf.random.set_seed(9) # chosen to show an example of trajectory at the end of an episode\n", - "\n", - "#trajectories, buffer_info = replay_buffer.get_next( # get_next() is deprecated\n", - "# sample_batch_size=2, num_steps=3)\n", - "\n", - "trajectories, buffer_info = next(iter(replay_buffer.as_dataset(\n", - " sample_batch_size=2,\n", - " num_steps=3,\n", - " single_deterministic_pass=False)))" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "metadata": {}, - "outputs": [], - "source": [ - "trajectories._fields" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [], - "source": [ - "trajectories.observation.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 112, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.trajectories.trajectory import to_transition\n", - "\n", - "time_steps, action_steps, next_time_steps = to_transition(trajectories)\n", - "time_steps.observation.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 113, - "metadata": {}, - "outputs": [], - "source": [ - "trajectories.step_type.numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 114, - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(10, 6.8))\n", - "for row in range(2):\n", - " for col in range(3):\n", - " plt.subplot(2, 3, row * 3 + col + 1)\n", - " plot_observation(trajectories.observation[row, col].numpy())\n", - "plt.subplots_adjust(left=0, right=1, bottom=0, top=1, hspace=0, wspace=0.02)\n", - "save_fig(\"sub_episodes_plot\")\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Now let's create the dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": 115, - "metadata": {}, - "outputs": [], - "source": [ - "dataset = replay_buffer.as_dataset(\n", - " sample_batch_size=64,\n", - " num_steps=2,\n", - " num_parallel_calls=3).prefetch(3)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Convert the main functions to TF Functions for better performance:" - ] - }, - { - "cell_type": "code", - "execution_count": 116, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.utils.common import function\n", - "\n", - "collect_driver.run = function(collect_driver.run)\n", - "agent.train = function(agent.train)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "And now we are ready to run the main loop!" - ] - }, - { - "cell_type": "code", - "execution_count": 117, - "metadata": {}, - "outputs": [], - "source": [ - "def train_agent(n_iterations):\n", - " time_step = None\n", - " policy_state = agent.collect_policy.get_initial_state(tf_env.batch_size)\n", - " iterator = iter(dataset)\n", - " for iteration in range(n_iterations):\n", - " time_step, policy_state = collect_driver.run(time_step, policy_state)\n", - " trajectories, buffer_info = next(iterator)\n", - " train_loss = agent.train(trajectories)\n", - " print(\"\\r{} loss:{:.5f}\".format(\n", - " iteration, train_loss.loss.numpy()), end=\"\")\n", - " if iteration % 1000 == 0:\n", - " log_metrics(train_metrics)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Run the next cell to train the agent for 50,000 steps. Then look at its behavior by running the following cell. 
You can run these two cells as many times as you wish. The agent will keep improving! It will likely take over 200,000 iterations for the agent to become reasonably good." - ] - }, - { - "cell_type": "code", - "execution_count": 118, - "metadata": {}, - "outputs": [], - "source": [ - "train_agent(n_iterations=50000)" - ] - }, - { - "cell_type": "code", - "execution_count": 119, - "metadata": {}, - "outputs": [], - "source": [ - "frames = []\n", - "def save_frames(trajectory):\n", - " global frames\n", - " frames.append(tf_env.pyenv.envs[0].render(mode=\"rgb_array\"))\n", - "\n", - "watch_driver = DynamicStepDriver(\n", - " tf_env,\n", - " agent.policy,\n", - " observers=[save_frames, ShowProgress(1000)],\n", - " num_steps=1000)\n", - "final_time_step, final_policy_state = watch_driver.run()\n", - "\n", - "plot_animation(frames)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If you want to save an animated GIF to show off your agent to your friends, here's one way to do it:" - ] - }, - { - "cell_type": "code", - "execution_count": 120, - "metadata": {}, - "outputs": [], - "source": [ - "import PIL\n", - "\n", - "image_path = Path() / \"images\" / \"rl\" / \"breakout.gif\"\n", - "frame_images = [PIL.Image.fromarray(frame) for frame in frames[:150]]\n", - "frame_images[0].save(image_path, format='GIF',\n", - " append_images=frame_images[1:],\n", - " save_all=True,\n", - " duration=30,\n", - " loop=0)" - ] - }, - { - "cell_type": "code", - "execution_count": 121, - "metadata": {}, - "outputs": [], - "source": [ - "%%html\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Extra material" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Deque vs Rotating List" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The `deque` class offers fast append, but fairly slow random access (for large replay memories):" - ] - }, - { - "cell_type": "code", - "execution_count": 122, - "metadata": {}, - "outputs": [], - "source": [ - "from collections import deque\n", - "np.random.seed(42)\n", - "\n", - "mem = deque(maxlen=1000000)\n", - "for i in range(1000000):\n", - " mem.append(i)\n", - "[mem[i] for i in np.random.randint(1000000, size=5)]" - ] - }, - { - "cell_type": "code", - "execution_count": 123, - "metadata": {}, - "outputs": [], - "source": [ - "%timeit mem.append(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 124, - "metadata": {}, - "outputs": [], - "source": [ - "%timeit [mem[i] for i in np.random.randint(1000000, size=5)]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Alternatively, you could use a rotating list like this `ReplayMemory` class. 
This would make random access faster for large replay memories:" - ] - }, - { - "cell_type": "code", - "execution_count": 125, - "metadata": {}, - "outputs": [], - "source": [ - "class ReplayMemory:\n", - " def __init__(self, max_size):\n", - " self.buffer = np.empty(max_size, dtype=np.object)\n", - " self.max_size = max_size\n", - " self.index = 0\n", - " self.size = 0\n", - "\n", - " def append(self, obj):\n", - " self.buffer[self.index] = obj\n", - " self.size = min(self.size + 1, self.max_size)\n", - " self.index = (self.index + 1) % self.max_size\n", - "\n", - " def sample(self, batch_size):\n", - " indices = np.random.randint(self.size, size=batch_size)\n", - " return self.buffer[indices]" - ] - }, - { - "cell_type": "code", - "execution_count": 126, - "metadata": {}, - "outputs": [], - "source": [ - "mem = ReplayMemory(max_size=1000000)\n", - "for i in range(1000000):\n", - " mem.append(i)\n", - "mem.sample(5)" - ] - }, - { - "cell_type": "code", - "execution_count": 127, - "metadata": {}, - "outputs": [], - "source": [ - "%timeit mem.append(1)" - ] - }, - { - "cell_type": "code", - "execution_count": 128, - "metadata": {}, - "outputs": [], - "source": [ - "%timeit mem.sample(5)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Creating a Custom TF-Agents Environment" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "To create a custom TF-Agent environment, you just need to write a class that inherits from the `PyEnvironment` class and implements a few methods. For example, the following minimal environment represents a simple 4x4 grid. The agent starts in one corner (0,0) and must move to the opposite corner (3,3). The episode is done if the agent reaches the goal (it gets a +10 reward) or if the agent goes out of bounds (-1 reward). The actions are up (0), down (1), left (2) and right (3)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": {}, - "outputs": [], - "source": [ - "class MyEnvironment(tf_agents.environments.py_environment.PyEnvironment):\n", - " def __init__(self, discount=1.0):\n", - " super().__init__()\n", - " self._action_spec = tf_agents.specs.BoundedArraySpec(\n", - " shape=(), dtype=np.int32, name=\"action\", minimum=0, maximum=3)\n", - " self._observation_spec = tf_agents.specs.BoundedArraySpec(\n", - " shape=(4, 4), dtype=np.int32, name=\"observation\", minimum=0, maximum=1)\n", - " self.discount = discount\n", - "\n", - " def action_spec(self):\n", - " return self._action_spec\n", - "\n", - " def observation_spec(self):\n", - " return self._observation_spec\n", - "\n", - " def _reset(self):\n", - " self._state = np.zeros(2, dtype=np.int32)\n", - " obs = np.zeros((4, 4), dtype=np.int32)\n", - " obs[self._state[0], self._state[1]] = 1\n", - " return tf_agents.trajectories.time_step.restart(obs)\n", - "\n", - " def _step(self, action):\n", - " self._state += [(-1, 0), (+1, 0), (0, -1), (0, +1)][action]\n", - " reward = 0\n", - " obs = np.zeros((4, 4), dtype=np.int32)\n", - " done = (self._state.min() < 0 or self._state.max() > 3)\n", - " if not done:\n", - " obs[self._state[0], self._state[1]] = 1\n", - " if done or np.all(self._state == np.array([3, 3])):\n", - " reward = -1 if done else +10\n", - " return tf_agents.trajectories.time_step.termination(obs, reward)\n", - " else:\n", - " return tf_agents.trajectories.time_step.transition(obs, reward,\n", - " self.discount)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The action and observation specs will generally be instances of the `ArraySpec` or `BoundedArraySpec` classes from the `tf_agents.specs` package (check out the other specs in this package as well). Optionally, you can also define a `render()` method, a `close()` method to free resources, as well as a `time_step_spec()` method if you don't want the `reward` and `discount` to be 32-bit float scalars. Note that the base class takes care of keeping track of the current time step, which is why we must implement `_reset()` and `_step()` rather than `reset()` and `step()`.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 130, - "metadata": {}, - "outputs": [], - "source": [ - "my_env = MyEnvironment()\n", - "time_step = my_env.reset()\n", - "time_step" - ] - }, - { - "cell_type": "code", - "execution_count": 131, - "metadata": {}, - "outputs": [], - "source": [ - "time_step = my_env.step(1)\n", - "time_step" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -2821,9 +1921,28 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. to 7.\n", - "\n", - "See Appendix A." + "## 1. to 7." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. Reinforcement Learning is an area of Machine Learning aimed at creating agents capable of taking actions in an environment in a way that maximizes rewards over time. There are many differences between RL and regular supervised and unsupervised learning. Here are a few:\n", + " * In supervised and unsupervised learning, the goal is generally to find patterns in the data and use them to make predictions. In Reinforcement Learning, the goal is to find a good policy.\n", + " * Unlike in supervised learning, the agent is not explicitly given the \"right\" answer. It must learn by trial and error.\n", + " * Unlike in unsupervised learning, there is a form of supervision, through rewards. 
We do not tell the agent how to perform the task, but we do tell it when it is making progress or when it is failing.\n", + " * A Reinforcement Learning agent needs to find the right balance between exploring the environment, looking for new ways of getting rewards, and exploiting sources of rewards that it already knows. In contrast, supervised and unsupervised learning systems generally don't need to worry about exploration; they just feed on the training data they are given.\n", + " * In supervised and unsupervised learning, training instances are typically independent (in fact, they are generally shuffled). In Reinforcement Learning, consecutive observations are generally _not_ independent. An agent may remain in the same region of the environment for a while before it moves on, so consecutive observations will be very correlated. In some cases a replay memory (buffer) is used to ensure that the training algorithm gets fairly independent observations.\n", + "2. Here are a few possible applications of Reinforcement Learning, other than those mentioned in Chapter 18:\n", + " * Music personalization: The environment is a user's personalized web radio. The agent is the software deciding what song to play next for that user. Its possible actions are to play any song in the catalog (it must try to choose a song the user will enjoy) or to play an advertisement (it must try to choose an ad that the user will be interested in). It gets a small reward every time the user listens to a song, a larger reward every time the user listens to an ad, a negative reward when the user skips a song or an ad, and a very negative reward if the user leaves.\n", + " * Marketing: The environment is your company's marketing department. The agent is the software that defines which customers a mailing campaign should be sent to, given their profile and purchase history (for each customer it has two possible actions: send or don't send). It gets a negative reward for the cost of the mailing campaign, and a positive reward for estimated revenue generated from this campaign.\n", + " * Product delivery: Let the agent control a fleet of delivery trucks, deciding what they should pick up at the depots, where they should go, what they should drop off, and so on. It will get positive rewards for each product delivered on time, and negative rewards for late deliveries.\n", + "3. When estimating the value of an action, Reinforcement Learning algorithms typically sum all the rewards that this action led to, giving more weight to immediate rewards and less weight to later rewards (considering that an action has more influence on the near future than on the distant future). To model this, a discount factor is typically applied at each time step. For example, with a discount factor of 0.9, a reward of 100 that is received two time steps later is counted as only 0.9² × 100 = 81 when you are estimating the value of the action. You can think of the discount factor as a measure of how much the future is valued relative to the present: if it is very close to 1, then the future is valued almost as much as the present; if it is close to 0, then only immediate rewards matter. Of course, this impacts the optimal policy tremendously: if you value the future, you may be willing to put up with a lot of immediate pain for the prospect of eventual rewards, while if you don't value the future, you will just grab any immediate reward you can find, never investing in the future.\n", + "4. 
To measure the performance of a Reinforcement Learning agent, you can simply sum up the rewards it gets. In a simulated environment, you can run many episodes and look at the total rewards it gets on average (and possibly look at the min, max, standard deviation, and so on).\n", + "5. The credit assignment problem is the fact that when a Reinforcement Learning agent receives a reward, it has no direct way of knowing which of its previous actions contributed to this reward. It typically occurs when there is a large delay between an action and the resulting reward (e.g., during a game of Atari's _Pong_, there may be a few dozen time steps between the moment the agent hits the ball and the moment it wins the point). One way to alleviate it is to provide the agent with shorter-term rewards, when possible. This usually requires prior knowledge about the task. For example, if we want to build an agent that will learn to play chess, instead of giving it a reward only when it wins the game, we could give it a reward every time it captures one of the opponent's pieces.\n", + "6. An agent can often remain in the same region of its environment for a while, so all of its experiences will be very similar for that period of time. This can introduce some bias in the learning algorithm. It may tune its policy for this region of the environment, but it will not perform well as soon as it moves out of this region. To solve this problem, you can use a replay memory; instead of using only the most immediate experiences for learning, the agent will learn based on a buffer of its past experiences, recent and not so recent (perhaps this is why we dream at night: to replay our experiences of the day and better learn from them?).\n", + "7. An off-policy RL algorithm learns the value of the optimal policy (i.e., the sum of discounted rewards that can be expected for each state if the agent acts optimally) while the agent follows a different policy. Q-Learning is a good example of such an algorithm. In contrast, an on-policy algorithm learns the value of the policy that the agent actually executes, including both exploration and exploitation." 
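To make this last answer concrete, here is a minimal tabular Q-Learning sketch. The toy 2-state MDP, its transition probabilities and rewards, and all variable names are made up purely for illustration (they are not from the notebook); only the update rule is the standard Q-Learning update. It shows the off-policy idea from answer 7 (the agent behaves randomly, yet the learning target bootstraps on the greedy next action) as well as the discount factor from answer 3 at work:

```python
import numpy as np

# Toy 2-state, 2-action MDP (purely illustrative).
# transition_probs[s, a, s'] = probability of landing in s' after action a in state s.
transition_probs = np.array([[[0.9, 0.1], [0.2, 0.8]],
                             [[0.0, 1.0], [0.8, 0.2]]])
# rewards[s, a, s'] = reward received for that transition.
rewards = np.array([[[1.0, 0.0], [0.0, 2.0]],
                    [[0.0, 0.0], [5.0, 0.0]]])

rng = np.random.default_rng(42)

def step(state, action):
    next_state = rng.choice(2, p=transition_probs[state, action])
    return next_state, rewards[state, action, next_state]

gamma = 0.9           # discount factor: a reward k steps ahead is worth gamma**k as much
alpha = 0.05          # learning rate
Q = np.zeros((2, 2))  # Q-value estimates, one per (state, action) pair

state = 0
for _ in range(100_000):
    action = rng.integers(2)        # behavior policy: act completely at random (exploration)
    next_state, reward = step(state, action)
    # Off-policy target: bootstrap on the *greedy* next action (max over Q-values),
    # not on the random action the behavior policy will actually take next.
    target = reward + gamma * Q[next_state].max()
    Q[state, action] += alpha * (target - Q[state, action])
    state = next_state

print(np.round(Q, 1))      # estimated Q-values of the optimal policy
print(Q.argmax(axis=1))    # greedy (estimated optimal) action in each state
```

Because the target uses the greedy next action rather than the action actually executed, the Q-values converge toward the optimal policy's values even though the behavior policy never stops exploring; an on-policy method such as SARSA would instead bootstrap on the action the exploratory policy really takes next.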
] }, { @@ -2843,7 +1962,7 @@ }, { "cell_type": "code", - "execution_count": 240, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2859,7 +1978,7 @@ }, { "cell_type": "code", - "execution_count": 241, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2868,7 +1987,7 @@ }, { "cell_type": "code", - "execution_count": 242, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2897,7 +2016,7 @@ }, { "cell_type": "code", - "execution_count": 243, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2924,7 +2043,7 @@ }, { "cell_type": "code", - "execution_count": 244, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -2959,7 +2078,7 @@ }, { "cell_type": "code", - "execution_count": 245, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3000,7 +2119,7 @@ }, { "cell_type": "code", - "execution_count": 246, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3029,7 +2148,7 @@ }, { "cell_type": "code", - "execution_count": 247, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3048,7 +2167,7 @@ }, { "cell_type": "code", - "execution_count": 248, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3065,7 +2184,7 @@ }, { "cell_type": "code", - "execution_count": 249, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3101,7 +2220,7 @@ }, { "cell_type": "code", - "execution_count": 250, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3123,7 +2242,7 @@ }, { "cell_type": "code", - "execution_count": 257, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3148,7 +2267,7 @@ }, { "cell_type": "code", - "execution_count": 264, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -3163,43 +2282,6 @@ "That's pretty good. You can try training it for longer and/or tweaking the hyperparameters to see if you can get it to go over 200." ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 9.\n", - "_Exercise: Use TF-Agents to train an agent that can achieve a superhuman level at SpaceInvaders-v4 using any of the available algorithms._" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Please follow the steps in the [Using TF-Agents to Beat Breakout](http://localhost:8888/notebooks/18_reinforcement_learning.ipynb#Using-TF-Agents-to-Beat-Breakout) section above, replacing `\"Breakout-v4\"` with `\"SpaceInvaders-v4\"`. There will be a few things to tweak, however. For example, the Space Invaders game does not require the user to press FIRE to begin the game. Instead, the player's laser cannon blinks for a few seconds then the game starts automatically. For better performance, you may want to skip this blinking phase (which lasts about 40 steps) at the beginning of each episode and after each life lost. Indeed, it's impossible to do anything at all during this phase, and nothing moves. 
One way to do this is to use the following custom environment wrapper, instead of the `AtariPreprocessingWithAutoFire` wrapper:" - ] - }, - { - "cell_type": "code", - "execution_count": 132, - "metadata": {}, - "outputs": [], - "source": [ - "class AtariPreprocessingWithSkipStart(AtariPreprocessing):\n", - " def skip_frames(self, num_skip):\n", - " for _ in range(num_skip):\n", - " super().step(0) # NOOP for num_skip steps\n", - " def reset(self, **kwargs):\n", - " obs = super().reset(**kwargs)\n", - " self.skip_frames(40)\n", - " return obs\n", - " def step(self, action):\n", - " lives_before_action = self.ale.lives()\n", - " obs, rewards, done, info = super().step(action)\n", - " if self.ale.lives() < lives_before_action and not done:\n", - " self.skip_frames(40)\n", - " return obs, rewards, done, info" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -3213,7 +2295,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 10.\n", + "## 9.\n", "_Exercise: If you have about $100 to spare, you can purchase a Raspberry Pi 3 plus some cheap robotics components, install TensorFlow on the Pi, and go wild! For an example, check out this [fun post](https://homl.info/2) by Lukas Biewald, or take a look at GoPiGo or BrickPi. Start with simple goals, like making the robot turn around to find the brightest angle (if it has a light sensor) or the closest object (if it has a sonar sensor), and move in that direction. Then you can start using Deep Learning: for example, if the robot has a camera, you can try to implement an object detection algorithm so it detects people and moves toward them. You can also try to use RL to make the agent learn on its own how to use the motors to achieve that goal. Have fun!_" ] }, @@ -3226,10 +2308,13 @@ } ], "metadata": { + "interpreter": { + "hash": "95c485e91159f3a8b550e08492cb4ed2557284663e79130c96242e7ff9e65ae1" + }, "kernelspec": { - "display_name": "homl3", + "display_name": "Python 3", "language": "python", - "name": "homl3" + "name": "python3" }, "language_info": { "codemirror_mode": { diff --git a/19_training_and_deploying_at_scale.ipynb b/19_training_and_deploying_at_scale.ipynb index 3a69957..1866330 100644 --- a/19_training_and_deploying_at_scale.ipynb +++ b/19_training_and_deploying_at_scale.ipynb @@ -4,14 +4,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "**Chapter 18 – Training and Deploying TensorFlow Models at Scale**" + "**Chapter 19 – Training and Deploying TensorFlow Models at Scale**" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "_This notebook contains all the sample code and solutions to the exercises in chapter 18._" + "_This notebook contains all the sample code and solutions to the exercises in chapter 19._" ] }, { @@ -20,81 +20,147 @@ "source": [ "\n", " \n", " \n", "
\n", - " \"Open\n", + " \"Open\n", " \n", - " \n", + " \n", "
" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "tags": [] + }, + "source": [ + "# WORK IN PROGRESS\n", + "\n", + "\n", + "**I'm still working on updating this chapter to the 3rd edition. Please come back in a few weeks.**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dFXIv9qNpKzt", + "tags": [] + }, "source": [ "# Setup" ] }, { "cell_type": "markdown", - "metadata": {}, + "metadata": { + "id": "8IPbJEmZpKzu" + }, "source": [ - "First, let's import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures." + "This project requires Python 3.8 or above:" ] }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, + "execution_count": null, + "metadata": { + "id": "TFSU3FCOpKzu" + }, "outputs": [], "source": [ - "# Python ≥3.8 is required\n", "import sys\n", - "assert sys.version_info >= (3, 8)\n", "\n", - "# Is this notebook running on Colab or Kaggle?\n", - "IS_COLAB = \"google.colab\" in sys.modules\n", - "IS_KAGGLE = \"kaggle_secrets\" in sys.modules\n", + "assert sys.version_info >= (3, 8)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TAlKky09pKzv" + }, + "source": [ + "It also requires Scikit-Learn ≥ 1.0.1:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YqCwW7cMpKzw" + }, + "outputs": [], + "source": [ + "import sklearn\n", "\n", - "if IS_COLAB or IS_KAGGLE:\n", - " !echo \"deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal\" > /etc/apt/sources.list.d/tensorflow-serving.list\n", - " !curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -\n", - " !apt update && apt-get install -y tensorflow-model-server\n", - " %pip install -q -U tensorflow-serving-api\n", + "assert sklearn.__version__ >= \"1.0.1\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GJtVEqxfpKzw" + }, + "source": [ + "And TensorFlow ≥ 2.6:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0Piq5se2pKzx" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", "\n", - "# Common imports\n", - "import os\n", - "import numpy as np\n", + "assert tf.__version__ >= \"2.6.0\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DDaDoLQTpKzx" + }, + "source": [ + "As we did in earlier chapters, let's define the default font sizes to make the figures prettier:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8d4TH3NbpKzx" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "plt.rc('font', size=14)\n", + "plt.rc('axes', labelsize=14, titlesize=14)\n", + "plt.rc('legend', fontsize=14)\n", + "plt.rc('xtick', labelsize=10)\n", + "plt.rc('ytick', labelsize=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RcoUIRsvpKzy" + }, + "source": [ + "And let's create the `images/deploy` folder (if it doesn't already exist), and define the `save_fig()` function which is used through this notebook to save the figures in high-res for the book:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PQFH5Y9PpKzy" + }, + "outputs": [], + "source": [ "from pathlib import Path\n", "\n", - "# Scikit-Learn ≥1.0 is required\n", - "import sklearn\n", - "assert sklearn.__version__ >= \"1.0\"\n", - "\n", - "# TensorFlow ≥2.6 is required\n", - "import tensorflow as tf\n", - 
"assert tf.__version__ >= \"2.6\"\n", - "\n", - "# to make this notebook's output stable across runs\n", - "np.random.seed(42)\n", - "tf.random.set_seed(42)\n", - "\n", - "if not tf.config.list_physical_devices('GPU'):\n", - " print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n", - " if IS_COLAB:\n", - " print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n", - " if IS_KAGGLE:\n", - " print(\"Go to Settings > Accelerator and select GPU.\")\n", - "\n", - "# To plot pretty figures\n", - "%matplotlib inline\n", - "import matplotlib as mpl\n", - "import matplotlib.pyplot as plt\n", - "mpl.rc('axes', labelsize=14)\n", - "mpl.rc('xtick', labelsize=12)\n", - "mpl.rc('ytick', labelsize=12)\n", - "\n", - "# Where to save the figures\n", "IMAGES_PATH = Path() / \"images\" / \"deploy\"\n", "IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n", "\n", @@ -105,6 +171,52 @@ " plt.savefig(path, format=fig_extension, dpi=resolution)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "YTsawKlapKzy" + }, + "source": [ + "This chapter can be very slow without a GPU, so let's make sure there's one, or else issue a warning:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ekxzo6pOpKzy" + }, + "outputs": [], + "source": [ + "if not tf.config.list_physical_devices('GPU'):\n", + " print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n", + " if \"google.colab\" in sys.modules:\n", + " print(\"Go to Runtime > Change runtime and select a GPU hardware \"\n", + " \"accelerator.\")\n", + " if \"kaggle_secrets\" in sys.modules:\n", + " print(\"Go to Settings > Accelerator and select GPU.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you are running this notebook in Colab or Kaggle, let's install TensorFlow Server:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if \"google.colab\" in sys.modules or \"kaggle_secrets\" in sys.modules:\n", + " !echo \"deb http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal\" > /etc/apt/sources.list.d/tensorflow-serving.list\n", + " !curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add -\n", + " !apt update && apt-get install -y tensorflow-model-server\n", + " %pip install -q -U tensorflow-serving-api" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -122,7 +234,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -136,7 +248,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -156,16 +268,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "np.round(model.predict(X_new), 2)" + "model.predict(X_new).round(2)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -177,7 +289,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -186,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -202,7 +314,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ @@ -223,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -232,7 +344,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -241,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -250,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -260,7 +372,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -276,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -285,7 +397,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -302,7 +414,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -313,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -363,7 +475,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -372,7 +484,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -385,7 +497,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -394,7 +506,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -408,7 +520,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -424,7 +536,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -438,7 +550,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -447,7 +559,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -464,7 +576,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -479,7 +591,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -493,7 +605,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -509,7 +621,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": { "scrolled": true }, @@ -530,7 +642,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -550,7 +662,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": { "scrolled": true }, @@ -573,7 +685,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -585,7 +697,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -594,7 +706,7 @@ 
}, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -610,7 +722,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -625,7 +737,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -634,7 +746,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -658,7 +770,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -667,7 +779,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -682,7 +794,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -698,12 +810,12 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "Y_probas = predict(X_new)\n", - "np.round(Y_probas, 2)" + "Y_probas.round(2)" ] }, { @@ -722,7 +834,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -732,7 +844,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -741,7 +853,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -750,7 +862,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -769,7 +881,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -780,7 +892,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -803,7 +915,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -818,7 +930,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -856,7 +968,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -867,7 +979,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -883,7 +995,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -954,7 +1066,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -978,7 +1090,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1007,7 +1119,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1019,7 +1131,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1028,7 +1140,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1048,7 +1160,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 
null, "metadata": {}, "outputs": [], "source": [ @@ -1115,7 +1227,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1131,7 +1243,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1158,7 +1270,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -1175,13 +1287,13 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = tf.keras.models.load_model(\"my_mnist_multiworker_model.h5\")\n", "Y_pred = model.predict(X_new)\n", - "np.argmax(Y_pred, axis=-1)" + "Y_pred.argmax(axis=-1)" ] }, { @@ -1202,9 +1314,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. to 8.\n", - "\n", - "See Appendix A." + "## 1. to 8." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "1. A SavedModel contains a TensorFlow model, including its architecture (a computation graph) and its weights. It is stored as a directory containing a _saved_model.pb_ file, which defines the computation graph (represented as a serialized protocol buffer), and a _variables_ subdirectory containing the variable values. For models containing a large number of weights, these variable values may be split across multiple files. A SavedModel also includes an _assets_ subdirectory that may contain additional data, such as vocabulary files, class names, or some example instances for this model. To be more accurate, a SavedModel can contain one or more _metagraphs_. A metagraph is a computation graph plus some function signature definitions (including their input and output names, types, and shapes). Each metagraph is identified by a set of tags. To inspect a SavedModel, you can use the command-line tool `saved_model_cli` or just load it using `tf.saved_model.load()` and inspect it in Python.\n", + "2. TF Serving allows you to deploy multiple TensorFlow models (or multiple versions of the same model) and make them accessible to all your applications easily via a REST API or a gRPC API. Using your models directly in your applications would make it harder to deploy a new version of a model across all applications. Implementing your own microservice to wrap a TF model would require extra work, and it would be hard to match TF Serving's features. TF Serving has many features: it can monitor a directory and autodeploy the models that are placed there, and you won't have to change or even restart any of your applications to benefit from the new model versions; it's fast, well tested, and scales very well; and it supports A/B testing of experimental models and deploying a new model version to just a subset of your users (in this case the model is called a _canary_). TF Serving is also capable of grouping individual requests into batches to run them jointly on the GPU. To deploy TF Serving, you can install it from source, but it is much simpler to install it using a Docker image. To deploy a cluster of TF Serving Docker images, you can use an orchestration tool such as Kubernetes, or use a fully hosted solution such as Google Cloud AI Platform.\n", + "3. To deploy a model across multiple TF Serving instances, all you need to do is configure these TF Serving instances to monitor the same _models_ directory, and then export your new model as a SavedModel into a subdirectory.\n", + "4. 
The gRPC API is more efficient than the REST API. However, its client libraries are not as widely available, and if you activate compression when using the REST API, you can get almost the same performance. So, the gRPC API is most useful when you need the highest possible performance and the clients are not limited to the REST API.\n", + "5. To reduce a model's size so it can run on a mobile or embedded device, TFLite uses several techniques:\n", + " * It provides a converter which can optimize a SavedModel: it shrinks the model and reduces its latency. To do this, it prunes all the operations that are not needed to make predictions (such as training operations), and it optimizes and fuses operations whenever possible.\n", + " * The converter can also perform post-training quantization: this technique dramatically reduces the model’s size, so it’s much faster to download and store.\n", + " * It saves the optimized model using the FlatBuffer format, which can be loaded to RAM directly, without parsing. This reduces the loading time and memory footprint.\n", + "6. Quantization-aware training consists in adding fake quantization operations to the model during training. This allows the model to learn to ignore the quantization noise; the final weights will be more robust to quantization.\n", + "7. Model parallelism means chopping your model into multiple parts and running them in parallel across multiple devices, hopefully speeding up the model during training or inference. Data parallelism means creating multiple exact replicas of your model and deploying them across multiple devices. At each iteration during training, each replica is given a different batch of data, and it computes the gradients of the loss with regard to the model parameters. In synchronous data parallelism, the gradients from all replicas are then aggregated and the optimizer performs a Gradient Descent step. The parameters may be centralized (e.g., on parameter servers) or replicated across all replicas and kept in sync using AllReduce. In asynchronous data parallelism, the parameters are centralized and the replicas run independently from each other, each updating the central parameters directly at the end of each training iteration, without having to wait for the other replicas. To speed up training, data parallelism turns out to work better than model parallelism, in general. This is mostly because it requires less communication across devices. Moreover, it is much easier to implement, and it works the same way for any model, whereas model parallelism requires analyzing the model to determine the best way to chop it into pieces.\n", + "8. When training a model across multiple servers, you can use the following distribution strategies:\n", + " * The `MultiWorkerMirroredStrategy` performs mirrored data parallelism. The model is replicated across all available servers and devices, and each replica gets a different batch of data at each training iteration and computes its own gradients. The mean of the gradients is computed and shared across all replicas using a distributed AllReduce implementation (NCCL by default), and all replicas perform the same Gradient Descent step. This strategy is the simplest to use since all servers and devices are treated in exactly the same way, and it performs fairly well. In general, you should use this strategy. Its main limitation is that it requires the model to fit in RAM on every replica.\n", + " * The `ParameterServerStrategy` performs asynchronous data parallelism. 
The model is replicated across all devices on all workers, and the parameters are sharded across all parameter servers. Each worker has its own training loop, running asynchronously with the other workers; at each training iteration, each worker gets its own batch of data and fetches the latest version of the model parameters from the parameter servers, then it computes the gradients of the loss with regard to these parameters, and it sends them to the parameter servers. Lastly, the parameter servers perform a Gradient Descent step using these gradients. This strategy is generally slower than the previous strategy, and a bit harder to deploy, since it requires managing parameter servers. However, it can be useful in some situations, especially when you can take advantage of the asynchronous updates, for example to reduce I/O bottlenecks. This depends on many factors, including hardware, network topology, number of servers, model size, and more, so your mileage may vary." ] }, { @@ -1262,7 +1391,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" },
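The SavedModel layout described in answer 1 above can be seen with a few lines of TensorFlow. This is only an illustrative sketch, not part of the notebook diff: the toy model and the `my_model/0001` path are made up.

```python
import tensorflow as tf

# Hypothetical toy model, just so there is something to export.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=[3])])
model.compile(loss="mse", optimizer="sgd")

# Keras exports the SavedModel format when the path has no .h5 extension:
# my_model/0001/ will contain saved_model.pb plus variables/ and assets/ subdirectories.
model.save("my_model/0001")

# Load it back and list its serving signatures (the "function signature
# definitions" mentioned in the answer).
loaded = tf.saved_model.load("my_model/0001")
print(list(loaded.signatures.keys()))  # typically ['serving_default']
```

From a shell, `saved_model_cli show --dir my_model/0001 --all` prints the same metagraph and signature information.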
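Answers 2 to 4 describe querying TF Serving through its REST API. Here is a minimal client sketch, under the assumption that a TF Serving instance is already running locally with its REST port on 8501 and serving a model named `my_model` (neither is set up by this notebook):

```python
import json

import numpy as np
import requests

X_new = np.random.rand(3, 3)  # hypothetical batch of 3 instances with 3 features

# TF Serving's REST API expects a JSON body with the signature name and the instances.
request_json = json.dumps({
    "signature_name": "serving_default",
    "instances": X_new.tolist(),
})
response = requests.post(
    "http://localhost:8501/v1/models/my_model:predict",
    data=request_json,
)
response.raise_for_status()
y_proba = np.array(response.json()["predictions"])
```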
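Answer 5 mentions the TFLite converter and post-training quantization. A minimal sketch, assuming the SavedModel exported in the earlier sketch sits in `my_model/0001`:

```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("my_model/0001")
converter.optimizations = [tf.lite.Optimize.DEFAULT]  # enable post-training quantization
tflite_model = converter.convert()  # returns the model as a FlatBuffer (bytes)

with open("model.tflite", "wb") as f:
    f.write(tflite_model)
```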
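Finally, the `MultiWorkerMirroredStrategy` from answer 8 boils down to creating the strategy and building the model inside its scope on every worker. This sketch assumes the `TF_CONFIG` environment variable has already been set on each worker to describe the cluster (worker addresses and this task's index):

```python
import tensorflow as tf

# Every worker runs this same script; TF_CONFIG tells each one who its peers are.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    # Model creation and compilation must happen inside the scope so that the
    # variables are mirrored across all replicas.
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=[28, 28]),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd",
                  metrics=["accuracy"])

# model.fit(...) then trains one replica per worker, averaging the gradients
# across workers with AllReduce at each step.
```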