From 0ec213547c2593ec868c402193ed49da90a735d5 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 27 Oct 2019 15:32:34 -0400 Subject: [PATCH] [PYTHON][KERNEL] Added benchmarking functionalities for kernels --- lib/codegen/transform/disassociate.cc | 8 +- python/examples/dot.py | 19 ++-- python/examples/einsum_test.py | 17 ++-- python/src/bindings.cc | 49 +++++++--- python/triton/kernel.py | 52 ++++++++++- python/triton/ops/dot.py | 41 ++++----- python/triton/ops/einsum.py | 128 +++++++++++++++----------- python/triton/utils.py | 3 +- tests/bench/dot.cc | 2 +- 9 files changed, 207 insertions(+), 112 deletions(-) diff --git a/lib/codegen/transform/disassociate.cc b/lib/codegen/transform/disassociate.cc index 1134463ec..2244ebccd 100644 --- a/lib/codegen/transform/disassociate.cc +++ b/lib/codegen/transform/disassociate.cc @@ -56,14 +56,12 @@ void disassociate::run(ir::module &mod) { bld.set_insert_point(y); bld.insert(cloned); clone_map[y] = cloned; - // replace in above level - if(depth > 1){ + // replace operands of parents + if(depth > 1) for(ir::user* ux: x.second.at(depth - 1)) clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned); - } - else{ + else x.first->replace_uses_of_with(y, cloned); - } } depth += 1; } diff --git a/python/examples/dot.py b/python/examples/dot.py index eaa9c2d68..8fd0b35d9 100644 --- a/python/examples/dot.py +++ b/python/examples/dot.py @@ -2,11 +2,11 @@ import numpy as np import triton def run_tf(): - M, N, K = 128, 128, 128 + M, N, K = 2048, 2048, 2048 a = tf.placeholder(tf.float32, shape=[M, K]) b = tf.placeholder(tf.float32, shape=[N, K]) - tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True) - tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False) + tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=10) + tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=10) tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True) tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False) # Gradient @@ -20,15 +20,20 @@ def run_tf(): sess.run(tf.global_variables_initializer()) result = sess.run([tr_da, tf_da], feed_dict = {a: ha, b: hb}) + # Benchmark + nanosec = triton.bench_registry[tr_d] + print('NANOSEC: ', nanosec) + print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3) # Test print(result[0][0]) print(result[1][0]) dif = np.abs(result[0][0] - result[1][0]) print("dif: %f" % np.max(dif)) + def run_torch(): torch.manual_seed(0) - M, N, K = 128, 128, 128 + M, N, K = 2048, 2048, 2048 a = torch.randn(M, K).cuda() b = torch.randn(K, N).cuda() a.requires_grad_(True) @@ -37,9 +42,8 @@ def run_torch(): torch_d = torch.matmul(torch.t(torch_c), b) torch_y = torch.mean(torch_d) triton_c = triton.ops.dot(a, b, False, True) - triton_d = triton.ops.dot(triton_c, b, True, False) + triton_d = triton.ops.dot(triton_c, b, True, False, 1) triton_y = torch.mean(triton_d) - # torch gradient torch_y.backward() torch_da = a.grad.clone() @@ -51,6 +55,9 @@ def run_torch(): triton_da = a.grad.clone() triton_db = b.grad.clone() + nanosec = triton.bench_registry[triton_d] + print(nanosec) + print('TFLOPS:', 2. 
* M * N * K / nanosec * 1e-3) print('Diff DA:', (torch_da - triton_da).max()) print('Diff DB:', (torch_db - triton_db).max()) diff --git a/python/examples/einsum_test.py b/python/examples/einsum_test.py index 799efbf70..4a7c2f2c7 100644 --- a/python/examples/einsum_test.py +++ b/python/examples/einsum_test.py @@ -12,7 +12,8 @@ from tensorflow.python.ops import gradient_checker one = 0 out = 0 -bench = 0 +bench = 10 + class ProdKeyTest(tf.test.TestCase): def testEinsum(self): @@ -36,9 +37,9 @@ class ProdKeyTest(tf.test.TestCase): # key_dim = 16 for a_shape, b_shape, c_shape, einsum in [ - [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ], - [ [ 4, 1024, 1024 ], [ 1024, 512 ], [ 4, 1024, 512 ], "btc,ck->btk" ], - [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ], + #[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ], + [ [4, 2048, 2048 ], [ 2048, 2048 ], [4, 2048, 2048 ], "btc,ck->btk" ], + #[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ], ]: if one: @@ -57,7 +58,7 @@ class ProdKeyTest(tf.test.TestCase): e = tf.placeholder(tf.float32, c_shape, name="e") feed_dict = { a:A, b:B, e:E } - cc = triton.ops.einsum(einsum, a, b) + cc = triton.ops.einsum(einsum, a, b, bench=bench) # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) # # print(error) @@ -71,8 +72,12 @@ class ProdKeyTest(tf.test.TestCase): # c, = sess.run( [ c, ], feed_dict ) c, da, db = sess.run( [ cc, da, db ], feed_dict ) - if bench == 0: + if bench > 0: + nanosec = triton.bench_registry[cc] + print(A.shape, B.shape) + print(nanosec) + else: C = np.einsum(einsum, A, B) id = cc.op.get_attr('id') ctx = triton.ops._einsum.contexts[id] diff --git a/python/src/bindings.cc b/python/src/bindings.cc index 969f74df4..59b5c54d6 100644 --- a/python/src/bindings.cc +++ b/python/src/bindings.cc @@ -20,13 +20,13 @@ using namespace triton; namespace rt = triton::runtime; - -/* TF triton op properties */ - std::map> id_grid_map; std::map> id_fn_map; +std::map fp64scalar_map; std::map i64scalar_map; +/* Grid map */ + void register_grid(size_t id, const rt::function::grid_fn_ty& grid_fn) { id_grid_map[id].reset(new rt::function::grid_fn_ty(grid_fn)); @@ -36,6 +36,8 @@ void delete_grid(size_t id) { id_grid_map.erase(id); } +/* Function map */ + void register_fn(size_t id, const std::string& src, const rt::function::options_space_t& opt) { @@ -56,8 +58,11 @@ size_t make_op_id() { return id_fn_map.size(); } +/* TF scalar wrapper */ size_t make_scalar_id() { - return i64scalar_map.size(); + size_t ret = i64scalar_map.size(); + i64scalar_map[ret] = int64_t(); + return ret; } bool has_scalar(size_t id) { @@ -135,8 +140,9 @@ void gen_make_handles(std::ostream &os, const std::vector& args) } } -void gen_make_launch_function(std::ostream &os, const std::vector& args) { - os << " (*id_fn_map.at(id_))({"; +void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vector& args) { + os << " std::function run = [&](){\n "; + os << " (*id_fn_map.at(id_))({"; for(unsigned i = 0; i < args.size() ; i++){ ir::argument *arg = args[i]; std::string name = arg->get_name(); @@ -146,7 +152,11 @@ void gen_make_launch_function(std::ostream &os, const std::vector os << ", "; os << name; } - os << "}, *id_grid_map.at(id_), stream); \n"; + os << "}, *id_grid_map.at(id_), stream);\n"; + os << " };\n "; + 
os << " run();"; + os << " if(bench_ > 0)\n "; + os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n "; } void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name, @@ -186,7 +196,9 @@ void gen_tf_register_op(std::ostream &os, const std::string &name, throw std::runtime_error("unknown output"); os << " .Output(\"out" << i << ": T" << idx << "\")\n"; } - os << " .Attr(\"id: int\")" << std::endl; + os << " .Attr(\"id: int\")\n"; + os << " .Attr(\"bench: int\")\n"; + os << " .Attr(\"bench_id: int\")\n"; os << ";\n"; } @@ -247,6 +259,7 @@ std::tuple> id_grid_map; extern std::map> id_fn_map; - +extern std::map i64scalar_map; class )" << opname << R"(: public OpKernel { public: explicit )" << opname << R"((OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("id", &id_)); + OP_REQUIRES_OK(context, context->GetAttr("bench", &bench_)); + OP_REQUIRES_OK(context, context->GetAttr("bench_id", &bench_id_)); } void Compute(OpKernelContext* context){ @@ -291,12 +306,14 @@ oss << R"( oss << R"( // launch function )"; -gen_make_launch_function(oss, fn->args()); +gen_make_launch_function(oss, outputs.size(), fn->args()); oss << R"( } private: int id_; + int bench_; + int bench_id_; }; // register kernel builder @@ -379,6 +396,7 @@ void gen_torch_signature(std::ostringstream& oss, oss << ret_ty << " " << name << "("; oss << "int64_t id, "; + oss << "int64_t bench, "; for(size_t i = 0; i < args.size(); i++) { ir::argument* arg = args[i]; if(i > 0) @@ -420,7 +438,8 @@ void gen_torch_make_handles(std::ostream &os, } void gen_torch_make_launch_function(std::ostream &os, const std::vector& args) { - os << " (*id_fn_map.at(id))({"; + os << " std::function run = [&](){\n "; + os << " (*id_fn_map.at(id))({"; for(unsigned i = 0; i < args.size() ; i++){ ir::argument *arg = args[i]; std::string name = "arg_" + arg->get_name(); @@ -431,7 +450,11 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector 0)\n "; + os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n "; + } void gen_torch_ret(std::ostream &os, const std::vector& outputs) { if(outputs.size() == 1){ @@ -465,6 +488,7 @@ std::tuple> id_grid_map; extern std::map> id_fn_map; +extern std::map i64scalar_map; )"; diff --git a/python/triton/kernel.py b/python/triton/kernel.py index 50ade154e..3a71d0ecd 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -5,6 +5,7 @@ import shutil import hashlib import sysconfig import sys +import weakref # import for just-in-time compilation import distutils import setuptools.command.build_ext @@ -176,6 +177,38 @@ def _make_grid(args) : return grid +class bench_dict: + + # Lazy entry for e.g., tensorflow, when value of benchmark is + # not known at graph compile time + class lazy_entry: + def __init__(self, id): + self.id = id + + def get(self): + return libtriton.retrieve_scalar(self.id) + + + def __init__(self): + self.data = dict() + + def __delitem__(self, key): + del self.data[id(key)] + + def __getitem__(self, key): + ret = self.data[id(key)] + if isinstance(ret, bench_dict.lazy_entry): + return ret.get() + return ret + + def __len__(self): + return len(self.data) + + def __setitem__(self, key, value): + self.data[id(key)] = value + +bench_registry = bench_dict() + class kernel: def __init__(self, src, outputs): @@ -200,7 +233,7 @@ class kernel: defines.append((k, values)) opt = libtriton.options_space() opt.defines = defines - opt.num_warps = [4] + opt.num_warps = [2, 4, 8] # create unique id for this op 
op_id = libtriton.make_op_id() self.fw_id[key] = op_id @@ -209,6 +242,10 @@ class kernel: if self.fw_op is None: self.fw_op = _make_framework_op(self.src, self.outputs, opt) + # benchmarking info + bench = 0 + if 'bench' in kwargs: + bench = kwargs['bench'] # retrieve framework op op_id = self.fw_id[key] # register grid @@ -217,9 +254,16 @@ class kernel: op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]] # call framework function if fw.has_tensorflow(): - return self.fw_op(*op_args, id=op_id) + bench_id = libtriton.make_scalar_id() if bench > 0 else 0 + ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id) + if bench > 0: + bench_registry[ret] = bench_dict.lazy_entry(bench_id) + elif fw.has_torch(): args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args] - return self.fw_op(op_id, *args) + ret = self.fw_op(op_id, bench, *args) + if bench > 0: + bench_registry[ret] = libtriton.retrieve_scalar(op_id) else: - assert False \ No newline at end of file + assert False + return ret \ No newline at end of file diff --git a/python/triton/ops/dot.py b/python/triton/ops/dot.py index b37f2e32a..7a5069701 100644 --- a/python/triton/ops/dot.py +++ b/python/triton/ops/dot.py @@ -11,38 +11,36 @@ void dot(TYPE * A, TYPE * B, TYPE * C, // prologue int ridx = get_program_id(0); int ridy = get_program_id(1); - int rxa[TM] = ridx * TM + 0 ... TM; - int ryb[TN] = ridy * TN + 0 ... TN; - int rka[TK] = 0 ... TK; - int rkb[TK] = 0 ... TK; + int rm[TM] = ridx * TM + 0 ... TM; + int rn[TN] = ridy * TN + 0 ... TN; + int rk[TK] = 0 ... TK; float c[TM, TN] = 0; // pointers to operands - TYPE* pa[SHAPE_A] = A + rka[BROADCAST_AK] * STRIDE_AK + rxa[BROADCAST_AM] * STRIDE_AM; - TYPE* pb[SHAPE_B] = B + rkb[BROADCAST_BK] * STRIDE_BK + ryb[BROADCAST_BN] * STRIDE_BN; + TYPE* pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM; + TYPE* pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + rn[BROADCAST_BN] * STRIDE_BN; // prefetches operands - TYPE a[SHAPE_A] = (*pa); - TYPE b[SHAPE_B] = (*pb); + TYPE a[SHAPE_A] = *pa; + TYPE b[SHAPE_B] = *pb; // reduction loop for(int k = K; k > 0; k-= TK){ c += USE_A @ USE_B; pa = pa + TK * STRIDE_AK; pb = pb + TK * STRIDE_BK; - a = *pa; - b = *pb; + bool checka[SHAPE_A] = k > TK; + bool checkb[SHAPE_B] = k > TK; + a = checka ? *pa : 0; + b = checkb ? *pb : 0; } // epilogue - int rxc[TM] = ridx * TM + 0 ... TM; - int ryc[TN] = ridy * TN + 0 ... 
TN; - TYPE* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis] * ldc; - bool checkc[TM, TN] = (rxc < M)[:, newaxis] && (ryc < N)[newaxis, :]; - *?(checkc) pc = c; + TYPE* pc[TM, TN] = C + rm[:, newaxis] * ldc + rn[newaxis, :]; + *pc = c; } """ kernel = triton.kernel(src, ['C']) @staticmethod - def _call(a, b, transpose_a, transpose_b): + def _call(a, b, transpose_a, transpose_b, bench = 0): # extract shapes shape_a = triton.shape(a) shape_b = triton.shape(b) @@ -78,16 +76,17 @@ void dot(TYPE * A, TYPE * B, TYPE * C, 'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis', 'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :', 'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'} - return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid, + return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, + grid, bench=bench, AT = transpose_a, BT = transpose_b, TYPE = dtype, - TM = [128], TN = [128], TK = [8], **macros) + TM = [64, 128], TN = [64, 128], TK = [8], **macros) @staticmethod - def forward(ctx, a, b, transpose_a = False, transpose_b = False): + def forward(ctx, a, b, transpose_a = False, transpose_b = False, bench = 0): ctx.save_for_backward(a, b) ctx.t_a = transpose_a ctx.t_b = transpose_b - return _dot._call(a, b, transpose_a, transpose_b) + return _dot._call(a, b, transpose_a, transpose_b, bench) @staticmethod def backward(ctx, dy): @@ -108,5 +107,5 @@ void dot(TYPE * A, TYPE * B, TYPE * C, else: assert False return da, db, None, None, None, None, None, None, None - + dot = _dot.apply \ No newline at end of file diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 7f4457d99..d6207c194 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py @@ -2,52 +2,58 @@ import triton +import math class _einsum(triton.function): src = """ - void einsum_(TYPE * A, TYPE * B, TYPE * C, - int dim_M, int dim_N, int dim_K, - int std_A0, int std_B0, int std_C0, - int std_A1, int std_B1, int std_C1) { - // program id - int pgm = get_program_id(0); - int pgn = get_program_id(1); - int pgb = get_program_id(2); - // range - int rm[TM] = pgm * TM + 0 ... TM; - int rn[TN] = pgn * TN + 0 ... TN; - int rb[TB] = pgb * TB + 0 ... TB; - int rk[TK] = 0 ... TK; - // accumulator - TYPE c[TM, TN, TB] = 0; - // pointers to a - TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK - + rm[BROADCAST_AM] * STRIDE_AM - + rb[newaxis, newaxis, :] * std_A0; - // pointers to b - TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK - + rn[BROADCAST_BN] * STRIDE_BN - + rb[newaxis, newaxis, :] * std_B0; - // accumulation - for(int k = dim_K; k > 0; k -= TK) { - TYPE a[SHAPE_A] = *pa; - TYPE b[SHAPE_B] = *pb; - c += USE_A @ USE_B; - pa += TK * STRIDE_AK; - pb += TK * STRIDE_BK; - } - // write-back - TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1 - + rn[newaxis, :, newaxis] * 1 - + rb[newaxis, newaxis, :] * std_C0; - bool checkm[TM] = rm < dim_M; - bool checkn[TN] = rn < dim_N; - bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] && - checkn[newaxis, :, newaxis]; - *?(checkc)pc = c; +void einsum_(TYPE * A, TYPE * B, TYPE * C, + int dim_M, int dim_N, int dim_K, + int std_A0, int std_B0, int std_C0, + int std_A1, int std_B1, int std_C1) { + // program id + int pgm = get_program_id(0); + int pgn = get_program_id(1); + int pgb = get_program_id(2); + // range + int rm[TM] = pgm * TM + 0 ... TM; + int rn[TN] = pgn * TN + 0 ... TN; + int rb[TB] = pgb * TB + 0 ... TB; + int rk[TK] = 0 ... 
TK; + // accumulator + TYPE c[TM, TN, TB] = 0; + // pointers to a + TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK + + rm[BROADCAST_AM] * STRIDE_AM + + rb[newaxis, newaxis, :] * std_A0; + // pointers to b + TYPE *pb[SHAPE_B] = B + rk[BROADCAST_BK] * STRIDE_BK + + rn[BROADCAST_BN] * STRIDE_BN + + rb[newaxis, newaxis, :] * std_B0; + // prefetch + TYPE a[SHAPE_A] = *pa; + TYPE b[SHAPE_B] = *pb; + // accumulation + for(int k = dim_K; k > 0; k -= TK) { + c += USE_A @ USE_B; + pa += TK * STRIDE_AK; + pb += TK * STRIDE_BK; + bool checka[SHAPE_A] = k > TK; + bool checkb[SHAPE_B] = k > TK; + a = checka ? *pa : 0; + b = checkb ? *pb : 0; } - """ + // write-back + TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1 + + rn[newaxis, :, newaxis] * 1 + + rb[newaxis, newaxis, :] * std_C0; + bool checkm[TM] = rm < dim_M; + bool checkn[TN] = rn < dim_N; + bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] && + checkn[newaxis, :, newaxis]; + *?(checkc)pc = c; +} +""" kernel = triton.kernel(src, ['C']) @@ -134,7 +140,8 @@ class _einsum(triton.function): @staticmethod def call(a, b, trans_a, trans_b, shape_c, bmnk, - std0, std1, einsum_a, einsum_b, einsum_c): + std0, std1, einsum_a, einsum_b, einsum_c, + bench): dtype = a.dtype c = triton.empty(shape_c, dtype) grid = lambda opt: [triton.cdiv(bmnk[1], opt.d('TM')), @@ -154,16 +161,22 @@ class _einsum(triton.function): 'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis', 'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis', 'SHAPE_B' : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'} - return _einsum.kernel(a, b, c, + TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1 ))))] + TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))] + TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))] + print(TM) + print(TN) + return _einsum.kernel(a, b, c, bmnk[1], bmnk[2], bmnk[3], std0[0], std0[1], std0[2], std1[0], std1[1], std1[2], - grid, **macros, - TYPE='float', TM=32, TN=32, TK=8, TB=1) + grid, bench=bench, + **macros, + TYPE='float', TM=TM, TN=TN, TK=8, TB=TB) @staticmethod - def forward(ctx, subscripts, a, b): + def forward(ctx, subscripts, a, b, **kwargs): ctx.save_for_backward(a, b) if type(subscripts) is str: einsum_a, einsum_bc = subscripts.split(",") @@ -173,14 +186,16 @@ class _einsum(triton.function): shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum( einsum_a, einsum_b, einsum_c, - a.shape.as_list(), b.shape.as_list() + triton.shape(a), triton.shape(b) ) + bench = kwargs['bench'] if 'bench' in kwargs else 0 ctx.trans_a = ta ctx.trans_b = tb ctx.einsum_a = einsum_a ctx.einsum_b = einsum_b ctx.einsum_c = einsum_c - return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c) + ctx.bench = bench + return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1, einsum_a, einsum_b, einsum_c, bench) @staticmethod @@ -191,22 +206,23 @@ class _einsum(triton.function): einsum_a = ctx.einsum_a einsum_b = ctx.einsum_b einsum_c = ctx.einsum_c + bench = ctx.bench if not trans_a and not trans_b: # NN - da = einsum((einsum_c, einsum_b, einsum_a), dc, b) - db = einsum((einsum_a, einsum_c, einsum_b), a, dc) + da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench) + db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench) elif not trans_a and trans_b: # NT - da = einsum((einsum_c, einsum_b, einsum_a), dc, b) - db = einsum((einsum_c, einsum_a, einsum_b), dc, a) + da = einsum((einsum_c, einsum_b, 
einsum_a), dc, b, bench=bench) + db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench) elif trans_a and not trans_b: # TN - da = einsum((einsum_b, einsum_c, einsum_a), b, dc) - db = einsum((einsum_a, einsum_c, einsum_b), a, dc) + da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench) + db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench) elif trans_a and trans_b: # TT (not used) - da = einsum((einsum_b, einsum_c, einsum_a), b, dc) - db = einsum((einsum_c, einsum_a, einsum_b), dc, a) + da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench) + db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench) return da, db, None, None, None, None, None, None, None, None, None, None diff --git a/python/triton/utils.py b/python/triton/utils.py index 127d67364..d5f5f4129 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -22,7 +22,8 @@ class lazy_shape: def shape(A) : if fw.has_tensorflow(): - return lazy_shape(fw.tensorflow.shape(A)) + return A.shape.as_list() + #return lazy_shape(fw.tensorflow.shape(A)) elif fw.has_torch(): return A.shape else: diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc index 927f0044b..c87e1c938 100644 --- a/tests/bench/dot.cc +++ b/tests/bench/dot.cc @@ -34,7 +34,7 @@ int main() { for(const auto& c: configs){ std::tie(ord, AT, BT, M, N, K) = c; std::cout << "// " << c << std::flush; - for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord)) + for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord)) std::cout << ", " << perf << std::flush; std::cout << std::endl; }