import math
from mxnet import nd
from mxnet.gluon import nn
Masked softmax.
# X: 3-D tensor, valid_length: 1-D or 2-D tensor
def masked_softmax(X, valid_length):
    """Apply softmax over the last axis of ``X``, ignoring padded positions.

    Args:
        X: 3-D NDArray of scores, e.g. (batch_size, #rows, #cols).
        valid_length: ``None`` (no masking), a 1-D NDArray holding one length
            per batch element (shared by all ``X.shape[1]`` rows of that
            element), or a 2-D NDArray holding one length per
            (batch element, row).

    Returns:
        NDArray with the same shape as ``X``. Along the last axis, entries at
        or beyond the valid length get (near-)zero probability.
    """
    if valid_length is None:
        return X.softmax()

    shape = X.shape
    # SequenceMask works on the 2-D view (-1, shape[-1]), so it needs exactly
    # one length per row of that view.
    if valid_length.ndim == 1:
        row_lengths = valid_length.repeat(shape[1], axis=0)
    else:
        row_lengths = valid_length.reshape((-1,))

    rows = X.reshape((-1, shape[-1]))
    # Overwrite masked entries with a large negative value; its exp is
    # effectively 0, so these entries drop out of the softmax.
    masked_rows = nd.SequenceMask(rows, row_lengths, True,
                                  axis=1, value=-1e6)
    return masked_rows.softmax().reshape(shape)
Example
# Two batch elements, each with 2 rows of 4 scores; rows of the first element
# keep 2 entries, rows of the second keep 3.
demo_scores = nd.random.uniform(shape=(2, 2, 4))
demo_lengths = nd.array([2, 3])
masked_softmax(demo_scores, demo_lengths)
[[[0.488994 0.511006 0. 0. ] [0.43654838 0.56345165 0. 0. ]] [[0.28817102 0.3519408 0.3598882 0. ] [0.29034293 0.25239873 0.45725834 0. ]]] <NDArray 2x2x4 @cpu(0)>
class DotProductAttention(nn.Block):
    """Scaled dot-product attention with dropout on the attention weights.

    Scores are ``query @ key^T / sqrt(d)``, where ``d`` is the size of the
    last axis of ``query``. Masking follows ``masked_softmax``.
    """

    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, valid_length=None):
        """Return the attention-weighted sum of ``value``.

        Args:
            query: (batch_size, #queries, d)
            key: (batch_size, #kv_pairs, d)
            value: (batch_size, #kv_pairs, dim_v)
            valid_length: ``None``, (batch_size,), or (batch_size, #queries).

        Returns:
            NDArray of shape (batch_size, #queries, dim_v).
        """
        scale = math.sqrt(query.shape[-1])
        # transpose_b=True swaps key's last two axes, so the batched product
        # yields (batch_size, #queries, #kv_pairs) scores.
        raw_scores = nd.batch_dot(query, key, transpose_b=True) / scale
        weights = self.dropout(masked_softmax(raw_scores, valid_length))
        return nd.batch_dot(weights, value)
Example:
atten = DotProductAttention(dropout=0.5)
atten.initialize()
# Every key is identical, so all keys get the same score. The weights are
# therefore uniform over the first valid_length keys of each batch element.
keys = nd.ones((2, 10, 2))
# Both batch elements share the values 0..39, laid out as 10 rows of 4.
values = nd.arange(40).reshape((1, 10, 4)).repeat(2, axis=0)
dot_query = nd.ones((2, 1, 2))
atten(dot_query, keys, values, nd.array([2, 6]))
[[[ 2. 3. 4. 5. ]] [[10. 11. 12.000001 13. ]]] <NDArray 2x1x4 @cpu(0)>
$\mathbf W_k\in\mathbb R^{h\times d_k}$, $\mathbf W_q\in\mathbb R^{h\times d_q}$, and $\mathbf v\in\mathbb R^{h}$:
$$\alpha(\mathbf k, \mathbf q) = \mathbf v^T \text{tanh}(\mathbf W_k \mathbf k + \mathbf W_q\mathbf q). $$class MLPAttention(nn.Block): # This class is saved in d2l.
def __init__(self, units, dropout, **kwargs):
super(MLPAttention, self).__init__(**kwargs)
# Use flatten=True to keep query's and key's 3-D shapes.
self.W_k = nn.Dense(units, activation='tanh',
use_bias=False, flatten=False)
self.W_q = nn.Dense(units, activation='tanh',
use_bias=False, flatten=False)
self.v = nn.Dense(1, use_bias=False, flatten=False)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, valid_length):
query, key = self.W_k(query), self.W_q(key)
# expand query to (batch_size, #querys, 1, units), and key to
# (batch_size, 1, #kv_pairs, units). Then plus them with broadcast.
features = query.expand_dims(axis=2) + key.expand_dims(axis=1)
scores = self.v(features).squeeze(axis=-1)
attention_weights = self.dropout(masked_softmax(scores, valid_length))
return nd.batch_dot(attention_weights, value)
Example
# Reuses `keys` and `values` from the dot-product example. The keys are
# identical, so MLP attention also averages the first valid_length values.
atten = MLPAttention(units=8, dropout=0.1)
atten.initialize()
mlp_query = nd.ones((2, 1, 2))
atten(mlp_query, keys, values, nd.array([2, 6]))
[[[ 2. 3. 4. 5. ]] [[10. 11. 12.000001 13. ]]] <NDArray 2x1x4 @cpu(0)>