Add a basic implementation of MultiHeadAttention
parent
d24fee1576
commit
61bb73074b
|
@ -2109,6 +2109,13 @@
|
|||
"decoder_in = positional_encoding(decoder_embeddings)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here is a (very) simplified Transformer (the actual architecture has skip connections, layer norm, dense nets, and most importantly it uses Multi-Head Attention instead of regular Attention):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 131,
|
||||
|
@ -2128,6 +2135,69 @@
|
|||
"outputs = output_layer(final_enc)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Here's a basic implementation of the `MultiHeadAttention` layer. One will likely be added to `keras.layers` in the near future. Note that `Conv1D` layers with `kernel_size=1` (and the default `padding=\"valid\"` and `strides=1`) is equivalent to a `TimeDistributed(Dense(...))` layer."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 132,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"K = keras.backend\n",
|
||||
"\n",
|
||||
"class MultiHeadAttention(keras.layers.Layer):\n",
|
||||
" def __init__(self, n_heads, causal=False, use_scale=False, **kwargs):\n",
|
||||
" self.n_heads = n_heads\n",
|
||||
" self.causal = causal\n",
|
||||
" self.use_scale = use_scale\n",
|
||||
" super().__init__(**kwargs)\n",
|
||||
" def build(self, batch_input_shape):\n",
|
||||
" self.dims = batch_input_shape[0][-1]\n",
|
||||
" self.q_dims, self.v_dims, self.k_dims = [self.dims // self.n_heads] * 3 # could be hyperparameters instead\n",
|
||||
" self.q_linear = keras.layers.Conv1D(self.n_heads * self.q_dims, kernel_size=1, use_bias=False)\n",
|
||||
" self.v_linear = keras.layers.Conv1D(self.n_heads * self.v_dims, kernel_size=1, use_bias=False)\n",
|
||||
" self.k_linear = keras.layers.Conv1D(self.n_heads * self.k_dims, kernel_size=1, use_bias=False)\n",
|
||||
" self.attention = keras.layers.Attention(causal=self.causal, use_scale=self.use_scale)\n",
|
||||
" self.out_linear = keras.layers.Conv1D(self.dims, kernel_size=1, use_bias=False)\n",
|
||||
" super().build(batch_input_shape)\n",
|
||||
" def _multi_head_linear(self, inputs, linear):\n",
|
||||
" shape = K.concatenate([K.shape(inputs)[:-1], [self.n_heads, -1]])\n",
|
||||
" projected = K.reshape(linear(inputs), shape)\n",
|
||||
" perm = K.permute_dimensions(projected, [0, 2, 1, 3])\n",
|
||||
" return K.reshape(perm, [shape[0] * self.n_heads, shape[1], -1])\n",
|
||||
" def call(self, inputs):\n",
|
||||
" q = inputs[0]\n",
|
||||
" v = inputs[1]\n",
|
||||
" k = inputs[2] if len(inputs) > 2 else v\n",
|
||||
" shape = K.shape(q)\n",
|
||||
" q_proj = self._multi_head_linear(q, self.q_linear)\n",
|
||||
" v_proj = self._multi_head_linear(v, self.v_linear)\n",
|
||||
" k_proj = self._multi_head_linear(k, self.k_linear)\n",
|
||||
" multi_attended = self.attention([q_proj, v_proj, k_proj])\n",
|
||||
" shape_attended = K.shape(multi_attended)\n",
|
||||
" reshaped_attended = K.reshape(multi_attended, [shape[0], self.n_heads, shape_attended[1], shape_attended[2]])\n",
|
||||
" perm = K.permute_dimensions(reshaped_attended, [0, 2, 1, 3])\n",
|
||||
" concat = K.reshape(perm, [shape[0], shape_attended[1], -1])\n",
|
||||
" return self.out_linear(concat)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 133,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"Q = np.random.rand(2, 50, 512)\n",
|
||||
"V = np.random.rand(2, 80, 512)\n",
|
||||
"multi_attn = MultiHeadAttention(8)\n",
|
||||
"multi_attn([Q, V]).shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
|
@ -2167,7 +2237,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 132,
|
||||
"execution_count": 134,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -2212,7 +2282,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 133,
|
||||
"execution_count": 135,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -2229,7 +2299,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 134,
|
||||
"execution_count": 136,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -2246,7 +2316,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 135,
|
||||
"execution_count": 137,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -2267,7 +2337,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 136,
|
||||
"execution_count": 138,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
|
Loading…
Reference in New Issue