mean_squared_error(y_true, y_pred) instead of (y_pred, y_true), for clarity (result unchanged). Fixes #158
parent
eefe262dca
commit
e9c97ff3b8
|
@ -617,8 +617,8 @@
|
||||||
" model.fit(X_train[:m], y_train[:m])\n",
|
" model.fit(X_train[:m], y_train[:m])\n",
|
||||||
" y_train_predict = model.predict(X_train[:m])\n",
|
" y_train_predict = model.predict(X_train[:m])\n",
|
||||||
" y_val_predict = model.predict(X_val)\n",
|
" y_val_predict = model.predict(X_val)\n",
|
||||||
" train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))\n",
|
" train_errors.append(mean_squared_error(y_train[:m], y_train_predict))\n",
|
||||||
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n",
|
" val_errors.append(mean_squared_error(y_val, y_val_predict))\n",
|
||||||
"\n",
|
"\n",
|
||||||
" plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"train\")\n",
|
" plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"train\")\n",
|
||||||
" plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"val\")\n",
|
" plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"val\")\n",
|
||||||
|
@ -822,8 +822,8 @@
|
||||||
" sgd_reg.fit(X_train_poly_scaled, y_train)\n",
|
" sgd_reg.fit(X_train_poly_scaled, y_train)\n",
|
||||||
" y_train_predict = sgd_reg.predict(X_train_poly_scaled)\n",
|
" y_train_predict = sgd_reg.predict(X_train_poly_scaled)\n",
|
||||||
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
|
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
|
||||||
" train_errors.append(mean_squared_error(y_train_predict, y_train))\n",
|
" train_errors.append(mean_squared_error(y_train, y_train_predict))\n",
|
||||||
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n",
|
" val_errors.append(mean_squared_error(y_val, y_val_predict))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"best_epoch = np.argmin(val_errors)\n",
|
"best_epoch = np.argmin(val_errors)\n",
|
||||||
"best_val_rmse = np.sqrt(val_errors[best_epoch])\n",
|
"best_val_rmse = np.sqrt(val_errors[best_epoch])\n",
|
||||||
|
@ -863,7 +863,7 @@
|
||||||
"for epoch in range(1000):\n",
|
"for epoch in range(1000):\n",
|
||||||
" sgd_reg.fit(X_train_poly_scaled, y_train) # continues where it left off\n",
|
" sgd_reg.fit(X_train_poly_scaled, y_train) # continues where it left off\n",
|
||||||
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
|
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
|
||||||
" val_error = mean_squared_error(y_val_predict, y_val)\n",
|
" val_error = mean_squared_error(y_val, y_val_predict)\n",
|
||||||
" if val_error < minimum_val_error:\n",
|
" if val_error < minimum_val_error:\n",
|
||||||
" minimum_val_error = val_error\n",
|
" minimum_val_error = val_error\n",
|
||||||
" best_epoch = epoch\n",
|
" best_epoch = epoch\n",
|
||||||
|
|
Loading…
Reference in New Issue