Add workaround for a bug introduced by Scikit-Learn 0.19.0, in chapter 03

main
Aurélien Geron 2017-09-15 17:58:29 +02:00
parent d016b56672
commit ca35dddc38
1 changed files with 110 additions and 61 deletions

View File

@ -26,7 +26,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# To support both python 2 and python 3\n", "# To support both python 2 and python 3\n",
@ -143,7 +145,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 8,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# EXTRA\n", "# EXTRA\n",
@ -313,7 +317,9 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 20,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.model_selection import cross_val_predict\n", "from sklearn.model_selection import cross_val_predict\n",
@ -463,9 +469,38 @@
" method=\"decision_function\")" " method=\"decision_function\")"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: there is an [issue](https://github.com/scikit-learn/scikit-learn/issues/9589) introduced in Scikit-Learn 0.19.0 where the result of `cross_val_predict()` is incorrect in the binary classification case when using `method=\"decision_function\"`, as in the code above. The resulting array has an extra first dimension full of 0s. We need to add this small hack for now to work around this issue:"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"y_scores.shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# hack to work around issue #9589 introduced in Scikit-Learn 0.19.0\n",
"if y_scores.ndim == 2:\n",
" y_scores = y_scores[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": { "metadata": {
"collapsed": true "collapsed": true
}, },
@ -478,7 +513,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 38,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -498,7 +533,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 37, "execution_count": 39,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -507,8 +542,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 40,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_train_pred_90 = (y_scores > 70000)" "y_train_pred_90 = (y_scores > 70000)"
@ -516,7 +553,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 41,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -525,7 +562,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 42,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -534,7 +571,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 41, "execution_count": 43,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -559,8 +596,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 44,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.metrics import roc_curve\n", "from sklearn.metrics import roc_curve\n",
@ -570,7 +609,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 45,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -589,7 +628,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": 46,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -600,7 +639,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 45, "execution_count": 47,
"metadata": { "metadata": {
"collapsed": true "collapsed": true
}, },
@ -614,7 +653,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 46, "execution_count": 48,
"metadata": { "metadata": {
"collapsed": true "collapsed": true
}, },
@ -626,7 +665,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 47, "execution_count": 49,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -640,7 +679,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": 50,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -649,7 +688,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 49, "execution_count": 51,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -659,7 +698,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 50, "execution_count": 52,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -675,7 +714,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 51, "execution_count": 53,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -685,7 +724,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52, "execution_count": 54,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -695,7 +734,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 53, "execution_count": 55,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -704,7 +743,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 54, "execution_count": 56,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -713,7 +752,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 55, "execution_count": 57,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -722,7 +761,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 56, "execution_count": 58,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -734,7 +773,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 57, "execution_count": 59,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -743,7 +782,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 58, "execution_count": 60,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -753,7 +792,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 59, "execution_count": 61,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -762,7 +801,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60, "execution_count": 62,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -771,7 +810,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 61, "execution_count": 63,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -783,7 +822,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 62, "execution_count": 64,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -794,8 +833,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63, "execution_count": 65,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def plot_confusion_matrix(matrix):\n", "def plot_confusion_matrix(matrix):\n",
@ -808,7 +849,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 64, "execution_count": 66,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -819,8 +860,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 65, "execution_count": 67,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"row_sums = conf_mx.sum(axis=1, keepdims=True)\n", "row_sums = conf_mx.sum(axis=1, keepdims=True)\n",
@ -829,7 +872,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 66, "execution_count": 68,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -841,7 +884,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 67, "execution_count": 69,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -869,7 +912,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 68, "execution_count": 70,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -885,7 +928,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 69, "execution_count": 71,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -894,7 +937,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 70, "execution_count": 72,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -911,8 +954,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 71, "execution_count": 73,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"noise = np.random.randint(0, 100, (len(X_train), 784))\n", "noise = np.random.randint(0, 100, (len(X_train), 784))\n",
@ -925,7 +970,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 72, "execution_count": 74,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -938,7 +983,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 73, "execution_count": 75,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -964,8 +1009,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 74, "execution_count": 76,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.dummy import DummyClassifier\n", "from sklearn.dummy import DummyClassifier\n",
@ -976,7 +1023,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 75, "execution_count": 77,
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
@ -995,7 +1042,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 76, "execution_count": 78,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1006,8 +1053,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 77, "execution_count": 79,
"metadata": {}, "metadata": {
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"y_knn_pred = knn_clf.predict(X_test)" "y_knn_pred = knn_clf.predict(X_test)"
@ -1015,7 +1064,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 78, "execution_count": 80,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1025,7 +1074,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 79, "execution_count": 81,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1038,7 +1087,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 80, "execution_count": 82,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1056,7 +1105,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 81, "execution_count": 83,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1065,7 +1114,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 82, "execution_count": 84,
"metadata": { "metadata": {
"collapsed": true "collapsed": true
}, },
@ -1076,7 +1125,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 83, "execution_count": 85,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1085,7 +1134,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 84, "execution_count": 86,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1095,7 +1144,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 85, "execution_count": 87,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -1144,7 +1193,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.3" "version": "3.5.2"
}, },
"nav_menu": {}, "nav_menu": {},
"toc": { "toc": {