[PYTHON] Added missing files for nn submodule
committed by Philippe Tillet
parent 3d769b57e2, commit ecb0d81b2d
python/triton/nn/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
from .conv import replace_conv2d
from .attention import replace_mah
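
The submodule's public surface is just these two in-place patching helpers. A minimal sketch of how they are meant to be invoked (the import path follows from the file location; `model` is any hypothetical torch.nn.Module and is not part of this commit):

    from triton.nn import replace_conv2d, replace_mah

    replace_conv2d(model)  # swap every nn.Conv2d child for the Triton-backed Conv2d below
    replace_mah(model)     # likewise for nn.MultiheadAttention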
python/triton/nn/attention.py (new file, 312 lines)
@@ -0,0 +1,312 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton


def bmm(x, w, mask=None):
    # batched matmul lowered to a Triton einsum; `mask` is the optional
    # accumulation bitmask forwarded to triton.ops.einsum
    b, m, k = x.size()
    b, k, n = w.size()
    out = torch.empty([b, m, n], device=x.device)
    triton.ops.einsum('bmk,bkn->bmn', x, w, out, mask=mask, bench=False)
    return out


def multi_head_attention_forward(query,                           # type: Tensor
                                 key,                             # type: Tensor
                                 value,                           # type: Tensor
                                 embed_dim_to_check,              # type: int
                                 num_heads,                       # type: int
                                 in_proj_weight,                  # type: Tensor
                                 in_proj_bias,                    # type: Tensor
                                 bias_k,                          # type: Optional[Tensor]
                                 bias_v,                          # type: Optional[Tensor]
                                 add_zero_attn,                   # type: bool
                                 dropout_p,                       # type: float
                                 out_proj_weight,                 # type: Tensor
                                 out_proj_bias,                   # type: Tensor
                                 training=True,                   # type: bool
                                 key_padding_mask=None,           # type: Optional[Tensor]
                                 need_weights=True,               # type: bool
                                 attn_mask=None,                  # type: Optional[Tensor]
                                 use_separate_proj_weight=False,  # type: bool
                                 q_proj_weight=None,              # type: Optional[Tensor]
                                 k_proj_weight=None,              # type: Optional[Tensor]
                                 v_proj_weight=None,              # type: Optional[Tensor]
                                 static_k=None,                   # type: Optional[Tensor]
                                 static_v=None,                   # type: Optional[Tensor]
                                 acc_bitmask=None
                                 ):
    # type: (...) -> Tuple[Tensor, Optional[Tensor]]
    r"""
    Args:
        query, key, value: map a query and a set of key-value pairs to an output.
            See "Attention Is All You Need" for more details.
        embed_dim_to_check: total dimension of the model.
        num_heads: parallel attention heads.
        in_proj_weight, in_proj_bias: input projection weight and bias.
        bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
        add_zero_attn: add a new batch of zeros to the key and
                       value sequences at dim=1.
        dropout_p: probability of an element to be zeroed.
        out_proj_weight, out_proj_bias: the output projection weight and bias.
        training: apply dropout if ``True``.
        key_padding_mask: if provided, specified padding elements in the key will
            be ignored by the attention. This is a binary mask. When the value is True,
            the corresponding value on the attention layer will be filled with -inf.
        need_weights: output attn_output_weights.
        attn_mask: mask that prevents attention to certain positions. This is an additive mask
            (i.e. the values will be added to the attention layer).
        use_separate_proj_weight: the function accepts the proj. weights for query, key,
            and value in different forms. If false, in_proj_weight will be used, which is
            a combination of q_proj_weight, k_proj_weight, v_proj_weight.
        q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
        static_k, static_v: static key and value used for attention operators.

    Shape:
        Inputs:
        - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
          the embedding dimension.
        - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
          the embedding dimension.
        - key_padding_mask: :math:`(N, S)`, ByteTensor, where N is the batch size, S is the source sequence length.
        - attn_mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
        - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
        - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
          N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

        Outputs:
        - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
          E is the embedding dimension.
        - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
          L is the target sequence length, S is the source sequence length.
    """
    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == embed_dim_to_check
    assert key.size() == value.size()

    head_dim = embed_dim // num_heads
    assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
    scaling = float(head_dim) ** -0.5

    if not use_separate_proj_weight:
        if torch.equal(query, key) and torch.equal(key, value):
            # self-attention
            q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)

        elif torch.equal(key, value):
            # encoder-decoder attention
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            if key is None:
                assert value is None
                k = None
                v = None
            else:
                # This is inline in_proj function with in_proj_weight and in_proj_bias
                _b = in_proj_bias
                _start = embed_dim
                _end = None
                _w = in_proj_weight[_start:, :]
                if _b is not None:
                    _b = _b[_start:]
                k, v = F.linear(key, _w, _b).chunk(2, dim=-1)

        else:
            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = 0
            _end = embed_dim
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            q = F.linear(query, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim
            _end = embed_dim * 2
            _w = in_proj_weight[_start:_end, :]
            if _b is not None:
                _b = _b[_start:_end]
            k = F.linear(key, _w, _b)

            # This is inline in_proj function with in_proj_weight and in_proj_bias
            _b = in_proj_bias
            _start = embed_dim * 2
            _end = None
            _w = in_proj_weight[_start:, :]
            if _b is not None:
                _b = _b[_start:]
            v = F.linear(value, _w, _b)
    else:
        q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
        len1, len2 = q_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == query.size(-1)

        k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
        len1, len2 = k_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == key.size(-1)

        v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
        len1, len2 = v_proj_weight_non_opt.size()
        assert len1 == embed_dim and len2 == value.size(-1)

        if in_proj_bias is not None:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
        else:
            q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
            k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
            v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
    q = q * scaling

    if bias_k is not None and bias_v is not None:
        if static_k is None and static_v is None:
            k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat([attn_mask,
                                       torch.zeros((attn_mask.size(0), 1),
                                                   dtype=attn_mask.dtype,
                                                   device=attn_mask.device)], dim=1)
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
                                                   dtype=key_padding_mask.dtype,
                                                   device=key_padding_mask.device)], dim=1)
        else:
            assert static_k is None, "bias cannot be added to static key."
            assert static_v is None, "bias cannot be added to static value."
    else:
        assert bias_k is None
        assert bias_v is None

    q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)

    if static_k is not None:
        assert static_k.size(0) == bsz * num_heads
        assert static_k.size(2) == head_dim
        k = static_k

    if static_v is not None:
        assert static_v.size(0) == bsz * num_heads
        assert static_v.size(2) == head_dim
        v = static_v

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if add_zero_attn:
        src_len += 1
        k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
        v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat([attn_mask, torch.zeros((attn_mask.size(0), 1),
                                                          dtype=attn_mask.dtype,
                                                          device=attn_mask.device)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [key_padding_mask, torch.zeros((key_padding_mask.size(0), 1),
                                               dtype=key_padding_mask.dtype,
                                               device=key_padding_mask.device)], dim=1)

    attn_output_weights = bmm(q, k.transpose(1, 2), mask=acc_bitmask)
    assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        attn_output_weights += attn_mask

    if key_padding_mask is not None:
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        attn_output_weights = attn_output_weights.masked_fill(
            key_padding_mask.unsqueeze(1).unsqueeze(2),
            float('-inf'),
        )
        attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

    attn_output_weights = F.softmax(attn_output_weights, dim=-1)
    attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)

    attn_output = bmm(attn_output_weights, v, mask=acc_bitmask)
    assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
    attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)

    if need_weights:
        # average attention weights over heads
        attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
        return attn_output, attn_output_weights.sum(dim=1) / num_heads
    else:
        return attn_output, None


class MultiheadAttention(nn.modules.activation.MultiheadAttention):

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False,
                 add_zero_attn=False, kdim=None, vdim=None, acc_bitmask=None):
        super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout, bias,
                                                 add_bias_kv, add_zero_attn, kdim, vdim)
        self.acc_bitmask = acc_bitmask

    def forward(self, query, key, value, key_padding_mask=None,
                need_weights=True, attn_mask=None):
        if not self._qkv_same_embed_dim:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
                acc_bitmask=self.acc_bitmask)
        else:
            return multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask,
                acc_bitmask=self.acc_bitmask)


def replace_mah(model, mask=None):
    for child_name, child in model.named_children():
        if isinstance(child, nn.modules.activation.MultiheadAttention):
            add_bias_kv = child.bias_k is not None
            device = child.in_proj_weight.device
            mah = MultiheadAttention(child.embed_dim, child.num_heads,
                                     dropout=child.dropout, add_bias_kv=add_bias_kv,
                                     add_zero_attn=child.add_zero_attn, kdim=child.kdim,
                                     vdim=child.vdim, acc_bitmask=mask).to(device)
            for yparam, xparam in zip(mah.parameters(), child.parameters()):
                yparam.data.copy_(xparam.data)
            setattr(model, child_name, mah)
        else:
            replace_mah(child, mask)
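
For context, a usage sketch for replace_mah. The host model here is an assumption for illustration only, not part of this commit; since replace_mah recurses over named_children(), it also reaches the self_attn modules nested inside standard encoder layers (exact compatibility of the forward signature depends on the installed torch version):

    import torch
    import torch.nn as nn
    from triton.nn import replace_mah

    # hypothetical host model: two encoder layers, each owning an nn.MultiheadAttention
    layer = nn.TransformerEncoderLayer(d_model=256, nhead=8)
    model = nn.TransformerEncoder(layer, num_layers=2).cuda()

    replace_mah(model)  # swaps each nn.MultiheadAttention for the Triton-backed subclass

    x = torch.randn(128, 4, 256, device='cuda')  # (L, N, E), as documented above
    y = model(x)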
python/triton/nn/conv.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import triton
import torch.nn as nn
import torch
import torch.nn.functional as F


class _conv2d(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias,
                stride, padding, dilation, groups,
                acc_bitmask):
        assert dilation == (1, 1)
        assert groups == 1
        assert bias is None
        pad_h, pad_w = padding
        stride_h, stride_w = stride
        n, c, h, w = input.size()
        k, c, r, s = weight.size()
        # allocate output
        p = (h + 2*padding[0] - r)//stride[0] + 1
        q = (w + 2*padding[1] - s)//stride[1] + 1
        output = torch.empty((n, k, p, q), dtype=input.dtype, device=input.device)
        # padding
        if pad_h or pad_w:
            input = triton.ops._einsum.pad(input, [pad_w, pad_w, pad_h, pad_h])
        # convolution
        triton.ops.einsum(f'nc(h*stride_h + r - pad_h)(w*stride_w + s - pad_w),kcrs->nkhw',
                          input, weight, mask=acc_bitmask,
                          output=output,
                          values={'pad_h': pad_h,
                                  'stride_h': stride_h,
                                  'pad_w': pad_w,
                                  'stride_w': stride_w})
        # prepare backprop
        ctx.save_for_backward(input, weight)
        ctx.stride = stride
        ctx.padding = padding
        ctx.acc_bitmask = acc_bitmask
        # return
        return output

    @staticmethod
    def backward(ctx, dy):
        # retrieve contextual information
        input, weight = ctx.saved_tensors
        stride = ctx.stride
        padding = ctx.padding
        acc_bitmask = ctx.acc_bitmask
        # gradient of the input
        dx = None
        if ctx.needs_input_grad[0]:
            # dy must be padded
            n, k, p, q = dy.size()
            n, c, h, w = input.size()
            k, c, r, s = weight.size()
            dypad = triton.ops._einsum.pad(dy, [4, 4, 4, 4])
            # have to be careful here
            # the gradient of strided conv is a conv over a sparse image
            # which can be decomposed as a set of smaller convs
            dx = torch.empty_like(input)
            for offh in range(stride[0]):
                for offw in range(stride[1]):
                    poffh = (offh + padding[0]) % stride[0]
                    poffw = (offw + padding[1]) % stride[1]
                    pad_h = int((padding[0] + (stride[0] - 1)*offh) / stride[0])
                    pad_w = int((padding[1] + (stride[1] - 1)*offw) / stride[1])
                    if offh >= r or offw >= s:
                        dx[:, :, poffh::stride[0], poffw::stride[1]] = 0
                    else:
                        triton.ops.einsum(f'nk(h - r + pad_h)(w - s + pad_w),kcrs->nchw',
                                          dypad[:, :, :, :],
                                          weight[:, :, offh::stride[0], offw::stride[1]],
                                          output=dx[:, :, poffh::stride[0], poffw::stride[1]],
                                          mask=acc_bitmask,
                                          values={'pad_h': pad_h,
                                                  'pad_w': pad_w})
        # gradient for the weight
        dw = None
        if ctx.needs_input_grad[1]:
            dw = torch.empty_like(weight)
            triton.ops.einsum(f'nc(p*{stride[0]}+r-{padding[0]})(q*{stride[1]}+s-{padding[1]}),nkpq->kcrs',
                              input, dy, output=dw, mask=acc_bitmask)
        return dx, dw, None, None, None, None, None, None

conv2d = _conv2d.apply


class Conv2d(nn.Conv2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros',
                 acc_bitmask=None):
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias, padding_mode)
        self.acc_bitmask = acc_bitmask

    def forward(self, input):
        #if self.kernel_size[0] == 3:
        #    return F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return conv2d(input, self.weight, self.bias, self.stride,
                      self.padding, self.dilation, self.groups,
                      self.acc_bitmask)


def replace_conv2d(model, acc_bitmask=None):
    for child_name, child in model.named_children():
        if isinstance(child, nn.Conv2d):
            conv2d = Conv2d(child.in_channels, child.out_channels, child.kernel_size,
                            child.stride, child.padding, child.dilation, child.groups,
                            child.bias, child.padding_mode,
                            acc_bitmask=acc_bitmask)
            for yparam, xparam in zip(conv2d.parameters(), child.parameters()):
                yparam.data.copy_(xparam.data)
            setattr(model, child_name, conv2d)
        else:
            replace_conv2d(child, acc_bitmask)
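
A usage sketch for replace_conv2d (the host CNN is hypothetical). Two constraints follow from the code above: _conv2d.forward asserts dilation == (1, 1), groups == 1 and bias is None, so only bias-free, undilated, ungrouped convolutions can take the Triton path; and unlike replace_mah, replace_conv2d does not move the new modules to the original device, so it is safest to patch before moving the model to the GPU:

    import torch
    import torch.nn as nn
    from triton.nn import replace_conv2d

    # hypothetical CNN; bias=False so the asserts in _conv2d.forward hold
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1, bias=False),
        nn.ReLU(),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
    )
    replace_conv2d(model)  # every nn.Conv2d child becomes triton.nn.Conv2d
    model = model.cuda()

    y = model(torch.randn(8, 3, 32, 32, device='cuda'))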
python/triton/nn/linear.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import torch
import triton


def linear(x, w, bias=None):
    print(x.size(), w.size())
    m, k = x.size()
    k, n = w.size()
    out = torch.empty([m, n], device=x.device)
    triton.ops.einsum('mk,nk->mn', x, w, bias)
    if bias is not None:
        out += bias
    return out
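
Note that linear.py is not re-exported by __init__.py, and as committed the einsum call neither writes into the allocated out buffer nor agrees with the k, n = w.size() unpacking, so it reads as a stub. A minimal sketch of a working version, assuming an nn.Linear-style (n, k) weight layout and the same positional-output convention as bmm in attention.py; this is an assumption, not the author's code:

    import torch
    import triton

    def linear(x, w, bias=None):
        # assumed shapes: x is (m, k), w is (n, k) as in torch.nn.Linear
        m, k = x.size()
        n, _ = w.size()
        out = torch.empty([m, n], device=x.device)
        # assumption: the output tensor is passed positionally, as in bmm() above
        triton.ops.einsum('mk,nk->mn', x, w, out)
        if bias is not None:
            out += bias
        return out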