From 33f9ff10b47b1c8088325d48c1207b87de87e9ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Thu, 4 Mar 2021 15:17:19 +1300 Subject: [PATCH] Use default splitter="best" instead of splitter="random", fixes #340 --- 07_ensemble_learning_and_random_forests.ipynb | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/07_ensemble_learning_and_random_forests.ipynb b/07_ensemble_learning_and_random_forests.ipynb index 089f502..0aabfed 100644 --- a/07_ensemble_learning_and_random_forests.ipynb +++ b/07_ensemble_learning_and_random_forests.ipynb @@ -242,7 +242,7 @@ "from sklearn.tree import DecisionTreeClassifier\n", "\n", "bag_clf = BaggingClassifier(\n", - " DecisionTreeClassifier(random_state=42), n_estimators=500,\n", + " DecisionTreeClassifier(), n_estimators=500,\n", " max_samples=100, bootstrap=True, random_state=42)\n", "bag_clf.fit(X_train, y_train)\n", "y_pred = bag_clf.predict(X_test)" @@ -327,9 +327,11 @@ "metadata": {}, "outputs": [], "source": [ + "from math import ceil, sqrt\n", + "\n", "bag_clf = BaggingClassifier(\n", - " DecisionTreeClassifier(splitter=\"random\", max_leaf_nodes=16, random_state=42),\n", - " n_estimators=500, max_samples=1.0, bootstrap=True, random_state=42)" + " DecisionTreeClassifier(max_leaf_nodes=16),\n", + " n_estimators=500, max_features=ceil(sqrt(X_train.shape[1])), random_state=42)" ] }, { @@ -362,7 +364,7 @@ "metadata": {}, "outputs": [], "source": [ - "np.sum(y_pred == y_pred_rf) / len(y_pred) # almost identical predictions" + "np.sum(y_pred == y_pred_rf) / len(y_pred) # very similar predictions" ] }, { @@ -419,7 +421,7 @@ "outputs": [], "source": [ "bag_clf = BaggingClassifier(\n", - " DecisionTreeClassifier(random_state=42), n_estimators=500,\n", + " DecisionTreeClassifier(), n_estimators=500,\n", " bootstrap=True, oob_score=True, random_state=40)\n", "bag_clf.fit(X_train, y_train)\n", "bag_clf.oob_score_"