diff --git a/.gitignore b/.gitignore index 374ffcd..e49febb 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,9 @@ tf_logs/* images/**/*.png images/**/*.dot my_* +person.proto +person.desc +person_pb2.py datasets/flowers datasets/lifesat/lifesat.csv datasets/spam diff --git a/13_loading_and_preprocessing_data.ipynb b/13_loading_and_preprocessing_data.ipynb index f66c58a..ed70265 100644 --- a/13_loading_and_preprocessing_data.ipynb +++ b/13_loading_and_preprocessing_data.ipynb @@ -770,16 +770,40 @@ "### A Brief Intro to Protocol Buffers" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For this section you need to [install protobuf](https://developers.google.com/protocol-buffers/docs/downloads). In general you will not have to do so when using TensorFlow, as it comes with functions to create and parse protocol buffers of type `tf.train.Example`, which are generally sufficient. However, in this section we will learn about protocol buffers by creating our own simple protobuf definition, so we need the protobuf compiler (`protoc`): we will use it to compile the protobuf definition to a Python module that we can then use in our code." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First let's write a simple protobuf definition:" + ] + }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ - "from homl.person_pb2 import Person\n", - "\n", - "person = Person(name=\"Al\", id=123, email=[\"a@b.com\"]) # create a Person\n", - "print(person) # display the Person" + "%%writefile person.proto\n", + "syntax = \"proto3\";\n", + "message Person {\n", + " string name = 1;\n", + " int32 id = 2;\n", + " repeated string email = 3;\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "And let's compile it (the `--descriptor_set_out` and `--include_imports` options are only required for the `tf.io.decode_proto()` example below):" ] }, { @@ -788,7 +812,7 @@ "metadata": {}, "outputs": [], "source": [ - "person.name # read a field" + "!protoc person.proto --python_out=. --descriptor_set_out=person.desc --include_imports" ] }, { @@ -797,7 +821,7 @@ "metadata": {}, "outputs": [], "source": [ - "person.name = \"Alice\" # modify a field" + "!ls person*" ] }, { @@ -806,7 +830,10 @@ "metadata": {}, "outputs": [], "source": [ - "person.email[0] # repeated fields can be accessed like arrays" + "from person_pb2 import Person\n", + "\n", + "person = Person(name=\"Al\", id=123, email=[\"a@b.com\"]) # create a Person\n", + "print(person) # display the Person" ] }, { @@ -815,7 +842,7 @@ "metadata": {}, "outputs": [], "source": [ - "person.email.append(\"c@d.com\") # add an email address" + "person.name # read a field" ] }, { @@ -823,6 +850,33 @@ "execution_count": 50, "metadata": {}, "outputs": [], + "source": [ + "person.name = \"Alice\" # modify a field" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "person.email[0] # repeated fields can be accessed like arrays" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "person.email.append(\"c@d.com\") # add an email address" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], "source": [ "s = person.SerializeToString() # serialize to a byte string\n", "s" @@ -830,7 +884,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -840,13 +894,50 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ "person == person2 # now they are equal" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Custom protobuf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In rare cases, you may want to parse a custom protobuf (like the one we just created) in TensorFlow. For this you can use the `tf.io.decode_proto()` function:" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "person_tf = tf.io.decode_proto(\n", + " bytes=s,\n", + " message_type=\"Person\",\n", + " field_names=[\"name\", \"id\", \"email\"],\n", + " output_types=[tf.string, tf.int32, tf.string],\n", + " descriptor_source=\"person.desc\")\n", + "\n", + "person_tf.values" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For more details, see the [`tf.io.decode_proto()`](https://www.tensorflow.org/api_docs/python/tf/io/decode_proto) documentation." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -885,7 +976,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -914,7 +1005,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [ @@ -930,7 +1021,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -939,7 +1030,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 60, "metadata": { "scrolled": true }, @@ -950,7 +1041,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 61, "metadata": {}, "outputs": [], "source": [ @@ -959,7 +1050,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -968,7 +1059,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ @@ -984,7 +1075,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ @@ -999,7 +1090,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 65, "metadata": {}, "outputs": [], "source": [ @@ -1012,7 +1103,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ @@ -1030,7 +1121,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 67, "metadata": {}, "outputs": [], "source": [ @@ -1039,7 +1130,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -1065,7 +1156,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -1076,7 +1167,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -1085,7 +1176,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -1095,7 +1186,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 72, "metadata": {}, "outputs": [], "source": [ @@ -1104,7 +1195,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ @@ -1116,7 +1207,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -1148,7 +1239,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -1187,7 +1278,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 76, "metadata": {}, "outputs": [], "source": [ @@ -1196,7 +1287,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 77, "metadata": {}, "outputs": [], "source": [ @@ -1205,7 +1296,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ @@ -1225,7 +1316,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ @@ -1234,7 +1325,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ @@ -1243,7 +1334,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -1252,7 +1343,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -1275,7 +1366,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ @@ -1299,7 +1390,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -1308,7 +1399,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -1321,7 +1412,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 86, "metadata": {}, "outputs": [], "source": [ @@ -1331,7 +1422,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 87, "metadata": {}, "outputs": [], "source": [ @@ -1340,7 +1431,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 88, "metadata": {}, "outputs": [], "source": [ @@ -1351,7 +1442,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 89, "metadata": {}, "outputs": [], "source": [ @@ -1362,7 +1453,7 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 90, "metadata": {}, "outputs": [], "source": [ @@ -1371,7 +1462,7 @@ }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 91, "metadata": {}, "outputs": [], "source": [ @@ -1382,7 +1473,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 92, "metadata": {}, "outputs": [], "source": [ @@ -1391,7 +1482,7 @@ }, { "cell_type": "code", - "execution_count": 89, + "execution_count": 93, "metadata": {}, "outputs": [], "source": [ @@ -1403,7 +1494,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 94, "metadata": {}, "outputs": [], "source": [ @@ -1415,7 +1506,7 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 95, "metadata": {}, "outputs": [], "source": [ @@ -1431,7 +1522,7 @@ }, { "cell_type": "code", - "execution_count": 92, + "execution_count": 96, "metadata": {}, "outputs": [], "source": [ @@ -1440,7 +1531,7 @@ }, { "cell_type": "code", - "execution_count": 93, + "execution_count": 97, "metadata": {}, "outputs": [], "source": [ @@ -1457,7 +1548,7 @@ }, { "cell_type": "code", - "execution_count": 94, + "execution_count": 98, "metadata": {}, "outputs": [], "source": [ @@ -1466,7 +1557,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 99, "metadata": {}, "outputs": [], "source": [ @@ -1477,7 +1568,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ @@ -1492,7 +1583,7 @@ }, { "cell_type": "code", - "execution_count": 97, + "execution_count": 101, "metadata": {}, "outputs": [], "source": [ @@ -1515,7 +1606,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 102, "metadata": {}, "outputs": [], "source": [ @@ -1532,7 +1623,7 @@ }, { "cell_type": "code", - "execution_count": 99, + "execution_count": 103, "metadata": {}, "outputs": [], "source": [ @@ -1553,7 +1644,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 104, "metadata": {}, "outputs": [], "source": [ @@ -1582,7 +1673,7 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 105, "metadata": {}, "outputs": [], "source": [ @@ -1594,7 +1685,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 106, "metadata": {}, "outputs": [], "source": [ @@ -1603,7 +1694,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 107, "metadata": {}, "outputs": [], "source": [ @@ -1624,7 +1715,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 108, "metadata": {}, "outputs": [], "source": [ @@ -1640,7 +1731,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 109, "metadata": {}, "outputs": [], "source": [ @@ -1658,7 +1749,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 110, "metadata": {}, "outputs": [], "source": [ @@ -1677,7 +1768,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -1696,7 +1787,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 112, "metadata": {}, "outputs": [], "source": [ @@ -1706,7 +1797,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 113, "metadata": {}, "outputs": [], "source": [