Make notebook code match book examples more closely in chapter 3
parent
291ae3e39d
commit
c226f328d3
|
@ -93,47 +93,14 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from six.moves import urllib\n",
|
||||
"from sklearn.datasets import fetch_mldata\n",
|
||||
"try:\n",
|
||||
"mnist = fetch_mldata('MNIST original')\n",
|
||||
"except urllib.error.HTTPError as ex:\n",
|
||||
" print(\"Could not download MNIST data from mldata.org, trying alternative...\")\n",
|
||||
"\n",
|
||||
" # Alternative method to load MNIST, if mldata.org is down\n",
|
||||
" from scipy.io import loadmat\n",
|
||||
" mnist_alternative_url = \"https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat\"\n",
|
||||
" mnist_path = \"./mnist-original.mat\"\n",
|
||||
" response = urllib.request.urlopen(mnist_alternative_url)\n",
|
||||
" with open(mnist_path, \"wb\") as f:\n",
|
||||
" content = response.read()\n",
|
||||
" f.write(content)\n",
|
||||
" mnist_raw = loadmat(mnist_path)\n",
|
||||
" mnist = {\n",
|
||||
" \"data\": mnist_raw[\"data\"].T,\n",
|
||||
" \"target\": mnist_raw[\"label\"][0],\n",
|
||||
" \"COL_NAMES\": [\"label\", \"data\"],\n",
|
||||
" \"DESCR\": \"mldata.org dataset: mnist-original\",\n",
|
||||
" }\n",
|
||||
" print(\"Success!\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mnist"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -147,7 +114,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -160,7 +127,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -173,25 +140,43 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"import matplotlib\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"some_digit = X[36000]\n",
|
||||
"some_digit_image = some_digit.reshape(28, 28)\n",
|
||||
"plt.imshow(some_digit_image, cmap = matplotlib.cm.binary,\n",
|
||||
" interpolation=\"nearest\")\n",
|
||||
"plt.axis(\"off\")\n",
|
||||
"\n",
|
||||
"save_fig(\"some_digit_plot\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_digit(data):\n",
|
||||
" image = data.reshape(28, 28)\n",
|
||||
" plt.imshow(image, cmap = matplotlib.cm.binary,\n",
|
||||
" interpolation=\"nearest\")\n",
|
||||
" plt.axis(\"off\")\n",
|
||||
"\n",
|
||||
"some_digit_index = 36000\n",
|
||||
"some_digit = X[some_digit_index]\n",
|
||||
"plot_digit(some_digit)\n",
|
||||
"save_fig(\"some_digit_plot\")\n",
|
||||
"plt.show()"
|
||||
" plt.axis(\"off\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -218,13 +203,7 @@
|
|||
" row_images.append(np.concatenate(rimages, axis=1))\n",
|
||||
" image = np.concatenate(row_images, axis=0)\n",
|
||||
" plt.imshow(image, cmap = matplotlib.cm.binary, **options)\n",
|
||||
" plt.axis(\"off\")\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(9,9))\n",
|
||||
"example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
|
||||
"plot_digits(example_images, images_per_row=10)\n",
|
||||
"save_fig(\"more_digits_plot\")\n",
|
||||
"plt.show()"
|
||||
" plt.axis(\"off\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -237,20 +216,24 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y[some_digit_index]"
|
||||
"plt.figure(figsize=(9,9))\n",
|
||||
"example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]\n",
|
||||
"plot_digits(example_images, images_per_row=10)\n",
|
||||
"save_fig(\"more_digits_plot\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]"
|
||||
"y[36000]"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -263,7 +246,22 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"shuffle_index = rnd.permutation(60000)\n",
|
||||
"X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"shuffle_index = np.random.permutation(60000)\n",
|
||||
"X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]"
|
||||
]
|
||||
},
|
||||
|
@ -279,7 +277,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
|
@ -293,7 +291,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -309,7 +307,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -322,7 +320,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -336,7 +334,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -364,7 +362,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
|
@ -382,7 +380,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -396,7 +394,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -411,7 +409,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -426,7 +424,33 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_train_perfect_predictions = y_train_5"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"confusion_matrix(y_train_5, y_train_perfect_predictions)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -441,7 +465,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -454,7 +478,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -467,7 +491,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -480,7 +504,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -494,7 +518,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -507,7 +531,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -521,7 +545,21 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"threshold = 0\n",
|
||||
"y_some_digit_pred = (y_scores > threshold)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -529,14 +567,12 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"threshold = 0\n",
|
||||
"y_some_digit_pred = (y_scores > threshold)\n",
|
||||
"y_some_digit_pred"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -551,7 +587,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
|
@ -559,12 +595,13 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")"
|
||||
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,\n",
|
||||
" method=\"decision_function\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
|
@ -579,7 +616,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -591,7 +628,7 @@
|
|||
" plt.plot(thresholds, precisions[:-1], \"b--\", label=\"Precision\", linewidth=2)\n",
|
||||
" plt.plot(thresholds, recalls[:-1], \"g-\", label=\"Recall\", linewidth=2)\n",
|
||||
" plt.xlabel(\"Threshold\", fontsize=16)\n",
|
||||
" plt.legend(loc=\"center left\", fontsize=16)\n",
|
||||
" plt.legend(loc=\"upper left\", fontsize=16)\n",
|
||||
" plt.ylim([0, 1])\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(8, 4))\n",
|
||||
|
@ -603,7 +640,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -616,7 +653,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -629,7 +666,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -642,7 +679,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -655,7 +692,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -687,7 +724,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -702,7 +739,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -710,8 +747,8 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_roc_curve(fpr, tpr, **options):\n",
|
||||
" plt.plot(fpr, tpr, linewidth=2, **options)\n",
|
||||
"def plot_roc_curve(fpr, tpr, label=None):\n",
|
||||
" plt.plot(fpr, tpr, linewidth=2, label=label)\n",
|
||||
" plt.plot([0, 1], [0, 1], 'k--')\n",
|
||||
" plt.axis([0, 1, 0, 1])\n",
|
||||
" plt.xlabel('False Positive Rate', fontsize=16)\n",
|
||||
|
@ -725,7 +762,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 44,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -740,7 +777,33 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 45,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"forest_clf = RandomForestClassifier(random_state=42)\n",
|
||||
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,\n",
|
||||
" method=\"predict_proba\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
|
||||
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -748,15 +811,9 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"forest_clf = RandomForestClassifier(random_state=42)\n",
|
||||
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
|
||||
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
|
||||
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(8, 6))\n",
|
||||
"plt.plot(fpr, tpr, \"b:\", linewidth=2, label=\"SGD\")\n",
|
||||
"plot_roc_curve(fpr_forest, tpr_forest, label=\"Random Forest\")\n",
|
||||
"plot_roc_curve(fpr_forest, tpr_forest, \"Random Forest\")\n",
|
||||
"plt.legend(loc=\"lower right\", fontsize=16)\n",
|
||||
"save_fig(\"roc_curve_comparison_plot\")\n",
|
||||
"plt.show()"
|
||||
|
@ -764,7 +821,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 53,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -777,7 +834,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 54,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -791,7 +848,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 55,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -814,7 +871,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 56,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -828,7 +885,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 57,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -842,7 +899,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": 58,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -855,7 +912,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": 59,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -868,7 +925,18 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"execution_count": 60,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sgd_clf.classes_[5]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -884,7 +952,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"execution_count": 62,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -897,7 +965,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"execution_count": 63,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -911,7 +979,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"execution_count": 64,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -924,7 +992,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"execution_count": 65,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -937,7 +1005,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"execution_count": 66,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -953,7 +1021,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 55,
|
||||
"execution_count": 67,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -968,7 +1036,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 56,
|
||||
"execution_count": 68,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -981,8 +1049,17 @@
|
|||
" fig = plt.figure(figsize=(8,8))\n",
|
||||
" ax = fig.add_subplot(111)\n",
|
||||
" cax = ax.matshow(conf_mx)\n",
|
||||
" fig.colorbar(cax)\n",
|
||||
"\n",
|
||||
" fig.colorbar(cax)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.matshow(conf_mx, cmap=plt.cm.gray)\n",
|
||||
"save_fig(\"confusion_matrix_plot\", tight_layout=False)\n",
|
||||
"plt.show()"
|
||||
|
@ -990,7 +1067,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 70,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -999,7 +1076,17 @@
|
|||
"outputs": [],
|
||||
"source": [
|
||||
"row_sums = conf_mx.sum(axis=1, keepdims=True)\n",
|
||||
"norm_conf_mx = conf_mx / row_sums\n",
|
||||
"norm_conf_mx = conf_mx / row_sums"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"np.fill_diagonal(norm_conf_mx, 0)\n",
|
||||
"plt.matshow(norm_conf_mx, cmap=plt.cm.gray)\n",
|
||||
"save_fig(\"confusion_matrix_errors_plot\", tight_layout=False)\n",
|
||||
|
@ -1008,7 +1095,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"execution_count": 72,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -1023,14 +1110,10 @@
|
|||
"X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]\n",
|
||||
"\n",
|
||||
"plt.figure(figsize=(8,8))\n",
|
||||
"plt.subplot(221)\n",
|
||||
"plot_digits(X_aa[:25], images_per_row=5)\n",
|
||||
"plt.subplot(222)\n",
|
||||
"plot_digits(X_ab[:25], images_per_row=5)\n",
|
||||
"plt.subplot(223)\n",
|
||||
"plot_digits(X_ba[:25], images_per_row=5)\n",
|
||||
"plt.subplot(224)\n",
|
||||
"plot_digits(X_bb[:25], images_per_row=5)\n",
|
||||
"plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)\n",
|
||||
"plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)\n",
|
||||
"plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)\n",
|
||||
"plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)\n",
|
||||
"save_fig(\"error_analysis_digits_plot\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
|
@ -1047,7 +1130,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"execution_count": 73,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -1067,7 +1150,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"execution_count": 74,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -1080,20 +1163,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"metadata": {
|
||||
"collapsed": true,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_train, cv=3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"execution_count": 77,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -1101,6 +1171,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_train, cv=3)\n",
|
||||
"f1_score(y_train, y_train_knn_pred, average=\"macro\")"
|
||||
]
|
||||
},
|
||||
|
@ -1116,7 +1187,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 63,
|
||||
"execution_count": 78,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -1134,7 +1205,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"execution_count": 79,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -1151,20 +1222,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
"editable": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"knn_clf.fit(X_train_mod, y_train_mod)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"execution_count": 82,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"deletable": true,
|
||||
|
@ -1172,10 +1230,10 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"knn_clf.fit(X_train_mod, y_train_mod)\n",
|
||||
"clean_digit = knn_clf.predict([X_test_mod[some_index]])\n",
|
||||
"plot_digit(clean_digit)\n",
|
||||
"save_fig(\"cleaned_digit_example_plot\")\n",
|
||||
"plt.show()"
|
||||
"save_fig(\"cleaned_digit_example_plot\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -1435,7 +1493,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.2+"
|
||||
"version": "3.5.3"
|
||||
},
|
||||
"nav_menu": {},
|
||||
"toc": {
|
||||
|
|
Loading…
Reference in New Issue