[PYTHON][EXAMPLES] Cleaned self-attention benchmarks

This commit is contained in:
Philippe Tillet
2020-01-22 18:09:00 -05:00
parent ce7a00674a
commit db941161ed
2 changed files with 7 additions and 6 deletions

View File

@@ -4,7 +4,7 @@ import reference
import optimized import optimized
from time import time from time import time
use_half = False use_half = True
def cast(x): def cast(x):
if use_half: if use_half:
return x.half() return x.half()
@@ -14,9 +14,9 @@ def cast(x):
# GPU device # GPU device
device = torch.device("cuda:0") device = torch.device("cuda:0")
# shapes # shapes
batch, nhead = 16, 8 batch, nhead = 8, 28
dm, dk, dv = 512, 512, 512 dm, dk, dv = 1024, 1024, 1024
lq, lk, lv = 256, 256, 256 lq, lk, lv = 1024, 1024, 1024
# initialize tensors # initialize tensors
torch.manual_seed(0) torch.manual_seed(0)
np.random.seed(0) np.random.seed(0)
@@ -44,4 +44,5 @@ start = time()
toutput, _ = tattn(query, key, value) toutput, _ = tattn(query, key, value)
end = time() end = time()
ttime = end - start ttime = end - start
print(rtime, ttime) print(f'Torch: {rtime} s')
print(f'Triton: {ttime} s')

View File

@@ -546,7 +546,7 @@ __global__ void {name}(
TZ = [x for x in [1, 2, 4, 8, 16, 32] \ TZ = [x for x in [1, 2, 4, 8, 16, 32] \
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256] if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2] TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
TM, TN, TB = [128], [64], [1] #TM, TN, TB = [128], [64], [1]
#print(TM, TN, TB) #print(TM, TN, TB)
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype } self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
self.dtype = dtype self.dtype = dtype