Clarify the mini-batch K-Means performance comparison diagram

main
Aurélien Geron 2018-05-07 16:07:15 +02:00
parent 71c40c7aec
commit 8fa23cceaf
1 changed files with 141 additions and 139 deletions

View File

@ -1476,7 +1476,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@ -1485,7 +1485,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@ -1497,7 +1497,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -1529,7 +1529,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@ -1538,7 +1538,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -1549,7 +1549,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -1564,7 +1564,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -1573,7 +1573,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@ -1596,7 +1596,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -1605,7 +1605,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -1620,7 +1620,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@ -1637,7 +1637,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
@ -1649,7 +1649,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -1675,7 +1675,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -1684,7 +1684,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@ -1702,7 +1702,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -1711,7 +1711,7 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@ -1727,7 +1727,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@ -1743,7 +1743,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@ -1759,7 +1759,7 @@
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@ -1783,7 +1783,7 @@
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@ -1829,7 +1829,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@ -1862,7 +1862,7 @@
},
{
"cell_type": "code",
"execution_count": 95,
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
@ -1878,7 +1878,7 @@
},
{
"cell_type": "code",
"execution_count": 96,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@ -1919,7 +1919,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@ -1943,7 +1943,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@ -1996,7 +1996,7 @@
},
{
"cell_type": "code",
"execution_count": 99,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@ -2019,7 +2019,7 @@
},
{
"cell_type": "code",
"execution_count": 100,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@ -2051,7 +2051,7 @@
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
@ -2067,7 +2067,7 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
@ -2084,7 +2084,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@ -2107,7 +2107,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
@ -2116,7 +2116,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
@ -2139,7 +2139,7 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@ -2157,7 +2157,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
@ -2199,7 +2199,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
@ -2208,7 +2208,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
@ -2241,7 +2241,7 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
@ -2250,7 +2250,7 @@
},
{
"cell_type": "code",
"execution_count": 111,
"execution_count": 43,
"metadata": {
"scrolled": true
},
@ -2275,7 +2275,7 @@
},
{
"cell_type": "code",
"execution_count": 112,
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
@ -2284,7 +2284,7 @@
},
{
"cell_type": "code",
"execution_count": 113,
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
@ -2294,7 +2294,7 @@
},
{
"cell_type": "code",
"execution_count": 114,
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
@ -2310,7 +2310,7 @@
},
{
"cell_type": "code",
"execution_count": 115,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
@ -2321,7 +2321,7 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 48,
"metadata": {
"scrolled": false
},
@ -2340,7 +2340,7 @@
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
@ -2357,7 +2357,7 @@
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
@ -2366,7 +2366,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
@ -2398,7 +2398,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
@ -2414,7 +2414,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
@ -2423,7 +2423,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
@ -2439,7 +2439,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
@ -2448,62 +2448,47 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"inertia_ratios = []\n",
"time_ratios = []\n",
"for k in range(1, 100):\n",
"times = np.empty((100, 2))\n",
"inertias = np.empty((100, 2))\n",
"for k in range(1, 101):\n",
" kmeans = KMeans(n_clusters=k, random_state=42)\n",
" minibatch_kmeans = MiniBatchKMeans(n_clusters=k, random_state=42)\n",
" print(\"\\r{}/{}\".format(k + 1, 100), end=\"\")\n",
" time_kmeans = timeit(\"kmeans.fit(X)\", number=10, globals=globals())\n",
" time_minibatch_kmeans = timeit(\"minibatch_kmeans.fit(X)\", number=10, globals=globals())\n",
" inertia_ratios.append(minibatch_kmeans.inertia_ / kmeans.inertia_)\n",
" time_ratios.append(time_minibatch_kmeans / time_kmeans)"
" print(\"\\r{}/{}\".format(k, 100), end=\"\")\n",
" times[k-1, 0] = timeit(\"kmeans.fit(X)\", number=10, globals=globals())\n",
" times[k-1, 1] = timeit(\"minibatch_kmeans.fit(X)\", number=10, globals=globals())\n",
" inertias[k-1, 0] = kmeans.inertia_\n",
" inertias[k-1, 1] = minibatch_kmeans.inertia_"
]
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 103,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LinearRegression"
]
},
{
"cell_type": "code",
"execution_count": 126,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"plt.figure(figsize=(10,4))\n",
"\n",
"lin_reg_inertia = LinearRegression()\n",
"lin_reg_time = LinearRegression()\n",
"lin_reg_inertia.fit(np.arange(1, 100).reshape(-1, 1), inertia_ratios)\n",
"lin_reg_time.fit(np.arange(1, 100).reshape(-1, 1), time_ratios)\n",
"\n",
"plt.subplot(121)\n",
"plt.plot(range(1, 100), inertia_ratios, \"bo\")\n",
"plt.plot([0, 100], [1.0, 1.0], \"r--\")\n",
"plt.plot([0, 100], lin_reg_inertia.predict([[0], [100]]), \"k-\")\n",
"plt.plot(range(1, 101), inertias[:, 0], \"r--\", label=\"K-Means\")\n",
"plt.plot(range(1, 101), inertias[:, 1], \"b.-\", label=\"Mini-batch K-Means\")\n",
"plt.xlabel(\"$k$\", fontsize=16)\n",
"plt.ylabel(\"Minibatch / K-Means Ratio\", fontsize=14)\n",
"plt.title(\"Inertia Ratio\", fontsize=14)\n",
"plt.axis([1, 99, np.min(inertia_ratios) * 0.98, np.max(inertia_ratios) * 1.02])\n",
"#plt.ylabel(\"Inertia\", fontsize=14)\n",
"plt.title(\"Inertia\", fontsize=14)\n",
"plt.legend(fontsize=14)\n",
"plt.axis([1, 100, 0, 100])\n",
"\n",
"plt.subplot(122)\n",
"plt.plot(range(1, 100), time_ratios, \"bo\")\n",
"plt.plot([0, 100], [1.0, 1.0], \"r--\")\n",
"plt.plot([0, 100], lin_reg_time.predict([[0], [100]]), \"k-\")\n",
"plt.plot(range(1, 101), times[:, 0], \"r--\", label=\"K-Means\")\n",
"plt.plot(range(1, 101), times[:, 1], \"b.-\", label=\"Mini-batch K-Means\")\n",
"plt.xlabel(\"$k$\", fontsize=16)\n",
"plt.title(\"Training Time Ratio\", fontsize=14)\n",
"plt.axis([2, 99, np.min(time_ratios) * 0.98, np.max(time_ratios) * 1.02])\n",
"#plt.ylabel(\"Training time (seconds)\", fontsize=14)\n",
"plt.title(\"Training time (seconds)\", fontsize=14)\n",
"plt.axis([1, 100, 0, 6])\n",
"#plt.legend(fontsize=14)\n",
"\n",
"save_fig(\"minibatch_kmeans_vs_kmeans\")\n",
"plt.show()"
@ -2747,7 +2732,7 @@
},
{
"cell_type": "code",
"execution_count": 138,
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
@ -2761,7 +2746,7 @@
},
{
"cell_type": "code",
"execution_count": 139,
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
@ -3324,7 +3309,7 @@
},
{
"cell_type": "code",
"execution_count": 176,
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
@ -3333,7 +3318,7 @@
},
{
"cell_type": "code",
"execution_count": 177,
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
@ -3342,7 +3327,7 @@
},
{
"cell_type": "code",
"execution_count": 178,
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
@ -3351,17 +3336,17 @@
},
{
"cell_type": "code",
"execution_count": 179,
"execution_count": 112,
"metadata": {},
"outputs": [],
"source": [
"dbscan = DBSCAN(eps=0.05)\n",
"dbscan = DBSCAN(eps=0.05, min_samples=5)\n",
"dbscan.fit(X)"
]
},
{
"cell_type": "code",
"execution_count": 180,
"execution_count": 113,
"metadata": {},
"outputs": [],
"source": [
@ -3370,7 +3355,7 @@
},
{
"cell_type": "code",
"execution_count": 181,
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
@ -3379,7 +3364,7 @@
},
{
"cell_type": "code",
"execution_count": 182,
"execution_count": 115,
"metadata": {},
"outputs": [],
"source": [
@ -3388,7 +3373,7 @@
},
{
"cell_type": "code",
"execution_count": 183,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@ -3397,7 +3382,7 @@
},
{
"cell_type": "code",
"execution_count": 184,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -3406,7 +3391,7 @@
},
{
"cell_type": "code",
"execution_count": 185,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -3416,7 +3401,7 @@
},
{
"cell_type": "code",
"execution_count": 186,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@ -3444,22 +3429,22 @@
" plt.ylabel(\"$x_2$\", fontsize=14, rotation=0)\n",
" else:\n",
" plt.tick_params(labelleft='off')\n",
" plt.title(\"eps={:.2f}\".format(dbscan.eps), fontsize=14)"
" plt.title(\"eps={:.2f}, min_samples={}\".format(dbscan.eps, dbscan.min_samples), fontsize=14)"
]
},
{
"cell_type": "code",
"execution_count": 187,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 3.2))\n",
"\n",
"plt.subplot(121)\n",
"plot_dbscan(dbscan, X, size=600)\n",
"plot_dbscan(dbscan, X, size=100)\n",
"\n",
"plt.subplot(122)\n",
"plot_dbscan(dbscan2, X, size=100, show_ylabels=False)\n",
"plot_dbscan(dbscan2, X, size=600, show_ylabels=False)\n",
"\n",
"save_fig(\"dbscan_diagram\")\n",
"plt.show()\n"
@ -3683,7 +3668,7 @@
},
{
"cell_type": "code",
"execution_count": 205,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -3704,7 +3689,7 @@
},
{
"cell_type": "code",
"execution_count": 206,
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
@ -3713,7 +3698,7 @@
},
{
"cell_type": "code",
"execution_count": 207,
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
@ -3730,7 +3715,7 @@
},
{
"cell_type": "code",
"execution_count": 208,
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
@ -3739,7 +3724,7 @@
},
{
"cell_type": "code",
"execution_count": 209,
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@ -3748,7 +3733,7 @@
},
{
"cell_type": "code",
"execution_count": 210,
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
@ -3764,7 +3749,7 @@
},
{
"cell_type": "code",
"execution_count": 211,
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
@ -3780,7 +3765,7 @@
},
{
"cell_type": "code",
"execution_count": 212,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@ -3796,7 +3781,7 @@
},
{
"cell_type": "code",
"execution_count": 213,
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
@ -3805,7 +3790,7 @@
},
{
"cell_type": "code",
"execution_count": 214,
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
@ -3821,7 +3806,7 @@
},
{
"cell_type": "code",
"execution_count": 215,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@ -3831,7 +3816,7 @@
},
{
"cell_type": "code",
"execution_count": 216,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@ -3854,7 +3839,7 @@
},
{
"cell_type": "code",
"execution_count": 217,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@ -3870,7 +3855,7 @@
},
{
"cell_type": "code",
"execution_count": 218,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@ -3893,7 +3878,7 @@
},
{
"cell_type": "code",
"execution_count": 219,
"execution_count": 178,
"metadata": {},
"outputs": [],
"source": [
@ -3932,7 +3917,7 @@
},
{
"cell_type": "code",
"execution_count": 220,
"execution_count": 155,
"metadata": {},
"outputs": [],
"source": [
@ -4028,7 +4013,7 @@
},
{
"cell_type": "code",
"execution_count": 225,
"execution_count": 94,
"metadata": {},
"outputs": [],
"source": [
@ -4039,7 +4024,7 @@
},
{
"cell_type": "code",
"execution_count": 226,
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
@ -4047,8 +4032,9 @@
"\n",
"plot_gaussian_mixture(gm, X)\n",
"plt.scatter(anomalies[:, 0], anomalies[:, 1], color='r', marker='*')\n",
"plt.ylim(ymax=5.1)\n",
"\n",
"save_fig(\"anomaly_detection_diagram\")\n",
"save_fig(\"mixture_anomaly_detection_diagram\")\n",
"plt.show()"
]
},
@ -4254,7 +4240,7 @@
},
{
"cell_type": "code",
"execution_count": 238,
"execution_count": 157,
"metadata": {},
"outputs": [],
"source": [
@ -4263,7 +4249,7 @@
},
{
"cell_type": "code",
"execution_count": 239,
"execution_count": 158,
"metadata": {},
"outputs": [],
"source": [
@ -4280,7 +4266,7 @@
},
{
"cell_type": "code",
"execution_count": 240,
"execution_count": 159,
"metadata": {},
"outputs": [],
"source": [
@ -4289,7 +4275,7 @@
},
{
"cell_type": "code",
"execution_count": 241,
"execution_count": 160,
"metadata": {},
"outputs": [],
"source": [
@ -4300,7 +4286,7 @@
},
{
"cell_type": "code",
"execution_count": 242,
"execution_count": 161,
"metadata": {},
"outputs": [],
"source": [
@ -4315,7 +4301,7 @@
},
{
"cell_type": "code",
"execution_count": 243,
"execution_count": 162,
"metadata": {},
"outputs": [],
"source": [
@ -4324,7 +4310,7 @@
},
{
"cell_type": "code",
"execution_count": 244,
"execution_count": 163,
"metadata": {},
"outputs": [],
"source": [
@ -4333,7 +4319,7 @@
},
{
"cell_type": "code",
"execution_count": 245,
"execution_count": 179,
"metadata": {},
"outputs": [],
"source": [
@ -4351,33 +4337,49 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: the fact that you see only 3 regions in the right plot although there are 4 centroids is not a bug: the weight of the top-right cluster is much larger than the weight of the lower-right cluster, so the probability that any given point in this region belongs to the top-right cluster is greater than the probability that it belongs to the lower-right cluster."
]
},
{
"cell_type": "code",
"execution_count": 246,
"execution_count": 167,
"metadata": {},
"outputs": [],
"source": [
"X_moons, y_moons = make_moons(n_samples=1000, noise=0.05, random_state=42)"
]
},
{
"cell_type": "code",
"execution_count": 168,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"bgm = BayesianGaussianMixture(n_components=10, n_init=10, random_state=42)\n",
"bgm.fit(X)"
"bgm.fit(X_moons)"
]
},
{
"cell_type": "code",
"execution_count": 247,
"execution_count": 169,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(9, 3.2))\n",
"\n",
"plt.subplot(121)\n",
"plot_data(X)\n",
"plot_data(X_moons)\n",
"plt.xlabel(\"$x_1$\", fontsize=14)\n",
"plt.ylabel(\"$x_2$\", fontsize=14, rotation=0)\n",
"\n",
"plt.subplot(122)\n",
"plot_gaussian_mixture(bgm, X, show_ylabels=False)\n",
"plot_gaussian_mixture(bgm, X_moons, show_ylabels=False)\n",
"\n",
"save_fig(\"moons_vs_bgm_diagram\")\n",
"plt.show()"
@ -5284,7 +5286,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
"version": "3.6.5"
}
},
"nbformat": 4,