[PYTHON][EXAMPLES] Cleaned self-attention benchmarks

Philippe Tillet
2020-01-22 18:09:00 -05:00
parent ce7a00674a
commit db941161ed
2 changed files with 7 additions and 6 deletions

@@ -4,7 +4,7 @@ import reference
 import optimized
 from time import time
-use_half = False
+use_half = True
 def cast(x):
     if use_half:
         return x.half()
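
The cast helper funnels every benchmark tensor through a single precision flag. A minimal sketch of it in use (the shapes and names here are illustrative, not taken from the file):

import torch

use_half = True

def cast(x):
    # fp16 when the flag is set, otherwise leave the tensor in fp32
    return x.half() if use_half else x

x = cast(torch.randn(1024, 1024, device="cuda:0"))
assert x.dtype == torch.float16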
@@ -14,9 +14,9 @@ def cast(x):
 # GPU device
 device = torch.device("cuda:0")
 # shapes
-batch, nhead = 16, 8
-dm, dk, dv = 512, 512, 512
-lq, lk, lv = 256, 256, 256
+batch, nhead = 8, 28
+dm, dk, dv = 1024, 1024, 1024
+lq, lk, lv = 1024, 1024, 1024
 # initialize tensors
 torch.manual_seed(0)
 np.random.seed(0)
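
The new shapes make the workload considerably heavier. A rough sizing sketch (this assumes the attention score matrix is laid out as (batch, nhead, lq, lk), which the diff does not show):

batch, nhead = 8, 28
lq, lk = 1024, 1024
bytes_fp16 = 2
# hypothetical score-matrix footprint: 8 * 28 * 1024 * 1024 elements
scores = batch * nhead * lq * lk
print(f'{scores * bytes_fp16 / 2**20:.0f} MiB of fp16 scores')  # ~448 MiB
# versus the old shapes: 16 * 8 * 256 * 256 elements, only 16 MiB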
@@ -44,4 +44,5 @@ start = time()
 toutput, _ = tattn(query, key, value)
 end = time()
 ttime = end - start
-print(rtime, ttime)
+print(f'Torch: {rtime} s')
+print(f'Triton: {ttime} s')
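
One caveat worth noting: time() brackets asynchronous CUDA launches, so wall-clock deltas like rtime and ttime can under-report kernel time. A sketch of a synchronized variant of the same harness (the rattn/tattn usage mirrors the diff; the bench helper itself is hypothetical):

import torch
from time import time

def bench(attn, query, key, value):
    attn(query, key, value)        # warm-up: exclude compilation/caching
    torch.cuda.synchronize()       # drain pending work before starting the clock
    start = time()
    output, _ = attn(query, key, value)
    torch.cuda.synchronize()       # wait for the timed kernels to finish
    return time() - start

# ttime = bench(tattn, query, key, value)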

@@ -546,7 +546,7 @@ __global__ void {name}(
         TZ = [x for x in [1, 2, 4, 8, 16, 32] \
                    if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
         TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
-        TM, TN, TB = [128], [64], [1]
+        #TM, TN, TB = [128], [64], [1]
         #print(TM, TN, TB)
         self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
         self.dtype = dtype
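
The macros dict carries one candidate list per tile parameter, so presumably the autotuner enumerates the cross-product of those lists and keeps the fastest kernel. A sketch of that enumeration (the candidate values and the itertools loop are assumptions, not code from this file):

from itertools import product

# hypothetical candidate tile sizes mirroring the dict above (TYPE omitted)
macros = {'TM': [64, 128], 'TN': [64, 128], 'TB': [1], 'TK': [16], 'TZ': [1, 2]}
names = list(macros)
for values in product(*(macros[k] for k in names)):
    config = dict(zip(names, values))
    # each config would be compiled and timed; the fastest wins
    print(config)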