Replace n_inputs with n_outputs, fixes #125
parent
53c2133bc7
commit
63c1523528
|
@ -1909,7 +1909,7 @@
|
|||
" error = Y_proba - Y_train_one_hot\n",
|
||||
" if iteration % 500 == 0:\n",
|
||||
" print(iteration, loss)\n",
|
||||
" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_inputs]), alpha * Theta[1:]]\n",
|
||||
" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]\n",
|
||||
" Theta = Theta - eta * gradients"
|
||||
]
|
||||
},
|
||||
|
@ -1987,7 +1987,7 @@
|
|||
" l2_loss = 1/2 * np.sum(np.square(Theta[1:]))\n",
|
||||
" loss = xentropy_loss + alpha * l2_loss\n",
|
||||
" error = Y_proba - Y_train_one_hot\n",
|
||||
" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_inputs]), alpha * Theta[1:]]\n",
|
||||
" gradients = 1/m * X_train.T.dot(error) + np.r_[np.zeros([1, n_outputs]), alpha * Theta[1:]]\n",
|
||||
" Theta = Theta - eta * gradients\n",
|
||||
"\n",
|
||||
" logits = X_valid.dot(Theta)\n",
|
||||
|
|
Loading…
Reference in New Issue