diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index bdd695298..3cc5ea87b 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -256,7 +256,6 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c return std::unique_ptr(); barriers.run(module); isel.visit(module, *llvm); - // return binary std::unique_ptr res(driver::module::create(context, std::move(llvm))); // done diff --git a/python/examples/attention/bench.py b/python/examples/attention/bench.py new file mode 100644 index 000000000..99c1722eb --- /dev/null +++ b/python/examples/attention/bench.py @@ -0,0 +1,47 @@ +import torch +import numpy as np +import reference +import optimized +from time import time + +use_half = False +def cast(x): + if use_half: + return x.half() + else: + return x + +# GPU device +device = torch.device("cuda:0") +# shapes +batch, nhead = 16, 8 +dm, dk, dv = 512, 512, 512 +lq, lk, lv = 256, 256, 256 +# initialize tensors +torch.manual_seed(0) +np.random.seed(0) +query = cast(torch.randn(batch, lq, dm)).cuda() +key = cast(torch.randn(batch, lk, dm)).cuda() +value = cast(torch.randn(batch, lv, dm)).cuda() +# initialize layers +torch.manual_seed(0) +np.random.seed(0) +rattn = cast(reference.MultiHeadAttention(nhead, dm, dk, dv).to(device)) +torch.manual_seed(0) +np.random.seed(0) +tattn = cast(optimized.MultiHeadAttention(nhead, dm, dk, dv).to(device)) +# test +routput, _ = rattn(query, key, value) +toutput, _ = tattn(query, key, value) +diff = torch.max(torch.abs(routput - toutput)) +assert diff < 1e-2 +# benchmark +start = time() +routput, _ = rattn(query, key, value) +end = time() +rtime = end - start +start = time() +toutput, _ = tattn(query, key, value) +end = time() +ttime = end - start +print(rtime, ttime) \ No newline at end of file diff --git a/python/examples/attention/optimized.py b/python/examples/attention/optimized.py new file mode 100644 index 000000000..96cc14262 --- /dev/null +++ b/python/examples/attention/optimized.py @@ -0,0 +1,50 @@ +import numpy as np +import torch +import torch.nn as nn +import triton + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head, d_model, d_k, d_v): + super().__init__() + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + # linear layers + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + self.fc = nn.Linear(n_head * d_v, d_model) + # initialize weights + nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) + nn.init.xavier_normal_(self.fc.weight) + # layer normalization + self.layer_norm = nn.LayerNorm(d_model) + + + def forward(self, q, k, v, mask=None): + # dimensions + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, _ = q.size() + sz_b, len_k, _ = k.size() + sz_b, len_v, _ = v.size() + # linear transformations + residual = q + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + # scaled dot-product attention + attn = triton.ops.einsum('blhk,bthk->hblt', q, k, [n_head, sz_b, len_q, len_k]) + attn = attn / np.sqrt(d_k) + if mask is not None: + attn = attn.masked_fill(mask[None], -np.inf) + attn = torch.softmax(attn, dim=3) + output = triton.ops.einsum('hblt,bthv->blhv', attn, v, [sz_b, len_q, n_head, d_v]) + output = output.view(sz_b, len_q, -1) + output = self.fc(output) + # epilogue + output = self.layer_norm(output + residual) + return output, attn \ No newline at end of file diff --git a/python/examples/attention/reference.py b/python/examples/attention/reference.py new file mode 100644 index 000000000..e60f474f6 --- /dev/null +++ b/python/examples/attention/reference.py @@ -0,0 +1,72 @@ +import numpy as np +import torch +import torch.nn as nn + +class ScaledDotProductAttention(nn.Module): + ''' Scaled Dot-Product Attention ''' + + def __init__(self, temperature, attn_dropout=0.1): + super().__init__() + self.temperature = temperature + self.softmax = nn.Softmax(dim=2) + + def forward(self, q, k, v, mask=None): + attn = torch.bmm(q, k.transpose(1, 2)) + attn = attn / self.temperature + if mask is not None: + attn = attn.masked_fill(mask, -np.inf) + attn = self.softmax(attn) + output = torch.bmm(attn, v) + return output, attn + + + +class MultiHeadAttention(nn.Module): + ''' Multi-Head Attention module ''' + + def __init__(self, n_head, d_model, d_k, d_v): + super().__init__() + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + # linear layers + self.w_qs = nn.Linear(d_model, n_head * d_k) + self.w_ks = nn.Linear(d_model, n_head * d_k) + self.w_vs = nn.Linear(d_model, n_head * d_v) + self.fc = nn.Linear(n_head * d_v, d_model) + # initialize weights + nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) + nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) + nn.init.xavier_normal_(self.fc.weight) + # normalization + self.layer_norm = nn.LayerNorm(d_model) + # scaled dot-product + self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) + + + def forward(self, q, k, v, mask=None): + # dimensions + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, _ = q.size() + sz_b, len_k, _ = k.size() + sz_b, len_v, _ = v.size() + # linear transformations + residual = q + q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) + # scaled dot-product attention + q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk + k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk + v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv + if mask: + mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. + output, attn = self.attention(q, k, v, mask=mask) + # linear transformation + output = output.view(n_head, sz_b, len_q, d_v) + output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) + output = self.fc(output) + # normalization + output = self.layer_norm(output + residual) + return output, attn \ No newline at end of file diff --git a/python/examples/einsum.py b/python/examples/einsum.py index f61347d0c..e571ec955 100644 --- a/python/examples/einsum.py +++ b/python/examples/einsum.py @@ -11,25 +11,25 @@ configs = [] # Matrix multiplication MNK = [ - (512, 512 ,512), - (2048, 2048, 2048), - (8192, 8192, 8192), + # (512, 512 ,512), + # (2048, 2048, 2048), + # (8192, 8192, 8192), - (64, 64, 64000), - (64, 64, 128000), - (256, 256, 64000), - (256, 256, 128000), + # (64, 64, 64000), + # (64, 64, 128000), + # (256, 256, 64000), + # (256, 256, 128000), - (1536, 16, 1536), - (1536, 32, 1536), - (1536, 64, 1536), - (1536, 128, 1536), - (4096, 16, 4096), - (4096, 32, 4096), - (4096, 64, 4096), - (4096, 128, 4096), + # (1536, 16, 1536), + # (1536, 32, 1536), + # (1536, 64, 1536), + # (1536, 128, 1536), + # (4096, 16, 4096), + # (4096, 32, 4096), + # (4096, 64, 4096), + # (4096, 128, 4096), - #(127008, 768, 576) + # (127008, 768, 576) ] for M, N, K in MNK: matmul = lambda a, b: torch.matmul(a, b) @@ -43,32 +43,32 @@ for M, N, K in MNK: # Relative attention NTHSE = [ - #(16, 512, 1, 64, 64), + (16, 512, 1, 64, 64), # (16, 512, 1, 128, 128), # (16, 512, 1, 256, 256), # (16, 512, 1, 256, 512), - #(16, 512, 8, 64, 64), + (16, 512, 8, 64, 64), # (16, 512, 8, 128, 128), # (16, 512, 8, 256, 256), # (16, 512, 8, 256, 512), # (64, 1024, 1, 64, 64), - #(64, 1024, 1, 128, 128), + (64, 1024, 1, 128, 128), # (64, 1024, 1, 256, 256), # (64, 1024, 1, 256, 512), # (64, 1024, 8, 64, 64), - #(64, 1024, 8, 128, 128), + (64, 1024, 8, 128, 128), # (64, 1024, 8, 256, 256), # (64, 1024, 8, 256, 512), # (128, 1024, 1, 64, 64), # (128, 1024, 1, 128, 128), # (128, 1024, 1, 256, 256), - #(128, 1024, 1, 256, 512), + (128, 1024, 1, 256, 512), # (128, 1024, 8, 64, 64), # (128, 1024, 8, 128, 128), # (128, 1024, 8, 256, 256), - #(128, 1024, 8, 256, 512) + (128, 1024, 8, 256, 512) ] for N, T, H, S, E in NTHSE: configs += [([N, T, H, S], [H, E, S], [N, H, T, E], None, 'nths,hes->nhte', dict())] @@ -168,12 +168,11 @@ for N, C, H, W, K, R, S in NCHWKRS: # Benchmark torch.set_num_threads(1) for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: - dtype = torch.cuda.HalfTensor + dtype = torch.cuda.FloatTensor # initialize input tensors a = torch.rand(*a_shape).type(dtype).cuda() b = torch.rand(*b_shape).type(dtype).cuda() # triton output - #ta = triton.ops._einsum.pad(a, [4,4,4,4]) tc = triton.ops.einsum(expr, a, b, c_shape, arrays = arrays, bench = True) # reference output if torch_fn: @@ -183,12 +182,16 @@ for a_shape, b_shape, c_shape, torch_fn, expr, arrays in configs: # performance relative to equivalent matrix multiplication ctx = triton.ctx_registry[tc] B, M, N, K = ctx.matmul_B, ctx.matmul_M, ctx.matmul_N, ctx.matmul_K - # a = torch.rand(B, M, K).type(dtype).cuda() - # b = torch.rand(B, K, N).type(dtype).cuda() - # tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True) - # ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc] - ratio = 0 + cmp_eqbmm = True + if cmp_eqbmm: + a = torch.rand(B, M, K).type(dtype).cuda() + b = torch.rand(B, K, N).type(dtype).cuda() + tmmc = triton.ops.einsum('bmk,bkn->bmn', a, b, [B, M, N], bench = True) + ratio = triton.bench_registry[tmmc] / triton.bench_registry[tc] + cmp_str = f'({ratio:4.2f})' + else: + cmp_str = '' # test and benchmark bench = 2. * B * M * N * K / triton.bench_registry[tc] * 1e-3 diff = (tc - rc).abs().max() / rc.abs().max() - print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} ({ratio:4.2f}); {diff:4.2f}') + print(f'{expr:>15}; {str(a_shape):>20}; {str(b_shape):>20}; {bench:4.2f} {cmp_str}; {diff:4.2f}') diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 77177d740..85195790d 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -206,6 +206,9 @@ class kernel: self.cst[name] = value def __call__(self, *args, **kwargs): + ######################### + # cache + ######################## # create a new framework op when defines are different key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()]) if key not in self.fw_id.keys(): @@ -230,17 +233,18 @@ class kernel: libtriton.register_cst(op_id, name, value) if self.fw_op is None: self.fw_op = _make_framework_op(self.src, self.outputs, self.tmp, opt) - # benchmarking info - bench = 0 - if 'bench' in kwargs: - bench = kwargs['bench'] - # retrieve framework op + + ######################## + # initialize + ######################## op_id = self.fw_id[key] - # register grid libtriton.register_grid(op_id, args[-1]) - # id for the benchmark result + bench = kwargs['bench'] if 'bench' in kwargs else 0 bench_id = libtriton.make_scalar_id() if bench > 0 else -1 + + ######################### # call framework function + ######################### if fw.has_tensorflow(): empty = [x for x in args[:-1] if isinstance(x, triton.utils.tf_empty_proxy)] if len(empty) != len(self.outputs): @@ -268,10 +272,15 @@ class kernel: if bench > 0: for y in ret: bench_registry[y] = triton.utils.id_dict.lazy_entry(bench_id) + + ############################ + # call torch function + ############################ elif fw.has_torch(): args = [x if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]] ret = self.fw_op(op_id, bench, bench_id, *args) if bench > 0: bench_registry[ret] = libtriton.retrieve_scalar(bench_id) + else: assert False \ No newline at end of file diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index ff29432e5..dbb236b18 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -536,8 +536,8 @@ __global__ void {name}( self.pos_b = 1 self.pos_c = 2 # pre-processor macros - TM = [x for x in [16, 32, 64, 128] if x <= M] - TN = [x for x in [16, 32, 64, 128] if x <= N] + TM = [16] + [x for x in [32, 64, 128] if x <= M] + TN = [16] + [x for x in [32, 64, 128] if x <= N] TB = [x for x in [1, 2, 4] if x <= B] MAX_GZ = K // 2048 MIN_GM = M // max(TM) @@ -546,8 +546,8 @@ __global__ void {name}( TZ = [x for x in [1, 2, 4, 8, 16, 32] \ if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256] TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2] - #TB, TZ = [1], [1] - #TM, TN, TB, TZ = [128], [128], [1], [1] + TM, TN, TB = [128], [64], [1] + #print(TM, TN, TB) self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype } self.dtype = dtype self.flops = 2 * B * M * N * K @@ -582,13 +582,13 @@ __global__ void {name}( # allocate output dtype = a.dtype c = triton.empty(shape_c, dtype=dtype) + # compile einsum instance + cache = _einsum.instance_cache key = (einsum, dtype, a.stride(), b.stride(), c.stride(), a.shape, b.shape, c.shape) - # compile einsum instance - cache = _einsum.instance_cache - #if key not in cache: - cache[key] = _einsum.instance(einsum, dtype, + if key not in cache: + cache[key] = _einsum.instance(einsum, dtype, a.stride(), b.stride(), c.stride(), a.shape, b.shape, c.shape, arrays) instance = cache[key] diff --git a/python/triton/utils.py b/python/triton/utils.py index 117f69136..da8a1e8f9 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -24,7 +24,7 @@ def empty(shape, dtype): return tf_empty_proxy(shape, dtype) #return fw.tf_extra_ops.alloc_empty(args, T = dtype) elif fw.has_torch(): - return fw.torch.empty(shape, dtype=dtype).cuda() + return fw.torch.empty(shape, dtype=dtype, device='cuda:0') def shape(A) : if fw.has_tensorflow():