Set n_init explicitly when creating KMeans or MiniBatchKMeans, to avoid warning

main
Aurélien Geron 2023-11-14 15:09:10 +13:00
parent 9b2c0e81c8
commit 1cf75d217b
1 changed files with 33 additions and 20 deletions

View File

@ -230,7 +230,7 @@
"mapping = {}\n",
"for class_id in np.unique(y):\n",
" mode, _ = stats.mode(y_pred[y==class_id])\n",
" mapping[mode[0]] = class_id\n",
" mapping[mode] = class_id\n",
"\n",
"y_pred = np.array([mapping[cluster_id] for cluster_id in y_pred])\n",
"\n",
@ -309,10 +309,17 @@
" random_state=7)\n",
"\n",
"k = 5\n",
"kmeans = KMeans(n_clusters=k, random_state=42)\n",
"kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)\n",
"y_pred = kmeans.fit_predict(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: Throughout this notebook, when `n_init` was not set when creating a `KMeans` estimator, I explicitly set it to `n_init=10` to avoid a warning about the fact that the default value for this hyperparameter will change from 10 to `\"auto\"` in Scikit-Learn 1.4."
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -1169,10 +1176,17 @@
"source": [
"from sklearn.cluster import MiniBatchKMeans\n",
"\n",
"minibatch_kmeans = MiniBatchKMeans(n_clusters=5, random_state=42)\n",
"minibatch_kmeans = MiniBatchKMeans(n_clusters=5, n_init=3, random_state=42)\n",
"minibatch_kmeans.fit(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: Throughout this notebook, when `n_init` was not set when creating a `MiniBatchKMeans` estimator, I explicitly set it to `n_init=3` to avoid a warning about the fact that the default value for this hyperparameter will change from 3 to `\"auto\"` in Scikit-Learn 1.4."
]
},
{
"cell_type": "code",
"execution_count": 31,
@ -1215,7 +1229,7 @@
"source": [
"from sklearn.datasets import fetch_openml\n",
"\n",
"mnist = fetch_openml('mnist_784', as_frame=False)"
"mnist = fetch_openml('mnist_784', as_frame=False, parser=\"auto\")"
]
},
{
@ -1275,7 +1289,7 @@
"from sklearn.cluster import MiniBatchKMeans\n",
"\n",
"minibatch_kmeans = MiniBatchKMeans(n_clusters=10, batch_size=10,\n",
" random_state=42)\n",
" n_init=3, random_state=42)\n",
"minibatch_kmeans.fit(X_memmap)"
]
},
@ -1320,8 +1334,8 @@
"times = np.empty((max_k, 2))\n",
"inertias = np.empty((max_k, 2))\n",
"for k in range(1, max_k + 1):\n",
" kmeans_ = KMeans(n_clusters=k, algorithm=\"full\", random_state=42)\n",
" minibatch_kmeans = MiniBatchKMeans(n_clusters=k, random_state=42)\n",
" kmeans_ = KMeans(n_clusters=k, algorithm=\"lloyd\", n_init=10, random_state=42)\n",
" minibatch_kmeans = MiniBatchKMeans(n_clusters=k, n_init=10, random_state=42)\n",
" print(f\"\\r{k}/{max_k}\", end=\"\") # \\r returns to the start of line\n",
" times[k - 1, 0] = timeit(\"kmeans_.fit(X)\", number=10, globals=globals())\n",
" times[k - 1, 1] = timeit(\"minibatch_kmeans.fit(X)\", number=10,\n",
@ -1387,8 +1401,8 @@
"source": [
"# extra code this cell generates and saves Figure 97\n",
"\n",
"kmeans_k3 = KMeans(n_clusters=3, random_state=42)\n",
"kmeans_k8 = KMeans(n_clusters=8, random_state=42)\n",
"kmeans_k3 = KMeans(n_clusters=3, n_init=10, random_state=42)\n",
"kmeans_k8 = KMeans(n_clusters=8, n_init=10, random_state=42)\n",
"\n",
"plot_clusterer_comparison(kmeans_k3, kmeans_k8, X, \"$k=3$\", \"$k=8$\")\n",
"save_fig(\"bad_n_clusters_plot\")\n",
@ -1470,7 +1484,7 @@
"source": [
"# extra code this cell generates and saves Figure 98\n",
"\n",
"kmeans_per_k = [KMeans(n_clusters=k, random_state=42).fit(X)\n",
"kmeans_per_k = [KMeans(n_clusters=k, n_init=10, random_state=42).fit(X)\n",
" for k in range(1, 10)]\n",
"inertias = [model.inertia_ for model in kmeans_per_k]\n",
"\n",
@ -1724,7 +1738,7 @@
"kmeans_good = KMeans(n_clusters=3,\n",
" init=np.array([[-1.5, 2.5], [0.5, 0], [4, 0]]),\n",
" n_init=1, random_state=42)\n",
"kmeans_bad = KMeans(n_clusters=3, random_state=42)\n",
"kmeans_bad = KMeans(n_clusters=3, n_init=10, random_state=42)\n",
"kmeans_good.fit(X)\n",
"kmeans_bad.fit(X)\n",
"\n",
@ -1805,7 +1819,7 @@
"outputs": [],
"source": [
"X = image.reshape(-1, 3)\n",
"kmeans = KMeans(n_clusters=8, random_state=42).fit(X)\n",
"kmeans = KMeans(n_clusters=8, n_init=10, random_state=42).fit(X)\n",
"segmented_img = kmeans.cluster_centers_[kmeans.labels_]\n",
"segmented_img = segmented_img.reshape(image.shape)"
]
@ -1834,7 +1848,7 @@
"segmented_imgs = []\n",
"n_colors = (10, 8, 6, 4, 2)\n",
"for n_clusters in n_colors:\n",
" kmeans = KMeans(n_clusters=n_clusters, random_state=42).fit(X)\n",
" kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42).fit(X)\n",
" segmented_img = kmeans.cluster_centers_[kmeans.labels_]\n",
" segmented_imgs.append(segmented_img.reshape(image.shape))\n",
"\n",
@ -1978,7 +1992,7 @@
"outputs": [],
"source": [
"k = 50\n",
"kmeans = KMeans(n_clusters=k, random_state=42)\n",
"kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)\n",
"X_digits_dist = kmeans.fit_transform(X_train)\n",
"representative_digit_idx = X_digits_dist.argmin(axis=0)\n",
"X_representative_digits = X_train[representative_digit_idx]"
@ -2623,8 +2637,7 @@
"source": [
"def plot_spectral_clustering(sc, X, size, alpha, show_xlabels=True,\n",
" show_ylabels=True):\n",
" plt.scatter(X[:, 0], X[:, 1], marker='o', s=size, c='gray', cmap=\"Paired\",\n",
" alpha=alpha)\n",
" plt.scatter(X[:, 0], X[:, 1], marker='o', s=size, c='gray', alpha=alpha)\n",
" plt.scatter(X[:, 0], X[:, 1], marker='o', s=30, c='w')\n",
" plt.scatter(X[:, 0], X[:, 1], marker='.', s=10, c=sc.labels_, cmap=\"Paired\")\n",
" \n",
@ -4005,7 +4018,7 @@
"kmeans_per_k = []\n",
"for k in k_range:\n",
" print(f\"k={k}\")\n",
" kmeans = KMeans(n_clusters=k, random_state=42)\n",
" kmeans = KMeans(n_clusters=k, n_init=10, random_state=42)\n",
" kmeans.fit(X_train_pca)\n",
" kmeans_per_k.append(kmeans)"
]
@ -6581,7 +6594,7 @@
"\n",
"for n_clusters in k_range:\n",
" pipeline = make_pipeline(\n",
" KMeans(n_clusters=n_clusters, random_state=42),\n",
" KMeans(n_clusters=n_clusters, n_init=10, random_state=42),\n",
" RandomForestClassifier(n_estimators=150, random_state=42)\n",
" )\n",
" pipeline.fit(X_train_pca, y_train)\n",
@ -6971,7 +6984,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -6985,7 +6998,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.13"
}
},
"nbformat": 4,