diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h
index 23ca7d1e0..539de8684 100644
--- a/include/triton/runtime/function.h
+++ b/include/triton/runtime/function.h
@@ -8,6 +8,7 @@
 #include 
 #include 
 #include 
+#include <mutex>
 // codegen
 #include "triton/ir/context.h"
 #include "triton/codegen/target.h"
@@ -110,6 +111,7 @@ private:
   std::string src_;
   options_space_t opt_space_;
   std::map cache_;
+  std::mutex src_mutex_;
 };
 
 }
diff --git a/lib/lang/cpp.cc b/lib/lang/cpp.cc
index 308eba1e6..2cdfb453a 100644
--- a/lib/lang/cpp.cc
+++ b/lib/lang/cpp.cc
@@ -9,8 +9,6 @@
 
 #include 
 
-extern std::string filename_in;
-extern std::string filename_out;
 
 using DirectiveMap = std::unordered_map;
 
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index bc55d65eb..a7072b757 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -29,9 +29,9 @@
 #include "triton/ir/print.h"
 #include "triton/tools/bench.hpp"
 #include "llvm/IR/Module.h"
+#include <mutex>
 
-
-
+std::mutex mut;
 
 namespace triton{
 namespace runtime {
@@ -168,7 +168,6 @@ function::caller function::autotune(driver::stream* stream, const grid_fn_ty& gr
   for(auto it: opt_space_.defines)
     cpp.AddMacro(it.first, &opt.defines.at(it.first));
   cpp.Process(tokens);
-// tokens.Print(stdout);
   // parse
   Parser parser(tokens);
   parser.Parse();
@@ -309,7 +308,10 @@ void function::operator()(const std::vector& args, const grid_fn_ty& grid_f
   }
 
   /* re-tune and re-compile */
-  cache_.insert({key, autotune(stream, grid_fn, args)});
+  {
+    std::lock_guard<std::mutex> lock(mut);
+    cache_.insert({key, autotune(stream, grid_fn, args)});
+  }
 }
 
 void function::operator()(const std::vector& args, const grid_t& grid, driver::stream *stream) {
diff --git a/python/examples/einsum.py b/python/examples/einsum.py
index 5585cc9b6..a8ec95435 100644
--- a/python/examples/einsum.py
+++ b/python/examples/einsum.py
@@ -1,38 +1,92 @@
-import numpy as np
-import torch
+#!/usr/bin/env python
+
+import numpy as np
+from enum import Enum
 import triton
 
-batch_dim = 16
-ctx_dim = 32
-head_dim = 8
-state_dim = 32
-key_dim = 32
-n_keys = 32
-bs = batch_dim * ctx_dim
+class MODE(Enum):
+    TF = 1
+    TORCH = 2
 
-# shapes
-x_shape = (bs, state_dim)
-qw_shape = (state_dim, head_dim * key_dim)
-kw_shape = (head_dim, 2, n_keys, key_dim // 2)
+try:
+    import tensorflow as tf
+    mode = MODE.TF
+except ModuleNotFoundError:
+    pass
 
-np.random.seed(0)
-x = np.random.uniform(-1.0, 1.0, x_shape).astype(np.float32) # layer input
-qw = np.random.uniform(-1.0, 1.0, qw_shape).astype(np.float32) # query weights
-kw = np.random.uniform(-1.0, 1.0, kw_shape).astype(np.float32) # key weights
-# (bs, head_dim * key_dim) = (bs, state_dim) * (state_dim, head_dim * key_dim)
-# (bs, head_dim, 2, key_dim//2) <== (bs, head_dim * key_dim)
-q = np.dot(x, qw).reshape(bs, head_dim, 2, key_dim//2) # normal matmul
+try:
+    import torch
+    mode = MODE.TORCH
+except ModuleNotFoundError:
+    pass
 
-# (bs, head_dim, 2, n_keys) = (bs, head_dim, 2, key_dim//2) * (head_dim, 2, n_keys, key_dim//2)
-# outer: bs, n_keys
-# inner: key_dim//2
-# batch: head_dim, 2 (key_axis)
-qk = np.einsum("bhak,hank->bhan", q, kw)
+cases = []
+# Matmul
+cases += [[[4, 1024, 1024], [1024, 1024], [4, 1024, 1024], "btc,ck->btk"]]
+# Attention
+cases += [[[4, 256, 8, 2, 64], [8, 2, 512, 64], [4, 256, 8, 2, 512], "bchak,hank->bchan"]]
 
-tq = torch.from_numpy(q).contiguous().cuda()
-tkw = torch.from_numpy(kw).contiguous().cuda()
-tqk = triton.ops.einsum("bhak,hank->bhan", tq, tkw)
-diff = np.abs(qk - tqk.cpu().numpy())
-print(np.max(diff))
-print(np.min(diff))
+if mode == MODE.TF:
+    sess = tf.InteractiveSession()
+for a_shape, b_shape, c_shape, einsum in cases:
+
+    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
+    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
+    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
+
+    # Execute (tensorflow)
+    if mode == MODE.TF:
+        a = tf.placeholder(tf.float32, a_shape, name="a")
+        b = tf.placeholder(tf.float32, b_shape, name="b")
+        e = tf.placeholder(tf.float32, c_shape, name="e")
+        c = triton.ops.einsum(einsum, a, b, 1)
+        da, db = tf.gradients(c, [a, b], e)
+        feed_dict = { a: A.astype(np.float32),
+                      b: B.astype(np.float32),
+                      e: E }
+        sess.run(tf.global_variables_initializer())
+        result = sess.run([c, da, db], feed_dict = feed_dict)
+    # Execute (torch)
+    if mode == MODE.TORCH:
+        a = torch.from_numpy(A).cuda()
+        b = torch.from_numpy(B).cuda()
+        e = torch.from_numpy(E).cuda()
+        a.requires_grad_(True)
+        b.requires_grad_(True)
+        c = triton.ops.einsum(einsum, a, b, 1)
+        torch.autograd.backward(c, e)
+        da = a.grad
+        db = b.grad
+        result = [c.cpu().detach().numpy(), da.cpu().detach().numpy(), db.cpu().detach().numpy()]
+
+    # benchmark
+    nanosec = triton.bench_registry[c]
+    ctx = triton.ctx_registry[c]
+    b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4)))
+    ops = 2.*b*m*n*k
+    print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3)
+    #print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3)
+    #print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3)
+
+    # test
+    ctx = triton.ctx_registry[c]
+    t_a = ctx.trans_a
+    t_b = ctx.trans_b
+    e_a = ctx.einsum_a
+    e_b = ctx.einsum_b
+    e_c = ctx.einsum_c
+    C = np.einsum(einsum, A, B)
+    if not t_a and not t_b: # NN
+        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
+        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
+    elif not t_a and t_b: # NT
+        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
+        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
+    elif t_a and not t_b: # TN
+        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
+        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
+    c, da, db = result[0], result[1], result[2]
+    print('C diff:', np.abs((C - c)).max())
+    print('DA diff:', np.abs((DA - da)).max())
+    print('DB diff:', np.abs((DB - db)).max())
\ No newline at end of file
diff --git a/python/examples/einsum_test.py b/python/examples/einsum_test.py
deleted file mode 100644
index b09f46cab..000000000
--- a/python/examples/einsum_test.py
+++ /dev/null
@@ -1,139 +0,0 @@
-#!/usr/bin/env python
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import tensorflow as tf
-import triton
-import blocksparse as bs
-from tensorflow.python.ops import gradient_checker
-
-one = 0
-out = 0
-bench = 0
-
-class ProdKeyTest(tf.test.TestCase):
-
-    def testEinsum(self):
-        # multi-threading screws up benchmarking
-        conf = tf.ConfigProto(
-            intra_op_parallelism_threads=1,
-            inter_op_parallelism_threads=1)
-
-        with self.test_session(config=conf) as sess, tf.device("/gpu:0"):
-
-            batch_dim = 4
-            ctx_dim = 256
-            head_dim = 8
-            n_keys = 512
-            key_dim = 128
-
-            # batch_dim = 2
-            # ctx_dim = 8
-            # head_dim = 2
-            # n_keys = 16
-            # key_dim = 16
-
-            for a_shape, b_shape, c_shape, einsum in [
-                [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
-                [ [4, 1024, 1024], [ 1024, 1024 ], [4, 1024, 1024 ], "btc,ck->btk" ],
-                [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys),
"bchak,hank->bchan" ], - ]: - - if one: - A = np.ones(a_shape, dtype=np.float16).astype(np.float32) - B = np.ones(b_shape, dtype=np.float16).astype(np.float32) - E = np.ones(c_shape, dtype=np.float32) - else: - # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32) - # V = np.random.normal(loc=0.0, scale=1.0, size=vw_shape).astype(np.float16).astype(np.float32) - A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32) - B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32) - E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32) - - a = tf.placeholder(tf.float32, a_shape, name="a") - b = tf.placeholder(tf.float32, b_shape, name="b") - e = tf.placeholder(tf.float32, c_shape, name="e") - feed_dict = { a: A.astype(np.float32), - b: B.astype(np.float32), - e: E } - - c = triton.ops.einsum(einsum, a, b, bench=bench) - - # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) # - # print(error) - # error = gradient_checker.compute_gradient_error(b, b_shape, c, c_shape, delta=1e-1, extra_feed_dict={ a:A }) # - # print(error) - # return - - with tf.control_dependencies([c.op]): - da, db = tf.gradients(c, [a, b], e) - - # c, = sess.run( [ c, ], feed_dict ) - rc, rda, rdb = sess.run( [ c, da, db ], feed_dict ) - - if bench > 0: - nanosec = triton.bench_registry[c] - ctx = triton.ctx_registry[c] - b, m, n, k = tuple((ctx.bmnk[i] for i in range(0, 4))) - ops = 2. * b * m * n * k - print('C TFLOPS:', ops / triton.bench_registry[c] * 1e-3) - print('DA TFLOPS:', ops / triton.bench_registry[da] * 1e-3) - print('DB TFLOPS:', ops / triton.bench_registry[db] * 1e-3) - - else: - C = np.einsum(einsum, A, B) - ctx = triton.ctx_registry[c] - t_a = ctx.trans_a - t_b = ctx.trans_b - e_a = ctx.einsum_a - e_b = ctx.einsum_b - e_c = ctx.einsum_c - - if not t_a and not t_b: # NN - DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B) - DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E) - elif not t_a and t_b: # NT - DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B) - DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A) - elif t_a and not t_b: # TN - DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E) - DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E) - - print("testProdKey", einsum) - if not bench: - for op, dev, cpu in [ - [ "C", rc, C ], - [ "DA", rda, DA ], - [ "DB", rdb, DB ], - ]: - self.compare_results(op, dev, cpu) - - def compare_results(self, op, dev, cpu): - dev = dev.astype(np.float64) - cpu = cpu.astype(np.float64) - - # print(dev.reshape(-1)[0:4]) - # print(cpu.reshape(-1)[0:4]) - - dif = np.abs(cpu - dev) - maxval = np.max(abs(cpu)) - avgval = np.average(abs(cpu)) - maxdif = dif.max() - max_err = maxdif if avgval == 0 else maxdif / avgval - l2_err = 0.0 if avgval == 0 else np.sqrt(np.square(dif).sum()) / np.sqrt(np.square(cpu).sum()) - - print("op:%3s, max:%18.12f, avg:%18.12f, dif:%18.12f, err:%18.12f, l2_err:%18.12f shape:%15s" % (op, maxval, avgval, maxdif, max_err, l2_err, str(cpu.shape))) - - if out: - dim = cpu.shape[-1] - np.savetxt("%s_dif.txt" % op, dif.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3 - np.savetxt("%s_cpu.txt" % op, cpu.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3 - np.savetxt("%s_dev.txt" % op, dev.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3 - exit() - -if __name__ == "__main__": - tf.test.main() - diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py index 7bbc18f5a..f91a178ad 100644 --- a/python/triton/ops/einsum.py +++ b/python/triton/ops/einsum.py 
@@ -179,7 +179,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
 
 
     @staticmethod
-    def forward(ctx, subscripts, a, b, **kwargs):
+    def forward(ctx, subscripts, a, b, bench = 0):
        ctx.save_for_backward(a, b)
        if type(subscripts) is str:
            einsum_a, einsum_bc = subscripts.split(",")
@@ -189,9 +189,7 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
 
        shape_c, bmnk, std0, std1, ta, tb = _einsum._parse_einsum(
            einsum_a, einsum_b, einsum_c,
-           triton.shape(a), triton.shape(b)
-        )
-        bench = kwargs['bench'] if 'bench' in kwargs else 0
+           triton.shape(a), triton.shape(b))
        ctx.trans_a = ta
        ctx.trans_b = tb
        ctx.einsum_a = einsum_a
@@ -213,20 +211,20 @@ void einsumk(TYPE * A, TYPE * B, TYPE * C,
 
        bench = ctx.bench
        if not trans_a and not trans_b: # NN
-           da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
-           db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)
+           da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
+           db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
        elif not trans_a and trans_b: # NT
-           da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench=bench)
-           db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)
+           da = einsum((einsum_c, einsum_b, einsum_a), dc, b, bench)
+           db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
        elif trans_a and not trans_b: # TN
-           da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
-           db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench=bench)
+           da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
+           db = einsum((einsum_a, einsum_c, einsum_b), a, dc, bench)
        elif trans_a and trans_b: # TT (not used)
-           da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench=bench)
-           db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench=bench)
+           da = einsum((einsum_b, einsum_c, einsum_a), b, dc, bench)
+           db = einsum((einsum_c, einsum_a, einsum_b), dc, a, bench)
        return da, db, None, None, None, None, None, None, None, None, None, None
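
Note (not part of the patch): a minimal usage sketch of the updated entry point, assuming a CUDA build of PyTorch and the Triton Python bindings from this revision. After this change, bench is an ordinary positional parameter of _einsum.forward (default 0) rather than a key pulled out of **kwargs, so callers pass it positionally, as the new example script does; the shapes below are illustrative only.

# Usage sketch only -- mirrors the PyTorch path of python/examples/einsum.py from this patch.
import numpy as np
import torch
import triton

# random FP32 operands on the GPU (shapes taken from the "Matmul" case above)
a = torch.from_numpy(np.random.uniform(-1.0, 1.0, (4, 1024, 1024)).astype(np.float32)).cuda()
b = torch.from_numpy(np.random.uniform(-1.0, 1.0, (1024, 1024)).astype(np.float32)).cuda()
a.requires_grad_(True)
b.requires_grad_(True)

# bench is now passed positionally (here: 1) instead of bench=1
c = triton.ops.einsum("btc,ck->btk", a, b, 1)
torch.autograd.backward(c, torch.ones_like(c))
print(c.shape, a.grad.shape, b.grad.shape)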