diff --git a/python/examples/attention/bench.py b/python/examples/attention/bench.py
index 99c1722eb..abd4ed24c 100644
--- a/python/examples/attention/bench.py
+++ b/python/examples/attention/bench.py
@@ -4,7 +4,7 @@ import reference
 import optimized
 from time import time
 
-use_half = False
+use_half = True
 def cast(x):
     if use_half:
         return x.half()
@@ -14,9 +14,9 @@ def cast(x):
 # GPU device
 device = torch.device("cuda:0")
 # shapes
-batch, nhead = 16, 8
-dm, dk, dv = 512, 512, 512
-lq, lk, lv = 256, 256, 256
+batch, nhead = 8, 28
+dm, dk, dv = 1024, 1024, 1024
+lq, lk, lv = 1024, 1024, 1024
 # initialize tensors
 torch.manual_seed(0)
 np.random.seed(0)
@@ -44,4 +44,5 @@ start = time()
 toutput, _ = tattn(query, key, value)
 end = time()
 ttime = end - start
-print(rtime, ttime)
\ No newline at end of file
+print(f'Torch: {rtime} s')
+print(f'Triton: {ttime} s')
\ No newline at end of file
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index dbb236b18..7af856be3 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -546,7 +546,7 @@ __global__ void {name}(
         TZ = [x for x in [1, 2, 4, 8, 16, 32] \
                 if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
         TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
-        TM, TN, TB = [128], [64], [1]
+        #TM, TN, TB = [128], [64], [1]
         #print(TM, TN, TB)
         self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
         self.dtype = dtype
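Note on the measurement in bench.py: the start = time() / end = time() pattern reads the host clock around CUDA kernel launches, which are asynchronous, so the elapsed time can under-report GPU work unless the device is synchronized first. Below is a minimal sketch of a synchronized variant; run_attention is a hypothetical callable standing in for either attention module's forward call, and the warm-up loop is an assumption to let any first-call autotuning or caching settle before timing.

    import torch
    from time import time

    def bench(run_attention, query, key, value, warmup=3, reps=10):
        # warm-up runs amortize one-time costs (e.g. autotuning, kernel caching)
        for _ in range(warmup):
            run_attention(query, key, value)
        torch.cuda.synchronize()          # drain queued kernels before starting the clock
        start = time()
        for _ in range(reps):
            run_attention(query, key, value)
        torch.cuda.synchronize()          # wait for all launched kernels to finish
        return (time() - start) / reps    # average seconds per iteration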