From a40b278df5e681785b9c64abb0f80b67caecbc43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Sat, 5 Nov 2016 18:13:54 +0100 Subject: [PATCH] Move to sklearn 0.18 --- 03_classification.ipynb | 148 ++++++++++++++++------------------------ 1 file changed, 58 insertions(+), 90 deletions(-) diff --git a/03_classification.ipynb b/03_classification.ipynb index ea7f79b..1140701 100644 --- a/03_classification.ipynb +++ b/03_classification.ipynb @@ -264,7 +264,7 @@ }, "outputs": [], "source": [ - "from sklearn.cross_validation import cross_val_score\n", + "from sklearn.model_selection import cross_val_score\n", "cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")" ] }, @@ -276,18 +276,18 @@ }, "outputs": [], "source": [ - "from sklearn.cross_validation import StratifiedKFold\n", + "from sklearn.model_selection import StratifiedKFold\n", "from sklearn.base import clone\n", "\n", - "skfolds = StratifiedKFold(y_train_5, n_folds=3, random_state=42)\n", + "skfolds = StratifiedKFold(n_splits=3, random_state=42)\n", "\n", - "for train_index, test_index in skfolds:\n", + "for train_index, test_index in skfolds.split(X_train, y_train_5):\n", " clone_clf = clone(sgd_clf)\n", " X_train_folds = X_train[train_index]\n", " y_train_folds = (y_train_5[train_index])\n", " X_test_fold = X_train[test_index]\n", " y_test_fold = (y_train_5[test_index])\n", - " \n", + "\n", " clone_clf.fit(X_train_folds, y_train_folds)\n", " y_pred = clone_clf.predict(X_test_fold)\n", " n_correct = sum(y_pred == y_test_fold)\n", @@ -330,7 +330,7 @@ }, "outputs": [], "source": [ - "from sklearn.cross_validation import cross_val_predict\n", + "from sklearn.model_selection import cross_val_predict\n", "\n", "y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)" ] @@ -459,32 +459,11 @@ "cell_type": "code", "execution_count": 30, "metadata": { - "collapsed": false + "collapsed": true }, "outputs": [], "source": [ - "# Implemented in https://github.com/scikit-learn/scikit-learn/pull/6671\n", - "# Pushed to master but not yet in pip module.\n", - "from sklearn.cross_validation import StratifiedKFold\n", - "from sklearn.base import clone\n", - "\n", - "def cross_val_predict_future(clf, X, y, cv, method=None):\n", - " clf_clone = clone(clf) # keep original intact\n", - " if method is None:\n", - " return cross_val_predict(clf, X, y, cv=cv)\n", - " else:\n", - " method_f = getattr(clf_clone, method)\n", - " scores = []\n", - " skfolds = StratifiedKFold(y, n_folds=cv)\n", - " for train_indices, test_indices in skfolds:\n", - " clf_clone.fit(X[train_indices], y[train_indices])\n", - " scores.append((method_f(X[test_indices]), test_indices))\n", - " res_shape = list(scores[0][0].shape)\n", - " res_shape[0] = len(X)\n", - " res = np.empty(tuple(res_shape))\n", - " for sc, test_indices in scores:\n", - " res[test_indices] = sc\n", - " return res" + "y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")" ] }, { @@ -494,17 +473,6 @@ "collapsed": true }, "outputs": [], - "source": [ - "y_scores = cross_val_predict_future(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "metadata": { - "collapsed": true - }, - "outputs": [], "source": [ "from sklearn.metrics import precision_recall_curve\n", "\n", @@ -513,7 +481,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 32, "metadata": { "collapsed": false }, @@ -535,7 +503,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 33, "metadata": { "collapsed": false }, @@ -546,7 +514,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 34, "metadata": { "collapsed": false }, @@ -557,7 +525,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 35, "metadata": { "collapsed": false }, @@ -568,7 +536,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 36, "metadata": { "collapsed": false }, @@ -579,7 +547,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 37, "metadata": { "collapsed": false }, @@ -606,7 +574,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 38, "metadata": { "collapsed": false }, @@ -619,7 +587,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 39, "metadata": { "collapsed": false }, @@ -640,7 +608,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 40, "metadata": { "collapsed": false }, @@ -653,7 +621,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 41, "metadata": { "collapsed": false }, @@ -661,7 +629,7 @@ "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "forest_clf = RandomForestClassifier(random_state=42)\n", - "y_probas_forest = cross_val_predict_future(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", + "y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", "y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n", "fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n", "\n", @@ -675,7 +643,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 42, "metadata": { "collapsed": false }, @@ -686,7 +654,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 43, "metadata": { "collapsed": false }, @@ -698,7 +666,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 44, "metadata": { "collapsed": false }, @@ -716,7 +684,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 45, "metadata": { "collapsed": false }, @@ -728,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 46, "metadata": { "collapsed": false }, @@ -740,7 +708,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 47, "metadata": { "collapsed": false }, @@ -751,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 48, "metadata": { "collapsed": false }, @@ -762,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 49, "metadata": { "collapsed": false }, @@ -776,7 +744,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 50, "metadata": { "collapsed": false }, @@ -787,7 +755,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 51, "metadata": { "collapsed": false }, @@ -799,7 +767,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 52, "metadata": { "collapsed": false }, @@ -810,7 +778,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 53, "metadata": { "collapsed": false }, @@ -821,7 +789,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 54, "metadata": { "collapsed": false }, @@ -835,7 +803,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 55, "metadata": { "collapsed": false }, @@ -848,7 +816,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 56, "metadata": { "collapsed": false }, @@ -868,7 +836,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 57, "metadata": { "collapsed": false }, @@ -884,7 +852,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 58, "metadata": { "collapsed": false }, @@ -918,7 +886,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 59, "metadata": { "collapsed": false }, @@ -936,7 +904,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 60, "metadata": { "collapsed": false }, @@ -947,7 +915,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 61, "metadata": { "collapsed": true }, @@ -958,7 +926,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 62, "metadata": { "collapsed": false }, @@ -976,7 +944,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 63, "metadata": { "collapsed": false }, @@ -992,7 +960,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 64, "metadata": { "collapsed": false }, @@ -1007,7 +975,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 65, "metadata": { "collapsed": false }, @@ -1018,7 +986,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 66, "metadata": { "collapsed": false }, @@ -1046,7 +1014,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 67, "metadata": { "collapsed": false }, @@ -1054,13 +1022,13 @@ "source": [ "from sklearn.dummy import DummyClassifier\n", "dmy_clf = DummyClassifier()\n", - "y_probas_dmy = cross_val_predict_future(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", + "y_probas_dmy = cross_val_predict(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n", "y_scores_dmy = y_probas_dmy[:, 1]" ] }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 68, "metadata": { "collapsed": false, "scrolled": true @@ -1080,7 +1048,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 69, "metadata": { "collapsed": false }, @@ -1093,7 +1061,7 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 70, "metadata": { "collapsed": false }, @@ -1104,7 +1072,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 71, "metadata": { "collapsed": false }, @@ -1116,7 +1084,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 72, "metadata": { "collapsed": false }, @@ -1131,7 +1099,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 73, "metadata": { "collapsed": false }, @@ -1151,7 +1119,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 74, "metadata": { "collapsed": false }, @@ -1162,7 +1130,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 75, "metadata": { "collapsed": true }, @@ -1173,7 +1141,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 76, "metadata": { "collapsed": false }, @@ -1184,7 +1152,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 77, "metadata": { "collapsed": false }, @@ -1196,7 +1164,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 78, "metadata": { "collapsed": false }, @@ -1247,7 +1215,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.1" + "version": "3.5.2" }, "nav_menu": {}, "toc": {