From ca35dddc38298ce909efbf16ff47df68fa2751ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Fri, 15 Sep 2017 17:58:29 +0200 Subject: [PATCH] Add workaround for a bug introduced by Scikit-Learn 0.19.0, in chapter 03 --- 03_classification.ipynb | 171 ++++++++++++++++++++++++++-------------- 1 file changed, 110 insertions(+), 61 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index 74f2b16..be6a2dc 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -26,7 +26,9 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# To support both python 2 and python 3\n", @@ -143,7 +145,9 @@ { "cell_type": "code", "execution_count": 8, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "# EXTRA\n", @@ -313,7 +317,9 @@ { "cell_type": "code", "execution_count": 20, - "metadata": {}, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_predict\n", @@ -463,9 +469,38 @@ " method=\"decision_function\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note: there is an [issue](https://github.com/scikit-learn/scikit-learn/issues/9589) introduced in Scikit-Learn 0.19.0 where the result of `cross_val_predict()` is incorrect in the binary classification case when using `method=\"decision_function\"`, as in the code above. The resulting array is 2D instead of 1D, with an extra first column full of 0s. 
We need to add this small hack for now to work around this issue:" + ] + }, { "cell_type": "code", "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "y_scores.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# hack to work around issue #9589 introduced in Scikit-Learn 0.19.0\n", + "if y_scores.ndim == 2:\n", + " y_scores = y_scores[:, 1]" + ] + }, + { + "cell_type": "code", + "execution_count": 37, "metadata": { "collapsed": true }, @@ -478,7 +513,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -498,7 +533,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ @@ -507,8 +542,10 @@ }, { "cell_type": "code", - "execution_count": 38, - "metadata": {}, + "execution_count": 40, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "y_train_pred_90 = (y_scores > 70000)" @@ -516,7 +553,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 41, "metadata": {}, "outputs": [], "source": [ @@ -525,7 +562,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 42, "metadata": {}, "outputs": [], "source": [ @@ -534,7 +571,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -559,8 +596,10 @@ }, { "cell_type": "code", - "execution_count": 42, - "metadata": {}, + "execution_count": 44, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "from sklearn.metrics import roc_curve\n", @@ -570,7 +609,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -589,7 +628,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -600,7 +639,7 @@ }, 
{ "cell_type": "code", - "execution_count": 45, + "execution_count": 47, "metadata": { "collapsed": true }, @@ -614,7 +653,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 48, "metadata": { "collapsed": true }, @@ -626,7 +665,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -640,7 +679,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -649,7 +688,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -659,7 +698,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ @@ -675,7 +714,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ @@ -685,7 +724,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -695,7 +734,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -704,7 +743,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 56, "metadata": {}, "outputs": [], "source": [ @@ -713,7 +752,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 57, "metadata": {}, "outputs": [], "source": [ @@ -722,7 +761,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [ @@ -734,7 +773,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -743,7 +782,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ -753,7 +792,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 
61, "metadata": {}, "outputs": [], "source": [ @@ -762,7 +801,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 62, "metadata": {}, "outputs": [], "source": [ @@ -771,7 +810,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 63, "metadata": {}, "outputs": [], "source": [ @@ -783,7 +822,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 64, "metadata": {}, "outputs": [], "source": [ @@ -794,8 +833,10 @@ }, { "cell_type": "code", - "execution_count": 63, - "metadata": {}, + "execution_count": 65, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "def plot_confusion_matrix(matrix):\n", @@ -808,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 66, "metadata": {}, "outputs": [], "source": [ @@ -819,8 +860,10 @@ }, { "cell_type": "code", - "execution_count": 65, - "metadata": {}, + "execution_count": 67, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "row_sums = conf_mx.sum(axis=1, keepdims=True)\n", @@ -829,7 +872,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -841,7 +884,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 69, "metadata": {}, "outputs": [], "source": [ @@ -869,7 +912,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 70, "metadata": {}, "outputs": [], "source": [ @@ -885,7 +928,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 71, "metadata": {}, "outputs": [], "source": [ @@ -894,7 +937,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 72, "metadata": {}, "outputs": [], "source": [ @@ -911,8 +954,10 @@ }, { "cell_type": "code", - "execution_count": 71, - "metadata": {}, + "execution_count": 73, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "noise = np.random.randint(0, 100, (len(X_train), 784))\n", 
@@ -925,7 +970,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -938,7 +983,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 75, "metadata": {}, "outputs": [], "source": [ @@ -964,8 +1009,10 @@ }, { "cell_type": "code", - "execution_count": 74, - "metadata": {}, + "execution_count": 76, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "from sklearn.dummy import DummyClassifier\n", @@ -976,7 +1023,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 77, "metadata": { "scrolled": true }, @@ -995,7 +1042,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ @@ -1006,8 +1053,10 @@ }, { "cell_type": "code", - "execution_count": 77, - "metadata": {}, + "execution_count": 79, + "metadata": { + "collapsed": true + }, "outputs": [], "source": [ "y_knn_pred = knn_clf.predict(X_test)" @@ -1015,7 +1064,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ @@ -1025,7 +1074,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -1038,7 +1087,7 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 82, "metadata": {}, "outputs": [], "source": [ @@ -1056,7 +1105,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 83, "metadata": {}, "outputs": [], "source": [ @@ -1065,7 +1114,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 84, "metadata": { "collapsed": true }, @@ -1076,7 +1125,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -1085,7 +1134,7 @@ }, { "cell_type": "code", - "execution_count": 84, + "execution_count": 86, "metadata": {}, "outputs": [], "source": [ @@ -1095,7 
+1144,7 @@ }, { "cell_type": "code", - "execution_count": 85, + "execution_count": 87, "metadata": {}, "outputs": [], "source": [ @@ -1144,7 +1193,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.3" + "version": "3.5.2" }, "nav_menu": {}, "toc": {