Improve the implementation of the test_set_check() function: faster, supports Python 2 and 3, and a finer-grained split (32 bits instead of 8)

main
Aurélien Geron 2018-04-03 16:45:53 +02:00
parent a164ffc699
commit 8f6a28e6bc
1 changed files with 147 additions and 122 deletions

View File

@ -219,26 +219,41 @@
"metadata": {},
"outputs": [],
"source": [
"import hashlib\n",
"from zlib import crc32\n",
"\n",
"def test_set_check(identifier, test_ratio, hash):\n",
" return hash(np.int64(identifier)).digest()[-1] < 256 * test_ratio\n",
"def test_set_check(identifier, test_ratio):\n",
" return crc32(np.int64(identifier)) & 0xffffffff < test_ratio * 2**32\n",
"\n",
"def split_train_test_by_id(data, test_ratio, id_column, hash=hashlib.md5):\n",
"def split_train_test_by_id(data, test_ratio, id_column):\n",
" ids = data[id_column]\n",
" in_test_set = ids.apply(lambda id_: test_set_check(id_, test_ratio, hash))\n",
" in_test_set = ids.apply(lambda id_: test_set_check(id_, test_ratio))\n",
" return data.loc[~in_test_set], data.loc[in_test_set]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The implementation of `test_set_check()` above works fine in both Python 2 and Python 3. In earlier releases, the following implementation was proposed, which supported any hash function, but was much slower and did not support Python 2:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# This version supports both Python 2 and Python 3, instead of just Python 3.\n",
"def test_set_check(identifier, test_ratio, hash):\n",
" return bytearray(hash(np.int64(identifier)).digest())[-1] < 256 * test_ratio"
"import hashlib\n",
"\n",
"def test_set_check(identifier, test_ratio, hash=hashlib.md5):\n",
" return hash(np.int64(identifier)).digest()[-1] < 256 * test_ratio"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want an implementation that supports any hash function and is compatible with both Python 2 and Python 3, here is one:"
]
},
{
@ -246,6 +261,16 @@
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def test_set_check(identifier, test_ratio, hash=hashlib.md5):\n",
" return bytearray(hash(np.int64(identifier)).digest())[-1] < 256 * test_ratio"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"housing_with_id = housing.reset_index() # adds an `index` column\n",
"train_set, test_set = split_train_test_by_id(housing_with_id, 0.2, \"index\")"
@ -253,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -263,7 +288,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -272,7 +297,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@ -283,7 +308,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -292,7 +317,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@ -301,7 +326,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@ -313,7 +338,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@ -322,7 +347,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@ -331,7 +356,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@ -345,7 +370,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@ -354,7 +379,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
@ -363,7 +388,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@ -383,7 +408,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@ -392,7 +417,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@ -409,7 +434,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@ -418,7 +443,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@ -428,7 +453,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@ -445,7 +470,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
@ -459,7 +484,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@ -488,7 +513,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@ -497,7 +522,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
@ -506,7 +531,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@ -521,7 +546,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
@ -533,7 +558,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
@ -551,7 +576,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
@ -561,7 +586,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
@ -573,7 +598,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
@ -589,7 +614,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
@ -599,7 +624,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
@ -609,7 +634,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
@ -618,7 +643,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
@ -627,7 +652,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
@ -638,7 +663,7 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
@ -656,7 +681,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
@ -666,7 +691,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
@ -675,7 +700,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
@ -691,7 +716,7 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
@ -707,7 +732,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
@ -716,7 +741,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
@ -726,7 +751,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 56,
"metadata": {},
"outputs": [],
"source": [
@ -735,7 +760,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
@ -744,7 +769,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
@ -761,7 +786,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
@ -778,7 +803,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
@ -788,7 +813,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
@ -811,7 +836,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
@ -831,7 +856,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
@ -847,7 +872,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
@ -1048,7 +1073,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
@ -1069,7 +1094,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
@ -1085,7 +1110,7 @@
},
{
"cell_type": "code",
"execution_count": 66,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@ -1096,7 +1121,7 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
@ -1112,7 +1137,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
@ -1142,7 +1167,7 @@
},
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@ -1159,7 +1184,7 @@
},
{
"cell_type": "code",
"execution_count": 70,
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
@ -1177,7 +1202,7 @@
},
{
"cell_type": "code",
"execution_count": 71,
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
@ -1193,7 +1218,7 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
@ -1219,7 +1244,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
@ -1241,7 +1266,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
@ -1255,7 +1280,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
@ -1265,7 +1290,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
@ -1281,7 +1306,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
@ -1293,7 +1318,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
@ -1314,7 +1339,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
@ -1323,7 +1348,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
@ -1332,7 +1357,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@ -1346,7 +1371,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
@ -1358,7 +1383,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
@ -1370,7 +1395,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
@ -1389,7 +1414,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
@ -1402,7 +1427,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
@ -1416,7 +1441,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
@ -1428,7 +1453,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
@ -1440,7 +1465,7 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
@ -1452,7 +1477,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
@ -1466,7 +1491,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
@ -1476,7 +1501,7 @@
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
@ -1492,7 +1517,7 @@
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
@ -1521,7 +1546,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
@ -1530,7 +1555,7 @@
},
{
"cell_type": "code",
"execution_count": 95,
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
@ -1546,7 +1571,7 @@
},
{
"cell_type": "code",
"execution_count": 96,
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
@ -1557,7 +1582,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 98,
"metadata": {},
"outputs": [],
"source": [
@ -1566,7 +1591,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
@ -1586,7 +1611,7 @@
},
{
"cell_type": "code",
"execution_count": 99,
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
@ -1597,7 +1622,7 @@
},
{
"cell_type": "code",
"execution_count": 100,
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
@ -1607,7 +1632,7 @@
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 102,
"metadata": {},
"outputs": [],
"source": [
@ -1620,7 +1645,7 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
@ -1638,7 +1663,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
@ -1661,7 +1686,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
@ -1683,7 +1708,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
@ -1692,7 +1717,7 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
@ -1711,7 +1736,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
@ -1749,7 +1774,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 109,
"metadata": {},
"outputs": [],
"source": [
@ -1775,7 +1800,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
@ -1793,7 +1818,7 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
@ -1823,7 +1848,7 @@
},
{
"cell_type": "code",
"execution_count": 111,
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
@ -1856,7 +1881,7 @@
},
{
"cell_type": "code",
"execution_count": 112,
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
@ -1874,7 +1899,7 @@
},
{
"cell_type": "code",
"execution_count": 113,
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
@ -1897,7 +1922,7 @@
},
{
"cell_type": "code",
"execution_count": 114,
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
@ -1922,7 +1947,7 @@
},
{
"cell_type": "code",
"execution_count": 115,
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
@ -1961,7 +1986,7 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
@ -1997,7 +2022,7 @@
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
@ -2013,7 +2038,7 @@
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
@ -2023,7 +2048,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
@ -2039,7 +2064,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
@ -2055,7 +2080,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
@ -2067,7 +2092,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
@ -2083,7 +2108,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
@ -2099,7 +2124,7 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
@ -2129,7 +2154,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
@ -2142,7 +2167,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
@ -2158,7 +2183,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
@ -2192,7 +2217,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
@ -2208,7 +2233,7 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [