[PYTHON][EXAMPLES] Cleaned self-attention benchmarks
This commit is contained in:
@@ -4,7 +4,7 @@ import reference
|
|||||||
import optimized
|
import optimized
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
use_half = False
|
use_half = True
|
||||||
def cast(x):
|
def cast(x):
|
||||||
if use_half:
|
if use_half:
|
||||||
return x.half()
|
return x.half()
|
||||||
@@ -14,9 +14,9 @@ def cast(x):
|
|||||||
# GPU device
|
# GPU device
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
# shapes
|
# shapes
|
||||||
batch, nhead = 16, 8
|
batch, nhead = 8, 28
|
||||||
dm, dk, dv = 512, 512, 512
|
dm, dk, dv = 1024, 1024, 1024
|
||||||
lq, lk, lv = 256, 256, 256
|
lq, lk, lv = 1024, 1024, 1024
|
||||||
# initialize tensors
|
# initialize tensors
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
@@ -44,4 +44,5 @@ start = time()
|
|||||||
toutput, _ = tattn(query, key, value)
|
toutput, _ = tattn(query, key, value)
|
||||||
end = time()
|
end = time()
|
||||||
ttime = end - start
|
ttime = end - start
|
||||||
print(rtime, ttime)
|
print(f'Torch: {rtime} s')
|
||||||
|
print(f'Triton: {ttime} s')
|
@@ -546,7 +546,7 @@ __global__ void {name}(
|
|||||||
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
TZ = [x for x in [1, 2, 4, 8, 16, 32] \
|
||||||
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
if x < MAX_GZ and x*MIN_GM*MIN_GN*MIN_GB < 256]
|
||||||
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
TZ = [1] if not TZ else [TZ[-1], TZ[-1]*2]
|
||||||
TM, TN, TB = [128], [64], [1]
|
#TM, TN, TB = [128], [64], [1]
|
||||||
#print(TM, TN, TB)
|
#print(TM, TN, TB)
|
||||||
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
|
self.macros = { 'TM': TM, 'TN': TN, 'TB': TB, 'TK': TK, 'TZ': TZ, 'TYPE': dtype }
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
Reference in New Issue
Block a user