Move to sklearn 0.18
parent
9e414b6d64
commit
a40b278df5
|
@ -264,7 +264,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.cross_validation import cross_val_score\n",
|
||||
"from sklearn.model_selection import cross_val_score\n",
|
||||
"cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring=\"accuracy\")"
|
||||
]
|
||||
},
|
||||
|
@ -276,12 +276,12 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.cross_validation import StratifiedKFold\n",
|
||||
"from sklearn.model_selection import StratifiedKFold\n",
|
||||
"from sklearn.base import clone\n",
|
||||
"\n",
|
||||
"skfolds = StratifiedKFold(y_train_5, n_folds=3, random_state=42)\n",
|
||||
"skfolds = StratifiedKFold(n_splits=3, random_state=42)\n",
|
||||
"\n",
|
||||
"for train_index, test_index in skfolds:\n",
|
||||
"for train_index, test_index in skfolds.split(X_train, y_train_5):\n",
|
||||
" clone_clf = clone(sgd_clf)\n",
|
||||
" X_train_folds = X_train[train_index]\n",
|
||||
" y_train_folds = (y_train_5[train_index])\n",
|
||||
|
@ -330,7 +330,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.cross_validation import cross_val_predict\n",
|
||||
"from sklearn.model_selection import cross_val_predict\n",
|
||||
"\n",
|
||||
"y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)"
|
||||
]
|
||||
|
@ -459,32 +459,11 @@
|
|||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Implemented in https://github.com/scikit-learn/scikit-learn/pull/6671\n",
|
||||
"# Pushed to master but not yet in pip module.\n",
|
||||
"from sklearn.cross_validation import StratifiedKFold\n",
|
||||
"from sklearn.base import clone\n",
|
||||
"\n",
|
||||
"def cross_val_predict_future(clf, X, y, cv, method=None):\n",
|
||||
" clf_clone = clone(clf) # keep original intact\n",
|
||||
" if method is None:\n",
|
||||
" return cross_val_predict(clf, X, y, cv=cv)\n",
|
||||
" else:\n",
|
||||
" method_f = getattr(clf_clone, method)\n",
|
||||
" scores = []\n",
|
||||
" skfolds = StratifiedKFold(y, n_folds=cv)\n",
|
||||
" for train_indices, test_indices in skfolds:\n",
|
||||
" clf_clone.fit(X[train_indices], y[train_indices])\n",
|
||||
" scores.append((method_f(X[test_indices]), test_indices))\n",
|
||||
" res_shape = list(scores[0][0].shape)\n",
|
||||
" res_shape[0] = len(X)\n",
|
||||
" res = np.empty(tuple(res_shape))\n",
|
||||
" for sc, test_indices in scores:\n",
|
||||
" res[test_indices] = sc\n",
|
||||
" return res"
|
||||
"y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -494,17 +473,6 @@
|
|||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"y_scores = cross_val_predict_future(sgd_clf, X_train, y_train_5, cv=3, method=\"decision_function\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.metrics import precision_recall_curve\n",
|
||||
"\n",
|
||||
|
@ -513,7 +481,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -535,7 +503,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -546,7 +514,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -557,7 +525,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -568,7 +536,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -579,7 +547,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -606,7 +574,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -619,7 +587,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -640,7 +608,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -653,7 +621,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -661,7 +629,7 @@
|
|||
"source": [
|
||||
"from sklearn.ensemble import RandomForestClassifier\n",
|
||||
"forest_clf = RandomForestClassifier(random_state=42)\n",
|
||||
"y_probas_forest = cross_val_predict_future(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
|
||||
"y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
|
||||
"y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class\n",
|
||||
"fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)\n",
|
||||
"\n",
|
||||
|
@ -675,7 +643,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -686,7 +654,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 44,
|
||||
"execution_count": 43,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -698,7 +666,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"execution_count": 44,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -716,7 +684,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 46,
|
||||
"execution_count": 45,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -728,7 +696,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"execution_count": 46,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -740,7 +708,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"execution_count": 47,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -751,7 +719,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"execution_count": 48,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -762,7 +730,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"execution_count": 49,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -776,7 +744,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 51,
|
||||
"execution_count": 50,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -787,7 +755,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 52,
|
||||
"execution_count": 51,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -799,7 +767,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 53,
|
||||
"execution_count": 52,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -810,7 +778,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 54,
|
||||
"execution_count": 53,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -821,7 +789,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 55,
|
||||
"execution_count": 54,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -835,7 +803,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 56,
|
||||
"execution_count": 55,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -848,7 +816,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 56,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -868,7 +836,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"execution_count": 57,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -884,7 +852,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"execution_count": 58,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -918,7 +886,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"execution_count": 59,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -936,7 +904,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"execution_count": 60,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -947,7 +915,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"execution_count": 61,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
|
@ -958,7 +926,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 63,
|
||||
"execution_count": 62,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -976,7 +944,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"execution_count": 63,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -992,7 +960,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 65,
|
||||
"execution_count": 64,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1007,7 +975,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"execution_count": 65,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1018,7 +986,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 67,
|
||||
"execution_count": 66,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1046,7 +1014,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 68,
|
||||
"execution_count": 67,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1054,13 +1022,13 @@
|
|||
"source": [
|
||||
"from sklearn.dummy import DummyClassifier\n",
|
||||
"dmy_clf = DummyClassifier()\n",
|
||||
"y_probas_dmy = cross_val_predict_future(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
|
||||
"y_probas_dmy = cross_val_predict(dmy_clf, X_train, y_train_5, cv=3, method=\"predict_proba\")\n",
|
||||
"y_scores_dmy = y_probas_dmy[:, 1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 69,
|
||||
"execution_count": 68,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"scrolled": true
|
||||
|
@ -1080,7 +1048,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 70,
|
||||
"execution_count": 69,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1093,7 +1061,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 71,
|
||||
"execution_count": 70,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1104,7 +1072,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"execution_count": 71,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1116,7 +1084,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 73,
|
||||
"execution_count": 72,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1131,7 +1099,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 74,
|
||||
"execution_count": 73,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1151,7 +1119,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"execution_count": 74,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1162,7 +1130,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
"execution_count": 75,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
|
@ -1173,7 +1141,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 77,
|
||||
"execution_count": 76,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1184,7 +1152,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 78,
|
||||
"execution_count": 77,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1196,7 +1164,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 79,
|
||||
"execution_count": 78,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -1247,7 +1215,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.5.1"
|
||||
"version": "3.5.2"
|
||||
},
|
||||
"nav_menu": {},
|
||||
"toc": {
|
||||
|
|
Loading…
Reference in New Issue