From e9c787ef054c96b5038f5023b51463e93a4b6ef7 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Mon, 28 Oct 2019 11:33:18 -0400
Subject: [PATCH] [PYTHON][EINSUM] Added support for FP16

---
 include/triton/tools/bench.hpp     |  2 +-
 lib/codegen/analysis/layout.cc     |  4 +--
 lib/codegen/selection/generator.cc |  5 ++-
 python/examples/dot.py             |  6 ++--
 python/examples/einsum_test.py     | 51 ++++++++++++++++--------------
 python/setup.py                    |  1 +
 python/triton/function.py          |  8 ++++-
 python/triton/kernel.py            | 36 ++-------------------
 python/triton/ops/einsum.py        | 20 +++++++-----
 python/triton/utils.py             | 29 +++++++++++++++++
 tests/bench/dot.cc                 |  6 ++--
 tests/common/dot.h                 |  8 ++---
 tests/common/src/dot.h             |  4 +--
 13 files changed, 97 insertions(+), 83 deletions(-)

diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp
index 554b3bcc3..48a4ab972 100644
--- a/include/triton/tools/bench.hpp
+++ b/include/triton/tools/bench.hpp
@@ -38,7 +38,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
   double total_time = 0;
   op();
   stream->synchronize();
-  while(total_time*1e-9 < 1e-3){
+  while(total_time*1e-9 < 1e-2){
     float norm = 1;
     // normalize clock if possible to reduce noise in auto-tuning
     if(auto cu_device = dynamic_cast<driver::cu_device*>(stream->context()->device()))
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 6f717d77c..70ca9e3b2 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -314,11 +314,11 @@ layout_shared_t::layout_shared_t(const layout_t *arg,
   // padding
   pad = 0;
   if(hmma_dot_a){
-    bool row = is_trans(hmma_dot_a) ^ order[0] == 1;
+    bool row = is_trans(hmma_dot_a) ^ order[0] != 0;
     pad = 24 - shapes[row ? order[0] : order[1]] % 32;
   }
   else if(hmma_dot_b){
-    bool row = is_trans(hmma_dot_b) ^ order[0] == 1;
+    bool row = is_trans(hmma_dot_b) ^ order[0] != 0;
     pad = 24 - shapes[row ? order[1] : order[0]] % 32;
   }
   else if(order != arg->order) {
diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 1ff4287eb..2efa834a8 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -560,9 +560,8 @@ void generator::visit_hmma_dot(ir::dot_inst* dot, shared_tile *TA, shared_tile *
 
   bool is_a_trans = is_trans(dot->get_operand(0));
   bool is_b_trans = is_trans(dot->get_operand(1));
-  bool is_a_row = is_a_trans ^ (ord_a[0] == 1);
-  bool is_b_row = is_b_trans ^ (ord_b[0] == 1);
-
+  bool is_a_row = is_a_trans ^ (ord_a[0] != 0);
+  bool is_b_row = is_b_trans ^ (ord_b[0] != 0);
 
   Value *offset_a_i = hmma->offset_a_i_;
   Value *offset_a_k = hmma->offset_a_k_;
diff --git a/python/examples/dot.py b/python/examples/dot.py
index 8fd0b35d9..425fed986 100644
--- a/python/examples/dot.py
+++ b/python/examples/dot.py
@@ -5,8 +5,8 @@ def run_tf():
   M, N, K = 2048, 2048, 2048
   a = tf.placeholder(tf.float32, shape=[M, K])
   b = tf.placeholder(tf.float32, shape=[N, K])
-  tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=10)
-  tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=10)
+  tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
+  tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
   tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
   tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
   # Gradient
@@ -23,7 +23,7 @@ def run_tf():
   # Benchmark
   nanosec = triton.bench_registry[tr_d]
   print('NANOSEC: ', nanosec)
-  print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
+  #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
   # Test
   print(result[0][0])
   print(result[1][0])
diff --git a/python/examples/einsum_test.py b/python/examples/einsum_test.py
index 4a7c2f2c7..3363a88ea 100644
--- a/python/examples/einsum_test.py
+++ b/python/examples/einsum_test.py
@@ -12,7 +12,7 @@ from tensorflow.python.ops import gradient_checker
 
 one = 0
 out = 0
-bench = 10
+bench = 0
 
 class ProdKeyTest(tf.test.TestCase):
 
@@ -37,14 +37,14 @@ class ProdKeyTest(tf.test.TestCase):
         # key_dim = 16
 
        for a_shape, b_shape, c_shape, einsum in [
-            #[ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
-            [ [4, 2048, 2048 ], [ 2048, 2048 ], [4, 2048, 2048 ], "btc,ck->btk" ],
-            #[ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
+            [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
+            [ [4, 1024, 1024], [ 1024, 1024 ], [4, 1024, 1024 ], "btc,ck->btk" ],
+            [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
         ]:
 
            if one:
-                A = np.ones(a_shape, dtype=np.float32)
-                B = np.ones(b_shape, dtype=np.float32)
+                A = np.ones(a_shape, dtype=np.float16).astype(np.float32)
+                B = np.ones(b_shape, dtype=np.float16).astype(np.float32)
                 E = np.ones(c_shape, dtype=np.float32)
             else:
                 # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32)
@@ -53,12 +53,14 @@ class ProdKeyTest(tf.test.TestCase):
                 B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
                 E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
 
-            a = tf.placeholder(tf.float32, a_shape, name="a")
-            b = tf.placeholder(tf.float32, b_shape, name="b")
-            e = tf.placeholder(tf.float32, c_shape, name="e")
-            feed_dict = { a:A, b:B, e:E }
+            a = tf.placeholder(tf.float16, a_shape, name="a")
+            b = tf.placeholder(tf.float16, b_shape, name="b")
+            e = tf.placeholder(tf.float16, c_shape, name="e")
+            feed_dict = { a: A.astype(np.float16),
+                          b: B.astype(np.float16),
+                          e: E }
 
-            cc = triton.ops.einsum(einsum, a, b, bench=bench)
+            c = triton.ops.einsum(einsum, a, b, bench=bench)
 
             # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B })
             # # print(error)
             # # print(error)
             # return
 
-            with tf.control_dependencies([cc.op]):
-                da, db = tf.gradients(cc, [a, b], e)
+            with tf.control_dependencies([c.op]):
+                da, db = tf.gradients(c, [a, b], e)
 
             # c, = sess.run( [ c, ], feed_dict )
-            c, da, db = sess.run( [ cc, da, db ], feed_dict )
-
+            rc, rda, rdb = sess.run( [ c, da, db ], feed_dict )
+
             if bench > 0:
-                nanosec = triton.bench_registry[cc]
-                print(A.shape, B.shape)
-                print(nanosec)
+                nanosec = triton.bench_registry[c]
+                ctx = triton.ctx_registry[c]
+                b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
+                ops = 2. * b * m * n * k
+                print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
+                print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
+                print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
             else:
                 C = np.einsum(einsum, A, B)
-                id = cc.op.get_attr('id')
-                ctx = triton.ops._einsum.contexts[id]
+                ctx = triton.ctx_registry[c]
                 t_a = ctx.trans_a
                 t_b = ctx.trans_b
                 e_a = ctx.einsum_a
@@ -100,9 +105,9 @@ class ProdKeyTest(tf.test.TestCase):
             print("testProdKey", einsum)
 
             if not bench:
                 for op, dev, cpu in [
-                    [ "C",  c,   C  ],
-                    [ "DA", da,  DA ],
-                    [ "DB", db,  DB ],
+                    [ "C",  rc,  C  ],
+                    [ "DA", rda, DA ],
+                    [ "DB", rdb, DB ],
                 ]:
                     self.compare_results(op, dev, cpu)
diff --git a/python/setup.py b/python/setup.py
index ea1568b2f..060a1c450 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -77,6 +77,7 @@ class CMakeBuild(build_ext):
             pass
 
         cfg = 'Debug' if self.debug else 'Release'
+        cfg = 'Release'
         build_args = ['--config', cfg]
 
         if platform.system() == "Windows":
diff --git a/python/triton/function.py b/python/triton/function.py
index 125cad668..79a0e5ec8 100644
--- a/python/triton/function.py
+++ b/python/triton/function.py
@@ -1,4 +1,5 @@
 import triton.frameworks as fw
+import triton.utils
 
 class OpContext(object):
 
@@ -16,6 +17,8 @@ class function_meta(type):
     cls.registered = False
     return super(function_meta, cls).__init__(name, bases, attrs)
 
+ctx_registry = triton.utils.id_dict()
+
 class function(metaclass = function_meta):
 
   @staticmethod
@@ -31,7 +34,9 @@ class function(metaclass = function_meta):
       class TorchFunction(fw.torch.autograd.Function):
         @staticmethod
         def forward(ctx, *targs, **tkwargs):
-          return cls.forward(ctx, *targs, **tkwargs)
+          y = cls.forward(ctx, *targs, **tkwargs)
+          ctx_registry[y] = ctx
+          return y
         @staticmethod
         def backward(ctx, grad_output):
           return cls.backward(ctx, grad_output)
@@ -43,6 +48,7 @@ class function(metaclass = function_meta):
       result = cls.forward(ctx, *args, **kwargs)
       id = result.op.get_attr('id')
       cls.contexts[id] = ctx
+      ctx_registry[result] = ctx
       # register backward
       name = result.op.op_def.name
       if not cls.registered:
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 3a71d0ecd..57e0afc13 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -177,37 +177,7 @@ def _make_grid(args) :
   return grid
 
 
-class bench_dict:
-
-  # Lazy entry for e.g., tensorflow, when value of benchmark is
-  # not known at graph compile time
-  class lazy_entry:
-    def __init__(self, id):
-      self.id = id
-
-    def get(self):
-      return libtriton.retrieve_scalar(self.id)
-
-
-  def __init__(self):
-    self.data = dict()
-
-  def __delitem__(self, key):
-    del self.data[id(key)]
-
-  def __getitem__(self, key):
-    ret = self.data[id(key)]
-    if isinstance(ret, bench_dict.lazy_entry):
-      return ret.get()
-    return ret
-
-  def __len__(self):
-    return len(self.data)
-
-  def __setitem__(self, key, value):
-    self.data[id(key)] = value
-
-bench_registry = bench_dict()
+bench_registry = triton.utils.id_dict()
 
 class kernel:
 
@@ -233,7 +203,7 @@ class kernel:
         defines.append((k, values))
       opt = libtriton.options_space()
       opt.defines = defines
-      opt.num_warps = [2, 4, 8]
+      opt.num_warps = [4]
       # create unique id for this op
       op_id = libtriton.make_op_id()
       self.fw_id[key] = op_id
@@ -257,7 +227,7 @@ class kernel:
      bench_id = libtriton.make_scalar_id() if bench > 0 else 0
      ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
      if bench > 0:
-        bench_registry[ret] = bench_dict.lazy_entry(bench_id)
+        bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
    elif fw.has_torch():
      args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index d6207c194..7bbc18f5a 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -7,10 +7,14 @@ import math
 class _einsum(triton.function):
 
     src = """
-void einsum_(TYPE * A, TYPE * B, TYPE * C,
+void einsumk(TYPE * A, TYPE * B, TYPE * C,
             int dim_M, int dim_N, int dim_K,
-            int std_A0, int std_B0, int std_C0,
-            int std_A1, int std_B1, int std_C1) {
+            int std_A0 __multipleof(8),
+            int std_B0 __multipleof(8),
+            int std_C0 __multipleof(8),
+            int std_A1 __multipleof(8),
+            int std_B1 __multipleof(8),
+            int std_C1 __multipleof(8)) {
   // program id
   int pgm = get_program_id(0);
   int pgn = get_program_id(1);
@@ -21,7 +25,7 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
   int rb[TB] = pgb * TB + 0 ... TB;
   int rk[TK] = 0 ... TK;
   // accumulator
-  TYPE c[TM, TN, TB] = 0;
+  float c[TM, TN, TB] = 0;
   // pointers to a
   TYPE *pa[SHAPE_A] = A + rk[BROADCAST_AK] * STRIDE_AK
                         + rm[BROADCAST_AM] * STRIDE_AM
@@ -51,7 +55,7 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
   bool checkn[TN] = rn < dim_N;
   bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
                             checkn[newaxis, :, newaxis];
-  *?(checkc)pc = c;
+  *?(checkc)pc = (TYPE[TM, TN, TB])c;
 }
 """
 
@@ -164,15 +168,14 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
         TM = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[1]) + 1 ))))]
         TN = [2**i for i in range(5, max(6, min(8, int(math.log2(bmnk[2]) + 1 ))))]
         TB = [2**i for i in range(0, max(1, min(3, int(math.log2(bmnk[0]) + 1 ))))]
-        print(TM)
-        print(TN)
+        TK = [bmnk[2]] if bmnk[2] < 16 else [8, 16]
 
         return _einsum.kernel(a, b, c,
                               bmnk[1], bmnk[2], bmnk[3],
                               std0[0], std0[1], std0[2],
                               std1[0], std1[1], std1[2],
                               grid, bench=bench, **macros,
-                              TYPE='float', TM=TM, TN=TN, TK=8, TB=TB)
+                              TYPE=dtype, TM=TM, TN=TN, TK=TK, TB=TB)
 
 
     @staticmethod
@@ -195,6 +198,7 @@ void einsum_(TYPE * A, TYPE * B, TYPE * C,
         ctx.einsum_b = einsum_b
         ctx.einsum_c = einsum_c
         ctx.bench = bench
+        ctx.bmnk = bmnk
         return _einsum.call(a, b, ta, tb, shape_c, bmnk, std0, std1,
                             einsum_a, einsum_b, einsum_c, bench)
 
diff --git a/python/triton/utils.py b/python/triton/utils.py
index d5f5f4129..5b832668f 100644
--- a/python/triton/utils.py
+++ b/python/triton/utils.py
@@ -89,3 +89,32 @@ class scalar:
     return -self.get_value()
 
 
+class id_dict:
+
+  # Lazy entry for e.g., tensorflow, when value of benchmark is
+  # not known at graph compile time
+  class lazy_entry:
+    def __init__(self, id):
+      self.id = id
+
+    def get(self):
+      return libtriton.retrieve_scalar(self.id)
+
+
+  def __init__(self):
+    self.data = dict()
+
+  def __delitem__(self, key):
+    del self.data[id(key)]
+
+  def __getitem__(self, key):
+    ret = self.data[id(key)]
+    if isinstance(ret, id_dict.lazy_entry):
+      return ret.get()
+    return ret
+
+  def __len__(self):
+    return len(self.data)
+
+  def __setitem__(self, key, value):
+    self.data[id(key)] = value
\ No newline at end of file
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index c87e1c938..876ce0962 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -9,9 +9,9 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(auto ord: std::vector<std::vector<int>>{{0, 1}, {1, 0}})
+  for(auto ord: std::vector<std::vector<int>>{{1, 0}})
   for(auto x: std::vector<std::array<bool, 2>>{{false, false}, {false, true},
-                                               {true, false}, {true, true}}){
+                                               {true, false}}){
     std::vector<config_t> tmp = {
       config_t{ord, x[0], x[1], 2048, 2048, 2048},
//       config_t{ord, x[0], x[1], 16, 2048, 2048},
@@ -34,7 +34,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c << std::flush;
-    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
diff --git a/tests/common/dot.h b/tests/common/dot.h
index 23bb46c72..a157d7994 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -109,10 +109,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
     opt.num_warps = {nwarp};
   }
   if(mode == BENCH) {
-    opt.defines.push_back({"TM", {"64", "128"}});
-    opt.defines.push_back({"TN", {"64", "128"}});
-    opt.defines.push_back({"TK", {"8"}});
-    opt.num_warps = {2, 4, 8};
+    opt.defines.push_back({"TM", {"128"}});
+    opt.defines.push_back({"TN", {"128"}});
+    opt.defines.push_back({"TK", {"16"}});
+    opt.num_warps = {4};
   }
 
   // kernels
diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h
index 05ed68a7b..7c368e593 100644
--- a/tests/common/src/dot.h
+++ b/tests/common/src/dot.h
@@ -23,8 +23,8 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   // reduction loop
   for(int k = K; k > 0; k-= TK){
     c += USEA @ USEB;
-    pa = pa + TK * STRIDE_AK;
-    pb = pb + TK * STRIDE_BK;
+    pa += TK * STRIDE_AK;
+    pb += TK * STRIDE_BK;
     bool checka[SHAPE_A] = k > TK;
     bool checkb[SHAPE_B] = k > TK;
     a = checka ? *pa : 0;