Add a basic implementation of MultiHeadAttention
parent d24fee1576
commit 61bb73074b
@@ -2109,6 +2109,13 @@
 "decoder_in = positional_encoding(decoder_embeddings)"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"Here is a (very) simplified Transformer (the actual architecture has skip connections, layer norm, dense nets, and most importantly it uses Multi-Head Attention instead of regular Attention):"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 131,
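The markdown cell added above introduces the simplified Transformer built in the following code cell; only that cell's last line, "outputs = output_layer(final_enc)", appears as context in the next hunk. As a rough, hedged illustration of what such a simplified model looks like with keras.layers.Attention (this is not the notebook's exact cell, and the vocabulary and embedding sizes are arbitrary choices):

import tensorflow as tf
from tensorflow import keras

vocab_size, embed_size = 1000, 512  # arbitrary sizes for the sketch

encoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
decoder_inputs = keras.layers.Input(shape=[None], dtype=tf.int32)
embeddings = keras.layers.Embedding(vocab_size, embed_size)
encoder_in = embeddings(encoder_inputs)   # positional encoding omitted in this sketch
decoder_in = embeddings(decoder_inputs)

# Encoder: plain self-attention (queries, values and keys are all the encoder inputs)
encoder_outputs = keras.layers.Attention(use_scale=True)([encoder_in, encoder_in])

# Decoder: causal self-attention, then attention over the encoder outputs
Z = keras.layers.Attention(use_scale=True, causal=True)([decoder_in, decoder_in])
Z = keras.layers.Attention(use_scale=True)([Z, encoder_outputs])

output_layer = keras.layers.TimeDistributed(
    keras.layers.Dense(vocab_size, activation="softmax"))
outputs = output_layer(Z)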
@@ -2128,6 +2135,69 @@
 "outputs = output_layer(final_enc)"
 ]
 },
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
"Here's a basic implementation of the `MultiHeadAttention` layer. One will likely be added to `keras.layers` in the near future. Note that `Conv1D` layers with `kernel_size=1` (and the default `padding=\"valid\"` and `strides=1`) is equivalent to a `TimeDistributed(Dense(...))` layer."
|
+]
+},
+{
+"cell_type": "code",
+"execution_count": 132,
+"metadata": {},
+"outputs": [],
+"source": [
+"K = keras.backend\n",
+"\n",
+"class MultiHeadAttention(keras.layers.Layer):\n",
+"    def __init__(self, n_heads, causal=False, use_scale=False, **kwargs):\n",
+"        self.n_heads = n_heads\n",
+"        self.causal = causal\n",
+"        self.use_scale = use_scale\n",
+"        super().__init__(**kwargs)\n",
+"    def build(self, batch_input_shape):\n",
+"        self.dims = batch_input_shape[0][-1]\n",
+"        self.q_dims, self.v_dims, self.k_dims = [self.dims // self.n_heads] * 3 # could be hyperparameters instead\n",
+"        self.q_linear = keras.layers.Conv1D(self.n_heads * self.q_dims, kernel_size=1, use_bias=False)\n",
+"        self.v_linear = keras.layers.Conv1D(self.n_heads * self.v_dims, kernel_size=1, use_bias=False)\n",
+"        self.k_linear = keras.layers.Conv1D(self.n_heads * self.k_dims, kernel_size=1, use_bias=False)\n",
+"        self.attention = keras.layers.Attention(causal=self.causal, use_scale=self.use_scale)\n",
+"        self.out_linear = keras.layers.Conv1D(self.dims, kernel_size=1, use_bias=False)\n",
+"        super().build(batch_input_shape)\n",
+"    def _multi_head_linear(self, inputs, linear):\n",
+"        shape = K.concatenate([K.shape(inputs)[:-1], [self.n_heads, -1]])\n",
+"        projected = K.reshape(linear(inputs), shape)\n",
+"        perm = K.permute_dimensions(projected, [0, 2, 1, 3])\n",
+"        return K.reshape(perm, [shape[0] * self.n_heads, shape[1], -1])\n",
+"    def call(self, inputs):\n",
+"        q = inputs[0]\n",
+"        v = inputs[1]\n",
+"        k = inputs[2] if len(inputs) > 2 else v\n",
+"        shape = K.shape(q)\n",
+"        q_proj = self._multi_head_linear(q, self.q_linear)\n",
+"        v_proj = self._multi_head_linear(v, self.v_linear)\n",
+"        k_proj = self._multi_head_linear(k, self.k_linear)\n",
+"        multi_attended = self.attention([q_proj, v_proj, k_proj])\n",
+"        shape_attended = K.shape(multi_attended)\n",
+"        reshaped_attended = K.reshape(multi_attended, [shape[0], self.n_heads, shape_attended[1], shape_attended[2]])\n",
+"        perm = K.permute_dimensions(reshaped_attended, [0, 2, 1, 3])\n",
+"        concat = K.reshape(perm, [shape[0], shape_attended[1], -1])\n",
+"        return self.out_linear(concat)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 133,
+"metadata": {},
+"outputs": [],
+"source": [
+"Q = np.random.rand(2, 50, 512)\n",
+"V = np.random.rand(2, 80, 512)\n",
+"multi_attn = MultiHeadAttention(8)\n",
+"multi_attn([Q, V]).shape"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {
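The added markdown cell claims that a Conv1D layer with kernel_size=1 is equivalent to TimeDistributed(Dense(...)), which is why the new layer uses Conv1D for its per-head linear projections. A small sketch that checks this claim numerically (the input shape and the 128-unit size are arbitrary choices, not taken from the notebook):

import numpy as np
from tensorflow import keras

X = np.random.rand(2, 50, 512).astype(np.float32)   # [batch, time steps, features]

conv = keras.layers.Conv1D(128, kernel_size=1, use_bias=False)  # padding="valid", strides=1 by default
dense = keras.layers.Dense(128, use_bias=False)
td_dense = keras.layers.TimeDistributed(dense)

y_conv = conv(X)        # builds the Conv1D layer; its kernel has shape [1, 512, 128]
y_dense = td_dense(X)   # builds the Dense layer; its kernel has shape [512, 128]

# Copy the convolution kernel into the Dense layer (dropping the kernel_size axis);
# both layers then apply the same linear map independently at every time step.
dense.set_weights([conv.get_weights()[0][0]])
print(np.allclose(conv(X).numpy(), td_dense(X).numpy(), atol=1e-5))  # True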
@@ -2167,7 +2237,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 132,
+"execution_count": 134,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -2212,7 +2282,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 133,
+"execution_count": 135,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -2229,7 +2299,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 134,
+"execution_count": 136,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -2246,7 +2316,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 135,
+"execution_count": 137,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -2267,7 +2337,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 136,
+"execution_count": 138,
 "metadata": {},
 "outputs": [],
 "source": [
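The trickiest part of the MultiHeadAttention layer added in the second hunk is the reshape/transpose dance that folds the heads into the batch dimension before calling keras.layers.Attention, then unfolds them afterwards. A NumPy-only sketch of that shape flow (the concrete sizes mirror the notebook's test cell: batch 2, 50 query steps, 512 dimensions, 8 heads, hence 64 dimensions per head):

import numpy as np

batch, steps, dims, n_heads = 2, 50, 512, 8
head_dims = dims // n_heads                      # 64

projected = np.random.rand(batch, steps, dims)   # output of one kernel_size=1 Conv1D

# Split heads: [batch, steps, dims] -> [batch * n_heads, steps, head_dims],
# mirroring _multi_head_linear()
split = projected.reshape(batch, steps, n_heads, head_dims)
split = split.transpose(0, 2, 1, 3).reshape(batch * n_heads, steps, head_dims)
print(split.shape)                               # (16, 50, 64)

# After attention, merge the heads back: [batch * n_heads, steps, head_dims]
# -> [batch, steps, dims], mirroring the end of call()
attended = split                                 # stand-in for the Attention output
merged = attended.reshape(batch, n_heads, steps, head_dims)
merged = merged.transpose(0, 2, 1, 3).reshape(batch, steps, dims)
print(merged.shape)                              # (2, 50, 512)

For the test inputs in the new cell 133 (2 x 50 x 512 queries, 2 x 80 x 512 values), the layer's output shape should therefore be (2, 50, 512): one 512-dimensional vector per query time step.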