Move to sklearn 0.18

main
Aurélien Geron 2016-11-05 18:13:54 +01:00
parent 9e414b6d64
commit a40b278df5
1 changed files with 58 additions and 90 deletions

View File

@ -264,7 +264,7 @@
},
"outputs": [],
"source": [
"from sklearn.cross_validation import cross_val_score\n",
"from sklearn.model_selection import cross_val_score\n",
"cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
]
},
@ -276,12 +276,12 @@
},
"outputs": [],
"source": [
"from sklearn.cross_validation import StratifiedKFold\n",
"from sklearn.model_selection import StratifiedKFold\n",
"from sklearn.base import clone\n",
"\n",
"skfolds = StratifiedKFold(y_train_5, n_folds=3, random_state=42)\n",
"skfolds = StratifiedKFold(n_splits=3, random_state=42)\n",
"\n",
"for train_index, test_index in skfolds:\n",
"for train_index, test_index in skfolds.split(X_train, y_train_5):\n",
" clone_clf = clone(sgd_clf)\n",
" X_train_folds = X_train[train_index]\n",
" y_train_folds = (y_train_5[train_index])\n",
@ -330,7 +330,7 @@
},
"outputs": [],
"source": [
"from sklearn.cross_validation import cross_val_predict\n",
"from sklearn.model_selection import cross_val_predict\n",
"\n",
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
]
@ -459,32 +459,11 @@
"cell_type": "code",
"execution_count": 30,
"metadata": {
"collapsed": false
"collapsed": true
},
"outputs": [],
"source": [
"# Implemented in https://github.com/scikit-learn/scikit-learn/pull/6671\n",
"# Pushed to master but not yet in pip module.\n",
"from sklearn.cross_validation import StratifiedKFold\n",
"from sklearn.base import clone\n",
"\n",
"def cross_val_predict_future(clf, X, y, cv, method=None):\n",
" clf_clone = clone(clf) # keep original intact\n",
" if method is None:\n",
" return cross_val_predict(clf, X, y, cv=cv)\n",
" else:\n",
" method_f = getattr(clf_clone, method)\n",
" scores = []\n",
" skfolds = StratifiedKFold(y, n_folds=cv)\n",
" for train_indices, test_indices in skfolds:\n",
" clf_clone.fit(X[train_indices], y[train_indices])\n",
" scores.append((method_f(X[test_indices]), test_indices))\n",
" res_shape = list(scores[0][0].shape)\n",
" res_shape[0] = len(X)\n",
" res = np.empty(tuple(res_shape))\n",
" for sc, test_indices in scores:\n",
" res[test_indices] = sc\n",
" return res"
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")"
]
},
{
@ -494,17 +473,6 @@
"collapsed": true
},
"outputs": [],
"source": [
"y_scores = cross_val_predict_future(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from sklearn.metrics import precision_recall_curve\n",
"\n",
@ -513,7 +481,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 32,
"metadata": {
"collapsed": false
},
@ -535,7 +503,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 33,
"metadata": {
"collapsed": false
},
@ -546,7 +514,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 34,
"metadata": {
"collapsed": false
},
@ -557,7 +525,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 35,
"metadata": {
"collapsed": false
},
@ -568,7 +536,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 36,
"metadata": {
"collapsed": false
},
@ -579,7 +547,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 37,
"metadata": {
"collapsed": false
},
@ -606,7 +574,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 38,
"metadata": {
"collapsed": false
},
@ -619,7 +587,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 39,
"metadata": {
"collapsed": false
},
@ -640,7 +608,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 40,
"metadata": {
"collapsed": false
},
@ -653,7 +621,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 41,
"metadata": {
"collapsed": false
},
@ -661,7 +629,7 @@
"source": [
"from sklearn.ensemble import RandomForestClassifier\n",
"forest_clf = RandomForestClassifier(random_state=42)\n",
"y_probas_forest = cross_val_predict_future(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n",
"\n",
@ -675,7 +643,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 42,
"metadata": {
"collapsed": false
},
@ -686,7 +654,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 43,
"metadata": {
"collapsed": false
},
@ -698,7 +666,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 44,
"metadata": {
"collapsed": false
},
@ -716,7 +684,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 45,
"metadata": {
"collapsed": false
},
@ -728,7 +696,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 46,
"metadata": {
"collapsed": false
},
@ -740,7 +708,7 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 47,
"metadata": {
"collapsed": false
},
@ -751,7 +719,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 48,
"metadata": {
"collapsed": false
},
@ -762,7 +730,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 49,
"metadata": {
"collapsed": false
},
@ -776,7 +744,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 50,
"metadata": {
"collapsed": false
},
@ -787,7 +755,7 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 51,
"metadata": {
"collapsed": false
},
@ -799,7 +767,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 52,
"metadata": {
"collapsed": false
},
@ -810,7 +778,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 53,
"metadata": {
"collapsed": false
},
@ -821,7 +789,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 54,
"metadata": {
"collapsed": false
},
@ -835,7 +803,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 55,
"metadata": {
"collapsed": false
},
@ -848,7 +816,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 56,
"metadata": {
"collapsed": false
},
@ -868,7 +836,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 57,
"metadata": {
"collapsed": false
},
@ -884,7 +852,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 58,
"metadata": {
"collapsed": false
},
@ -918,7 +886,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 59,
"metadata": {
"collapsed": false
},
@ -936,7 +904,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 60,
"metadata": {
"collapsed": false
},
@ -947,7 +915,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 61,
"metadata": {
"collapsed": true
},
@ -958,7 +926,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 62,
"metadata": {
"collapsed": false
},
@ -976,7 +944,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 63,
"metadata": {
"collapsed": false
},
@ -992,7 +960,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 64,
"metadata": {
"collapsed": false
},
@ -1007,7 +975,7 @@
},
{
"cell_type": "code",
"execution_count": 66,
"execution_count": 65,
"metadata": {
"collapsed": false
},
@ -1018,7 +986,7 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": 66,
"metadata": {
"collapsed": false
},
@ -1046,7 +1014,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 67,
"metadata": {
"collapsed": false
},
@ -1054,13 +1022,13 @@
"source": [
"from sklearn.dummy import DummyClassifier\n",
"dmy_clf = DummyClassifier()\n",
"y_probas_dmy = cross_val_predict_future(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_probas_dmy = cross_val_predict(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
"y_scores_dmy = y_probas_dmy[:, 1]"
]
},
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 68,
"metadata": {
"collapsed": false,
"scrolled": true
@ -1080,7 +1048,7 @@
},
{
"cell_type": "code",
"execution_count": 70,
"execution_count": 69,
"metadata": {
"collapsed": false
},
@ -1093,7 +1061,7 @@
},
{
"cell_type": "code",
"execution_count": 71,
"execution_count": 70,
"metadata": {
"collapsed": false
},
@ -1104,7 +1072,7 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 71,
"metadata": {
"collapsed": false
},
@ -1116,7 +1084,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 72,
"metadata": {
"collapsed": false
},
@ -1131,7 +1099,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 73,
"metadata": {
"collapsed": false
},
@ -1151,7 +1119,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 74,
"metadata": {
"collapsed": false
},
@ -1162,7 +1130,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 75,
"metadata": {
"collapsed": true
},
@ -1173,7 +1141,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 76,
"metadata": {
"collapsed": false
},
@ -1184,7 +1152,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 77,
"metadata": {
"collapsed": false
},
@ -1196,7 +1164,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 78,
"metadata": {
"collapsed": false
},
@ -1247,7 +1215,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
"version": "3.5.2"
},
"nav_menu": {},
"toc": {