[PYTHON][EINSUM] Added support for FP16
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
   double total_time = 0;
   op();
   stream->synchronize();
-  while(total_time*1e-9 < 1e-3){
+  while(total_time*1e-9 < 1e-2){
     float norm = 1;
     // normalize clock if possible to reduce noise in auto-tuning
     if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(stream->context()->device()))
@@ -314,11 +314,11 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
   // padding
   pad = 0;
   if(hmma_dot_a){
-    bool row = is_trans(hmma_dot_a) ^ order[0] == 1;
+    bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
     pad = 24 - shapes[row ? order[0] : order[1]] % 32;
   }
   else if(hmma_dot_b){
-    bool row = is_trans(hmma_dot_b) ^ order[0] == 1;
+    bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
     pad = 24 - shapes[row ? order[1] : order[0]] % 32;
   }
   else if(order != arg->order) {
@@ -560,9 +560,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
 
   bool is_a_trans = is_trans(dot->get_operand(0));
   bool is_b_trans = is_trans(dot->get_operand(1));
-  bool is_a_row = is_a_trans ^ (ord_a[0] == 1);
-  bool is_b_row = is_b_trans ^ (ord_b[0] == 1);
-
+  bool is_a_row = is_a_trans ^ (ord_a[0] != 0);
+  bool is_b_row = is_b_trans ^ (ord_b[0] != 0);
 
   Value *offset_a_i = hmma->offset_a_i_;
   Value *offset_a_k = hmma->offset_a_k_;
@@ -5,8 +5,8 @@ def run_tf():
   M, N, K = 2048, 2048, 2048
   a = tf.placeholder(tf.float32, shape=[M, K])
   b = tf.placeholder(tf.float32, shape=[N, K])
-  tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=10)
-  tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=10)
+  tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
+  tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
   tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
   tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
   # Gradient
@@ -23,7 +23,7 @@ def run_tf():
   # Benchmark
   nanosec = triton.bench_registry[tr_d]
   print('NANOSEC: ', nanosec)
-  print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
+  #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
   # Test
   print(result[0][0])
   print(result[1][0])
@@ -12,7 +12,7 @@ from tensorflow.python.ops import gradient_checker
 
 one = 0
 out = 0
-bench = 10
+bench = 0
 
 class ProdKeyTest(tf.test.TestCase):
 
@@ -37,14 +37,14 @@ class ProdKeyTest(tf.test.TestCase):
     # key_dim = 16
 
     for a_shape, b_shape, c_shape, einsum in [
-        #[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
-        [ [4, 2048, 2048 ], [ 2048, 2048 ], [4, 2048, 2048 ], "btc,ck->btk" ],
-        #[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
+        [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
+        [ [4, 1024, 1024], [ 1024, 1024 ], [4, 1024, 1024 ], "btc,ck->btk" ],
+        [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
     ]:
 
       if one:
-        A = np.ones(a_shape, dtype=np.float32)
-        B = np.ones(b_shape, dtype=np.float32)
+        A = np.ones(a_shape, dtype=np.float16).astype(np.float32)
+        B = np.ones(b_shape, dtype=np.float16).astype(np.float32)
         E = np.ones(c_shape, dtype=np.float32)
       else:
         # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32)
@@ -53,12 +53,14 @@ class ProdKeyTest(tf.test.TestCase):
         B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
         E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
 
-      a = tf.placeholder(tf.float32, a_shape, name="a")
-      b = tf.placeholder(tf.float32, b_shape, name="b")
-      e = tf.placeholder(tf.float32, c_shape, name="e")
-      feed_dict = { a:A, b:B, e:E }
+      a = tf.placeholder(tf.float16, a_shape, name="a")
+      b = tf.placeholder(tf.float16, b_shape, name="b")
+      e = tf.placeholder(tf.float16, c_shape, name="e")
+      feed_dict = { a: A.astype(np.float16),
+                    b: B.astype(np.float16),
+                    e: E }
 
-      cc = triton.ops.einsum(einsum, a, b, bench=bench)
+      c = triton.ops.einsum(einsum, a, b, bench=bench)
 
       # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
       # print(error)
@@ -66,21 +68,24 @@ class ProdKeyTest(tf.test.TestCase):
       # print(error)
       # return
 
-      with tf.control_dependencies([cc.op]):
-        da, db = tf.gradients(cc, [a, b], e)
+      with tf.control_dependencies([c.op]):
+        da, db = tf.gradients(c, [a, b], e)
 
       # c, = sess.run( [ c, ], feed_dict )
-      c, da, db = sess.run( [ cc, da, db ], feed_dict )
+      rc, rda, rdb = sess.run( [ c, da, db ], feed_dict )
 
       if bench > 0:
-        nanosec = triton.bench_registry[cc]
-        print(A.shape, B.shape)
-        print(nanosec)
+        nanosec = triton.bench_registry[c]
+        ctx = triton.ctx_registry[c]
+        b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
+        ops = 2. * b * m * n * k
+        print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
+        print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
+        print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
 
       else:
         C = np.einsum(einsum, A, B)
-        id = cc.op.get_attr('id')
-        ctx = triton.ops._einsum.contexts[id]
+        ctx = triton.ctx_registry[c]
         t_a = ctx.trans_a
         t_b = ctx.trans_b
         e_a = ctx.einsum_a
@@ -100,9 +105,9 @@ class ProdKeyTest(tf.test.TestCase):
       print("testProdKey", einsum)
       if not bench:
         for op, dev, cpu in [
-          [ "C", c, C ],
-          [ "DA", da, DA ],
-          [ "DB", db, DB ],
+          [ "C", rc, C ],
+          [ "DA", rda, DA ],
+          [ "DB", rdb, DB ],
         ]:
           self.compare_results(op, dev, cpu)
 
|
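For reference, a minimal sketch of the FP16 call pattern the updated test exercises (the shapes, session setup, and variable names below are illustrative, not part of the diff): placeholders are declared as tf.float16, inputs are fed as float16 arrays, and the op context and benchmark time are looked up through the triton.ctx_registry / triton.bench_registry dictionaries introduced in this commit, keyed by the output tensor.

# Illustrative sketch only; assumes a TF1-style session and the registries added in this commit.
import numpy as np
import tensorflow as tf
import triton

a = tf.placeholder(tf.float16, [4, 1024, 1024], name="a")
b = tf.placeholder(tf.float16, [1024, 1024], name="b")
c = triton.ops.einsum("btc,ck->btk", a, b, bench=10)

feed_dict = {
    a: np.random.uniform(-1.0, 1.0, [4, 1024, 1024]).astype(np.float16),
    b: np.random.uniform(-1.0, 1.0, [1024, 1024]).astype(np.float16),
}
with tf.Session() as sess:
    result = sess.run(c, feed_dict)

# ctx_registry exposes the op context (bmnk, transpose flags, einsum strings);
# bench_registry resolves the benchmark time (nanoseconds) lazily after the run.
ctx = triton.ctx_registry[c]
B, M, N, K = (ctx.bmnk[i] for i in range(4))
print('TFLOPS:', 2. * B * M * N * K / triton.bench_registry[c] * 1e-3)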
@@ -77,6 +77,7 @@ class CMakeBuild(build_ext):
         pass
 
     cfg = 'Debug' if self.debug else 'Release'
+    cfg = 'Release'
    build_args = ['--config', cfg]
 
    if platform.system() == "Windows":
@@ -1,4 +1,5 @@
 import triton.frameworks as fw
+import triton.utils
 
 class OpContext(object):
 
@@ -16,6 +17,8 @@ class function_meta(type):
     cls.registered = False
     return super(function_meta, cls).__init__(name, bases, attrs)
 
+ctx_registry = triton.utils.id_dict()
+
 class function(metaclass = function_meta):
 
   @staticmethod
@@ -31,7 +34,9 @@ class function(metaclass = function_meta):
     class TorchFunction(fw.torch.autograd.Function):
       @staticmethod
       def forward(ctx, *targs, **tkwargs):
-        return cls.forward(ctx, *targs, **tkwargs)
+        y = cls.forward(ctx, *targs, **tkwargs)
+        ctx_registry[y] = ctx
+        return y
       @staticmethod
       def backward(ctx, grad_output):
         return cls.backward(ctx, grad_output)
@@ -43,6 +48,7 @@ class function(metaclass = function_meta):
       result = cls.forward(ctx, *args, **kwargs)
       id = result.op.get_attr('id')
       cls.contexts[id] = ctx
+      ctx_registry[result] = ctx
       # register backward
       name = result.op.op_def.name
       if not cls.registered:
@@ -177,37 +177,7 @@ def _make_grid(args) :
   return grid
 
 
-class bench_dict:
-
-  # Lazy entry for e.g., tensorflow, when value of benchmark is
-  # not known at graph compile time
-  class lazy_entry:
-    def __init__(self, id):
-      self.id = id
-
-    def get(self):
-      return libtriton.retrieve_scalar(self.id)
-
-
-  def __init__(self):
-    self.data = dict()
-
-  def __delitem__(self, key):
-    del self.data[id(key)]
-
-  def __getitem__(self, key):
-    ret = self.data[id(key)]
-    if isinstance(ret, bench_dict.lazy_entry):
-      return ret.get()
-    return ret
-
-  def __len__(self):
-    return len(self.data)
-
-  def __setitem__(self, key, value):
-    self.data[id(key)] = value
-
-bench_registry = bench_dict()
+bench_registry = triton.utils.id_dict()
 
 class kernel:
 
@@ -233,7 +203,7 @@ class kernel:
       defines.append((k, values))
     opt = libtriton.options_space()
     opt.defines = defines
-    opt.num_warps = [2, 4, 8]
+    opt.num_warps = [4]
     # create unique id for this op
     op_id = libtriton.make_op_id()
     self.fw_id[key] = op_id
@@ -257,7 +227,7 @@ class kernel:
       bench_id = libtriton.make_scalar_id() if bench > 0 else 0
       ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
       if bench > 0:
-        bench_registry[ret] = bench_dict.lazy_entry(bench_id)
+        bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
 
     elif fw.has_torch():
       args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
@@ -7,10 +7,14 @@ import math
 class _einsum(triton.function):
 
   src = """
-void einsum_(TYPE * A, TYPE * B, TYPE * C,
+void einsumk(TYPE * A, TYPE * B, TYPE * C,
             int dim_M, int dim_N, int dim_K,
-            int std_A0, int std_B0, int std_C0,
-            int std_A1, int std_B1, int std_C1) {
+            int std_A0 __multipleof(8),
+            int std_B0 __multipleof(8),
+            int std_C0 __multipleof(8),
+            int std_A1 __multipleof(8),
+            int std_B1 __multipleof(8),
+            int std_C1 __multipleof(8)) {
   // program id
   int pgm = get_program_id(0);
   int pgn = get_program_id(1);
@@ -21,7 +25,7 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
   int rb[TB] = pgb * TB + 0 ... TB;
   int rk[TK] = 0 ... TK;
   // accumulator
-  TYPE c[TM, TN, TB] = 0;
+  float c[TM, TN, TB] = 0;
   // pointers to a
   TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
                         + rm[BROADCAST_AM] * STRIDE_AM
@@ -51,7 +55,7 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
   bool checkn[TN] = rn < dim_N;
   bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
                             checkn[newaxis, :, newaxis];
-  *?(checkc)pc = c;
+  *?(checkc)pc = (TYPE[TM, TN, TB])c;
 }
 """
 
@@ -164,15 +168,14 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
     TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1 ))))]
     TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
     TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
-    print(TM)
-    print(TN)
+    TK = [bmnk[2]] if bmnk[2] < 16 else [8, 16]
     return _einsum.kernel(a, b, c,
                           bmnk[1], bmnk[2], bmnk[3],
                           std0[0], std0[1], std0[2],
                           std1[0], std1[1], std1[2],
                           grid, bench=bench,
                           **macros,
-                          TYPE='float', TM=TM, TN=TN, TK=8, TB=TB)
+                          TYPE=dtype, TM=TM, TN=TN, TK=TK, TB=TB)
 
 
   @staticmethod
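As a worked example of the tile-size heuristic above (the bmnk value is an assumed sample, matching the 4x1024x1024 case in the updated test, not something stated in the diff):

# Worked example of the tile-size lists; bmnk is an assumed sample shape.
import math
bmnk = (4, 1024, 1024, 1024)
TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1))))]  # [32, 64, 128]
TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1))))]  # [32, 64, 128]
TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1))))]  # [1, 2, 4]
TK = [bmnk[2]] if bmnk[2] < 16 else [8, 16]                                 # [8, 16]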
@@ -195,6 +198,7 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
     ctx.einsum_b = einsum_b
     ctx.einsum_c = einsum_c
     ctx.bench = bench
+    ctx.bmnk = bmnk
     return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench)
 
 
@@ -89,3 +89,32 @@ class scalar:
     return -self.get_value()
 
 
+class id_dict:
+
+  # Lazy entry for e.g., tensorflow, when value of benchmark is
+  # not known at graph compile time
+  class lazy_entry:
+    def __init__(self, id):
+      self.id = id
+
+    def get(self):
+      return libtriton.retrieve_scalar(self.id)
+
+
+  def __init__(self):
+    self.data = dict()
+
+  def __delitem__(self, key):
+    del self.data[id(key)]
+
+  def __getitem__(self, key):
+    ret = self.data[id(key)]
+    if isinstance(ret, id_dict.lazy_entry):
+      return ret.get()
+    return ret
+
+  def __len__(self):
+    return len(self.data)
+
+  def __setitem__(self, key, value):
+    self.data[id(key)] = value
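A brief usage sketch of the new id_dict (the output_tensor and bench_id names below are illustrative placeholders): entries are keyed by Python object identity, and lazy_entry values are resolved through libtriton only when read, which is how bench_registry defers benchmark results until the graph has actually run.

# Illustrative only; output_tensor and bench_id are placeholder names.
registry = id_dict()
registry[output_tensor] = id_dict.lazy_entry(bench_id)  # stored under id(output_tensor)
nanosec = registry[output_tensor]                       # __getitem__ detects the lazy_entry and
                                                        # returns libtriton.retrieve_scalar(bench_id)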
@@ -9,9 +9,9 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(auto ord: std::vector<std::vector<int>>{{0, 1}, {1, 0}})
+  for(auto ord: std::vector<std::vector<int>>{{1, 0}})
   for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
-                                               {true, false}, {true, true}}){
+                                               {true, false}}){
    std::vector<config_t> tmp = {
      config_t{ord, x[0], x[1], 2048, 2048, 2048},
      // config_t{ord, x[0], x[1], 16, 2048, 2048},
@@ -34,7 +34,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c << std::flush;
-    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
@@ -109,10 +109,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
     opt.num_warps = {nwarp};
   }
   if(mode == BENCH) {
-    opt.defines.push_back({"TM", {"64", "128"}});
-    opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"8"}});
-    opt.num_warps = {2, 4, 8};
+    opt.defines.push_back({"TM", {"128"}});
+    opt.defines.push_back({"TN", {"128"}});
+    opt.defines.push_back({"TK", {"16"}});
+    opt.num_warps = {4};
   }
 
   // kernels
@@ -23,8 +23,8 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   // reduction loop
   for(int k = K; k > 0; k-= TK){
     c += USEA @ USEB;
-    pa = pa + TK * STRIDE_AK;
-    pb = pb + TK * STRIDE_BK;
+    pa += TK * STRIDE_AK;
+    pb += TK * STRIDE_BK;
     bool checka[SHAPE_A] = k > TK;
     bool checkb[SHAPE_B] = k > TK;
     a = checka ? *pa : 0;