From ecb0d81b2d6c7b3484a02e678bc8e83fb0c4d8c1 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Mon, 24 Feb 2020 17:58:24 -0500 Subject: [PATCH] [PYTHON] Added missing files for nn submodule --- python/triton/nn/__init__.py | 2 + python/triton/nn/attention.py | 312 ++++++++++++++++++++++++++++++++++ python/triton/nn/conv.py | 116 +++++++++++++ python/triton/nn/linear.py | 13 ++ 4 files changed, 443 insertions(+) create mode 100644 python/triton/nn/__init__.py create mode 100644 python/triton/nn/attention.py create mode 100644 python/triton/nn/conv.py create mode 100644 python/triton/nn/linear.py diff --git a/python/triton/nn/__init__.py b/python/triton/nn/__init__.py new file mode 100644 index 000000000..84c1a0a78 --- /dev/null +++ b/python/triton/nn/__init__.py @@ -0,0 +1,2 @@ +from .conv import replace_conv2d +from .attention import replace_mah \ No newline at end of file diff --git a/python/triton/nn/attention.py b/python/triton/nn/attention.py new file mode 100644 index 000000000..84d198b2e --- /dev/null +++ b/python/triton/nn/attention.py @@ -0,0 +1,312 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton + +def bmm(x, w, mask = None): + b, m, k = x.size() + b, k, n = w.size() + out = torch.empty([b, m, n], device=x.device) + triton.ops.einsum('bmk,bkn->bmn', x, w, out, mask=mask, bench=False) + return out + +def multi_head_attention_forward(query, # type: Tensor + key, # type: Tensor + value, # type: Tensor + embed_dim_to_check, # type: int + num_heads, # type: int + in_proj_weight, # type: Tensor + in_proj_bias, # type: Tensor + bias_k, # type: Optional[Tensor] + bias_v, # type: Optional[Tensor] + add_zero_attn, # type: bool + dropout_p, # type: float + out_proj_weight, # type: Tensor + out_proj_bias, # type: Tensor + training=True, # type: bool + key_padding_mask=None, # type: Optional[Tensor] + need_weights=True, # type: bool + attn_mask=None, # type: Optional[Tensor] + use_separate_proj_weight=False, # type: bool + q_proj_weight=None, # type: Optional[Tensor] + k_proj_weight=None, # type: Optional[Tensor] + v_proj_weight=None, # type: Optional[Tensor] + static_k=None, # type: Optional[Tensor] + static_v=None, # type: Optional[Tensor] + acc_bitmask=None + ): + # type: (...) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: mask that prevents attention to certain positions. This is an additive mask + (i.e. the values will be added to the attention layer). + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in differnt forms. 
If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length. + - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size() == value.size() + + head_dim = embed_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + + if not use_separate_proj_weight: + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + if key is None: + assert value is None + k = None + v = None + else: + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = F.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = F.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = F.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = F.linear(value, _w, _b) + else: + q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) + len1, len2 = q_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == query.size(-1) + + k_proj_weight_non_opt = 
torch.jit._unwrap_optional(k_proj_weight) + len1, len2 = k_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == key.size(-1) + + v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) + len1, len2 = v_proj_weight_non_opt.size() + assert len1 == embed_dim and len2 == value.size(-1) + + if in_proj_bias is not None: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) + v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) + else: + q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) + k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) + v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) + q = q * scaling + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, + torch.zeros((attn_mask.size(0), 1), + dtype=attn_mask.dtype, + device=attn_mask.device)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), + dtype=key_padding_mask.dtype, + device=key_padding_mask.device)], dim=1) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." + else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1), + dtype=attn_mask.dtype, + device=attn_mask.device)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1), + dtype=key_padding_mask.dtype, + device=key_padding_mask.device)], dim=1) + + + attn_output_weights = bmm(q, k.transpose(1, 2), mask=acc_bitmask) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = F.softmax( + attn_output_weights, dim=-1) + attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training) + + 
attn_output = bmm(attn_output_weights, v, mask=acc_bitmask) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + +class MultiheadAttention(nn.modules.activation.MultiheadAttention): + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, acc_bitmask=None): + super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout, bias, add_bias_kv, add_zero_attn, kdim, vdim) + self.acc_bitmask = acc_bitmask + + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None): + if not self._qkv_same_embed_dim: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + acc_bitmask=self.acc_bitmask) + else: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, + acc_bitmask=self.acc_bitmask) + + +def replace_mah(model, mask = None): + for child_name, child in model.named_children(): + if isinstance(child, nn.modules.activation.MultiheadAttention): + add_bias_kv = child.bias_k is not None + device = child.in_proj_weight.device + mah = MultiheadAttention(child.embed_dim, child.num_heads, + dropout=child.dropout, add_bias_kv=add_bias_kv, + add_zero_attn=child.add_zero_attn, kdim=child.kdim, + vdim=child.vdim, acc_bitmask=mask).to(device) + for yparam, xparam in zip(mah.parameters(), child.parameters()): + yparam.data.copy_(xparam.data) + setattr(model, child_name, mah) + else: + replace_mah(child, mask) \ No newline at end of file diff --git a/python/triton/nn/conv.py b/python/triton/nn/conv.py new file mode 100644 index 000000000..f0bcf589c --- /dev/null +++ b/python/triton/nn/conv.py @@ -0,0 +1,116 @@ +import triton +import torch.nn as nn +import torch +import torch.nn.functional as F + +class _conv2d(torch.autograd.Function): + + @staticmethod + def forward(ctx, input, weight, bias, + stride, padding, dilation, groups, + acc_bitmask): + assert dilation == (1, 1) + assert groups == 1 + assert bias == None + pad_h, pad_w = padding + stride_h, stride_w = stride + n, c, h, w = input.size() + k, c, r, s = weight.size() + # allocate output + p = (h + 2*padding[0] - r)//stride[0] + 1 + q = (w + 2*padding[1] - s)//stride[1] + 1 + output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device) + # padding + if pad_h or pad_w: + input = triton.ops._einsum.pad(input, [pad_w, pad_w, pad_h, pad_h]) + # convolution + triton.ops.einsum(f'nc(h*stride_h + r 
- pad_h)(w*stride_w + s - pad_w),kcrs->nkhw', + input, weight, mask=acc_bitmask, + output=output, + values = {'pad_h': pad_h, + 'stride_h': stride_h, + 'pad_w': pad_w, + 'stride_w': stride_w}) + # prepare backprop + ctx.save_for_backward(input, weight) + ctx.stride = stride + ctx.padding = padding + ctx.acc_bitmask = acc_bitmask + # return + return output + + @staticmethod + def backward(ctx, dy): + # retrieve contextual information + input, weight = ctx.saved_tensors + stride = ctx.stride + padding = ctx.padding + acc_bitmask = ctx.acc_bitmask + # gradient of the input + dx = None + if ctx.needs_input_grad[0]: + # dy must be padded + n, k, p, q = dy.size() + n, c, h, w = input.size() + k, c, r, s = weight.size() + dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4]) + # have to be careful here + # the gradient of strided conv is a conv over a sparse image + # which can be decomposed as a set of smaller convs + dx = torch.empty_like(input) + for offh in range(stride[0]): + for offw in range(stride[1]): + poffh = (offh + padding[0]) % stride[0] + poffw = (offw + padding[1]) % stride[1] + pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0]) + pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1]) + if offh >= r or offw >= s: + dx[:, :, poffh::stride[0], poffw::stride[1]] = 0 + else: + triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw', + dypad[:, :, :, :], + weight[:, :, offh::stride[0], offw::stride[1]], + output = dx[:, :, poffh::stride[0], poffw::stride[1]], + mask = acc_bitmask, + values = {'pad_h': pad_h, + 'pad_w': pad_w}) + # gradient for the weight + dw = None + if ctx.needs_input_grad[1]: + dw = torch.empty_like(weight) + triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs', + input, dy, output = dw, mask = acc_bitmask) + return dx, dw, None, None, None, None, None, None +conv2d = _conv2d.apply + +class Conv2d(nn.Conv2d): + + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, + bias=True, padding_mode='zeros', + acc_bitmask = None): + super(Conv2d, self).__init__( + in_channels, out_channels, kernel_size, stride, padding, dilation, + groups, bias, padding_mode) + self.acc_bitmask = acc_bitmask + + def forward(self, input): + #if self.kernel_size[0] == 3: + # return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + return conv2d(input, self.weight, self.bias, self.stride, + self.padding, self.dilation, self.groups, + self.acc_bitmask) + + +def replace_conv2d(model, acc_bitmask = None): + for child_name, child in model.named_children(): + if isinstance(child, nn.Conv2d): + conv2d = Conv2d(child.in_channels, child.out_channels, child.kernel_size, + child.stride, child.padding, child.dilation, child.groups, + child.bias, child.padding_mode, + acc_bitmask=acc_bitmask) + for yparam, xparam in zip(conv2d.parameters(), child.parameters()): + yparam.data.copy_(xparam.data) + setattr(model, child_name, conv2d) + else: + replace_conv2d(child, acc_bitmask) \ No newline at end of file diff --git a/python/triton/nn/linear.py b/python/triton/nn/linear.py new file mode 100644 index 000000000..a62fee600 --- /dev/null +++ b/python/triton/nn/linear.py @@ -0,0 +1,13 @@ +import torch +import triton + + +def linear(x, w, bias = None): + print(x.size(), w.size()) + m, k = x.size() + k, n = w.size() + out = torch.empty([m, n], device=x.device) + triton.ops.einsum('mk,nk->mn', x, w, bias) + if bias is not None: + out += bias + return out \ No 
newline at end of file
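
Usage note for attention.py: the module is exposed through triton.nn.replace_mah, which walks a model and swaps every nn.MultiheadAttention for the einsum-backed subclass above, copying its parameters in place. A minimal usage sketch (my own example, not part of the patch; it assumes a CUDA device and that the triton.ops.einsum kernels handle the shapes involved):

import torch
import torch.nn as nn
from triton.nn import replace_mah

# A stock encoder whose layers contain nn.MultiheadAttention modules.
layer = nn.TransformerEncoderLayer(d_model=256, nhead=8)
model = nn.TransformerEncoder(layer, num_layers=2).cuda().eval()

# Recursively replace every nn.MultiheadAttention with the Triton-backed
# MultiheadAttention defined in attention.py (weights are copied over).
replace_mah(model)

src = torch.randn(32, 4, 256, device='cuda')  # (seq_len, batch, embed_dim)
with torch.no_grad():
    out = model(src)
print(out.shape)                              # torch.Size([32, 4, 256])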
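
Usage note for conv.py: triton.nn.replace_conv2d performs the same in-place swap for nn.Conv2d. Note that _conv2d.forward asserts bias is None, dilation == (1, 1) and groups == 1, so only bias-free, undilated, ungrouped convolutions are supported. A sketch under those constraints (my own example; assumes a CUDA device):

import torch
import torch.nn as nn
from triton.nn import replace_conv2d

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, stride=1, padding=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(16, 32, 3, stride=2, padding=1, bias=False),
).cuda()
replace_conv2d(model)  # swaps both nn.Conv2d modules in place, copying weights

x = torch.randn(8, 3, 32, 32, device='cuda', requires_grad=True)
y = model(x)           # forward: the 'nc(h*stride_h + r - pad_h)...' einsum
y.sum().backward()     # backward: the strided conv is decomposed into smaller convs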
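
Note on linear.py: the file is added but not re-exported from __init__.py, and the helper looks inconsistent with the rest of the patch: the unpacking "k, n = w.size()" treats w as (k, n) while the subscripts 'mk,nk->mn' expect an (n, k) weight (the nn.Linear layout), the freshly allocated out tensor is never handed to the einsum, and a debug print is left in. Below is a corrected sketch in the style of the bmm() wrapper in attention.py (my reading; it assumes the nn.Linear (out_features, in_features) layout is intended):

import torch
import triton

def linear(x, w, bias=None):
    # x: (m, k), w: (n, k) as in nn.Linear; computes x @ w.t() (+ bias).
    m, k = x.size()
    n, _ = w.size()
    out = torch.empty([m, n], device=x.device)
    # Same calling convention as bmm() above: the output tensor is passed positionally.
    triton.ops.einsum('mk,nk->mn', x, w, out, bench=False)
    if bias is not None:
        out += bias
    return out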