[PYTHON][EXAMPLES] Added self-attention example using triton.ops.einsum

This commit is contained in:
Philippe Tillet
2020-01-21 16:45:04 -05:00
parent 78b98fb7cf
commit ce7a00674a
8 changed files with 227 additions and 47 deletions

View File

@@ -0,0 +1,47 @@
import torch
import numpy as np
import reference
import optimized
from time import time
use_half = False
def cast(x):
if use_half:
return x.half()
else:
return x
# GPU device
device = torch.device("cuda:0")
# shapes
batch, nhead = 16, 8
dm, dk, dv = 512, 512, 512
lq, lk, lv = 256, 256, 256
# initialize tensors
torch.manual_seed(0)
np.random.seed(0)
query = cast(torch.randn(batch, lq, dm)).cuda()
key = cast(torch.randn(batch, lk, dm)).cuda()
value = cast(torch.randn(batch, lv, dm)).cuda()
# initialize layers
torch.manual_seed(0)
np.random.seed(0)
rattn = cast(reference.MultiHeadAttention(nhead, dm, dk, dv).to(device))
torch.manual_seed(0)
np.random.seed(0)
tattn = cast(optimized.MultiHeadAttention(nhead, dm, dk, dv).to(device))
# test
routput, _ = rattn(query, key, value)
toutput, _ = tattn(query, key, value)
diff = torch.max(torch.abs(routput - toutput))
assert diff < 1e-2
# benchmark
start = time()
routput, _ = rattn(query, key, value)
end = time()
rtime = end - start
start = time()
toutput, _ = tattn(query, key, value)
end = time()
ttime = end - start
print(rtime, ttime)

View File

@@ -0,0 +1,50 @@
import numpy as np
import torch
import torch.nn as nn
import triton
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
# linear layers
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.fc = nn.Linear(n_head * d_v, d_model)
# initialize weights
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
nn.init.xavier_normal_(self.fc.weight)
# layer normalization
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, q, k, v, mask=None):
# dimensions
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
# linear transformations
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# scaled dot-product attention
attn = triton.ops.einsum('blhk,bthk->hblt', q, k, [n_head, sz_b, len_q, len_k])
attn = attn / np.sqrt(d_k)
if mask is not None:
attn = attn.masked_fill(mask[None], -np.inf)
attn = torch.softmax(attn, dim=3)
output = triton.ops.einsum('hblt,bthv->blhv', attn, v, [sz_b, len_q, n_head, d_v])
output = output.view(sz_b, len_q, -1)
output = self.fc(output)
# epilogue
output = self.layer_norm(output + residual)
return output, attn

View File

@@ -0,0 +1,72 @@
import numpy as np
import torch
import torch.nn as nn
class ScaledDotProductAttention(nn.Module):
''' Scaled Dot-Product Attention '''
def __init__(self, temperature, attn_dropout=0.1):
super().__init__()
self.temperature = temperature
self.softmax = nn.Softmax(dim=2)
def forward(self, q, k, v, mask=None):
attn = torch.bmm(q, k.transpose(1, 2))
attn = attn / self.temperature
if mask is not None:
attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
output = torch.bmm(attn, v)
return output, attn
class MultiHeadAttention(nn.Module):
''' Multi-Head Attention module '''
def __init__(self, n_head, d_model, d_k, d_v):
super().__init__()
self.n_head = n_head
self.d_k = d_k
self.d_v = d_v
# linear layers
self.w_qs = nn.Linear(d_model, n_head * d_k)
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.fc = nn.Linear(n_head * d_v, d_model)
# initialize weights
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
nn.init.xavier_normal_(self.fc.weight)
# normalization
self.layer_norm = nn.LayerNorm(d_model)
# scaled dot-product
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
def forward(self, q, k, v, mask=None):
# dimensions
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
sz_b, len_q, _ = q.size()
sz_b, len_k, _ = k.size()
sz_b, len_v, _ = v.size()
# linear transformations
residual = q
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
# scaled dot-product attention
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
if mask:
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
output, attn = self.attention(q, k, v, mask=mask)
# linear transformation
output = output.view(n_head, sz_b, len_q, d_v)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
output = self.fc(output)
# normalization
output = self.layer_norm(output + residual)
return output, attn

View File

@@ -11,25 +11,25 @@ configs = []
# Matrix multiplication
MNK = [
(512, 512 ,512),
(2048, 2048, 2048),
(8192, 8192, 8192),
# (512, 512 ,512),
# (2048, 2048, 2048),
# (8192, 8192, 8192),
(64, 64, 64000),
(64, 64, 128000),
(256, 256, 64000),
(256, 256, 128000),
# (64, 64, 64000),
# (64, 64, 128000),
# (256, 256, 64000),
# (256, 256, 128000),
(1536, 16, 1536),
(1536, 32, 1536),
(1536, 64, 1536),
(1536, 128, 1536),
(4096, 16, 4096),
(4096, 32, 4096),
(4096, 64, 4096),
(4096, 128, 4096),
# (1536, 16, 1536),
# (1536, 32, 1536),
# (1536, 64, 1536),
# (1536, 128, 1536),
# (4096, 16, 4096),
# (4096, 32, 4096),
# (4096, 64, 4096),
# (4096, 128, 4096),
#(127008, 768, 576)
# (127008, 768, 576)
]
for M, N, K in MNK:
matmul = lambda a, b: torch.matmul(a, b)
@@ -43,32 +43,32 @@ for M, N, K in MNK:
# Relative attention
NTHSE = [
#(16, 512, 1, 64, 64),
(16, 512, 1, 64, 64),
# (16, 512, 1, 128, 128),
# (16, 512, 1, 256, 256),
# (16, 512, 1, 256, 512),
#(16, 512, 8, 64, 64),
(16, 512, 8, 64, 64),
# (16, 512, 8, 128, 128),
# (16, 512, 8, 256, 256),
# (16, 512, 8, 256, 512),
# (64, 1024, 1, 64, 64),
#(64, 1024, 1, 128, 128),
(64, 1024, 1, 128, 128),
# (64, 1024, 1, 256, 256),
# (64, 1024, 1, 256, 512),
# (64, 1024, 8, 64, 64),
#(64, 1024, 8, 128, 128),
(64, 1024, 8, 128, 128),
# (64, 1024, 8, 256, 256),
# (64, 1024, 8, 256, 512),
# (128, 1024, 1, 64, 64),
# (128, 1024, 1, 128, 128),
# (128, 1024, 1, 256, 256),
#(128, 1024, 1, 256, 512),
(128, 1024, 1, 256, 512),
# (128, 1024, 8, 64, 64),
# (128, 1024, 8, 128, 128),
# (128, 1024, 8, 256, 256),
#(128, 1024, 8, 256, 512)
(128, 1024, 8, 256, 512)
]
for N, T, H, S, E in NTHSE:
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
@@ -168,12 +168,11 @@ for N, C, H, W, K, R, S in NCHWKRS:
# Benchmark
torch.set_num_threads(1)
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
dtype = torch.cuda.HalfTensor
dtype = torch.cuda.FloatTensor
# initialize input tensors
a = torch.rand(*a_shape).type(dtype).cuda()
b = torch.rand(*b_shape).type(dtype).cuda()
# triton output
#ta = triton.ops._einsum.pad(a, [4,4,4,4])
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
# reference output
if torch_fn:
@@ -183,12 +182,16 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
# performance relative to equivalent matrix multiplication
ctx = triton.ctx_registry[tc]
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
# a = torch.rand(B, M, K).type(dtype).cuda()
# b = torch.rand(B, K, N).type(dtype).cuda()
# tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
# ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
ratio = 0
cmp_eqbmm = True
if cmp_eqbmm:
a = torch.rand(B, M, K).type(dtype).cuda()
b = torch.rand(B, K, N).type(dtype).cuda()
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
cmp_str = f'({ratio:4.2f})'
else:
cmp_str = ''
# test and benchmark
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
diff = (tc - rc).abs().max() / rc.abs().max()
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}')
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')