[PYTHON][EXAMPLES] Cleaned self-attention benchmarks
This commit is contained in:
@@ -4,7 +4,7 @@ import reference
|
||||
import optimized
|
||||
from time import time
|
||||
|
||||
use_half = False
|
||||
use_half = True
|
||||
def cast(x):
|
||||
if use_half:
|
||||
return x.half()
|
||||
@@ -14,9 +14,9 @@ def cast(x):
|
||||
# GPU device
|
||||
device = torch.device("cuda:0")
|
||||
# shapes
|
||||
batch, nhead = 16, 8
|
||||
dm, dk, dv = 512, 512, 512
|
||||
lq, lk, lv = 256, 256, 256
|
||||
batch, nhead = 8, 28
|
||||
dm, dk, dv = 1024, 1024, 1024
|
||||
lq, lk, lv = 1024, 1024, 1024
|
||||
# initialize tensors
|
||||
torch.manual_seed(0)
|
||||
np.random.seed(0)
|
||||
@@ -44,4 +44,5 @@ start = time()
|
||||
toutput, _ = tattn(query, key, value)
|
||||
end = time()
|
||||
ttime = end - start
|
||||
print(rtime, ttime)
|
||||
print(f'Torch: {rtime} s')
|
||||
print(f'Triton: {ttime} s')
|
@@ -546,7 +546,7 @@ __global__ void {name}(
|
||||
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
||||
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
||||
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
||||
TM, TN, TB = [128], [64], [1]
|
||||
#TM, TN, TB = [128], [64], [1]
|
||||
#print(TM, TN, TB)
|
||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
|
||||
self.dtype = dtype
|
||||
|
Reference in New Issue
Block a user