mean_squared_error(y_true, y_pred) instead of (y_pred, y_true), for clarity (result unchanged). Fixes #158

main
Aurélien Geron 2018-03-15 19:17:51 +01:00
parent eefe262dca
commit e9c97ff3b8
1 changed files with 5 additions and 5 deletions

View File

@ -617,8 +617,8 @@
" model.fit(X_train[:m], y_train[:m])\n", " model.fit(X_train[:m], y_train[:m])\n",
" y_train_predict = model.predict(X_train[:m])\n", " y_train_predict = model.predict(X_train[:m])\n",
" y_val_predict = model.predict(X_val)\n", " y_val_predict = model.predict(X_val)\n",
" train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))\n", " train_errors.append(mean_squared_error(y_train[:m], y_train_predict))\n",
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n", " val_errors.append(mean_squared_error(y_val, y_val_predict))\n",
"\n", "\n",
" plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"train\")\n", " plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"train\")\n",
" plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"val\")\n", " plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"val\")\n",
@ -822,8 +822,8 @@
" sgd_reg.fit(X_train_poly_scaled, y_train)\n", " sgd_reg.fit(X_train_poly_scaled, y_train)\n",
" y_train_predict = sgd_reg.predict(X_train_poly_scaled)\n", " y_train_predict = sgd_reg.predict(X_train_poly_scaled)\n",
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n", " y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
" train_errors.append(mean_squared_error(y_train_predict, y_train))\n", " train_errors.append(mean_squared_error(y_train, y_train_predict))\n",
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n", " val_errors.append(mean_squared_error(y_val, y_val_predict))\n",
"\n", "\n",
"best_epoch = np.argmin(val_errors)\n", "best_epoch = np.argmin(val_errors)\n",
"best_val_rmse = np.sqrt(val_errors[best_epoch])\n", "best_val_rmse = np.sqrt(val_errors[best_epoch])\n",
@ -863,7 +863,7 @@
"for epoch in range(1000):\n", "for epoch in range(1000):\n",
" sgd_reg.fit(X_train_poly_scaled, y_train) # continues where it left off\n", " sgd_reg.fit(X_train_poly_scaled, y_train) # continues where it left off\n",
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n", " y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
" val_error = mean_squared_error(y_val_predict, y_val)\n", " val_error = mean_squared_error(y_val, y_val_predict)\n",
" if val_error < minimum_val_error:\n", " if val_error < minimum_val_error:\n",
" minimum_val_error = val_error\n", " minimum_val_error = val_error\n",
" best_epoch = epoch\n", " best_epoch = epoch\n",