Use as_frame=False for fetch_open_ml(), and svd_solver=full for PCA, fixes #358

main
Aurélien Geron 2021-03-02 09:19:21 +13:00
parent 9fede98b42
commit 5663779ae8
1 changed files with 18 additions and 11 deletions

View File

@ -761,6 +761,13 @@
"# MNIST compression"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning:** since Scikit-Learn 0.24, `fetch_openml()` returns a Pandas `DataFrame` by default. To avoid this and keep the same code as in the book, we set `as_frame=True`."
]
},
{
"cell_type": "code",
"execution_count": 31,
@ -769,7 +776,7 @@
"source": [
"from sklearn.datasets import fetch_openml\n",
"\n",
"mnist = fetch_openml('mnist_784', version=1)\n",
"mnist = fetch_openml('mnist_784', version=1, as_frame=False)\n",
"mnist.target = mnist.target.astype(np.uint8)"
]
},
@ -1101,15 +1108,15 @@
"\n",
"for n_components in (2, 10, 154):\n",
" print(\"n_components =\", n_components)\n",
" regular_pca = PCA(n_components=n_components)\n",
" regular_pca = PCA(n_components=n_components, svd_solver=\"full\")\n",
" inc_pca = IncrementalPCA(n_components=n_components, batch_size=500)\n",
" rnd_pca = PCA(n_components=n_components, random_state=42, svd_solver=\"randomized\")\n",
"\n",
" for pca in (regular_pca, inc_pca, rnd_pca):\n",
" for name, pca in ((\"PCA\", regular_pca), (\"Inc PCA\", inc_pca), (\"Rnd PCA\", rnd_pca)):\n",
" t1 = time.time()\n",
" pca.fit(X_train)\n",
" t2 = time.time()\n",
" print(\" {}: {:.1f} seconds\".format(pca.__class__.__name__, t2 - t1))"
" print(\" {}: {:.1f} seconds\".format(name, t2 - t1))"
]
},
{
@ -1135,7 +1142,7 @@
" pca.fit(X)\n",
" t2 = time.time()\n",
" times_rpca.append(t2 - t1)\n",
" pca = PCA(n_components = 2)\n",
" pca = PCA(n_components=2, svd_solver=\"full\")\n",
" t1 = time.time()\n",
" pca.fit(X)\n",
" t2 = time.time()\n",
@ -1174,7 +1181,7 @@
" pca.fit(X)\n",
" t2 = time.time()\n",
" times_rpca.append(t2 - t1)\n",
" pca = PCA(n_components = 2)\n",
" pca = PCA(n_components=2, svd_solver=\"full\")\n",
" t1 = time.time()\n",
" pca.fit(X)\n",
" t2 = time.time()\n",
@ -2252,7 +2259,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.8"
"version": "3.7.9"
}
},
"nbformat": 4,