mean_squared_error 버그 수정(5786ce3047)

This commit is contained in:
Haesun Park
2018-03-26 17:07:03 +09:00
parent 7a0ce2129d
commit 08d8d790bd

View File

@@ -863,8 +863,8 @@
" model.fit(X_train[:m], y_train[:m])\n",
" y_train_predict = model.predict(X_train[:m])\n",
" y_val_predict = model.predict(X_val)\n",
" train_errors.append(mean_squared_error(y_train_predict, y_train[:m]))\n",
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n",
" train_errors.append(mean_squared_error(y_train[:m], y_train_predict))\n",
" val_errors.append(mean_squared_error(y_val, y_val_predict))\n",
"\n",
" plt.plot(np.sqrt(train_errors), \"r-+\", linewidth=2, label=\"훈련\")\n",
" plt.plot(np.sqrt(val_errors), \"b-\", linewidth=3, label=\"검증\")\n",
@@ -1178,8 +1178,8 @@
" sgd_reg.fit(X_train_poly_scaled, y_train)\n",
" y_train_predict = sgd_reg.predict(X_train_poly_scaled)\n",
" y_val_predict = sgd_reg.predict(X_val_poly_scaled)\n",
" train_errors.append(mean_squared_error(y_train_predict, y_train))\n",
" val_errors.append(mean_squared_error(y_val_predict, y_val))\n",
" train_errors.append(mean_squared_error(y_train, y_train_predict))\n",
" val_errors.append(mean_squared_error(y_val, y_val_predict))\n",
"\n",
"best_epoch = np.argmin(val_errors)\n",
"best_val_rmse = np.sqrt(val_errors[best_epoch])\n",