Replace n_inputs with n_outputs, fixes #125

This commit is contained in:
Aurelien Geron
2017-12-07 18:57:30 -08:00
parent 93792c131b
commit 49d67ecc14

View File

@@ -2749,7 +2749,7 @@
 " error = Y_proba - Y_train_one_hot\n",
 " if iteration % 500 == 0:\n",
 " print(iteration, loss)\n",
-" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_inputs]), alpha * Theta[1:]]\n",
+" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]\n",
 " Theta = Theta - eta * gradients"
 ]
 },
@@ -2853,7 +2853,7 @@
 " l2_loss = 1/2 * np.sum(np.square(Theta[1:]))\n",
 " loss = xentropy_loss + alpha * l2_loss\n",
 " error = Y_proba - Y_train_one_hot\n",
-" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_inputs]), alpha * Theta[1:]]\n",
+" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]\n",
 " Theta = Theta - eta * gradients\n",
 "\n",
 " logits = X_valid.dot(Theta)\n",