diff --git a/15_recurrent_neural_networks.ipynb b/15_recurrent_neural_networks.ipynb index 3c5d541..e884fc3 100644 --- a/15_recurrent_neural_networks.ipynb +++ b/15_recurrent_neural_networks.ipynb @@ -2109,6 +2109,13 @@ "decoder_in = positional_encoding(decoder_embeddings)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is a (very) simplified Transformer (the actual architecture has skip connections, layer norm, dense nets, and most importantly it uses Multi-Head Attention instead of regular Attention):" + ] + }, { "cell_type": "code", "execution_count": 131, @@ -2128,6 +2135,69 @@ "outputs = output_layer(final_enc)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here's a basic implementation of the `MultiHeadAttention` layer. One will likely be added to `keras.layers` in the near future. Note that `Conv1D` layers with `kernel_size=1` (and the default `padding=\"valid\"` and `strides=1`) is equivalent to a `TimeDistributed(Dense(...))` layer." + ] + }, + { + "cell_type": "code", + "execution_count": 132, + "metadata": {}, + "outputs": [], + "source": [ + "K = keras.backend\n", + "\n", + "class MultiHeadAttention(keras.layers.Layer):\n", + " def __init__(self, n_heads, causal=False, use_scale=False, **kwargs):\n", + " self.n_heads = n_heads\n", + " self.causal = causal\n", + " self.use_scale = use_scale\n", + " super().__init__(**kwargs)\n", + " def build(self, batch_input_shape):\n", + " self.dims = batch_input_shape[0][-1]\n", + " self.q_dims, self.v_dims, self.k_dims = [self.dims // self.n_heads] * 3 # could be hyperparameters instead\n", + " self.q_linear = keras.layers.Conv1D(self.n_heads * self.q_dims, kernel_size=1, use_bias=False)\n", + " self.v_linear = keras.layers.Conv1D(self.n_heads * self.v_dims, kernel_size=1, use_bias=False)\n", + " self.k_linear = keras.layers.Conv1D(self.n_heads * self.k_dims, kernel_size=1, use_bias=False)\n", + " self.attention = keras.layers.Attention(causal=self.causal, use_scale=self.use_scale)\n", + " self.out_linear = keras.layers.Conv1D(self.dims, kernel_size=1, use_bias=False)\n", + " super().build(batch_input_shape)\n", + " def _multi_head_linear(self, inputs, linear):\n", + " shape = K.concatenate([K.shape(inputs)[:-1], [self.n_heads, -1]])\n", + " projected = K.reshape(linear(inputs), shape)\n", + " perm = K.permute_dimensions(projected, [0, 2, 1, 3])\n", + " return K.reshape(perm, [shape[0] * self.n_heads, shape[1], -1])\n", + " def call(self, inputs):\n", + " q = inputs[0]\n", + " v = inputs[1]\n", + " k = inputs[2] if len(inputs) > 2 else v\n", + " shape = K.shape(q)\n", + " q_proj = self._multi_head_linear(q, self.q_linear)\n", + " v_proj = self._multi_head_linear(v, self.v_linear)\n", + " k_proj = self._multi_head_linear(k, self.k_linear)\n", + " multi_attended = self.attention([q_proj, v_proj, k_proj])\n", + " shape_attended = K.shape(multi_attended)\n", + " reshaped_attended = K.reshape(multi_attended, [shape[0], self.n_heads, shape_attended[1], shape_attended[2]])\n", + " perm = K.permute_dimensions(reshaped_attended, [0, 2, 1, 3])\n", + " concat = K.reshape(perm, [shape[0], shape_attended[1], -1])\n", + " return self.out_linear(concat)" + ] + }, + { + "cell_type": "code", + "execution_count": 133, + "metadata": {}, + "outputs": [], + "source": [ + "Q = np.random.rand(2, 50, 512)\n", + "V = np.random.rand(2, 80, 512)\n", + "multi_attn = MultiHeadAttention(8)\n", + "multi_attn([Q, V]).shape" + ] + }, { "cell_type": "markdown", "metadata": { @@ -2167,7 +2237,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 134, "metadata": {}, "outputs": [], "source": [ @@ -2212,7 +2282,7 @@ }, { "cell_type": "code", - "execution_count": 133, + "execution_count": 135, "metadata": {}, "outputs": [], "source": [ @@ -2229,7 +2299,7 @@ }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 136, "metadata": {}, "outputs": [], "source": [ @@ -2246,7 +2316,7 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 137, "metadata": {}, "outputs": [], "source": [ @@ -2267,7 +2337,7 @@ }, { "cell_type": "code", - "execution_count": 136, + "execution_count": 138, "metadata": {}, "outputs": [], "source": [