{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 13 – Loading and Preprocessing Data with TensorFlow**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-02-19 06:19:26 +01:00
"_This notebook contains all the sample code and solutions to the exercises in chapter 13._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/13_loading_and_preprocessing_data.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/13_loading_and_preprocessing_data.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"assert sklearn.__version__ >= \"1.0.1\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And TensorFlow ≥ 2.7:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"assert tuple(map(int, tf.__version__.split(\".\")[:2])) >= (2, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The tf.data API"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<TensorSliceDataset shapes: (), types: tf.int32>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tensorflow as tf\n",
"\n",
"X = tf.range(10) # any data tensor\n",
"dataset = tf.data.Dataset.from_tensor_slices(X)\n",
"dataset"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(0, shape=(), dtype=int32)\n",
"tf.Tensor(1, shape=(), dtype=int32)\n",
"tf.Tensor(2, shape=(), dtype=int32)\n",
"tf.Tensor(3, shape=(), dtype=int32)\n",
"tf.Tensor(4, shape=(), dtype=int32)\n",
"tf.Tensor(5, shape=(), dtype=int32)\n",
"tf.Tensor(6, shape=(), dtype=int32)\n",
"tf.Tensor(7, shape=(), dtype=int32)\n",
"tf.Tensor(8, shape=(), dtype=int32)\n",
"tf.Tensor(9, shape=(), dtype=int32)\n"
]
}
],
"source": [
"for item in dataset:\n",
" print(item)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'a': (<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=4>), 'b': <tf.Tensor: shape=(), dtype=int32, numpy=7>}\n",
"{'a': (<tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: shape=(), dtype=int32, numpy=5>), 'b': <tf.Tensor: shape=(), dtype=int32, numpy=8>}\n",
"{'a': (<tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=6>), 'b': <tf.Tensor: shape=(), dtype=int32, numpy=9>}\n"
]
}
],
"source": [
"X_nested = {\"a\": ([1, 2, 3], [4, 5, 6]), \"b\": [7, 8, 9]}\n",
"dataset = tf.data.Dataset.from_tensor_slices(X_nested)\n",
"for item in dataset:\n",
" print(item)"
]
},
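{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each element of a dataset built from a nested structure keeps that structure. As a small illustrative sketch (the variable name `nested_dataset` is only used for this example), the `element_spec` property shows the structure, shapes and dtypes of the elements without iterating over the dataset:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"X_nested = {\"a\": ([1, 2, 3], [4, 5, 6]), \"b\": [7, 8, 9]}\n",
"nested_dataset = tf.data.Dataset.from_tensor_slices(X_nested)\n",
"print(nested_dataset.element_spec)  # a dict of TensorSpec objects mirroring X_nested"
]
},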
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chaining Transformations"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": [
"raises-exception"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([0 1 2 3 4 5 6], shape=(7,), dtype=int32)\n",
"tf.Tensor([7 8 9 0 1 2 3], shape=(7,), dtype=int32)\n",
"tf.Tensor([4 5 6 7 8 9 0], shape=(7,), dtype=int32)\n",
"tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int32)\n",
"tf.Tensor([8 9], shape=(2,), dtype=int32)\n"
]
}
],
"source": [
"dataset = tf.data.Dataset.from_tensor_slices(tf.range(10))\n",
"dataset = dataset.repeat(3).batch(7)\n",
"for item in dataset:\n",
" print(item)"
]
},
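{
"cell_type": "markdown",
"metadata": {},
"source": [
"The last batch above contains only 2 elements. If every batch must have exactly the same size, one option is to call `batch()` with `drop_remainder=True`, which discards the final incomplete batch. Here is a short sketch that uses a separate variable (`full_batches`, chosen for illustration) so the cells below are unaffected:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"full_batches = tf.data.Dataset.from_tensor_slices(tf.range(10))\n",
"full_batches = full_batches.repeat(3).batch(7, drop_remainder=True)\n",
"for item in full_batches:\n",
"    print(item)  # 4 batches of 7 elements; the leftover 2 elements are dropped"
]
},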
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([ 0 2 4 6 8 10 12], shape=(7,), dtype=int32)\n",
"tf.Tensor([14 16 18 0 2 4 6], shape=(7,), dtype=int32)\n",
"tf.Tensor([ 8 10 12 14 16 18 0], shape=(7,), dtype=int32)\n",
"tf.Tensor([ 2 4 6 8 10 12 14], shape=(7,), dtype=int32)\n",
"tf.Tensor([16 18], shape=(2,), dtype=int32)\n"
]
}
],
"source": [
"dataset = dataset.map(lambda x: x * 2) # x is a batch\n",
"for item in dataset:\n",
" print(item)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([14 16 18 0 2 4 6], shape=(7,), dtype=int32)\n",
"tf.Tensor([ 8 10 12 14 16 18 0], shape=(7,), dtype=int32)\n",
"tf.Tensor([ 2 4 6 8 10 12 14], shape=(7,), dtype=int32)\n"
]
}
],
"source": [
"dataset = dataset.filter(lambda x: tf.reduce_sum(x) > 50)\n",
"for item in dataset:\n",
" print(item)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([14 16 18 0 2 4 6], shape=(7,), dtype=int32)\n",
"tf.Tensor([ 8 10 12 14 16 18 0], shape=(7,), dtype=int32)\n"
]
}
],
"source": [
"for item in dataset.take(2):\n",
" print(item)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Shuffling the Data"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor([1 4 2 3 5 0 6], shape=(7,), dtype=int64)\n",
"tf.Tensor([9 8 2 0 3 1 4], shape=(7,), dtype=int64)\n",
"tf.Tensor([5 7 9 6 7 8], shape=(6,), dtype=int64)\n"
]
}
],
"source": [
"dataset = tf.data.Dataset.range(10).repeat(2)\n",
"dataset = dataset.shuffle(buffer_size=4, seed=42).batch(7)\n",
"for item in dataset:\n",
" print(item)"
]
},
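{
"cell_type": "markdown",
"metadata": {},
"source": [
"By default, `shuffle()` reshuffles the elements at every iteration over the dataset (i.e., at each epoch), even when a seed is set: the seed makes the sequence of shuffles reproducible, not identical across epochs. Setting `reshuffle_each_iteration=False` keeps the same order every time. A minimal sketch (the variable names are illustrative):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"shuffled = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=42)\n",
"print([int(x) for x in shuffled.as_numpy_iterator()])  # 1st epoch\n",
"print([int(x) for x in shuffled.as_numpy_iterator()])  # 2nd epoch: typically a different order\n",
"\n",
"same_order = tf.data.Dataset.range(10).shuffle(\n",
"    buffer_size=10, seed=42, reshuffle_each_iteration=False)\n",
"print([int(x) for x in same_order.as_numpy_iterator()])\n",
"print([int(x) for x in same_order.as_numpy_iterator()])  # same order as the previous line"
]
},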
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Interleaving Lines from Multiple Files"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by loading the California housing dataset and splitting it into a training set, a validation set, and a test set:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# extra code – fetches and splits the California housing dataset\n",
"\n",
"from sklearn.datasets import fetch_california_housing\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"housing = fetch_california_housing()\n",
"X_train_full, X_test, y_train_full, y_test = train_test_split(\n",
" housing.data, housing.target.reshape(-1, 1), random_state=42)\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
" X_train_full, y_train_full, random_state=42)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a very large dataset that does not fit in memory, you will typically want to split it into many files first, then have TensorFlow read these files in parallel. To demonstrate this, let's split the housing training set into 20 CSV files (and the validation and test sets into 10 CSV files each):"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# extra code – split each dataset into several parts and save them to CSV files\n",
"\n",
"import numpy as np\n",
"from pathlib import Path\n",
"\n",
"def save_to_csv_files(data, name_prefix, header=None, n_parts=10):\n",
" housing_dir = Path() / \"datasets\" / \"housing\"\n",
" housing_dir.mkdir(parents=True, exist_ok=True)\n",
" filename_format = \"my_{}_{:02d}.csv\"\n",
"\n",
" filepaths = []\n",
" m = len(data)\n",
" chunks = np.array_split(np.arange(m), n_parts)\n",
" for file_idx, row_indices in enumerate(chunks):\n",
" part_csv = housing_dir / filename_format.format(name_prefix, file_idx)\n",
" filepaths.append(str(part_csv))\n",
" with open(part_csv, \"w\") as f:\n",
" if header is not None:\n",
" f.write(header)\n",
" f.write(\"\\n\")\n",
" for row_idx in row_indices:\n",
"                f.write(\",\".join(repr(float(col)) for col in data[row_idx]))\n",
" f.write(\"\\n\")\n",
" return filepaths\n",
"\n",
"train_data = np.c_[X_train, y_train]\n",
"valid_data = np.c_[X_valid, y_valid]\n",
"test_data = np.c_[X_test, y_test]\n",
"header_cols = housing.feature_names + [\"MedianHouseValue\"]\n",
"header = \",\".join(header_cols)\n",
"\n",
"train_filepaths = save_to_csv_files(train_data, \"train\", header, n_parts=20)\n",
"valid_filepaths = save_to_csv_files(valid_data, \"valid\", header, n_parts=10)\n",
"test_filepaths = save_to_csv_files(test_data, \"test\", header, n_parts=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Okay, now let's take a peek at the first few lines of one of these CSV files:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MedInc,HouseAge,AveRooms,AveBedrms,Population,AveOccup,Latitude,Longitude,MedianHouseValue\n",
"3.5214,15.0,3.0499445061043287,1.106548279689234,1447.0,1.6059933407325193,37.63,-122.43,1.442\n",
"5.3275,5.0,6.490059642147117,0.9910536779324056,3464.0,3.4433399602385686,33.69,-117.39,1.687\n",
"3.1,29.0,7.5423728813559325,1.5915254237288134,1328.0,2.2508474576271187,38.44,-122.98,1.621\n",
"\n"
]
}
],
"source": [
"print(\"\".join(open(train_filepaths[0]).readlines()[:4]))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['datasets/housing/my_train_00.csv',\n",
" 'datasets/housing/my_train_01.csv',\n",
" 'datasets/housing/my_train_02.csv',\n",
" 'datasets/housing/my_train_03.csv',\n",
" 'datasets/housing/my_train_04.csv',\n",
" 'datasets/housing/my_train_05.csv',\n",
" 'datasets/housing/my_train_06.csv',\n",
" 'datasets/housing/my_train_07.csv',\n",
" 'datasets/housing/my_train_08.csv',\n",
" 'datasets/housing/my_train_09.csv',\n",
" 'datasets/housing/my_train_10.csv',\n",
" 'datasets/housing/my_train_11.csv',\n",
" 'datasets/housing/my_train_12.csv',\n",
" 'datasets/housing/my_train_13.csv',\n",
" 'datasets/housing/my_train_14.csv',\n",
" 'datasets/housing/my_train_15.csv',\n",
" 'datasets/housing/my_train_16.csv',\n",
" 'datasets/housing/my_train_17.csv',\n",
" 'datasets/housing/my_train_18.csv',\n",
" 'datasets/housing/my_train_19.csv']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_filepaths"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Building an Input Pipeline**"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"filepath_dataset = tf.data.Dataset.list_files(train_filepaths, seed=42)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'datasets/housing/my_train_05.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_16.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_01.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_17.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_00.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_14.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_10.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_02.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_12.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_19.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_07.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_09.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_13.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_15.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_11.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_18.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_04.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_06.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_03.csv', shape=(), dtype=string)\n",
"tf.Tensor(b'datasets/housing/my_train_08.csv', shape=(), dtype=string)\n"
]
}
],
"source": [
"# extra code – shows that the file paths are shuffled\n",
"for filepath in filepath_dataset:\n",
" print(filepath)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"n_readers = 5\n",
"dataset = filepath_dataset.interleave(\n",
" lambda filepath: tf.data.TextLineDataset(filepath).skip(1),\n",
" cycle_length=n_readers)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'4.5909,16.0,5.475877192982456,1.0964912280701755,1357.0,2.9758771929824563,33.63,-117.71,2.418', shape=(), dtype=string)\n",
"tf.Tensor(b'2.4792,24.0,3.4547038327526134,1.1341463414634145,2251.0,3.921602787456446,34.18,-118.38,2.0', shape=(), dtype=string)\n",
"tf.Tensor(b'4.2708,45.0,5.121387283236994,0.953757225433526,492.0,2.8439306358381504,37.48,-122.19,2.67', shape=(), dtype=string)\n",
"tf.Tensor(b'2.1856,41.0,3.7189873417721517,1.0658227848101265,803.0,2.0329113924050635,32.76,-117.12,1.205', shape=(), dtype=string)\n",
"tf.Tensor(b'4.1812,52.0,5.701388888888889,0.9965277777777778,692.0,2.4027777777777777,33.73,-118.31,3.215', shape=(), dtype=string)\n"
]
}
],
"source": [
"for line in dataset.take(5):\n",
" print(line)"
]
},
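{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the ordering of `interleave()` concrete, here is a toy sketch in which small in-memory `range` datasets stand in for the CSV files (a substitution made purely for illustration). With `cycle_length=3` and the default `block_length=1`, it pulls one element from each of the 3 inner datasets in turn:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Source i stands in for a file containing the values i*10, i*10+1, i*10+2\n",
"sources = tf.data.Dataset.range(3)\n",
"toy_interleaved = sources.interleave(\n",
"    lambda i: tf.data.Dataset.range(i * 10, i * 10 + 3),\n",
"    cycle_length=3)\n",
"print([int(x) for x in toy_interleaved.as_numpy_iterator()])\n",
"# [0, 10, 20, 1, 11, 21, 2, 12, 22]"
]
},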
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocessing the Data"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"StandardScaler()"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – compute the mean and standard deviation of each feature\n",
"\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"scaler = StandardScaler()\n",
"scaler.fit(X_train)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"X_mean, X_std = scaler.mean_, scaler.scale_ # extra code\n",
"n_inputs = 8\n",
"\n",
"def parse_csv_line(line):\n",
" defs = [0.] * n_inputs + [tf.constant([], dtype=tf.float32)]\n",
" fields = tf.io.decode_csv(line, record_defaults=defs)\n",
" return tf.stack(fields[:-1]), tf.stack(fields[-1:])\n",
"\n",
"def preprocess(line):\n",
" x, y = parse_csv_line(line)\n",
" return (x - X_mean) / X_std, y"
]
},
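{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tf.io.decode_csv()` uses `record_defaults` both to infer each column's type and to fill in missing fields. A column whose default is an empty tensor (like the target column above) is required, so a missing value raises an error. A small sketch with a made-up 3-column record:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# Two optional float columns (default 0.0) and one required float column\n",
"defs = [0., 0., tf.constant([], dtype=tf.float32)]\n",
"\n",
"fields = tf.io.decode_csv(\"1.5,,3.0\", record_defaults=defs)\n",
"print([float(f) for f in fields])  # [1.5, 0.0, 3.0]: the empty field got its default\n",
"\n",
"try:\n",
"    tf.io.decode_csv(\"1.5,2.0,\", record_defaults=defs)  # required field is empty\n",
"except tf.errors.InvalidArgumentError as ex:\n",
"    print(\"Error:\", ex.message)"
]
},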
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(<tf.Tensor: shape=(8,), dtype=float32, numpy=\n",
" array([ 0.16579159, 1.216324 , -0.05204564, -0.39215982, -0.5277444 ,\n",
" -0.2633488 , 0.8543046 , -1.3072058 ], dtype=float32)>,\n",
" <tf.Tensor: shape=(1,), dtype=float32, numpy=array([2.782], dtype=float32)>)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocess(b'4.2083,44.0,5.3232,0.9171,846.0,2.3370,37.47,-122.2,2.782')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Putting Everything Together + Prefetching"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"def csv_reader_dataset(filepaths, n_readers=5, n_read_threads=None,\n",
" n_parse_threads=5, shuffle_buffer_size=10_000, seed=42,\n",
" batch_size=32):\n",
" dataset = tf.data.Dataset.list_files(filepaths, seed=seed)\n",
" dataset = dataset.interleave(\n",
" lambda filepath: tf.data.TextLineDataset(filepath).skip(1),\n",
" cycle_length=n_readers, num_parallel_calls=n_read_threads)\n",
" dataset = dataset.map(preprocess, num_parallel_calls=n_parse_threads)\n",
" dataset = dataset.shuffle(shuffle_buffer_size, seed=seed)\n",
" return dataset.batch(batch_size).prefetch(1)"
]
},
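{
"cell_type": "markdown",
"metadata": {},
"source": [
"`csv_reader_dataset()` uses `n_parse_threads=5` and `prefetch(1)` by default. Alternatively, you can let tf.data tune these values at runtime by passing `tf.data.AUTOTUNE`, and if the data fits in RAM, `cache()` avoids recomputing the earlier steps at every epoch. Here is a toy sketch on in-memory data (not the housing pipeline); it places `cache()` before `shuffle()` so each epoch still gets a fresh order:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"toy_pipeline = (\n",
"    tf.data.Dataset.range(1000)\n",
"    .map(lambda x: tf.cast(x, tf.float32) / 1000.,\n",
"         num_parallel_calls=tf.data.AUTOTUNE)  # parallelism tuned at runtime\n",
"    .cache()  # keep the mapped elements in RAM after the first epoch\n",
"    .shuffle(1000, seed=42)\n",
"    .batch(32)\n",
"    .prefetch(tf.data.AUTOTUNE))  # prefetch buffer size tuned at runtime\n",
"\n",
"for toy_batch in toy_pipeline.take(1):\n",
"    print(toy_batch.shape)  # (32,)"
]
},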
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"X = tf.Tensor(\n",
"[[-1.3957452 -0.04940685 -0.22830808 0.22648273 2.2593622 0.35200632\n",
" 0.9667386 -1.4121602 ]\n",
" [ 2.7112627 -1.0778131 0.69413143 -0.14870553 0.51810503 0.3507294\n",
" -0.82285154 0.80680597]\n",
" [-0.13484643 -1.868895 0.01032507 -0.13787179 -0.12893449 0.03143518\n",
" 0.2687057 0.13212144]], shape=(3, 8), dtype=float32)\n",
"y = tf.Tensor(\n",
"[[1.819]\n",
" [3.674]\n",
" [0.954]], shape=(3, 1), dtype=float32)\n",
"\n",
"X = tf.Tensor(\n",
"[[ 0.09031774 0.9789995 0.1327582 -0.13753782 -0.23388447 0.10211545\n",
" 0.97610843 -1.4121602 ]\n",
" [ 0.05218809 -2.0271113 0.2940109 -0.02403445 0.16218767 -0.02844518\n",
" 1.4117942 -0.93737936]\n",
" [-0.672276 0.02970133 -0.76922584 -0.15086786 0.4962024 -0.02741998\n",
" -0.7853724 0.77182245]], shape=(3, 8), dtype=float32)\n",
"y = tf.Tensor(\n",
"[[2.725]\n",
" [1.205]\n",
" [1.625]], shape=(3, 1), dtype=float32)\n",
"\n"
]
}
],
"source": [
"# extra code – show the first couple of batches produced by the dataset\n",
"\n",
"example_set = csv_reader_dataset(train_filepaths, batch_size=3)\n",
"for X_batch, y_batch in example_set.take(2):\n",
" print(\"X =\", X_batch)\n",
" print(\"y =\", y_batch)\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a short description of each method in the `Dataset` class:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"● apply() Applies a transformation function to this dataset.\n",
"● as_numpy_iterator() Returns an iterator which converts all elements of the dataset to numpy.\n",
"● batch() Combines consecutive elements of this dataset into batches.\n",
"● bucket_by_sequence_length()A transformation that buckets elements in a `Dataset` by length.\n",
"● cache() Caches the elements in this dataset.\n",
"● cardinality() Returns the cardinality of the dataset, if known.\n",
"● choose_from_datasets()Creates a dataset that deterministically chooses elements from `datasets`.\n",
"● concatenate() Creates a `Dataset` by concatenating the given dataset with this dataset.\n",
"● element_spec() The type specification of an element of this dataset.\n",
"● enumerate() Enumerates the elements of this dataset.\n",
"● filter() Filters this dataset according to `predicate`.\n",
"● flat_map() Maps `map_func` across this dataset and flattens the result.\n",
"● from_generator() Creates a `Dataset` whose elements are generated by `generator`. (deprecated arguments)\n",
"● from_tensor_slices() Creates a `Dataset` whose elements are slices of the given tensors.\n",
"● from_tensors() Creates a `Dataset` with a single element, comprising the given tensors.\n",
"● get_single_element() Returns the single element of the `dataset`.\n",
"● group_by_window() Groups windows of elements by key and reduces them.\n",
"● interleave() Maps `map_func` across this dataset, and interleaves the results.\n",
"● list_files() A dataset of all files matching one or more glob patterns.\n",
"● map() Maps `map_func` across the elements of this dataset.\n",
"● options() Returns the options for this dataset and its inputs.\n",
"● padded_batch() Combines consecutive elements of this dataset into padded batches.\n",
"● prefetch() Creates a `Dataset` that prefetches elements from this dataset.\n",
"● random() Creates a `Dataset` of pseudorandom values.\n",
"● range() Creates a `Dataset` of a step-separated range of values.\n",
"● reduce() Reduces the input dataset to a single element.\n",
"● rejection_resample() A transformation that resamples a dataset to a target distribution.\n",
"● repeat() Repeats this dataset so each original value is seen `count` times.\n",
"● sample_from_datasets()Samples elements at random from the datasets in `datasets`.\n",
"● scan() A transformation that scans a function across an input dataset.\n",
"● shard() Creates a `Dataset` that includes only 1/`num_shards` of this dataset.\n",
"● shuffle() Randomly shuffles the elements of this dataset.\n",
"● skip() Creates a `Dataset` that skips `count` elements from this dataset.\n",
"● snapshot() API to persist the output of the input dataset.\n",
"● take() Creates a `Dataset` with at most `count` elements from this dataset.\n",
"● take_while() A transformation that stops dataset iteration based on a `predicate`.\n",
"● unbatch() Splits elements of a dataset into multiple elements.\n",
"● unique() A transformation that discards duplicate elements of a `Dataset`.\n",
"● window() Returns a dataset of \"windows\".\n",
"● with_options() Returns a new `tf.data.Dataset` with the given options set.\n",
"● zip() Creates a `Dataset` by zipping together the given datasets.\n"
]
}
],
"source": [
"# extra code – list all methods of the tf.data.Dataset class\n",
"for m in dir(tf.data.Dataset):\n",
" if not (m.startswith(\"_\") or m.endswith(\"_\")):\n",
" func = getattr(tf.data.Dataset, m)\n",
"        if func.__doc__:\n",
" print(\"● {:21s}{}\".format(m + \"()\", func.__doc__.split(\"\\n\")[0]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using the Dataset with Keras"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"train_set = csv_reader_dataset(train_filepaths)\n",
"valid_set = csv_reader_dataset(valid_filepaths)\n",
"test_set = csv_reader_dataset(test_filepaths)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# extra code – for reproducibility\n",
"tf.keras.backend.clear_session()\n",
"tf.random.set_seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"363/363 [==============================] - 1s 2ms/step - loss: 1.3569 - val_loss: 0.5272\n",
"Epoch 2/5\n",
"363/363 [==============================] - 0s 965us/step - loss: 0.5132 - val_loss: 63.7862\n",
"Epoch 3/5\n",
"363/363 [==============================] - 0s 902us/step - loss: 0.5916 - val_loss: 20.3634\n",
"Epoch 4/5\n",
"363/363 [==============================] - 1s 944us/step - loss: 0.5089 - val_loss: 0.3993\n",
"Epoch 5/5\n",
"363/363 [==============================] - 1s 905us/step - loss: 0.4200 - val_loss: 0.3639\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f82912a32b0>"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(30, activation=\"relu\", kernel_initializer=\"he_normal\",\n",
" input_shape=X_train.shape[1:]),\n",
" tf.keras.layers.Dense(1),\n",
"])\n",
"model.compile(loss=\"mse\", optimizer=\"sgd\")\n",
"model.fit(train_set, validation_data=valid_set, epochs=5)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"162/162 [==============================] - 0s 594us/step - loss: 0.3868\n"
]
}
],
"source": [
"test_mse = model.evaluate(test_set)\n",
"new_set = test_set.take(3)  # pretend we have 3 new batches of samples\n",
"y_pred = model.predict(new_set) # or you could just pass a NumPy array"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5"
]
}
],
"source": [
"# extra code – defines the optimizer and loss function for training\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"\n",
"n_epochs = 5\n",
"for epoch in range(n_epochs):\n",
"    print(\"\\rEpoch {}/{}\".format(epoch + 1, n_epochs), end=\"\")\n",
"    for X_batch, y_batch in train_set:\n",
"        # extra code – perform one Gradient Descent step\n",
"        #              as explained in Chapter 12\n",
"        with tf.GradientTape() as tape:\n",
" y_pred = model(X_batch)\n",
" main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))\n",
" loss = tf.add_n([main_loss] + model.losses)\n",
" gradients = tape.gradient(loss, model.trainable_variables)\n",
" optimizer.apply_gradients(zip(gradients, model.trainable_variables))"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5"
]
}
],
"source": [
"@tf.function\n",
"def train_one_epoch(model, optimizer, loss_fn, train_set):\n",
" for X_batch, y_batch in train_set:\n",
" with tf.GradientTape() as tape:\n",
" y_pred = model(X_batch)\n",
" main_loss = tf.reduce_mean(loss_fn(y_batch, y_pred))\n",
" loss = tf.add_n([main_loss] + model.losses)\n",
" gradients = tape.gradient(loss, model.trainable_variables)\n",
" optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
"\n",
"optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"for epoch in range(n_epochs):\n",
" print(\"\\rEpoch {}/{}\".format(epoch + 1, n_epochs), end=\"\")\n",
" train_one_epoch(model, optimizer, loss_fn, train_set)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The TFRecord Format"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A TFRecord file is just a list of binary records. You can create one using a `tf.io.TFRecordWriter`:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"with tf.io.TFRecordWriter(\"my_data.tfrecord\") as f:\n",
" f.write(b\"This is the first record\")\n",
" f.write(b\"And this is the second record\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And you can read it using a `tf.data.TFRecordDataset`:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'This is the first record', shape=(), dtype=string)\n",
"tf.Tensor(b'And this is the second record', shape=(), dtype=string)\n"
]
}
],
"source": [
"filepaths = [\"my_data.tfrecord\"]\n",
"dataset = tf.data.TFRecordDataset(filepaths)\n",
"for item in dataset:\n",
" print(item)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can read multiple TFRecord files with a single `TFRecordDataset`. By default, it reads them sequentially, one file after the other; if you set `num_parallel_reads=3`, it reads 3 files at a time in parallel and interleaves their records:"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'File 0 record 0', shape=(), dtype=string)\n",
"tf.Tensor(b'File 1 record 0', shape=(), dtype=string)\n",
"tf.Tensor(b'File 2 record 0', shape=(), dtype=string)\n",
"tf.Tensor(b'File 0 record 1', shape=(), dtype=string)\n",
"tf.Tensor(b'File 1 record 1', shape=(), dtype=string)\n",
"tf.Tensor(b'File 2 record 1', shape=(), dtype=string)\n",
"tf.Tensor(b'File 0 record 2', shape=(), dtype=string)\n",
"tf.Tensor(b'File 1 record 2', shape=(), dtype=string)\n",
"tf.Tensor(b'File 2 record 2', shape=(), dtype=string)\n",
"tf.Tensor(b'File 3 record 0', shape=(), dtype=string)\n",
"tf.Tensor(b'File 4 record 0', shape=(), dtype=string)\n",
"tf.Tensor(b'File 3 record 1', shape=(), dtype=string)\n",
"tf.Tensor(b'File 4 record 1', shape=(), dtype=string)\n",
"tf.Tensor(b'File 3 record 2', shape=(), dtype=string)\n",
"tf.Tensor(b'File 4 record 2', shape=(), dtype=string)\n"
]
}
],
"source": [
"# extra code – shows how to read multiple files in parallel and interleave them\n",
"\n",
"filepaths = [\"my_test_{}.tfrecord\".format(i) for i in range(5)]\n",
"for i, filepath in enumerate(filepaths):\n",
" with tf.io.TFRecordWriter(filepath) as f:\n",
" for j in range(3):\n",
" f.write(\"File {} record {}\".format(i, j).encode(\"utf-8\"))\n",
"\n",
"dataset = tf.data.TFRecordDataset(filepaths, num_parallel_reads=3)\n",
"for item in dataset:\n",
" print(item)"
]
},
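{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, here is a sketch of the same pipeline built explicitly with `tf.data.Dataset.interleave()`. It assumes that `num_parallel_reads=3` behaves like a deterministic interleave with `cycle_length=3` and the default `block_length=1`; if so, it should print the records in the same round-robin order as above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – an explicit interleave that should mimic num_parallel_reads=3\n",
"import tensorflow as tf\n",
"\n",
"filepaths = [\"my_test_{}.tfrecord\".format(i) for i in range(5)]\n",
"dataset = tf.data.Dataset.from_tensor_slices(filepaths).interleave(\n",
"    lambda filepath: tf.data.TFRecordDataset(filepath),  # one dataset per file\n",
"    cycle_length=3,  # read from 3 files at a time\n",
"    block_length=1)  # take 1 record from each file before moving to the next\n",
"for item in dataset:\n",
"    print(item)"
]
},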
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compressed TFRecord Files"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"options = tf.io.TFRecordOptions(compression_type=\"GZIP\")\n",
"with tf.io.TFRecordWriter(\"my_compressed.tfrecord\", options) as f:\n",
" f.write(b\"Compress, compress, compress!\")"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"dataset = tf.data.TFRecordDataset([\"my_compressed.tfrecord\"],\n",
" compression_type=\"GZIP\")"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b'Compress, compress, compress!', shape=(), dtype=string)\n"
]
}
],
"source": [
"# extra code – shows that the data is decompressed correctly\n",
"for item in dataset:\n",
" print(item)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A Brief Introduction to Protocol Buffers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For this section you need to [install protobuf](https://developers.google.com/protocol-buffers/docs/downloads). You usually won't need it when using TensorFlow, since TensorFlow comes with functions to create and parse protocol buffers of type `tf.train.Example`, which are generally sufficient. In this section, however, we will learn about protocol buffers by writing our own simple protobuf definition, so we need the protobuf compiler (`protoc`) to compile that definition into a Python module that we can then import in our code."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First let's write a simple protobuf definition:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Overwriting person.proto\n"
]
}
],
"source": [
"%%writefile person.proto\n",
"syntax = \"proto3\";\n",
"message Person {\n",
" string name = 1;\n",
" int32 id = 2;\n",
" repeated string email = 3;\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's compile it (the `--descriptor_set_out` and `--include_imports` options are only required for the `tf.io.decode_proto()` example below):"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"!protoc person.proto --python_out=. --descriptor_set_out=person.desc --include_imports"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"person.desc person.proto person_pb2.py\n"
]
}
],
"source": [
"%ls person*"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"name: \"Al\"\n",
"id: 123\n",
"email: \"a@b.com\"\n",
"\n"
]
}
],
"source": [
"from person_pb2 import Person # import the generated access class\n",
"\n",
"person = Person(name=\"Al\", id=123, email=[\"a@b.com\"]) # create a Person\n",
"print(person) # display the Person"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Al'"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"person.name # read a field"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"person.name = \"Alice\" # modify a field"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'a@b.com'"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"person.email[0] # repeated fields can be accessed like arrays"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"person.email.append(\"c@d.com\") # add an email address"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"b'\\n\\x05Alice\\x10{\\x1a\\x07a@b.com\\x1a\\x07c@d.com'"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"serialized = person.SerializeToString() # serialize person to a byte string\n",
"serialized"
]
},
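{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how this byte string encodes the three fields, here is a minimal pure-Python walk through the protobuf wire format. It is only an illustrative sketch: `read_varint()` and `decode_fields()` are hypothetical helpers, not part of the protobuf library, and they handle just the two wire types that `Person` uses: 0 (varint, used for `int32`) and 2 (length-delimited, used for `string`). Each field starts with a varint key equal to `(field_number << 3) | wire_type`. The loop should print one `(field_number, wire_type, value)` tuple per field, with field 3 (`email`) appearing once per repeated value."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – decodes the serialized Person by hand (hypothetical helpers)\n",
"def read_varint(data, pos):\n",
"    # 7 bits per byte, least significant group first;\n",
"    # the high bit of each byte is set when more bytes follow\n",
"    result, shift = 0, 0\n",
"    while True:\n",
"        byte = data[pos]\n",
"        pos += 1\n",
"        result |= (byte & 0x7F) << shift\n",
"        shift += 7\n",
"        if not byte & 0x80:\n",
"            return result, pos\n",
"\n",
"def decode_fields(data):\n",
"    pos, fields = 0, []\n",
"    while pos < len(data):\n",
"        key, pos = read_varint(data, pos)\n",
"        field_number, wire_type = key >> 3, key & 0x7\n",
"        if wire_type == 0:  # varint\n",
"            value, pos = read_varint(data, pos)\n",
"        elif wire_type == 2:  # length-delimited: a varint length, then the bytes\n",
"            length, pos = read_varint(data, pos)\n",
"            value = data[pos:pos + length]\n",
"            pos += length\n",
"        else:\n",
"            raise ValueError(\"wire type {} is not handled here\".format(wire_type))\n",
"        fields.append((field_number, wire_type, value))\n",
"    return fields\n",
"\n",
"for field in decode_fields(serialized):  # serialized comes from the cell above\n",
"    print(field)"
]
},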
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"27"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"person2 = Person() # create a new Person\n",
"person2.ParseFromString(serialized) # parse the byte string (27 bytes long)"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"person == person2 # now they are equal"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Custom protobuf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In rare cases, you may want to parse a custom protobuf (like the one we just created) in TensorFlow. For this you can use the `tf.io.decode_proto()` function:"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Alice'], dtype=object)>,\n",
" <tf.Tensor: shape=(1,), dtype=int32, numpy=array([123], dtype=int32)>,\n",
" <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a@b.com', b'c@d.com'], dtype=object)>]"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows how to use the tf.io.decode_proto() function\n",
"\n",
"person_tf = tf.io.decode_proto(\n",
" bytes=serialized,\n",
" message_type=\"Person\",\n",
" field_names=[\"name\", \"id\", \"email\"],\n",
" output_types=[tf.string, tf.int32, tf.string],\n",
" descriptor_source=\"person.desc\")\n",
"\n",
"person_tf.values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For more details, see the [`tf.io.decode_proto()`](https://www.tensorflow.org/api_docs/python/tf/io/decode_proto) documentation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TensorFlow Protobufs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is the definition of the `tf.train.Example` protobuf:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```proto\n",
"syntax = \"proto3\";\n",
"\n",
"message BytesList { repeated bytes value = 1; }\n",
"message FloatList { repeated float value = 1 [packed = true]; }\n",
"message Int64List { repeated int64 value = 1 [packed = true]; }\n",
"message Feature {\n",
" oneof kind {\n",
" BytesList bytes_list = 1;\n",
" FloatList float_list = 2;\n",
" Int64List int64_list = 3;\n",
" }\n",
"};\n",
"message Features { map<string, Feature> feature = 1; };\n",
"message Example { Features features = 1; };\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.train import BytesList, FloatList, Int64List\n",
"from tensorflow.train import Feature, Features, Example\n",
"\n",
"person_example = Example(\n",
" features=Features(\n",
" feature={\n",
" \"name\": Feature(bytes_list=BytesList(value=[b\"Alice\"])),\n",
" \"id\": Feature(int64_list=Int64List(value=[123])),\n",
" \"emails\": Feature(bytes_list=BytesList(value=[b\"a@b.com\",\n",
" b\"c@d.com\"]))\n",
" }))"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"with tf.io.TFRecordWriter(\"my_contacts.tfrecord\") as f:\n",
" for _ in range(5):\n",
" f.write(person_example.SerializeToString())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading and Parsing Examples"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f829147c040>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=123>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Alice'>}\n",
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f82390756a0>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=123>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Alice'>}\n",
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f8239068a60>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=123>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Alice'>}\n",
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f829147b310>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=123>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Alice'>}\n",
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f829155d850>, 'id': <tf.Tensor: shape=(), dtype=int64, numpy=123>, 'name': <tf.Tensor: shape=(), dtype=string, numpy=b'Alice'>}\n"
]
}
],
"source": [
"feature_description = {\n",
" \"name\": tf.io.FixedLenFeature([], tf.string, default_value=\"\"),\n",
" \"id\": tf.io.FixedLenFeature([], tf.int64, default_value=0),\n",
" \"emails\": tf.io.VarLenFeature(tf.string),\n",
"}\n",
"\n",
"def parse(serialized_example):\n",
" return tf.io.parse_single_example(serialized_example, feature_description)\n",
"\n",
"dataset = tf.data.TFRecordDataset([\"my_contacts.tfrecord\"]).map(parse)\n",
"for parsed_example in dataset:\n",
" print(parsed_example)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a@b.com', b'c@d.com'], dtype=object)>"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.sparse.to_dense(parsed_example[\"emails\"], default_value=b\"\")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2,), dtype=string, numpy=array([b'a@b.com', b'c@d.com'], dtype=object)>"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parsed_example[\"emails\"].values"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f8281729dc0>, 'id': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([123, 123])>, 'name': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Alice', b'Alice'], dtype=object)>}\n",
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f8239068f40>, 'id': <tf.Tensor: shape=(2,), dtype=int64, numpy=array([123, 123])>, 'name': <tf.Tensor: shape=(2,), dtype=string, numpy=array([b'Alice', b'Alice'], dtype=object)>}\n",
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x7f8239075b50>, 'id': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([123])>, 'name': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Alice'], dtype=object)>}\n"
]
}
],
"source": [
"def parse(serialized_examples):\n",
" return tf.io.parse_example(serialized_examples, feature_description)\n",
"\n",
"dataset = tf.data.TFRecordDataset([\"my_contacts.tfrecord\"]).batch(2).map(parse)\n",
"for parsed_examples in dataset:\n",
" print(parsed_examples) # two examples at a time"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'emails': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f8239075b50>,\n",
" 'id': <tf.Tensor: shape=(1,), dtype=int64, numpy=array([123])>,\n",
" 'name': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Alice'], dtype=object)>}"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parsed_examples"
]
},
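{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an alternative to `tf.io.VarLenFeature`, which produces a `SparseTensor`, you can describe a variable-length feature with `tf.io.RaggedFeature` to get a `tf.RaggedTensor` instead. Here is a sketch, assuming TensorFlow 2.x and the `my_contacts.tfrecord` file written earlier; `ragged_description` and `parse_ragged()` are just illustrative names."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – parses emails as a RaggedTensor instead of a SparseTensor\n",
"import tensorflow as tf\n",
"\n",
"ragged_description = {\n",
"    \"name\": tf.io.FixedLenFeature([], tf.string, default_value=\"\"),\n",
"    \"id\": tf.io.FixedLenFeature([], tf.int64, default_value=0),\n",
"    \"emails\": tf.io.RaggedFeature(tf.string),  # variable number of values\n",
"}\n",
"\n",
"def parse_ragged(serialized_examples):\n",
"    return tf.io.parse_example(serialized_examples, ragged_description)\n",
"\n",
"dataset = tf.data.TFRecordDataset([\"my_contacts.tfrecord\"]).batch(2)\n",
"for parsed_examples in dataset.map(parse_ragged):\n",
"    print(parsed_examples[\"emails\"])  # a tf.RaggedTensor, one row per example"
]
},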
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Extra Material – Storing Images and Tensors in TFRecords"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's load and display an example image:"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from sklearn.datasets import load_sample_images\n",
"\n",
"img = load_sample_images()[\"images\"][0]\n",
"plt.imshow(img)\n",
"plt.axis(\"off\")\n",
"plt.title(\"Original Image\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's create an `Example` protobuf containing the image encoded as JPEG:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"data = tf.io.encode_jpeg(img)\n",
"example_with_image = Example(features=Features(feature={\n",
" \"image\": Feature(bytes_list=BytesList(value=[data.numpy()]))}))\n",
"serialized_example = example_with_image.SerializeToString()\n",
"with tf.io.TFRecordWriter(\"my_image.tfrecord\") as f:\n",
" f.write(serialized_example)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's create a tf.data pipeline that will read this TFRecord file, parse each `Example` protobuf (in this case just one), and decode and display the image that the example contains:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"feature_description = { \"image\": tf.io.VarLenFeature(tf.string) }\n",
"\n",
"def parse(serialized_example):\n",
" example_with_image = tf.io.parse_single_example(serialized_example,\n",
" feature_description)\n",
" return tf.io.decode_jpeg(example_with_image[\"image\"].values[0])\n",
" # or you can use tf.io.decode_image() instead\n",
"\n",
"dataset = tf.data.TFRecordDataset(\"my_image.tfrecord\").map(parse)\n",
"for image in dataset:\n",
" plt.imshow(image)\n",
" plt.axis(\"off\")\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Alternatively, you can use `tf.io.decode_image()`, which supports BMP, GIF, JPEG and PNG formats."
]
},
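{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a sketch of the previous pipeline rewritten with `tf.io.decode_image()` (`image_description` and `parse_any_image()` are just illustrative names). It passes `expand_animations=False` so that an animated GIF yields only its first frame, which keeps every output a 3D `[height, width, channels]` tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – same pipeline as above, but with tf.io.decode_image()\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"\n",
"image_description = {\"image\": tf.io.VarLenFeature(tf.string)}\n",
"\n",
"def parse_any_image(serialized_example):\n",
"    example_with_image = tf.io.parse_single_example(serialized_example,\n",
"                                                    image_description)\n",
"    # expand_animations=False returns only the first frame of an animated GIF,\n",
"    # so the result is always a 3D tensor [height, width, channels]\n",
"    return tf.io.decode_image(example_with_image[\"image\"].values[0],\n",
"                              expand_animations=False)\n",
"\n",
"dataset = tf.data.TFRecordDataset(\"my_image.tfrecord\").map(parse_any_image)\n",
"for image in dataset:\n",
"    plt.imshow(image)\n",
"    plt.axis(\"off\")\n",
"    plt.show()"
]
},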
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Tensors can be serialized and parsed easily using `tf.io.serialize_tensor()` and `tf.io.parse_tensor()`:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(), dtype=string, numpy=b'\\x08\\x01\\x12\\x08\\x12\\x02\\x08\\x03\\x12\\x02\\x08\\x02\"\\x18\\x00\\x00\\x00\\x00\\x00\\x00\\x80?\\x00\\x00\\x00@\\x00\\x00@@\\x00\\x00\\x80@\\x00\\x00\\xa0@'>"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tensor = tf.constant([[0., 1.], [2., 3.], [4., 5.]])\n",
"serialized = tf.io.serialize_tensor(tensor)\n",
"serialized"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
"array([[0., 1.],\n",
" [2., 3.],\n",
" [4., 5.]], dtype=float32)>"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.io.parse_tensor(serialized, out_type=tf.float32)"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3,), dtype=string, numpy=\n",
"array([b'\\x08\\t\\x12\\x08\\x12\\x02\\x08\\x02\\x12\\x02\\x08\\x01\"\\x10\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x00\\x01\\x00\\x00\\x00\\x00\\x00\\x00\\x00',\n",
" b'\\x08\\x07\\x12\\x04\\x12\\x02\\x08\\x02\"\\x10\\x07\\x07a@b.comc@d.com',\n",
" b'\\x08\\t\\x12\\x04\\x12\\x02\\x08\\x01\"\\x08\\x02\\x00\\x00\\x00\\x00\\x00\\x00\\x00'],\n",
" dtype=object)>"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sparse_tensor = parsed_example[\"emails\"]\n",
"serialized_sparse = tf.io.serialize_sparse(sparse_tensor)\n",
"serialized_sparse"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"value: \"\\010\\t\\022\\010\\022\\002\\010\\002\\022\\002\\010\\001\\\"\\020\\000\\000\\000\\000\\000\\000\\000\\000\\001\\000\\000\\000\\000\\000\\000\\000\"\n",
"value: \"\\010\\007\\022\\004\\022\\002\\010\\002\\\"\\020\\007\\007a@b.comc@d.com\"\n",
"value: \"\\010\\t\\022\\004\\022\\002\\010\\001\\\"\\010\\002\\000\\000\\000\\000\\000\\000\\000\""
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"BytesList(value=serialized_sparse.numpy())"
]
},
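{
"cell_type": "markdown",
"metadata": {},
"source": [
"Combining these ideas, you can store a serialized tensor in a `BytesList` feature and restore it inside a tf.data pipeline. Here is a minimal sketch; the `my_tensor.tfrecord` filename, the `tensor` feature key and the `parse_tensor_example()` function are arbitrary choices for this example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# extra code – stores a serialized tensor in an Example, then parses it back\n",
"import tensorflow as tf\n",
"from tensorflow.train import BytesList, Feature, Features, Example\n",
"\n",
"tensor = tf.constant([[0., 1.], [2., 3.], [4., 5.]])\n",
"serialized_tensor = tf.io.serialize_tensor(tensor).numpy()  # a bytes object\n",
"example_with_tensor = Example(features=Features(feature={\n",
"    \"tensor\": Feature(bytes_list=BytesList(value=[serialized_tensor]))}))\n",
"with tf.io.TFRecordWriter(\"my_tensor.tfrecord\") as f:\n",
"    f.write(example_with_tensor.SerializeToString())\n",
"\n",
"tensor_description = {\"tensor\": tf.io.FixedLenFeature([], tf.string)}\n",
"\n",
"def parse_tensor_example(serialized_example):\n",
"    parsed = tf.io.parse_single_example(serialized_example, tensor_description)\n",
"    # parse_tensor() needs the expected dtype, which must match the stored one\n",
"    return tf.io.parse_tensor(parsed[\"tensor\"], out_type=tf.float32)\n",
"\n",
"dataset = tf.data.TFRecordDataset([\"my_tensor.tfrecord\"])\n",
"for parsed_tensor in dataset.map(parse_tensor_example):\n",
"    print(parsed_tensor)"
]
},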
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Handling Lists of Lists Using the `SequenceExample` Protobuf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```proto\n",
"syntax = \"proto3\";\n",
"\n",
"message FeatureList { repeated Feature feature = 1; };\n",
"message FeatureLists { map<string, FeatureList> feature_list = 1; };\n",
"message SequenceExample {\n",
" Features context = 1;\n",
" FeatureLists feature_lists = 2;\n",
"};\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.train import FeatureList, FeatureLists, SequenceExample\n",
"\n",
"context = Features(feature={\n",
" \"author_id\": Feature(int64_list=Int64List(value=[123])),\n",
" \"title\": Feature(bytes_list=BytesList(value=[b\"A\", b\"desert\", b\"place\", b\".\"])),\n",
" \"pub_date\": Feature(int64_list=Int64List(value=[1623, 12, 25]))\n",
"})\n",
"\n",
"content = [[\"When\", \"shall\", \"we\", \"three\", \"meet\", \"again\", \"?\"],\n",
" [\"In\", \"thunder\", \",\", \"lightning\", \",\", \"or\", \"in\", \"rain\", \"?\"]]\n",
"comments = [[\"When\", \"the\", \"hurlyburly\", \"'s\", \"done\", \".\"],\n",
" [\"When\", \"the\", \"battle\", \"'s\", \"lost\", \"and\", \"won\", \".\"]]\n",
"\n",
"def words_to_feature(words):\n",
" return Feature(bytes_list=BytesList(value=[word.encode(\"utf-8\")\n",
" for word in words]))\n",
"\n",
"content_features = [words_to_feature(sentence) for sentence in content]\n",
"comments_features = [words_to_feature(comment) for comment in comments]\n",
"\n",
"sequence_example = SequenceExample(\n",
" context=context,\n",
" feature_lists=FeatureLists(feature_list={\n",
" \"content\": FeatureList(feature=content_features),\n",
" \"comments\": FeatureList(feature=comments_features)\n",
" }))"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"context {\n",
" feature {\n",
" key: \"author_id\"\n",
" value {\n",
" int64_list {\n",
" value: 123\n",
" }\n",
" }\n",
" }\n",
" feature {\n",
" key: \"pub_date\"\n",
" value {\n",
" int64_list {\n",
" value: 1623\n",
" value: 12\n",
" value: 25\n",
" }\n",
" }\n",
" }\n",
" feature {\n",
" key: \"title\"\n",
" value {\n",
" bytes_list {\n",
" value: \"A\"\n",
" value: \"desert\"\n",
" value: \"place\"\n",
" value: \".\"\n",
" }\n",
" }\n",
" }\n",
"}\n",
"feature_lists {\n",
" feature_list {\n",
" key: \"comments\"\n",
" value {\n",
" feature {\n",
" bytes_list {\n",
" value: \"When\"\n",
" value: \"the\"\n",
" value: \"hurlyburly\"\n",
" value: \"\\'s\"\n",
" value: \"done\"\n",
" value: \".\"\n",
" }\n",
" }\n",
" feature {\n",
" bytes_list {\n",
" value: \"When\"\n",
" value: \"the\"\n",
" value: \"battle\"\n",
" value: \"\\'s\"\n",
" value: \"lost\"\n",
" value: \"and\"\n",
" value: \"won\"\n",
" value: \".\"\n",
" }\n",
" }\n",
" }\n",
" }\n",
" feature_list {\n",
" key: \"content\"\n",
" value {\n",
" feature {\n",
" bytes_list {\n",
" value: \"When\"\n",
" value: \"shall\"\n",
" value: \"we\"\n",
" value: \"three\"\n",
" value: \"meet\"\n",
" value: \"again\"\n",
" value: \"?\"\n",
" }\n",
" }\n",
" feature {\n",
" bytes_list {\n",
" value: \"In\"\n",
" value: \"thunder\"\n",
" value: \",\"\n",
" value: \"lightning\"\n",
" value: \",\"\n",
" value: \"or\"\n",
" value: \"in\"\n",
" value: \"rain\"\n",
" value: \"?\"\n",
" }\n",
" }\n",
" }\n",
" }\n",
"}"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sequence_example"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
"serialized_sequence_example = sequence_example.SerializeToString()"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"context_feature_descriptions = {\n",
" \"author_id\": tf.io.FixedLenFeature([], tf.int64, default_value=0),\n",
" \"title\": tf.io.VarLenFeature(tf.string),\n",
" \"pub_date\": tf.io.FixedLenFeature([3], tf.int64, default_value=[0, 0, 0]),\n",
"}\n",
"sequence_feature_descriptions = {\n",
" \"content\": tf.io.VarLenFeature(tf.string),\n",
" \"comments\": tf.io.VarLenFeature(tf.string),\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
"parsed_context, parsed_feature_lists = tf.io.parse_single_sequence_example(\n",
" serialized_sequence_example, context_feature_descriptions,\n",
" sequence_feature_descriptions)\n",
"parsed_content = tf.RaggedTensor.from_sparse(parsed_feature_lists[\"content\"])"
]
},
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'title': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f8281d310d0>,\n",
" 'author_id': <tf.Tensor: shape=(), dtype=int64, numpy=123>,\n",
" 'pub_date': <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1623, 12, 25])>}"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parsed_context"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(4,), dtype=string, numpy=array([b'A', b'desert', b'place', b'.'], dtype=object)>"
]
},
"execution_count": 70,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parsed_context[\"title\"].values"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'comments': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f8281d31be0>,\n",
" 'content': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f8281d31280>}"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"parsed_feature_lists"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<tf.RaggedTensor [[b'When', b'shall', b'we', b'three', b'meet', b'again', b'?'], [b'In', b'thunder', b',', b'lightning', b',', b'or', b'in', b'rain', b'?']]>\n"
]
}
],
"source": [
"print(tf.RaggedTensor.from_sparse(parsed_feature_lists[\"content\"]))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Keras Preprocessing Layers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The `Normalization` Layer"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"363/363 [==============================] - 0s 863us/step - loss: 2.6287 - val_loss: 1.2771\n",
"Epoch 2/5\n",
"363/363 [==============================] - 0s 691us/step - loss: 0.8460 - val_loss: 1.3751\n",
"Epoch 3/5\n",
"363/363 [==============================] - 0s 729us/step - loss: 0.6995 - val_loss: 1.2119\n",
"Epoch 4/5\n",
"363/363 [==============================] - 0s 716us/step - loss: 0.6606 - val_loss: 0.8703\n",
"Epoch 5/5\n",
"363/363 [==============================] - 0s 696us/step - loss: 0.6374 - val_loss: 0.6106\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f8241cba1f0>"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42) # extra code – ensures reproducibility\n",
"norm_layer = tf.keras.layers.Normalization()\n",
"model = tf.keras.models.Sequential([\n",
" norm_layer,\n",
" tf.keras.layers.Dense(1)\n",
"])\n",
"model.compile(loss=\"mse\", optimizer=tf.keras.optimizers.SGD(learning_rate=2e-3))\n",
"norm_layer.adapt(X_train) # computes the mean and variance of every feature\n",
"model.fit(X_train, y_train, validation_data=(X_valid, y_valid), epochs=5)"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"norm_layer = tf.keras.layers.Normalization()\n",
"norm_layer.adapt(X_train)\n",
"X_train_scaled = norm_layer(X_train)\n",
"X_valid_scaled = norm_layer(X_valid)"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"363/363 [==============================] - 0s 806us/step - loss: 2.6287 - val_loss: 1.2771\n",
"Epoch 2/5\n",
"363/363 [==============================] - 0s 642us/step - loss: 0.8460 - val_loss: 1.3751\n",
"Epoch 3/5\n",
"363/363 [==============================] - 0s 647us/step - loss: 0.6995 - val_loss: 1.2119\n",
"Epoch 4/5\n",
"363/363 [==============================] - 0s 669us/step - loss: 0.6606 - val_loss: 0.8703\n",
"Epoch 5/5\n",
"363/363 [==============================] - 0s 651us/step - loss: 0.6374 - val_loss: 0.6106\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7f8272695400>"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42) # extra code – ensures reproducibility\n",
"model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])\n",
"model.compile(loss=\"mse\", optimizer=tf.keras.optimizers.SGD(learning_rate=2e-3))\n",
"model.fit(X_train_scaled, y_train, epochs=5,\n",
" validation_data=(X_valid_scaled, y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"final_model = tf.keras.Sequential([norm_layer, model])\n",
"X_new = X_test[:3] # pretend we have a few new instances (unscaled)\n",
"y_pred = final_model(X_new) # preprocesses the data and makes predictions"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 1), dtype=float32, numpy=\n",
"array([[1.0205517],\n",
" [1.5699625],\n",
" [2.460654 ]], dtype=float32)>"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y_pred"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# extra code – creates a dataset to demo applying the norm_layer using map()\n",
"dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(5)"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"dataset = dataset.map(lambda X, y: (norm_layer(X), y))"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[(<tf.Tensor: shape=(5, 8), dtype=float32, numpy=\n",
" array([[-0.1939791 , -1.0778134 , -0.9433871 , 0.0148516 , 0.02073434,\n",
" -0.572917 , 0.92925584, -1.4221287 ],\n",
" [ 0.7519827 , -1.8688954 , 0.40547717, -0.23327832, 1.8614666 ,\n",
" 0.20516507, -0.9165531 , 1.0966995 ],\n",
" [-0.41469136, 0.02970134, 0.8180875 , 1.0567819 , -0.08786613,\n",
" -0.29983336, 1.3087229 , -1.6970023 ],\n",
" [ 1.7188951 , -1.315138 , 0.32664284, -0.21955258, -0.337921 ,\n",
" -0.11146677, -0.9821399 , 0.9417729 ],\n",
" [-0.96207225, -1.2360299 , -0.05625898, -0.03124549, 1.709061 ,\n",
" -0.30257043, -0.8041173 , 1.3265921 ]], dtype=float32)>,\n",
" <tf.Tensor: shape=(5, 1), dtype=float64, numpy=\n",
" array([[1.442],\n",
" [1.687],\n",
" [1.621],\n",
" [2.621],\n",
" [0.956]])>)]"
]
},
"execution_count": 80,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(dataset.take(1)) # extra code – shows the first batch"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
"class MyNormalization(tf.keras.layers.Layer):\n",
" def adapt(self, X):\n",
" self.mean_ = np.mean(X, axis=0, keepdims=True)\n",
" self.std_ = np.std(X, axis=0, keepdims=True)\n",
"\n",
" def call(self, inputs):\n",
" eps = tf.keras.backend.epsilon() # a small smoothing term\n",
" return (inputs - self.mean_) / (self.std_ + eps)"
]
},
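{
"cell_type": "markdown",
"metadata": {},
"source": [
"`MyNormalization` standardizes each feature $j$. It uses the per-feature mean $\\mu_j$ and standard deviation $\\sigma_j$ that its `adapt()` method computes on the training set. It also adds a small smoothing term $\\varepsilon$ (`tf.keras.backend.epsilon()`, which defaults to $10^{-7}$) so that it never divides by zero:\n",
"\n",
"$$z_j = \\frac{x_j - \\mu_j}{\\sigma_j + \\varepsilon}$$\n",
"\n",
"The built-in `Normalization` layer applies the same kind of standardization, although it may treat near-zero variances slightly differently."
]
},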
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
"my_norm_layer = MyNormalization()\n",
"my_norm_layer.adapt(X_train)\n",
"X_train_scaled = my_norm_layer(X_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The `Discretization` Layer"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(6, 1), dtype=int64, numpy=\n",
"array([[0],\n",
" [2],\n",
" [2],\n",
" [1],\n",
" [1],\n",
" [0]])>"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"age = tf.constant([[10.], [93.], [57.], [18.], [37.], [5.]])\n",
"discretize_layer = tf.keras.layers.Discretization(bin_boundaries=[18., 50.])\n",
"age_categories = discretize_layer(age)\n",
"age_categories"
]
},
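{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `bin_boundaries=[18., 50.]`, the layer creates three bins. As the output above shows (an age of exactly 18 lands in bin 1), each boundary belongs to the bin on its right:\n",
"\n",
"$$\\text{bin}(x) = \\begin{cases} 0 & \\text{if } x < 18 \\\\ 1 & \\text{if } 18 \\le x < 50 \\\\ 2 & \\text{if } x \\ge 50 \\end{cases}$$"
]
},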
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(6, 1), dtype=int64, numpy=\n",
"array([[1],\n",
" [2],\n",
" [2],\n",
" [1],\n",
" [2],\n",
" [0]])>"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"discretize_layer = tf.keras.layers.Discretization(num_bins=3)\n",
"discretize_layer.adapt(age)\n",
"age_categories = discretize_layer(age)\n",
"age_categories"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The `CategoryEncoding` Layer"
]
},
{
"cell_type": "code",
"execution_count": 85,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(6, 3), dtype=float32, numpy=\n",
"array([[0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [0., 0., 1.],\n",
" [0., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onehot_layer = tf.keras.layers.CategoryEncoding(num_tokens=3)\n",
"onehot_layer(age_categories)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 3), dtype=float32, numpy=\n",
"array([[1., 1., 0.],\n",
" [0., 0., 1.],\n",
" [1., 0., 1.]], dtype=float32)>"
]
},
"execution_count": 86,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"two_age_categories = np.array([[1, 0], [2, 2], [2, 0]])\n",
"onehot_layer(two_age_categories)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 3), dtype=float32, numpy=\n",
"array([[1., 1., 0.],\n",
" [0., 0., 2.],\n",
" [1., 0., 1.]], dtype=float32)>"
]
},
"execution_count": 87,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onehot_layer = tf.keras.layers.CategoryEncoding(num_tokens=3, output_mode=\"count\")\n",
"onehot_layer(two_age_categories)"
]
},
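{
"cell_type": "markdown",
"metadata": {},
"source": [
"In `\"multi_hot\"` mode, which the first `CategoryEncoding` layer above used by default, output $k$ is 1 if any of the instance's categories equals $k$. In `\"count\"` mode, output $k$ is the number of categories that equal $k$. For an instance with categories $c_1, \\dots, c_m$:\n",
"\n",
"$$y_k^{\\text{multi-hot}} = \\max_j \\mathbb{1}[c_j = k] \\qquad y_k^{\\text{count}} = \\sum_{j=1}^{m} \\mathbb{1}[c_j = k]$$\n",
"\n",
"For example, the second instance, `[2, 2]`, gives `[0., 0., 1.]` in multi-hot mode but `[0., 0., 2.]` in count mode."
]
},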
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 6), dtype=float32, numpy=\n",
"array([[0., 1., 0., 1., 0., 0.],\n",
" [0., 0., 1., 0., 0., 1.],\n",
" [0., 0., 1., 1., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 88,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"onehot_layer = tf.keras.layers.CategoryEncoding(num_tokens=3 + 3)\n",
"onehot_layer(two_age_categories + [0, 3]) # adds 3 to the second feature"
]
},
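{
"cell_type": "markdown",
"metadata": {},
"source": [
"Adding `[0, 3]` shifts the second feature's categories past the first feature's three categories. Both features can then share one encoding without colliding. With $K = 3$ categories per feature, category $c_j$ of feature $j$ (for $j = 0, 1$) maps to index:\n",
"\n",
"$$i = c_j + K j$$\n",
"\n",
"For example, the first instance `[1, 0]` activates indices $1 + 0 = 1$ and $0 + 3 = 3$, which gives `[0., 1., 0., 1., 0., 0.]`."
]
},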
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 6), dtype=float32, numpy=\n",
"array([[0., 1., 0., 1., 0., 0.],\n",
" [0., 0., 1., 0., 0., 1.],\n",
" [0., 0., 1., 1., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 89,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows another way to one-hot encode each feature separately\n",
"onehot_layer = tf.keras.layers.CategoryEncoding(num_tokens=3,\n",
" output_mode=\"one_hot\")\n",
"tf.keras.layers.concatenate([onehot_layer(cat)\n",
" for cat in tf.transpose(two_age_categories)])"
]
},
{
"cell_type": "code",
"execution_count": 90,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 6), dtype=float32, numpy=\n",
"array([[0., 1., 0., 1., 0., 0.],\n",
" [0., 0., 1., 0., 0., 1.],\n",
" [0., 0., 1., 1., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 90,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows another way to do this, using tf.one_hot() and Flatten\n",
"tf.keras.layers.Flatten()(tf.one_hot(two_age_categories, depth=3))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The `StringLookup` Layer"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(4, 1), dtype=int64, numpy=\n",
"array([[1],\n",
" [3],\n",
" [3],\n",
" [0]])>"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cities = [\"Auckland\", \"Paris\", \"Paris\", \"San Francisco\"]\n",
"str_lookup_layer = tf.keras.layers.StringLookup()\n",
"str_lookup_layer.adapt(cities)\n",
"str_lookup_layer([[\"Paris\"], [\"Auckland\"], [\"Auckland\"], [\"Montreal\"]])"
]
},
{
"cell_type": "code",
"execution_count": 92,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(5, 1), dtype=int64, numpy=\n",
"array([[5],\n",
" [7],\n",
" [4],\n",
" [3],\n",
" [4]])>"
]
},
"execution_count": 92,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"str_lookup_layer = tf.keras.layers.StringLookup(num_oov_indices=5)\n",
"str_lookup_layer.adapt(cities)\n",
"str_lookup_layer([[\"Paris\"], [\"Auckland\"], [\"Foo\"], [\"Bar\"], [\"Baz\"]])"
]
},
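{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `num_oov_indices=5`, indices 0 to 4 are reserved for out-of-vocabulary strings, which the layer hashes into one of these buckets. Known categories start at index 5, most frequent first, so \"Paris\" (seen twice) gets index 5. Let $r(w)$ be the frequency rank of a known category $w$ (starting at 0) and $h$ the hash function. Then:\n",
"\n",
"$$\\text{index}(w) = \\begin{cases} 5 + r(w) & \\text{if } w \\text{ is in the vocabulary} \\\\ h(w) \\bmod 5 & \\text{otherwise} \\end{cases}$$\n",
"\n",
"Different unknown strings may share a bucket: \"Foo\" and \"Baz\" both map to 4 above."
]
},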
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:5 out of the last 367 calls to <function PreprocessingLayer.make_adapt_function.<locals>.adapt_step at 0x7f8239426dc0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
]
},
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(4, 4), dtype=float32, numpy=\n",
"array([[0., 1., 0., 0.],\n",
" [0., 0., 0., 1.],\n",
" [0., 0., 0., 1.],\n",
" [1., 0., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 93,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"str_lookup_layer = tf.keras.layers.StringLookup(output_mode=\"one_hot\")\n",
"str_lookup_layer.adapt(cities)\n",
"str_lookup_layer([[\"Paris\"], [\"Auckland\"], [\"Auckland\"], [\"Montreal\"]])"
]
},
{
"cell_type": "code",
"execution_count": 94,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:6 out of the last 368 calls to <function PreprocessingLayer.make_adapt_function.<locals>.adapt_step at 0x7f8239426160> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n"
]
},
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(4, 1), dtype=int64, numpy=\n",
"array([[3],\n",
" [2],\n",
" [3],\n",
" [0]])>"
]
},
"execution_count": 94,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – an example using the IntegerLookup layer\n",
"ids = [123, 456, 789]\n",
"int_lookup_layer = tf.keras.layers.IntegerLookup()\n",
"int_lookup_layer.adapt(ids)\n",
"int_lookup_layer([[123], [456], [123], [111]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## The `Hashing` Layer"
]
},
{
"cell_type": "code",
"execution_count": 95,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(4, 1), dtype=int64, numpy=\n",
"array([[0],\n",
" [1],\n",
" [9],\n",
" [1]])>"
]
},
"execution_count": 95,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"hashing_layer = tf.keras.layers.Hashing(num_bins=10)\n",
"hashing_layer([[\"Paris\"], [\"Tokyo\"], [\"Auckland\"], [\"Montreal\"]])"
]
},
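{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Hashing` layer needs no `adapt()` step. It computes a deterministic hash $h$ of each input (the Keras documentation lists FarmHash64 as the default) and keeps the remainder modulo `num_bins`:\n",
"\n",
"$$\\text{index}(x) = h(x) \\bmod 10$$\n",
"\n",
"Because there are only 10 bins, unrelated categories can collide: \"Paris\" and \"Montreal\" both map to bin 1 above."
]
},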
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Encoding Categorical Features Using Embeddings"
]
},
{
"cell_type": "code",
"execution_count": 96,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
"array([[-0.04663396, 0.01846724],\n",
" [-0.02736737, -0.02768031],\n",
" [-0.04663396, 0.01846724]], dtype=float32)>"
]
},
"execution_count": 96,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"embedding_layer = tf.keras.layers.Embedding(input_dim=5, output_dim=2)\n",
"embedding_layer(np.array([2, 4, 2]))"
]
},
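{
"cell_type": "markdown",
"metadata": {},
"source": [
"An `Embedding` layer with `input_dim=5` and `output_dim=2` holds a trainable $5 \\times 2$ matrix $E$. It is randomly initialized, which is why the values above are small. Looking up category $i$ returns row $i$ of $E$. This gives the same result as multiplying the one-hot vector $\\mathbf{o}_i \\in \\{0, 1\\}^5$ by $E$, but at a lower cost:\n",
"\n",
"$$\\mathbf{e}_i^\\top = \\mathbf{o}_i^\\top E = E_{i,:}$$\n",
"\n",
"This is why the two lookups of category 2 above return identical vectors."
]
},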
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
"array([[-0.01896119, 0.02223358],\n",
" [ 0.02401174, 0.03724445],\n",
" [-0.01896119, 0.02223358]], dtype=float32)>"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"ocean_prox = [\"<1H OCEAN\", \"INLAND\", \"NEAR OCEAN\", \"NEAR BAY\", \"ISLAND\"]\n",
"str_lookup_layer = tf.keras.layers.StringLookup()\n",
"str_lookup_layer.adapt(ocean_prox)\n",
"lookup_and_embed = tf.keras.Sequential([\n",
" str_lookup_layer,\n",
" tf.keras.layers.Embedding(input_dim=str_lookup_layer.vocabulary_size(),\n",
" output_dim=2)\n",
"])\n",
"lookup_and_embed(np.array([[\"<1H OCEAN\"], [\"ISLAND\"], [\"<1H OCEAN\"]]))"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"313/313 [==============================] - 0s 903us/step - loss: 0.1491 - val_loss: 0.1188\n",
"Epoch 2/5\n",
"313/313 [==============================] - 0s 723us/step - loss: 0.1069 - val_loss: 0.0967\n",
"Epoch 3/5\n",
"313/313 [==============================] - 0s 667us/step - loss: 0.0924 - val_loss: 0.0886\n",
"Epoch 4/5\n",
"313/313 [==============================] - 0s 677us/step - loss: 0.0870 - val_loss: 0.0856\n",
"Epoch 5/5\n",
"313/313 [==============================] - 0s 671us/step - loss: 0.0849 - val_loss: 0.0843\n"
]
}
],
"source": [
"# extra code – sets seeds and generates fake random data\n",
"# (feel free to load the real dataset if you prefer)\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"X_train_num = np.random.rand(10_000, 8)\n",
"X_train_cat = np.random.choice(ocean_prox, size=10_000)\n",
"y_train = np.random.rand(10_000, 1)\n",
"X_valid_num = np.random.rand(2_000, 8)\n",
"X_valid_cat = np.random.choice(ocean_prox, size=2_000)\n",
"y_valid = np.random.rand(2_000, 1)\n",
"\n",
"num_input = tf.keras.layers.Input(shape=[8], name=\"num\")\n",
"cat_input = tf.keras.layers.Input(shape=[], dtype=tf.string, name=\"cat\")\n",
"cat_embeddings = lookup_and_embed(cat_input)\n",
"encoded_inputs = tf.keras.layers.concatenate([num_input, cat_embeddings])\n",
"outputs = tf.keras.layers.Dense(1)(encoded_inputs)\n",
"model = tf.keras.models.Model(inputs=[num_input, cat_input], outputs=[outputs])\n",
"model.compile(loss=\"mse\", optimizer=\"sgd\")\n",
"history = model.fit((X_train_num, X_train_cat), y_train, epochs=5,\n",
" validation_data=((X_valid_num, X_valid_cat), y_valid))"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"313/313 [==============================] - 1s 1ms/step - loss: 0.0839 - val_loss: 0.0838\n",
"Epoch 2/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0835 - val_loss: 0.0835\n",
"Epoch 3/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0832 - val_loss: 0.0833\n",
"Epoch 4/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0831 - val_loss: 0.0832\n",
"Epoch 5/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0830 - val_loss: 0.0831\n"
]
}
],
"source": [
"# extra code – shows that the model can also be trained using a tf.data.Dataset\n",
"train_set = tf.data.Dataset.from_tensor_slices(\n",
" ((X_train_num, X_train_cat), y_train)).batch(32)\n",
"valid_set = tf.data.Dataset.from_tensor_slices(\n",
" ((X_valid_num, X_valid_cat), y_valid)).batch(32)\n",
"history = model.fit(train_set, epochs=5,\n",
" validation_data=valid_set)"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"313/313 [==============================] - 1s 1ms/step - loss: 0.0829 - val_loss: 0.0830\n",
"Epoch 2/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0829 - val_loss: 0.0830\n",
"Epoch 3/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0830\n",
"Epoch 4/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0829\n",
"Epoch 5/5\n",
"313/313 [==============================] - 0s 1ms/step - loss: 0.0828 - val_loss: 0.0829\n"
]
}
],
"source": [
"# extra code – shows that the dataset can contain dictionaries\n",
"train_set = tf.data.Dataset.from_tensor_slices(\n",
" ({\"num\": X_train_num, \"cat\": X_train_cat}, y_train)).batch(32)\n",
"valid_set = tf.data.Dataset.from_tensor_slices(\n",
" ({\"num\": X_valid_num, \"cat\": X_valid_cat}, y_valid)).batch(32)\n",
"history = model.fit(train_set, epochs=5, validation_data=valid_set)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Text Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2, 4), dtype=int64, numpy=\n",
"array([[2, 1, 0, 0],\n",
" [6, 2, 1, 2]])>"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_data = [\"To be\", \"!(to be)\", \"That's the question\", \"Be, be, be.\"]\n",
"text_vec_layer = tf.keras.layers.TextVectorization()\n",
"text_vec_layer.adapt(train_data)\n",
"text_vec_layer([\"Be good!\", \"Question: be or be?\"])"
]
},
{
"cell_type": "code",
"execution_count": 102,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.RaggedTensor [[2, 1], [6, 2, 1, 2]]>"
]
},
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text_vec_layer = tf.keras.layers.TextVectorization(ragged=True)\n",
"text_vec_layer.adapt(train_data)\n",
"text_vec_layer([\"Be good!\", \"Question: be or be?\"])"
]
},
{
"cell_type": "code",
"execution_count": 103,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2, 6), dtype=float32, numpy=\n",
"array([[0.96725637, 0.6931472 , 0. , 0. , 0. ,\n",
" 0. ],\n",
" [0.96725637, 1.3862944 , 0. , 0. , 0. ,\n",
" 1.0986123 ]], dtype=float32)>"
]
},
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text_vec_layer = tf.keras.layers.TextVectorization(output_mode=\"tf_idf\")\n",
"text_vec_layer.adapt(train_data)\n",
"text_vec_layer([\"Be good!\", \"Question: be or be?\"])"
]
},
{
"cell_type": "code",
"execution_count": 104,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.3862943611198906"
]
},
"execution_count": 104,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"2 * np.log(1 + 4 / (1 + 3))"
]
},
{
"cell_type": "code",
"execution_count": 105,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0986122886681098"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"1 * np.log(1 + 4 / (1 + 1))"
2019-03-14 02:15:09 +01:00
]
},
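{
"cell_type": "markdown",
"metadata": {},
"source": [
"The two cells above check individual TF-IDF values by hand. As a rough sketch, the same arithmetic can be applied to every vocabulary word at once. This assumes `train_data` is the list of training sentences used above, approximates the layer's default standardization (lowercasing, stripping punctuation, splitting on whitespace) with a regular expression, and uses the weighting shown above: the term count times log(1 + _d_ / (1 + _f_)), where _d_ is the number of training documents and _f_ is the number of documents containing the word. The `standardize()` and `tf_idf()` helpers are hypothetical, and this sketch ignores the weight the layer assigns to out-of-vocabulary words. The values for `be` and `question` should match the two cells above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import re\n",
"from collections import Counter\n",
"\n",
"import numpy as np\n",
"\n",
"def standardize(text):\n",
"    # rough approximation of the default standardization: lowercase, strip punctuation, split\n",
"    return re.sub(r\"[^\\w\\s]\", \"\", text.lower()).split()\n",
"\n",
"docs = [standardize(text) for text in train_data]\n",
"n_docs = len(docs)\n",
"# number of training documents that contain each word\n",
"doc_freq = Counter(word for doc in docs for word in set(doc))\n",
"\n",
"def tf_idf(sentence):\n",
"    counts = Counter(standardize(sentence))\n",
"    # words never seen during training are simply skipped in this sketch\n",
"    return {word: count * np.log(1 + n_docs / (1 + doc_freq[word]))\n",
"            for word, count in counts.items() if word in doc_freq}\n",
"\n",
"tf_idf(\"Question: be or be?\")"
]
},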
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Using Pretrained Language Model Components"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[-0.25, 0.28, 0.01, 0.1 , 0.14, 0.16, 0.25, 0.02, 0.07,\n",
" 0.13, -0.19, 0.06, -0.04, -0.07, 0. , -0.08, -0.14, -0.16,\n",
" 0.02, -0.24, 0.16, -0.16, -0.03, 0.03, -0.14, 0.03, -0.09,\n",
" -0.04, -0.14, -0.19, 0.07, 0.15, 0.18, -0.23, -0.07, -0.08,\n",
" 0.01, -0.01, 0.09, 0.14, -0.03, 0.03, 0.08, 0.1 , -0.01,\n",
" -0.03, -0.07, -0.1 , 0.05, 0.31],\n",
" [-0.2 , 0.2 , -0.08, 0.02, 0.19, 0.05, 0.22, -0.09, 0.02,\n",
" 0.19, -0.02, -0.14, -0.2 , -0.04, 0.01, -0.07, -0.22, -0.1 ,\n",
" 0.16, -0.44, 0.31, -0.1 , 0.23, 0.15, -0.05, 0.15, -0.13,\n",
" -0.04, -0.08, -0.16, -0.1 , 0.13, 0.13, -0.18, -0.04, 0.03,\n",
" -0.1 , -0.07, 0.07, 0.03, -0.08, 0.02, 0.05, 0.07, -0.14,\n",
" -0.1 , -0.18, -0.13, -0.04, 0.15]], dtype=float32)"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tensorflow_hub as hub\n",
"\n",
"hub_layer = hub.KerasLayer(\"https://tfhub.dev/google/nnlm-en-dim50/2\")\n",
"sentence_embeddings = hub_layer(tf.constant([\"To be\", \"Not to be\"]))\n",
"sentence_embeddings.numpy().round(2)"
]
},
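{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the module maps each sentence to a 50-dimensional vector, sentences can be compared numerically. As an illustration, this sketch computes the cosine similarity between the two embeddings above (values close to 1 mean the vectors point in similar directions). The `cosine_similarity()` helper is hypothetical and defined here with NumPy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def cosine_similarity(u, v):\n",
"    # dot product of the two vectors divided by the product of their norms\n",
"    return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))\n",
"\n",
"embeddings = sentence_embeddings.numpy()\n",
"cosine_similarity(embeddings[0], embeddings[1])"
]
},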
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Image Preprocessing Layers"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_sample_images\n",
"\n",
"images = load_sample_images()[\"images\"]\n",
"crop_image_layer = tf.keras.layers.CenterCrop(height=100, width=100)\n",
"cropped_images = crop_image_layer(images)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(images[0])\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.imshow(cropped_images[0])\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
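{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `CenterCrop` does, here is a rough NumPy equivalent. It is only a sketch: it assumes each image is at least 100 by 100 pixels and that the crop offsets are half the size difference, rounded down, in each dimension. The `center_crop()` helper is hypothetical, and the cell reuses `images` and `cropped_images` from the cells above; the final comparison should return `True` if these assumptions hold."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def center_crop(image, height, width):\n",
"    # offsets that keep the crop centered (assumes the image is large enough)\n",
"    top = (image.shape[0] - height) // 2\n",
"    left = (image.shape[1] - width) // 2\n",
"    return image[top:top + height, left:left + width]\n",
"\n",
"manual_crops = np.stack([center_crop(image, 100, 100) for image in images])\n",
"np.allclose(manual_crops, cropped_images.numpy())"
]
},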
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TensorFlow Datasets"
]
},
{
"cell_type": "code",
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_datasets as tfds\n",
"\n",
"datasets = tfds.load(name=\"mnist\")\n",
"mnist_train, mnist_test = datasets[\"train\"], datasets[\"test\"]"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
"for batch in mnist_train.shuffle(10_000, seed=42).batch(32).prefetch(1):\n",
" images = batch[\"image\"]\n",
" labels = batch[\"label\"]\n",
" # [...] do something with the images and labels"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
"mnist_train = mnist_train.shuffle(10_000, seed=42).batch(32)\n",
"mnist_train = mnist_train.map(lambda items: (items[\"image\"], items[\"label\"]))\n",
"mnist_train = mnist_train.prefetch(1)"
]
},
{
"cell_type": "code",
"execution_count": 113,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1688/1688 [==============================] - 2s 1ms/step - loss: 9.6765 - accuracy: 0.8348 - val_loss: 5.8894 - val_accuracy: 0.8835\n",
"Epoch 2/5\n",
"1688/1688 [==============================] - 1s 796us/step - loss: 5.6335 - accuracy: 0.8785 - val_loss: 5.1325 - val_accuracy: 0.8800\n",
"Epoch 3/5\n",
"1688/1688 [==============================] - 1s 793us/step - loss: 5.0494 - accuracy: 0.8832 - val_loss: 5.3470 - val_accuracy: 0.8938\n",
"Epoch 4/5\n",
"1688/1688 [==============================] - 1s 767us/step - loss: 4.8245 - accuracy: 0.8867 - val_loss: 5.2491 - val_accuracy: 0.8870\n",
"Epoch 5/5\n",
"1688/1688 [==============================] - 1s 765us/step - loss: 4.6808 - accuracy: 0.8871 - val_loss: 5.1136 - val_accuracy: 0.8960\n",
"313/313 [==============================] - 0s 769us/step - loss: 4.6993 - accuracy: 0.8975\n"
]
}
],
"source": [
"train_set, valid_set, test_set = tfds.load(\n",
" name=\"mnist\",\n",
" split=[\"train[:90%]\", \"train[90%:]\", \"test\"],\n",
" as_supervised=True\n",
")\n",
"train_set = train_set.shuffle(10_000, seed=42).batch(32).prefetch(1)\n",
"valid_set = valid_set.batch(32).cache()\n",
"test_set = test_set.batch(32).cache()\n",
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(input_shape=(28, 28)),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=5)\n",
"test_loss, test_accuracy = model.evaluate(test_set)"
]
},
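{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tfds.load()` can also return the dataset's metadata when you pass `with_info=True`, such as the number of instances in each split and the feature specification. A short sketch (the `mnist_datasets` and `info` variable names are arbitrary):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_datasets as tfds\n",
"\n",
"mnist_datasets, info = tfds.load(name=\"mnist\", with_info=True)\n",
"print(info.splits[\"train\"].num_examples)  # number of training instances\n",
"print(info.splits[\"test\"].num_examples)  # number of test instances\n",
"print(info.features)  # feature names, shapes, and dtypes"
]
},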
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Exercises\n",
"\n",
"## 1. to 8.\n",
"1. Ingesting a large dataset and preprocessing it efficiently can be a complex engineering challenge. The Data API makes it fairly simple. It offers many features, including loading data from various sources (such as text or binary files), reading data in parallel from multiple sources, transforming it, interleaving the records, shuffling the data, batching it, and prefetching it.\n",
"2. Splitting a large dataset into multiple files makes it possible to shuffle it at a coarse level before shuffling it at a finer level using a shuffling buffer. It also makes it possible to handle huge datasets that do not fit on a single machine. It's also simpler to manipulate thousands of small files than one huge file; for example, it's easier to split the data into multiple subsets. Lastly, if the data is split across multiple files spread across multiple servers, it is possible to download several files from different servers simultaneously, which makes better use of the available bandwidth.\n",
"3. You can use TensorBoard to visualize profiling data: if the GPU is not fully utilized then your input pipeline is likely to be the bottleneck. You can fix it by making sure it reads and preprocesses the data in multiple threads in parallel, and ensuring it prefetches a few batches. If this is insufficient to get your GPU to 100% usage during training, make sure your preprocessing code is optimized. You can also try saving the dataset into multiple TFRecord files, and if necessary perform some of the preprocessing ahead of time so that it does not need to be done on the fly during training (TF Transform can help with this). If necessary, use a machine with more CPU and RAM, and ensure that the GPU bandwidth is large enough.\n",
"4. A TFRecord file is composed of a sequence of arbitrary binary records: you can store absolutely any binary data you want in each record. However, in practice most TFRecord files contain sequences of serialized protocol buffers. This makes it possible to benefit from the advantages of protocol buffers, such as the fact that they can be read easily across multiple platforms and languages and their definition can be updated later in a backward-compatible way.\n",
"5. The `Example` protobuf format has the advantage that TensorFlow provides some operations to parse it (such as `tf.io.parse_single_example()` and `tf.io.parse_example()`) without you having to define your own format. It is sufficiently flexible to represent instances in most datasets. However, if it does not cover your use case, you can define your own protocol buffer, compile it using `protoc` (setting the `--descriptor_set_out` and `--include_imports` arguments to export the protobuf descriptor), and use the `tf.io.decode_proto()` function to parse the serialized protobufs (see the \"Custom protobuf\" section of the notebook for an example). It's more complicated, and it requires deploying the descriptor along with the model, but it can be done.\n",
"6. When using TFRecords, you will generally want to activate compression if the TFRecord files will need to be downloaded by the training script, as compression will make files smaller and thus reduce download time. But if the files are located on the same machine as the training script, it's usually preferable to leave compression off, to avoid wasting CPU for decompression.\n",
"7. Let's look at the pros and cons of each preprocessing option:\n",
" * If you preprocess the data when creating the data files, the training script will run faster, since it will not have to perform preprocessing on the fly. In some cases, the preprocessed data will also be much smaller than the original data, so you can save some space and speed up downloads. It may also be helpful to materialize the preprocessed data, for example to inspect it or archive it. However, this approach has a few cons. First, it's not easy to experiment with different preprocessing strategies if you need to generate a preprocessed dataset for each variant. Second, if you want to perform data augmentation, you have to materialize many variants of your dataset, which will use a large amount of disk space and take a lot of time to generate. Lastly, the trained model will expect preprocessed data, so you will have to add preprocessing code in your application before it calls the model. There's a risk of code duplication and preprocessing mismatch in this case.\n",
" * If the data is preprocessed with the tf.data pipeline, it's much easier to tweak the preprocessing logic and apply data augmentation. Also, tf.data makes it easy to build highly efficient preprocessing pipelines (e.g., with multithreading and prefetching). However, preprocessing the data this way will slow down training. Moreover, each training instance will be preprocessed once per epoch rather than just once if the data was preprocessed when creating the data files. Well, unless the dataset fits in RAM and you can cache it using the dataset's `cache()` method. Lastly, the trained model will still expect preprocessed data. But if you use preprocessing layers in your tf.data pipeline to handle the preprocessing step, then you can just reuse these layers in your final model (adding them after training), to avoid code duplication and preprocessing mismatch.\n",
" * If you add preprocessing layers to your model, you will only have to write the preprocessing code once for both training and inference. If your model needs to be deployed to many different platforms, you will not need to write the preprocessing code multiple times. Plus, you will not run the risk of using the wrong preprocessing logic for your model, since it will be part of the model. On the downside, preprocessing the data on the fly during training will slow things down, and each instance will be preprocessed once per epoch.\n",
"8. Let's look at how to encode categorical text features and text:\n",
" * To encode a categorical feature that has a natural order, such as a movie rating (e.g., \"bad,\" \"average,\" \"good\"), the simplest option is to use ordinal encoding: sort the categories in their natural order and map each category to its rank (e.g., \"bad\" maps to 0, \"average\" maps to 1, and \"good\" maps to 2). However, most categorical features don't have such a natural order. For example, there's no natural order for professions or countries. In this case, you can use one-hot encoding, or embeddings if there are many categories. With Keras, the `StringLookup` layer can be used for ordinal encoding (using the default `output_mode=\"int\"`), or one-hot encoding (using `output_mode=\"one_hot\"`). It can also perform multi-hot encoding (using `output_mode=\"multi_hot\"`) if you want to encode multiple categorical text features together, assuming they share the same categories and it doesn't matter which feature contributed which category. For trainable embeddings, you must first use the `StringLookup` layer to produce an ordinal encoding, then use the `Embedding` layer.\n",
" * For text, the `TextVectorization` layer is easy to use and it can work well for simple tasks, or you can use TF Text for more advanced features. However, you'll often want to use pretrained language models, which you can obtain using tools like TF Hub or Hugging Face's Transformers library. These last two options are discussed in Chapter 16."
]
},
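{
"cell_type": "markdown",
"metadata": {},
"source": [
"Answer 8 mentions that `StringLookup` can perform ordinal or one-hot encoding. Here is a minimal sketch that uses the rating categories from that answer as a made-up vocabulary. Passing the categories in their natural order with `num_oov_indices=0` maps each one to its rank, while `output_mode=\"one_hot\"` produces one-hot vectors instead (by default, index 0 is reserved for unknown categories, so there are four columns)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"ratings = tf.constant([\"bad\", \"good\", \"average\", \"good\"])\n",
"vocab = [\"bad\", \"average\", \"good\"]  # categories sorted in their natural order\n",
"\n",
"# ordinal encoding: no OOV index, so \"bad\" -> 0, \"average\" -> 1, \"good\" -> 2\n",
"ordinal_layer = tf.keras.layers.StringLookup(vocabulary=vocab, num_oov_indices=0)\n",
"print(ordinal_layer(ratings))  # expected values: 0, 2, 1, 2\n",
"\n",
"# one-hot encoding: 1 OOV column followed by one column per category\n",
"one_hot_layer = tf.keras.layers.StringLookup(vocabulary=vocab, output_mode=\"one_hot\")\n",
"print(one_hot_layer(ratings))"
]
},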
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9.\n",
"### a.\n",
"_Exercise: Load the Fashion MNIST dataset (introduced in Chapter 10); split it into a training set, a validation set, and a test set; shuffle the training set; and save each dataset to multiple TFRecord files. Each record should be a serialized `Example` protobuf with two features: the serialized image (use `tf.io.serialize_tensor()` to serialize each image), and the label. Note: for large images, you could use `tf.io.encode_jpeg()` instead. This would save a lot of space, but it would lose a bit of image quality._"
]
},
{
"cell_type": "code",
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
"(X_train_full, y_train_full), (X_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()\n",
"X_valid, X_train = X_train_full[:5000], X_train_full[5000:]\n",
"y_valid, y_train = y_train_full[:5000], y_train_full[5000:]"
]
},
{
"cell_type": "code",
"execution_count": 115,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-20 15:27:32.431462: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
"source": [
"tf.random.set_seed(42)\n",
"train_set = tf.data.Dataset.from_tensor_slices((X_train, y_train))\n",
"train_set = train_set.shuffle(len(X_train), seed=42)\n",
"valid_set = tf.data.Dataset.from_tensor_slices((X_valid, y_valid))\n",
"test_set = tf.data.Dataset.from_tensor_slices((X_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
"def create_example(image, label):\n",
" image_data = tf.io.serialize_tensor(image)\n",
" #image_data = tf.io.encode_jpeg(image[..., np.newaxis])\n",
" return Example(\n",
" features=Features(\n",
" feature={\n",
" \"image\": Feature(bytes_list=BytesList(value=[image_data.numpy()])),\n",
" \"label\": Feature(int64_list=Int64List(value=[label])),\n",
" }))"
]
},
{
"cell_type": "code",
"execution_count": 117,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"features {\n",
" feature {\n",
" key: \"image\"\n",
" value {\n",
" bytes_list {\n",
" value: \"\\010\\004\\022\\010\\022\\002\\010\\034\\022\\002\\010\\034\\\"\\220\\006\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\001\\000\\000\\rI\\000\\000\\001\\004\\000\\000\\000\\000\\001\\001\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\003\\000$\\210\\177>6\\000\\000\\000\\001\\003\\004\\000\\000\\003\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\006\\000f\\314\\260\\206\\220{\\027\\000\\000\\000\\000\\014\\n\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\233\\354\\317\\262k\\234\\241m@\\027M\\202H\\017\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\001\\000E\\317\\337\\332\\330\\330\\243\\177yz\\222\\215X\\254B\\000\\000\\000\\000\\000\\000\\000\\000\\000\\001\\001\\001\\000\\310\\350\\350\\351\\345\\337\\337\\327\\325\\244\\177{\\304\\345\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\267\\341\\330\\337\\344\\353\\343\\340\\336\\340\\335\\337\\365\\255\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\301\\344\\332\\325\\306\\264\\324\\322\\323\\325\\337\\334\\363\\312\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\001\\003\\000\\014\\333\\334\\324\\332\\300\\251\\343\\320\\332\\340\\324\\342\\305\\3214\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\006\\000c\\364\\336\\334\\332\\313\\306\\335\\327\\325\\336\\334\\365w\\2478\\000\\000\\000\\000\\000\\000\\000\\000\\000\\004\\000\\0007\\354\\344\\346\\344\\360\\350\\325\\332\\337\\352\\331\\331\\321\\\\\\000\\000\\000\\001\\004\\006\\007\\002\\000\\000\\0
00\\000\\000\\355\\342\\331\\337\\336\\333\\336\\335\\330\\337\\345\\327\\332\\377M\\000\\000\\003\\000\\000\\000\\000\\000\\000\\000>\\221\\314\\344\\317\\325\\335\\332\\320\\323\\332\\340\\337\\333\\327\\340\\364\\237\\000\\000\\000\\000\\000\\022,Rk\\275\\344\\334\\336\\331\\342\\310\\315\\323\\346\\340\\352\\260\\274\\372\\370\\351\\356\\327\\000\\0009\\273\\320\\340\\335\\340\\320\\314\\326\\320\\321\\310\\237\\365\\301\\316\\337\\377\\377\\335\\352\\335\\323\\334\\350\\366\\000\\003\\312\\344\\340\\335\\323\\323\\326\\315\\315\\315\\334\\360P\\226\\377\\345\\335\\274\\232\\277\\322\\314\\321\\336\\344\\341\\000b\\351\\306\\322\\336\\345\\345\\352\\371\\334\\302\\327\\331\\361AIju\\250\\333\\335\\327\\331\\337\\337\\340\\345\\035K\\314\\324\\314\\301\\315\\323\\341\\330\\271\\305\\316\\306\\325\\360\\303\\343\\365\\357\\337\\332\\324\\321\\336\\334\\335\\346C0\\313\\267\\302\\325\\305\\271\\276\\302\\300\\312\\326\\333\\335\\334\\354\\341\\330\\307\\316\\272\\265\\261\\254\\265\\315\\316s\\000z\\333\\301\\263\\253\\267\\304\\314\\322\\325\\317\\323\\322\\310\\304\\302\\277\\303\\277\\306\\300\\260\\234\\247\\261\\322\\\\\\000\\000J\\275\\324\\277\\257\\254\\257\\265\\271\\274\\275\\274\\301\\306\\314\\321\\322\\322\\323\\274\\274\\302\\300\\330\\252\\000\\002\\000\\000\\000B\\310\\336\\355\\357\\362\\366\\363\\364\\335\\334\\301\\277\\263\\266\\266\\265\\260\\246\\250c:\\000\\000\\000\\000\\000\\000\\000\\000\\000(=,H)#\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\\000\"\n",
" }\n",
" }\n",
" }\n",
" feature {\n",
" key: \"label\"\n",
" value {\n",
" int64_list {\n",
" value: 9\n",
" }\n",
" }\n",
" }\n",
"}\n",
"\n"
]
}
],
"source": [
"for image, label in valid_set.take(1):\n",
" print(create_example(image, label))"
]
},
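{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before writing everything to disk, it can help to check that a record survives the round trip. This sketch reuses `create_example()` and `valid_set` from above, serializes one example, parses it back with `tf.io.parse_single_example()` and `tf.io.parse_tensor()` (the same calls used in `preprocess()` in part b), and asserts that the image and label are unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"\n",
"for image, label in valid_set.take(1):\n",
"    serialized = create_example(image, label).SerializeToString()\n",
"    parsed = tf.io.parse_single_example(serialized, {\n",
"        \"image\": tf.io.FixedLenFeature([], tf.string),\n",
"        \"label\": tf.io.FixedLenFeature([], tf.int64),\n",
"    })\n",
"    # the image bytes hold a serialized uint8 tensor\n",
"    restored = tf.io.parse_tensor(parsed[\"image\"], out_type=tf.uint8)\n",
"    assert np.array_equal(restored.numpy(), image.numpy())\n",
"    assert parsed[\"label\"].numpy() == label.numpy()"
]
},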
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following function saves a given dataset to a set of TFRecord files. The examples are written to the files in a round-robin fashion. To do this, we enumerate all the examples using the `dataset.enumerate()` method, and we compute `index % n_shards` to decide which file to write to. We use the standard `contextlib.ExitStack` class to make sure that all writers are properly closed whether or not an I/O error occurs while writing."
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"from contextlib import ExitStack\n",
"\n",
"def write_tfrecords(name, dataset, n_shards=10):\n",
" paths = [\"{}.tfrecord-{:05d}-of-{:05d}\".format(name, index, n_shards)\n",
" for index in range(n_shards)]\n",
" with ExitStack() as stack:\n",
" writers = [stack.enter_context(tf.io.TFRecordWriter(path))\n",
" for path in paths]\n",
" for index, (image, label) in dataset.enumerate():\n",
" shard = index % n_shards\n",
" example = create_example(image, label)\n",
" writers[shard].write(example.SerializeToString())\n",
" return paths"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"train_filepaths = write_tfrecords(\"my_fashion_mnist.train\", train_set)\n",
"valid_filepaths = write_tfrecords(\"my_fashion_mnist.valid\", valid_set)\n",
"test_filepaths = write_tfrecords(\"my_fashion_mnist.test\", test_set)"
]
},
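{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the examples are written in round-robin fashion, every shard should contain the same number of records here: 55,000 training, 5,000 validation, and 10,000 test instances split over 10 shards each gives 5,500, 500, and 1,000 records per file. A quick check, reusing the file paths returned above (the `count_records()` helper is hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"def count_records(filepath):\n",
"    # iterate over the raw (still serialized) records in one file\n",
"    return sum(1 for _ in tf.data.TFRecordDataset(filepath))\n",
"\n",
"for filepaths in (train_filepaths, valid_filepaths, test_filepaths):\n",
"    print([count_records(path) for path in filepaths])"
]
},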
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### b.\n",
"_Exercise: Then use tf.data to create an efficient dataset for each set. Finally, use a Keras model to train these datasets, including a preprocessing layer to standardize each input feature. Try to make the input pipeline as efficient as possible, using TensorBoard to visualize profiling data._"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"def preprocess(tfrecord):\n",
" feature_descriptions = {\n",
" \"image\": tf.io.FixedLenFeature([], tf.string, default_value=\"\"),\n",
" \"label\": tf.io.FixedLenFeature([], tf.int64, default_value=-1)\n",
" }\n",
" example = tf.io.parse_single_example(tfrecord, feature_descriptions)\n",
" image = tf.io.parse_tensor(example[\"image\"], out_type=tf.uint8)\n",
" #image = tf.io.decode_jpeg(example[\"image\"])\n",
" image = tf.reshape(image, shape=[28, 28])\n",
" return image, example[\"label\"]\n",
"\n",
"def mnist_dataset(filepaths, n_read_threads=5, shuffle_buffer_size=None,\n",
" n_parse_threads=5, batch_size=32, cache=True):\n",
" dataset = tf.data.TFRecordDataset(filepaths,\n",
" num_parallel_reads=n_read_threads)\n",
" if cache:\n",
" dataset = dataset.cache()\n",
" if shuffle_buffer_size:\n",
" dataset = dataset.shuffle(shuffle_buffer_size)\n",
" dataset = dataset.map(preprocess, num_parallel_calls=n_parse_threads)\n",
" dataset = dataset.batch(batch_size)\n",
" return dataset.prefetch(1)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"train_set = mnist_dataset(train_filepaths, shuffle_buffer_size=60000)\n",
"valid_set = mnist_dataset(valid_filepaths)\n",
"test_set = mnist_dataset(test_filepaths)"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 5 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"for X, y in train_set.take(1):\n",
" for i in range(5):\n",
" plt.subplot(1, 5, i + 1)\n",
" plt.imshow(X[i].numpy(), cmap=\"binary\")\n",
" plt.axis(\"off\")\n",
" plt.title(str(y[i].numpy()))"
]
},
{
"cell_type": "code",
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"standardization = tf.keras.layers.Normalization(input_shape=[28, 28])\n",
"\n",
"sample_image_batches = train_set.take(100).map(lambda image, label: image)\n",
"sample_images = np.concatenate(list(sample_image_batches.as_numpy_iterator()),\n",
" axis=0).astype(np.float32)\n",
"standardization.adapt(sample_images)\n",
"\n",
"model = tf.keras.Sequential([\n",
" standardization,\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(10, activation=\"softmax\")\n",
"])\n",
"model.compile(loss=\"sparse_categorical_crossentropy\",\n",
" optimizer=\"nadam\", metrics=[\"accuracy\"])"
]
},
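The `Normalization` layer above is created with the default `axis=-1`. As we read the Keras documentation, that keeps a separate mean and variance for each index along the last axis. With 28×28 inputs, the layer therefore learns 28 per-column statistics rather than one per pixel. Passing `axis=(1, 2)` should standardize each pixel individually instead. The NumPy sketch below mimics both cases on fake images. `normalization_stats()` is a hypothetical helper, not a Keras function. The epsilon value is an assumption; Keras's exact epsilon handling may differ.

```python
import numpy as np


def normalization_stats(images, axis=-1):
    """Hypothetical helper mimicking what we understand Normalization.adapt()
    computes: one mean/variance per index along `axis`, reducing over all
    the other axes (including the batch axis)."""
    images = np.asarray(images, dtype=np.float32)
    keep = {int(a) % images.ndim for a in np.atleast_1d(axis)}
    reduce_axes = tuple(a for a in range(images.ndim) if a not in keep)
    mean = images.mean(axis=reduce_axes, keepdims=True)
    var = images.var(axis=reduce_axes, keepdims=True)
    return mean, var


rng = np.random.default_rng(42)
images = rng.integers(0, 256, size=(100, 28, 28)).astype(np.float32)  # fake images

# Default axis=-1: one mean/variance per column (28 statistics)
mean_col, var_col = normalization_stats(images)
print(mean_col.shape)  # (1, 1, 28)

# axis=(1, 2): one mean/variance per pixel (28 x 28 statistics)
mean_px, var_px = normalization_stats(images, axis=(1, 2))
standardized = (images - mean_px) / np.sqrt(var_px + 1e-7)  # small epsilon (assumed)
print(mean_px.shape)  # (1, 28, 28)
```

Both variants produce roughly standardized inputs for Fashion MNIST. Only the per-pixel version matches the exercise's wording, "standardize each input feature", literally.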
{
"cell_type": "code",
"execution_count": 124,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-20 15:30:49.689831: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.\n",
"2022-02-20 15:30:49.689858: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.\n",
"2022-02-20 15:30:49.691427: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
" 59/Unknown - 1s 3ms/step - loss: 0.9230 - accuracy: 0.6817"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-02-20 15:30:50.428921: I tensorflow/core/profiler/lib/profiler_session.cc:110] Profiler session initializing.\n",
"2022-02-20 15:30:50.428945: I tensorflow/core/profiler/lib/profiler_session.cc:125] Profiler session started.\n",
"2022-02-20 15:30:50.433359: I tensorflow/core/profiler/lib/profiler_session.cc:67] Profiler session collecting data.\n",
"2022-02-20 15:30:50.446608: I tensorflow/core/profiler/lib/profiler_session.cc:143] Profiler session tear down.\n",
"2022-02-20 15:30:50.461272: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50\n",
"\n",
"2022-02-20 15:30:50.465450: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for trace.json.gz to my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50/kiwimac.trace.json.gz\n",
"2022-02-20 15:30:50.480245: I tensorflow/core/profiler/rpc/client/save_profile.cc:136] Creating directory: my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50\n",
"\n",
"2022-02-20 15:30:50.480582: I tensorflow/core/profiler/rpc/client/save_profile.cc:142] Dumped gzipped tool data for memory_profile.json.gz to my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50/kiwimac.memory_profile.json.gz\n",
"2022-02-20 15:30:50.482034: I tensorflow/core/profiler/rpc/client/capture_profile.cc:251] Creating directory: my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50\n",
"Dumped tool data for xplane.pb to my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50/kiwimac.xplane.pb\n",
"Dumped tool data for overview_page.pb to my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50/kiwimac.overview_page.pb\n",
"Dumped tool data for input_pipeline.pb to my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50/kiwimac.input_pipeline.pb\n",
"Dumped tool data for tensorflow_stats.pb to my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50/kiwimac.tensorflow_stats.pb\n",
"Dumped tool data for kernel_stats.pb to my_logs/run_/20220220_153049/plugins/profile/2022_02_20_15_30_50/kiwimac.kernel_stats.pb\n",
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1719/1719 [==============================] - 5s 2ms/step - loss: 0.4437 - accuracy: 0.8402 - val_loss: 0.3649 - val_accuracy: 0.8682\n",
"Epoch 2/5\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.3333 - accuracy: 0.8775 - val_loss: 0.3346 - val_accuracy: 0.8790\n",
"Epoch 3/5\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2970 - accuracy: 0.8905 - val_loss: 0.3235 - val_accuracy: 0.8866\n",
"Epoch 4/5\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2723 - accuracy: 0.8995 - val_loss: 0.3308 - val_accuracy: 0.8888\n",
"Epoch 5/5\n",
"1719/1719 [==============================] - 4s 2ms/step - loss: 0.2534 - accuracy: 0.9047 - val_loss: 0.3174 - val_accuracy: 0.8916\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa3e08af370>"
]
},
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from datetime import datetime\n",
"from pathlib import Path\n",
"\n",
"logs = Path() / \"my_logs\" / \"run_\" / datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n",
"\n",
"tensorboard_cb = tf.keras.callbacks.TensorBoard(\n",
" log_dir=logs, histogram_freq=1, profile_batch=10)\n",
"\n",
"model.fit(train_set, epochs=5, validation_data=valid_set,\n",
" callbacks=[tensorboard_cb])"
]
},
{
"cell_type": "code",
"execution_count": 125,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The tensorboard extension is already loaded. To reload it, use:\n",
" %reload_ext tensorboard\n"
]
},
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-a8e8524a8e4cf37d\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-a8e8524a8e4cf37d\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6007;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%load_ext tensorboard\n",
"%tensorboard --logdir=./my_logs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10.\n",
"_Exercise: In this exercise you will download a dataset, split it, create a `tf.data.Dataset` to load it and preprocess it efficiently, then build and train a binary classification model containing an `Embedding` layer._\n",
"\n",
"### a.\n",
"_Exercise: Download the [Large Movie Review Dataset](https://homl.info/imdb), which contains 50,000 movie reviews from the [Internet Movie Database](https://imdb.com/). The data is organized in two directories, `train` and `test`, each containing a `pos` subdirectory with 12,500 positive reviews and a `neg` subdirectory with 12,500 negative reviews. Each review is stored in a separate text file. There are other files and folders (including preprocessed bag-of-words), but we will ignore them in this exercise._"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"PosixPath('/Users/ageron/.keras/datasets/aclImdb')"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pathlib import Path\n",
"\n",
"root = \"https://ai.stanford.edu/~amaas/data/sentiment/\"\n",
"filename = \"aclImdb_v1.tar.gz\"\n",
"filepath = tf.keras.utils.get_file(filename, root + filename, extract=True)\n",
"path = Path(filepath).with_name(\"aclImdb\")\n",
"path"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a `tree()` function to view the structure of the `aclImdb` directory:"
]
},
{
"cell_type": "code",
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
"def tree(path, level=0, indent=4, max_files=3):\n",
" if level == 0:\n",
" print(f\"{path}/\")\n",
" level += 1\n",
" sub_paths = sorted(path.iterdir())\n",
" sub_dirs = [sub_path for sub_path in sub_paths if sub_path.is_dir()]\n",
"    filepaths = [sub_path for sub_path in sub_paths if sub_path not in sub_dirs]\n",
"    indent_str = \" \" * indent * level\n",
"    for sub_dir in sub_dirs:\n",
"        print(f\"{indent_str}{sub_dir.name}/\")\n",
"        tree(sub_dir, level + 1, indent, max_files)\n",
" for filepath in filepaths[:max_files]:\n",
" print(f\"{indent_str}{filepath.name}\")\n",
" if len(filepaths) > max_files:\n",
" print(f\"{indent_str}...\")"
]
},
{
"cell_type": "code",
"execution_count": 128,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/Users/ageron/.keras/datasets/aclImdb/\n",
" test/\n",
" neg/\n",
" 0_2.txt\n",
" 10000_4.txt\n",
" 10001_1.txt\n",
" ...\n",
" pos/\n",
" 0_10.txt\n",
" 10000_7.txt\n",
" 10001_9.txt\n",
" ...\n",
" labeledBow.feat\n",
" urls_neg.txt\n",
" urls_pos.txt\n",
" train/\n",
" neg/\n",
" 0_3.txt\n",
" 10000_4.txt\n",
" 10001_4.txt\n",
" ...\n",
" pos/\n",
" 0_9.txt\n",
" 10000_8.txt\n",
" 10001_10.txt\n",
" ...\n",
" unsup/\n",
" 0_0.txt\n",
" 10000_0.txt\n",
" 10001_0.txt\n",
" ...\n",
" labeledBow.feat\n",
" unsupBow.feat\n",
" urls_neg.txt\n",
" ...\n",
" README\n",
" imdb.vocab\n",
" imdbEr.txt\n"
]
}
],
"source": [
"tree(path)"
]
},
{
"cell_type": "code",
"execution_count": 129,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(12500, 12500, 12500, 12500)"
]
},
"execution_count": 129,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def review_paths(dirpath):\n",
" return [str(path) for path in dirpath.glob(\"*.txt\")]\n",
"\n",
"train_pos = review_paths(path / \"train\" / \"pos\")\n",
"train_neg = review_paths(path / \"train\" / \"neg\")\n",
"test_valid_pos = review_paths(path / \"test\" / \"pos\")\n",
"test_valid_neg = review_paths(path / \"test\" / \"neg\")\n",
"\n",
"len(train_pos), len(train_neg), len(test_valid_pos), len(test_valid_neg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### b.\n",
"_Exercise: Split the test set into a validation set (15,000) and a test set (10,000)._"
]
},
{
"cell_type": "code",
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
"np.random.shuffle(test_valid_pos)\n",
"np.random.shuffle(test_valid_neg)\n",
"\n",
"test_pos = test_valid_pos[:5000]\n",
"test_neg = test_valid_neg[:5000]\n",
"valid_pos = test_valid_pos[5000:]\n",
"valid_neg = test_valid_neg[5000:]"
]
},
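For a split that is reproducible and keeps both sets balanced, each class list can be shuffled independently with a fixed seed before slicing. Here is a minimal standard-library sketch. `split_test_valid()` and the fake file names are hypothetical, and the seed value is an arbitrary choice.

```python
import random


def split_test_valid(pos_paths, neg_paths, n_test_per_class=5000, seed=42):
    """Shuffle each class independently (with a fixed seed), then put the
    first n_test_per_class paths of each class in the test set and the rest
    in the validation set, so that both sets stay balanced."""
    rng = random.Random(seed)
    pos, neg = list(pos_paths), list(neg_paths)  # copies: inputs stay untouched
    rng.shuffle(pos)
    rng.shuffle(neg)
    test = (pos[:n_test_per_class], neg[:n_test_per_class])
    valid = (pos[n_test_per_class:], neg[n_test_per_class:])
    return test, valid


fake_pos = [f"pos/{i}.txt" for i in range(12_500)]  # hypothetical file names
fake_neg = [f"neg/{i}.txt" for i in range(12_500)]
(t_pos, t_neg), (v_pos, v_neg) = split_test_valid(fake_pos, fake_neg)
print(len(t_pos) + len(t_neg), len(v_pos) + len(v_neg))  # 10000 15000
```

Using a dedicated `random.Random(seed)` avoids touching NumPy's global random state, so the split stays the same no matter which cells ran before.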
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### c.\n",
"_Exercise: Use tf.data to create an efficient dataset for each set._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the dataset fits in memory, we can just load all the data using pure Python code and use `tf.data.Dataset.from_tensor_slices()`:"
]
},
{
"cell_type": "code",
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
"def imdb_dataset(filepaths_positive, filepaths_negative):\n",
" reviews = []\n",
" labels = []\n",
" for filepaths, label in ((filepaths_negative, 0), (filepaths_positive, 1)):\n",
" for filepath in filepaths:\n",
"            with open(filepath, encoding=\"utf-8\") as review_file:\n",
" reviews.append(review_file.read())\n",
" labels.append(label)\n",
" return tf.data.Dataset.from_tensor_slices(\n",
" (tf.constant(reviews), tf.constant(labels)))"
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b\"Working with one of the best Shakespeare sources, this film manages to be creditable to it's source, whilst still appealing to a wider audience.<br /><br />Branagh steals the film from under Fishburne's nose, and there's a talented cast on good form.\", shape=(), dtype=string)\n",
"tf.Tensor(0, shape=(), dtype=int32)\n",
"\n",
"tf.Tensor(b'Well...tremors I, the original started off in 1990 and i found the movie quite enjoyable to watch. however, they proceeded to make tremors II and III. Trust me, those movies started going downhill right after they finished the first one, i mean, ass blasters??? Now, only God himself is capable of answering the question \"why in Gods name would they create another one of these dumpster dives of a movie?\" Tremors IV cannot be considered a bad movie, in fact it cannot be even considered an epitome of a bad movie, for it lives up to more than that. As i attempted to sit though it, i noticed that my eyes started to bleed, and i hoped profusely that the little girl from the ring would crawl through the TV and kill me. did they really think that dressing the people who had stared in the other movies up as though they we\\'re from the wild west would make the movie (with the exact same occurrences) any better? honestly, i would never suggest buying this movie, i mean, there are cheaper ways to find things that burn well.', shape=(), dtype=string)\n",
"tf.Tensor(0, shape=(), dtype=int32)\n",
"\n",
"tf.Tensor(b\"Ouch! This one was a bit painful to sit through. It has a cute and amusing premise, but it all goes to hell from there. Matthew Modine is almost always pedestrian and annoying, and he does not disappoint in this one. Deborah Kara Unger and John Neville turned in surprisingly decent performances. Alan Bates and Jennifer Tilly, among others, played it way over the top. I know that's the way the parts were written, and it's hard to blame actors, when the script and director have them do such schlock. If you're going to have outrageous characters, that's OK, but you gotta have good material to make it work. It didn't here. Run away screaming from this movie if at all possible.\", shape=(), dtype=string)\n",
"tf.Tensor(0, shape=(), dtype=int32)\n",
"\n"
]
}
],
"source": [
"for X, y in imdb_dataset(train_pos, train_neg).take(3):\n",
" print(X)\n",
" print(y)\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"29.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
]
}
],
"source": [
"%timeit -r1 for X, y in imdb_dataset(train_pos, train_neg).repeat(10): pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It takes about 30 seconds to load the dataset and go through it 10 times (your timings will depend on your machine)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"But let's pretend the dataset does not fit in memory, just to make things more interesting. Luckily, each review fits on just one line (they use `<br />` to indicate line breaks), so we can read the reviews using a `TextLineDataset`. If they didn't, we would have to preprocess the input files (e.g., converting them to TFRecords). For very large datasets, it would make sense to use a tool like Apache Beam for that."
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
"def imdb_dataset(filepaths_positive, filepaths_negative, n_read_threads=5):\n",
" dataset_neg = tf.data.TextLineDataset(filepaths_negative,\n",
" num_parallel_reads=n_read_threads)\n",
" dataset_neg = dataset_neg.map(lambda review: (review, 0))\n",
" dataset_pos = tf.data.TextLineDataset(filepaths_positive,\n",
" num_parallel_reads=n_read_threads)\n",
" dataset_pos = dataset_pos.map(lambda review: (review, 1))\n",
"    return dataset_pos.concatenate(dataset_neg)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27.7 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
]
}
],
"source": [
"%timeit -r1 for X, y in imdb_dataset(train_pos, train_neg).repeat(10): pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This time it takes about 28 seconds to go through the dataset 10 times. That is slightly less than the in-memory version took above, even though this dataset is not cached in RAM, so the review files must be reread at each epoch. If you add `.cache()` just before `.repeat(10)`, the files are read only during the first epoch, and the loop runs faster still (about 21 seconds in the run below):"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20.6 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)\n"
]
}
],
"source": [
"%timeit -r1 for X, y in imdb_dataset(train_pos, train_neg).cache().repeat(10): pass"
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 32\n",
"\n",
"train_set = imdb_dataset(train_pos, train_neg).shuffle(25000, seed=42)\n",
"train_set = train_set.batch(batch_size).prefetch(1)\n",
"valid_set = imdb_dataset(valid_pos, valid_neg).batch(batch_size).prefetch(1)\n",
"test_set = imdb_dataset(test_pos, test_neg).batch(batch_size).prefetch(1)"
]
},
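Because `imdb_dataset()` yields all the positive reviews before all the negative ones, the shuffle buffer size matters a lot here. As we understand `Dataset.shuffle()`, it keeps a buffer of `buffer_size` elements, emits a random element from that buffer, and refills the freed slot with the next input element. The pure-Python emulation below is a sketch, not tf.data's actual implementation. It shows that with a small buffer, the first batch is made entirely of positive reviews. A buffer as large as the dataset (25,000 here) gives a uniform shuffle.

```python
import random


def buffer_shuffle(items, buffer_size, seed=42):
    """Emulate a fixed-size shuffle buffer: fill the buffer, then repeatedly
    emit a random buffered element and replace it with the next input
    element; finally drain the buffer in random order."""
    rng = random.Random(seed)
    iterator = iter(items)
    buffer = []
    for item in iterator:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break
    for item in iterator:
        index = rng.randrange(len(buffer))
        yield buffer[index]
        buffer[index] = item
    rng.shuffle(buffer)
    yield from buffer


demo_labels = [1] * 12_500 + [0] * 12_500  # positives first, like imdb_dataset()
small = list(buffer_shuffle(demo_labels, buffer_size=1_000))
large = list(buffer_shuffle(demo_labels, buffer_size=25_000))
print(sum(small[:32]))  # 32: the first batch contains only positive reviews
print(sum(large[:32]))  # depends on the seed, typically around 16
```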
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### d.\n",
"_Exercise: Create a binary classification model, using a `TextVectorization` layer to preprocess each review._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a `TextVectorization` layer and adapt it to the full IMDB training set. If the training set did not fit in RAM, we could adapt the layer to a smaller sample instead, e.g., `train_set.take(500)`, which would yield 500 batches of 32 reviews. Let's use TF-IDF for now."
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [],
"source": [
"max_tokens = 1000\n",
"sample_reviews = train_set.map(lambda review, label: review)\n",
"text_vectorization = tf.keras.layers.TextVectorization(\n",
" max_tokens=max_tokens, output_mode=\"tf_idf\")\n",
"text_vectorization.adapt(sample_reviews)"
]
},
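To get a feel for what `output_mode="tf_idf"` produces, here is a small pure-Python TF-IDF sketch. It multiplies each term count by the smoothed inverse document frequency $\log\left(1 + \frac{N}{1 + \mathrm{df}}\right)$, where $N$ is the number of documents and $\mathrm{df}$ is the number of documents containing the word. We believe this is the weighting Keras uses, but treat that as an assumption. The tokenization is also simplified: lowercasing and whitespace splitting, with no punctuation stripping and no `[UNK]` bucket.

```python
import math
from collections import Counter


def tf_idf(documents, vocabulary):
    """Term count times smoothed IDF, log(1 + N / (1 + df)). We believe this is
    the weighting Keras's TextVectorization uses; treat it as an assumption."""
    tokenized = [doc.lower().split() for doc in documents]
    n_docs = len(tokenized)
    doc_freq = Counter(word for tokens in tokenized for word in set(tokens))
    idf = {word: math.log(1 + n_docs / (1 + doc_freq[word])) for word in vocabulary}
    rows = []
    for tokens in tokenized:
        counts = Counter(tokens)
        rows.append([counts[word] * idf[word] for word in vocabulary])
    return rows


docs = ["Great movie great cast", "Terrible movie", "great fun"]
vocab = ["great", "movie", "terrible"]
for row in tf_idf(docs, vocab):
    print([round(value, 3) for value in row])
# [1.386, 0.693, 0.0]
# [0.0, 0.693, 0.916]
# [0.693, 0.0, 0.0]
```

A word like "terrible" appears in fewer documents than "great", so each of its occurrences gets a higher weight. This is how TF-IDF downweights very common words such as "the" and "and".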
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Good! Now let's take a look at the first 10 tokens in the vocabulary:"
]
},
{
"cell_type": "code",
"execution_count": 139,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i']"
]
},
"execution_count": 139,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"text_vectorization.get_vocabulary()[:10]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Apart from `[UNK]`, the token used for all out-of-vocabulary words, these are the most common words in the reviews."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're ready to train the model!"
]
},
{
"cell_type": "code",
"execution_count": 140,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"782/782 [==============================] - 4s 4ms/step - loss: 0.4521 - accuracy: 0.8189 - val_loss: 0.3894 - val_accuracy: 0.8419\n",
"Epoch 2/5\n",
"782/782 [==============================] - 4s 4ms/step - loss: 0.3608 - accuracy: 0.8537 - val_loss: 0.7081 - val_accuracy: 0.7643\n",
"Epoch 3/5\n",
"782/782 [==============================] - 4s 4ms/step - loss: 0.3123 - accuracy: 0.8742 - val_loss: 0.3367 - val_accuracy: 0.8569\n",
"Epoch 4/5\n",
"782/782 [==============================] - 4s 4ms/step - loss: 0.2535 - accuracy: 0.8968 - val_loss: 0.5343 - val_accuracy: 0.8040\n",
"Epoch 5/5\n",
"782/782 [==============================] - 4s 4ms/step - loss: 0.1879 - accuracy: 0.9274 - val_loss: 0.3888 - val_accuracy: 0.8439\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa401b8f9d0>"
]
},
"execution_count": 140,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.random.set_seed(42)\n",
"model = tf.keras.Sequential([\n",
" text_vectorization,\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\"),\n",
"])\n",
"model.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"model.fit(train_set, epochs=5, validation_data=valid_set)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We get about 84.2% accuracy on the validation set after just the first epoch, but after that the model makes no significant progress. The validation accuracy jumps around between roughly 76% and 86% while the training accuracy keeps climbing, a sign that the model is starting to overfit. We will do better in Chapter 16. For now the point is just to perform efficient preprocessing using `tf.data` and Keras preprocessing layers."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### e.\n",
"_Exercise: Add an `Embedding` layer and compute the mean embedding for each review, multiplied by the square root of the number of words (see Chapter 16). This rescaled mean embedding can then be passed to the rest of your model._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To compute the mean embedding for each review, and multiply it by the square root of the number of words in that review, we will need a little function. For each review, this function needs to compute $M \\times \\sqrt N$, where $M$ is the mean of all the word embeddings in the review (excluding padding tokens), and $N$ is the number of words in the review (also excluding padding tokens). We can rewrite $M$ as $\\dfrac{S}{N}$, where $S$ is the sum of all word embeddings (it does not matter whether or not we include the padding tokens in this sum, as long as their representation is a zero vector). So the function must return $M \\times \\sqrt N = \\dfrac{S}{N} \\times \\sqrt N = \\dfrac{S}{\\sqrt N \\times \\sqrt N} \\times \\sqrt N = \\dfrac{S}{\\sqrt N}$."
]
},
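As a quick numerical sanity check of the identity $M \times \sqrt N = S / \sqrt N$ derived above, here is a NumPy sketch, independent of TensorFlow. It computes both sides for the same toy batch used below, treating all-zero vectors as padding exactly as the derivation assumes. `scaled_mean_embedding()` is a hypothetical name.

```python
import numpy as np


def scaled_mean_embedding(embeddings):
    """S / sqrt(N): the sum of the token vectors divided by the square root of
    the number of non-padding (non-zero) vectors in each sequence."""
    not_pad = np.any(embeddings != 0, axis=-1)     # shape (batch, seq_len)
    n_words = not_pad.sum(axis=-1, keepdims=True)  # shape (batch, 1)
    return embeddings.sum(axis=1) / np.sqrt(n_words)


batch = np.array([[[1., 2., 3.], [4., 5., 0.], [0., 0., 0.]],
                  [[6., 0., 0.], [0., 0., 0.], [0., 0., 0.]]])
lhs = scaled_mean_embedding(batch)
# M * sqrt(N), computed directly from the non-padding tokens only
rhs = np.array([batch[0, :2].mean(axis=0) * np.sqrt(2),
                batch[1, :1].mean(axis=0) * np.sqrt(1)])
print(np.allclose(lhs, rhs))  # True
```

A review made only of padding would give $N = 0$ and a division by zero. This cannot happen with the toy batch, but real inputs may need a guard such as `np.maximum(n_words, 1)`.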
{
"cell_type": "code",
"execution_count": 141,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(2, 3), dtype=float32, numpy=\n",
"array([[3.535534 , 4.9497476, 2.1213205],\n",
" [6. , 0. , 0. ]], dtype=float32)>"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def compute_mean_embedding(inputs, mask=None):\n",
"    if mask is not None:  # Embedding(mask_zero=True) masks <pad> tokens but does not zero them\n",
"        inputs = inputs * tf.cast(mask, inputs.dtype)[..., tf.newaxis]\n",
"    not_pad = tf.math.count_nonzero(inputs, axis=-1)\n",
"    n_words = tf.math.count_nonzero(not_pad, axis=-1, keepdims=True)\n",
"    sqrt_n_words = tf.math.sqrt(tf.cast(n_words, tf.float32))\n",
"    return tf.reduce_sum(inputs, axis=1) / sqrt_n_words\n",
"\n",
"another_example = tf.constant([[[1., 2., 3.], [4., 5., 0.], [0., 0., 0.]],\n",
" [[6., 0., 0.], [0., 0., 0.], [0., 0., 0.]]])\n",
"compute_mean_embedding(another_example)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's check that this is correct. The first review contains 2 words (the last token is a zero vector, which represents the `<pad>` token). Let's compute the mean embedding for these 2 words, and multiply the result by the square root of 2:"
]
},
{
"cell_type": "code",
"execution_count": 142,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[3.535534 , 4.9497476, 2.1213202]], dtype=float32)>"
]
},
"execution_count": 142,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.reduce_mean(another_example[0:1, :2], axis=1) * tf.sqrt(2.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks good! Now let's check the second review, which contains just one word (we ignore the two padding tokens):"
]
},
{
"cell_type": "code",
"execution_count": 143,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[6., 0., 0.]], dtype=float32)>"
]
},
"execution_count": 143,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tf.reduce_mean(another_example[1:2, :1], axis=1) * tf.sqrt(1.)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Perfect. Now we're ready to train our final model. It's the same as before, except we replaced TF-IDF with ordinal encoding (`output_mode=\"int\"`) followed by an `Embedding` layer, followed by a `Lambda` layer that calls the `compute_mean_embedding()` function:"
]
},
{
"cell_type": "code",
2022-02-20 05:38:01 +01:00
"execution_count": 144,
2021-05-25 11:54:03 +02:00
"metadata": {},
"outputs": [],
"source": [
"embedding_size = 20\n",
"tf.random.set_seed(42)\n",
"\n",
"text_vectorization = tf.keras.layers.TextVectorization(\n",
" max_tokens=max_tokens, output_mode=\"int\")\n",
"text_vectorization.adapt(sample_reviews)\n",
"\n",
"model = tf.keras.Sequential([\n",
" text_vectorization,\n",
" tf.keras.layers.Embedding(input_dim=max_tokens,\n",
" output_dim=embedding_size,\n",
" mask_zero=True), # <pad> tokens => zero vectors\n",
" tf.keras.layers.Lambda(compute_mean_embedding),\n",
" tf.keras.layers.Dense(100, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\"),\n",
"])"
]
},
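{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the first two layers more concrete, here is a hypothetical, simplified pure-Python sketch of ordinal encoding with padding and masking. It assumes that, as in the default `TextVectorization` vocabulary, ID 0 is reserved for padding and ID 1 for out-of-vocabulary words (`[UNK]`). The toy vocabulary and the helpers `toy_ordinal_encode()` and `toy_padding_mask()` are illustrative only; the real layer also standardizes the text (by default it lowercases and strips punctuation). With `mask_zero=True`, the `Embedding` layer treats ID 0 as padding, much like the mask computed below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical, simplified sketch (not the real TextVectorization logic).\n",
"# Assumption: ID 0 is reserved for padding and ID 1 for unknown words.\n",
"toy_vocab = [\"\", \"[UNK]\", \"great\", \"movie\", \"terrible\"]\n",
"toy_word_to_id = {word: idx for idx, word in enumerate(toy_vocab)}\n",
"\n",
"def toy_ordinal_encode(reviews):\n",
"    \"\"\"Map each word to its ID (1 if unknown), then right-pad with 0s.\"\"\"\n",
"    encoded = [[toy_word_to_id.get(word, 1) for word in review.lower().split()]\n",
"               for review in reviews]\n",
"    max_len = max(len(ids) for ids in encoded)\n",
"    return [ids + [0] * (max_len - len(ids)) for ids in encoded]\n",
"\n",
"def toy_padding_mask(batch_ids):\n",
"    \"\"\"Like mask_zero=True: True for real tokens, False for <pad> (ID 0).\"\"\"\n",
"    return [[token_id != 0 for token_id in ids] for ids in batch_ids]\n",
"\n",
"toy_ids = toy_ordinal_encode([\"Great movie indeed\", \"terrible\"])\n",
"print(toy_ids)  # [[2, 3, 1], [4, 0, 0]]\n",
"print(toy_padding_mask(toy_ids))  # [[True, True, True], [True, False, False]]"
]
},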
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### f.\n",
"_Exercise: Train the model and see what accuracy you get. Try to optimize your pipelines to make training as fast as possible._"
]
},
{
"cell_type": "code",
"execution_count": 145,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"782/782 [==============================] - 9s 10ms/step - loss: 0.4758 - accuracy: 0.7675 - val_loss: 0.4153 - val_accuracy: 0.8009\n",
"Epoch 2/5\n",
"782/782 [==============================] - 8s 9ms/step - loss: 0.3438 - accuracy: 0.8537 - val_loss: 0.3814 - val_accuracy: 0.8245\n",
"Epoch 3/5\n",
"782/782 [==============================] - 8s 10ms/step - loss: 0.3244 - accuracy: 0.8618 - val_loss: 0.3341 - val_accuracy: 0.8520\n",
"Epoch 4/5\n",
"782/782 [==============================] - 10s 11ms/step - loss: 0.3153 - accuracy: 0.8666 - val_loss: 0.3122 - val_accuracy: 0.8655\n",
"Epoch 5/5\n",
"782/782 [==============================] - 11s 12ms/step - loss: 0.3135 - accuracy: 0.8676 - val_loss: 0.3119 - val_accuracy: 0.8625\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fa3a0bf9460>"
]
},
"execution_count": 145,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.compile(loss=\"binary_crossentropy\", optimizer=\"nadam\",\n",
" metrics=[\"accuracy\"])\n",
"model.fit(train_set, epochs=5, validation_data=valid_set)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model is only marginally better with embeddings: it reaches about 86% validation accuracy after 5 epochs (we will do better in Chapter 16). The pipeline looks fast enough, since we already optimized it earlier."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### g.\n",
"_Exercise: Use TFDS to load the same dataset more easily: `tfds.load(\"imdb_reviews\")`._"
]
},
{
"cell_type": "code",
"execution_count": 146,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow_datasets as tfds\n",
"\n",
"datasets = tfds.load(name=\"imdb_reviews\")\n",
"train_set, test_set = datasets[\"train\"], datasets[\"test\"]"
]
},
{
"cell_type": "code",
"execution_count": 147,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tf.Tensor(b\"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it.\", shape=(), dtype=string)\n",
"tf.Tensor(0, shape=(), dtype=int64)\n"
]
}
],
"source": [
"for example in train_set.take(1):\n",
" print(example[\"text\"])\n",
" print(example[\"label\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.12"
},
"nav_menu": {
"height": "264px",
"width": "369px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}