[PYTHON][EXAMPLES] Added self-attention example using triton.ops.einsum
This commit is contained in:
@@ -256,7 +256,6 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
|
||||
return std::unique_ptr<driver::module>();
|
||||
barriers.run(module);
|
||||
isel.visit(module, *llvm);
|
||||
|
||||
// return binary
|
||||
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
|
||||
// done
|
||||
|
47
python/examples/attention/bench.py
Normal file
47
python/examples/attention/bench.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import reference
|
||||
import optimized
|
||||
from time import time
|
||||
|
||||
use_half = False
|
||||
def cast(x):
|
||||
if use_half:
|
||||
return x.half()
|
||||
else:
|
||||
return x
|
||||
|
||||
# GPU device
|
||||
device = torch.device("cuda:0")
|
||||
# shapes
|
||||
batch, nhead = 16, 8
|
||||
dm, dk, dv = 512, 512, 512
|
||||
lq, lk, lv = 256, 256, 256
|
||||
# initialize tensors
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
query = cast(torch.randn(batch, lq, dm)).cuda()
|
||||
key = cast(torch.randn(batch, lk, dm)).cuda()
|
||||
value = cast(torch.randn(batch, lv, dm)).cuda()
|
||||
# initialize layers
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
rattn = cast(reference.MultiHeadAttention(nhead, dm, dk, dv).to(device))
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
tattn = cast(optimized.MultiHeadAttention(nhead, dm, dk, dv).to(device))
|
||||
# test
|
||||
routput, _ = rattn(query, key, value)
|
||||
toutput, _ = tattn(query, key, value)
|
||||
diff = torch.max(torch.abs(routput - toutput))
|
||||
assert diff < 1e-2
|
||||
# benchmark
|
||||
start = time()
|
||||
routput, _ = rattn(query, key, value)
|
||||
end = time()
|
||||
rtime = end - start
|
||||
start = time()
|
||||
toutput, _ = tattn(query, key, value)
|
||||
end = time()
|
||||
ttime = end - start
|
||||
print(rtime, ttime)
|
50
python/examples/attention/optimized.py
Normal file
50
python/examples/attention/optimized.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
''' Multi-Head Attention module '''
|
||||
|
||||
def __init__(self, n_head, d_model, d_k, d_v):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.d_k = d_k
|
||||
self.d_v = d_v
|
||||
# linear layers
|
||||
self.w_qs = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_ks = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_vs = nn.Linear(d_model, n_head * d_v)
|
||||
self.fc = nn.Linear(n_head * d_v, d_model)
|
||||
# initialize weights
|
||||
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
|
||||
nn.init.xavier_normal_(self.fc.weight)
|
||||
# layer normalization
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
# dimensions
|
||||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
||||
sz_b, len_q, _ = q.size()
|
||||
sz_b, len_k, _ = k.size()
|
||||
sz_b, len_v, _ = v.size()
|
||||
# linear transformations
|
||||
residual = q
|
||||
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
|
||||
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
|
||||
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
|
||||
# scaled dot-product attention
|
||||
attn = triton.ops.einsum('blhk,bthk->hblt', q, k, [n_head, sz_b, len_q, len_k])
|
||||
attn = attn / np.sqrt(d_k)
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask[None], -np.inf)
|
||||
attn = torch.softmax(attn, dim=3)
|
||||
output = triton.ops.einsum('hblt,bthv->blhv', attn, v, [sz_b, len_q, n_head, d_v])
|
||||
output = output.view(sz_b, len_q, -1)
|
||||
output = self.fc(output)
|
||||
# epilogue
|
||||
output = self.layer_norm(output + residual)
|
||||
return output, attn
|
72
python/examples/attention/reference.py
Normal file
72
python/examples/attention/reference.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class ScaledDotProductAttention(nn.Module):
|
||||
''' Scaled Dot-Product Attention '''
|
||||
|
||||
def __init__(self, temperature, attn_dropout=0.1):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
self.softmax = nn.Softmax(dim=2)
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
attn = torch.bmm(q, k.transpose(1, 2))
|
||||
attn = attn / self.temperature
|
||||
if mask is not None:
|
||||
attn = attn.masked_fill(mask, -np.inf)
|
||||
attn = self.softmax(attn)
|
||||
output = torch.bmm(attn, v)
|
||||
return output, attn
|
||||
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
''' Multi-Head Attention module '''
|
||||
|
||||
def __init__(self, n_head, d_model, d_k, d_v):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.d_k = d_k
|
||||
self.d_v = d_v
|
||||
# linear layers
|
||||
self.w_qs = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_ks = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_vs = nn.Linear(d_model, n_head * d_v)
|
||||
self.fc = nn.Linear(n_head * d_v, d_model)
|
||||
# initialize weights
|
||||
nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
|
||||
nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
|
||||
nn.init.xavier_normal_(self.fc.weight)
|
||||
# normalization
|
||||
self.layer_norm = nn.LayerNorm(d_model)
|
||||
# scaled dot-product
|
||||
self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
|
||||
|
||||
|
||||
def forward(self, q, k, v, mask=None):
|
||||
# dimensions
|
||||
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
||||
sz_b, len_q, _ = q.size()
|
||||
sz_b, len_k, _ = k.size()
|
||||
sz_b, len_v, _ = v.size()
|
||||
# linear transformations
|
||||
residual = q
|
||||
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
|
||||
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
|
||||
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
|
||||
# scaled dot-product attention
|
||||
q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
|
||||
k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
|
||||
v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
|
||||
if mask:
|
||||
mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
|
||||
output, attn = self.attention(q, k, v, mask=mask)
|
||||
# linear transformation
|
||||
output = output.view(n_head, sz_b, len_q, d_v)
|
||||
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
|
||||
output = self.fc(output)
|
||||
# normalization
|
||||
output = self.layer_norm(output + residual)
|
||||
return output, attn
|
@@ -11,25 +11,25 @@ configs = []
|
||||
|
||||
# Matrix multiplication
|
||||
MNK = [
|
||||
(512, 512 ,512),
|
||||
(2048, 2048, 2048),
|
||||
(8192, 8192, 8192),
|
||||
# (512, 512 ,512),
|
||||
# (2048, 2048, 2048),
|
||||
# (8192, 8192, 8192),
|
||||
|
||||
(64, 64, 64000),
|
||||
(64, 64, 128000),
|
||||
(256, 256, 64000),
|
||||
(256, 256, 128000),
|
||||
# (64, 64, 64000),
|
||||
# (64, 64, 128000),
|
||||
# (256, 256, 64000),
|
||||
# (256, 256, 128000),
|
||||
|
||||
(1536, 16, 1536),
|
||||
(1536, 32, 1536),
|
||||
(1536, 64, 1536),
|
||||
(1536, 128, 1536),
|
||||
(4096, 16, 4096),
|
||||
(4096, 32, 4096),
|
||||
(4096, 64, 4096),
|
||||
(4096, 128, 4096),
|
||||
# (1536, 16, 1536),
|
||||
# (1536, 32, 1536),
|
||||
# (1536, 64, 1536),
|
||||
# (1536, 128, 1536),
|
||||
# (4096, 16, 4096),
|
||||
# (4096, 32, 4096),
|
||||
# (4096, 64, 4096),
|
||||
# (4096, 128, 4096),
|
||||
|
||||
#(127008, 768, 576)
|
||||
# (127008, 768, 576)
|
||||
]
|
||||
for M, N, K in MNK:
|
||||
matmul = lambda a, b: torch.matmul(a, b)
|
||||
@@ -43,32 +43,32 @@ for M, N, K in MNK:
|
||||
|
||||
# Relative attention
|
||||
NTHSE = [
|
||||
#(16, 512, 1, 64, 64),
|
||||
(16, 512, 1, 64, 64),
|
||||
# (16, 512, 1, 128, 128),
|
||||
# (16, 512, 1, 256, 256),
|
||||
# (16, 512, 1, 256, 512),
|
||||
#(16, 512, 8, 64, 64),
|
||||
(16, 512, 8, 64, 64),
|
||||
# (16, 512, 8, 128, 128),
|
||||
# (16, 512, 8, 256, 256),
|
||||
# (16, 512, 8, 256, 512),
|
||||
|
||||
# (64, 1024, 1, 64, 64),
|
||||
#(64, 1024, 1, 128, 128),
|
||||
(64, 1024, 1, 128, 128),
|
||||
# (64, 1024, 1, 256, 256),
|
||||
# (64, 1024, 1, 256, 512),
|
||||
# (64, 1024, 8, 64, 64),
|
||||
#(64, 1024, 8, 128, 128),
|
||||
(64, 1024, 8, 128, 128),
|
||||
# (64, 1024, 8, 256, 256),
|
||||
# (64, 1024, 8, 256, 512),
|
||||
|
||||
# (128, 1024, 1, 64, 64),
|
||||
# (128, 1024, 1, 128, 128),
|
||||
# (128, 1024, 1, 256, 256),
|
||||
#(128, 1024, 1, 256, 512),
|
||||
(128, 1024, 1, 256, 512),
|
||||
# (128, 1024, 8, 64, 64),
|
||||
# (128, 1024, 8, 128, 128),
|
||||
# (128, 1024, 8, 256, 256),
|
||||
#(128, 1024, 8, 256, 512)
|
||||
(128, 1024, 8, 256, 512)
|
||||
]
|
||||
for N, T, H, S, E in NTHSE:
|
||||
configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())]
|
||||
@@ -168,12 +168,11 @@ for N, C, H, W, K, R, S in NCHWKRS:
|
||||
# Benchmark
|
||||
torch.set_num_threads(1)
|
||||
for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||
dtype = torch.cuda.HalfTensor
|
||||
dtype = torch.cuda.FloatTensor
|
||||
# initialize input tensors
|
||||
a = torch.rand(*a_shape).type(dtype).cuda()
|
||||
b = torch.rand(*b_shape).type(dtype).cuda()
|
||||
# triton output
|
||||
#ta = triton.ops._einsum.pad(a, [4,4,4,4])
|
||||
tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True)
|
||||
# reference output
|
||||
if torch_fn:
|
||||
@@ -183,12 +182,16 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs:
|
||||
# performance relative to equivalent matrix multiplication
|
||||
ctx = triton.ctx_registry[tc]
|
||||
B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K
|
||||
# a = torch.rand(B, M, K).type(dtype).cuda()
|
||||
# b = torch.rand(B, K, N).type(dtype).cuda()
|
||||
# tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
|
||||
# ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
|
||||
ratio = 0
|
||||
cmp_eqbmm = True
|
||||
if cmp_eqbmm:
|
||||
a = torch.rand(B, M, K).type(dtype).cuda()
|
||||
b = torch.rand(B, K, N).type(dtype).cuda()
|
||||
tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True)
|
||||
ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc]
|
||||
cmp_str = f'({ratio:4.2f})'
|
||||
else:
|
||||
cmp_str = ''
|
||||
# test and benchmark
|
||||
bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3
|
||||
diff = (tc - rc).abs().max() / rc.abs().max()
|
||||
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}')
|
||||
print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}')
|
||||
|
@@ -206,6 +206,9 @@ class kernel:
|
||||
self.cst[name] = value
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
#########################
|
||||
# cache
|
||||
########################
|
||||
# create a new framework op when defines are different
|
||||
key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()])
|
||||
if key not in self.fw_id.keys():
|
||||
@@ -230,17 +233,18 @@ class kernel:
|
||||
libtriton.register_cst(op_id, name, value)
|
||||
if self.fw_op is None:
|
||||
self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt)
|
||||
# benchmarking info
|
||||
bench = 0
|
||||
if 'bench' in kwargs:
|
||||
bench = kwargs['bench']
|
||||
# retrieve framework op
|
||||
|
||||
########################
|
||||
# initialize
|
||||
########################
|
||||
op_id = self.fw_id[key]
|
||||
# register grid
|
||||
libtriton.register_grid(op_id, args[-1])
|
||||
# id for the benchmark result
|
||||
bench = kwargs['bench'] if 'bench' in kwargs else 0
|
||||
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
|
||||
|
||||
#########################
|
||||
# call framework function
|
||||
#########################
|
||||
if fw.has_tensorflow():
|
||||
empty = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)]
|
||||
if len(empty) != len(self.outputs):
|
||||
@@ -268,10 +272,15 @@ class kernel:
|
||||
if bench > 0:
|
||||
for y in ret:
|
||||
bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id)
|
||||
|
||||
############################
|
||||
# call torch function
|
||||
############################
|
||||
elif fw.has_torch():
|
||||
args = [x if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
|
||||
ret = self.fw_op(op_id, bench, bench_id, *args)
|
||||
if bench > 0:
|
||||
bench_registry[ret] = libtriton.retrieve_scalar(bench_id)
|
||||
|
||||
else:
|
||||
assert False
|
@@ -536,8 +536,8 @@ __global__ void {name}(
|
||||
self.pos_b = 1
|
||||
self.pos_c = 2
|
||||
# pre-processor macros
|
||||
TM = [x for x in [16, 32, 64, 128] if x <= M]
|
||||
TN = [x for x in [16, 32, 64, 128] if x <= N]
|
||||
TM = [16] + [x for x in [32, 64, 128] if x <= M]
|
||||
TN = [16] + [x for x in [32, 64, 128] if x <= N]
|
||||
TB = [x for x in [1, 2, 4] if x <= B]
|
||||
MAX_GZ = K // 2048
|
||||
MIN_GM = M // max(TM)
|
||||
@@ -546,8 +546,8 @@ __global__ void {name}(
|
||||
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
||||
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
||||
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
||||
#TB, TZ = [1], [1]
|
||||
#TM, TN, TB, TZ = [128], [128], [1], [1]
|
||||
TM, TN, TB = [128], [64], [1]
|
||||
#print(TM, TN, TB)
|
||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
|
||||
self.dtype = dtype
|
||||
self.flops = 2 * B * M * N * K
|
||||
@@ -582,13 +582,13 @@ __global__ void {name}(
|
||||
# allocate output
|
||||
dtype = a.dtype
|
||||
c = triton.empty(shape_c, dtype=dtype)
|
||||
# compile einsum instance
|
||||
cache = _einsum.instance_cache
|
||||
key = (einsum, dtype,
|
||||
a.stride(), b.stride(), c.stride(),
|
||||
a.shape, b.shape, c.shape)
|
||||
# compile einsum instance
|
||||
cache = _einsum.instance_cache
|
||||
#if key not in cache:
|
||||
cache[key] = _einsum.instance(einsum, dtype,
|
||||
if key not in cache:
|
||||
cache[key] = _einsum.instance(einsum, dtype,
|
||||
a.stride(), b.stride(), c.stride(),
|
||||
a.shape, b.shape, c.shape, arrays)
|
||||
instance = cache[key]
|
||||
|
@@ -24,7 +24,7 @@ def empty(shape, dtype):
|
||||
return tf_empty_proxy(shape, dtype)
|
||||
#return fw.tf_extra_ops.alloc_empty(args, T = dtype)
|
||||
elif fw.has_torch():
|
||||
return fw.torch.empty(shape, dtype=dtype).cuda()
|
||||
return fw.torch.empty(shape, dtype=dtype, device='cuda:0')
|
||||
|
||||
def shape(A) :
|
||||
if fw.has_tensorflow():
|
||||
|
Reference in New Issue
Block a user