From e9c97ff3b8117ae584d8c52e57ab77535e031e8a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Thu, 15 Mar 2018 19:17:51 +0100 Subject: [PATCH] mean_squared_error(y_true, y_pred) instead of (y_pred, y_true), for clarity (result unchanged). Fixes #158 --- 04_training_linear_models.ipynb | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/04_training_linear_models.ipynb b/04_training_linear_models.ipynb index 8acdbfc..1845e8e 100644 --- a/04_training_linear_models.ipynb +++ b/04_training_linear_models.ipynb @@ -617,8 +617,8 @@ " model.fit(X_train[:m], y_train[:m])\n", " y_train_predict = model.predict(X_train[:m])\n", " y_val_predict = model.predict(X_val)\n", - " train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))\n", - " val_errors.append(mean_squared_error(y_val_predict, y_val))\n", + " train_errors.append(mean_squared_error(y_train[:m], y_train_predict))\n", + " val_errors.append(mean_squared_error(y_val, y_val_predict))\n", "\n", " plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"train\")\n", " plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"val\")\n", @@ -822,8 +822,8 @@ " sgd_reg.fit(X_train_poly_scaled, y_train)\n", " y_train_predict = sgd_reg.predict(X_train_poly_scaled)\n", " y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n", - " train_errors.append(mean_squared_error(y_train_predict, y_train))\n", - " val_errors.append(mean_squared_error(y_val_predict, y_val))\n", + " train_errors.append(mean_squared_error(y_train, y_train_predict))\n", + " val_errors.append(mean_squared_error(y_val, y_val_predict))\n", "\n", "best_epoch = np.argmin(val_errors)\n", "best_val_rmse = np.sqrt(val_errors[best_epoch])\n", @@ -863,7 +863,7 @@ "for epoch in range(1000):\n", " sgd_reg.fit(X_train_poly_scaled, y_train) # continues where it left off\n", " y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n", - " val_error = mean_squared_error(y_val_predict, y_val)\n", + " val_error = mean_squared_error(y_val, y_val_predict)\n", " if val_error < minimum_val_error:\n", " minimum_val_error = val_error\n", " best_epoch = epoch\n",