diff --git a/lib/codegen/selection/generator.cc b/lib/codegen/selection/generator.cc
index 03f393069..1ff4287eb 100644
--- a/lib/codegen/selection/generator.cc
+++ b/lib/codegen/selection/generator.cc
@@ -187,12 +187,12 @@ generator::generator(analysis::axes *a_axes,
 
 void generator::visit_value(ir::value* v) {
-  std::cout << "visiting " << typeid(*v).name() << std::endl;
   if(!seen_.insert(v).second)
     return;
   // create machine tile
-  if(v->get_type()->is_tile_ty())
+  if(v->get_type()->is_tile_ty()){
     tmap_[v] = machine_layouts_.at(layouts_->get(v))->create(v);
+  }
   // visit operands
   BasicBlock *current = builder_->GetInsertBlock();
   auto *inst = dynamic_cast<ir::instruction*>(v);
diff --git a/lib/codegen/transform/disassociate.cc b/lib/codegen/transform/disassociate.cc
index 70134b186..1134463ec 100644
--- a/lib/codegen/transform/disassociate.cc
+++ b/lib/codegen/transform/disassociate.cc
@@ -10,67 +10,62 @@ namespace codegen{
 namespace transform{
 
 void extract_retile_chain(ir::user *root,
-                          const std::vector<ir::user*>& current,
-                          std::vector<std::vector<ir::user*>>& result,
+                          std::map<int, std::set<ir::user*>>& result,
+                          int depth,
                           std::set<ir::user*>& seen) {
   if(!seen.insert(root).second)
    return;
-  if(dynamic_cast<ir::make_range*>(root) || dynamic_cast<ir::splat_inst*>(root)){
-    std::vector<ir::user*> next = current;
-    next.push_back(root);
-    result.push_back(next);
+  result[depth].insert(root);
+  if(dynamic_cast<ir::make_range*>(root) ||
+     dynamic_cast<ir::splat_inst*>(root)){
     return;
   }
   for(ir::value *op: root->ops()){
     ir::user *u = dynamic_cast<ir::user*>(op);
     if(!u)
       continue;
-    std::vector<ir::user*> next = current;
-    next.push_back(u);
-    extract_retile_chain(u, next, result, seen);
+    extract_retile_chain(u, result, depth + 1, seen);
   }
 }
 
 void disassociate::run(ir::module &mod) {
   ir::builder &bld = mod.get_builder();
 
-  std::map<ir::instruction*, std::vector<std::vector<ir::user*>>> clone_info;
+  std::map<ir::instruction*, std::map<int, std::set<ir::user*>>> clone_info;
   ir::for_each_instruction(mod, [&](ir::instruction *i){
     if(dynamic_cast<ir::reshape_inst*>(i)){
-      std::vector<std::vector<ir::user*>> chains;
+      std::map<int, std::set<ir::user*>> chains;
       std::set<ir::user*> seen;
       if(!dynamic_cast<ir::user*>(i->get_operand(0)))
         return;
-      extract_retile_chain(i, {}, chains, seen);
+      extract_retile_chain(i, chains, 0, seen);
       if(chains.size())
         clone_info[i] = chains;
     }
   });
 
-
-  for(auto x: clone_info){
-    for(auto chain: x.second){
-      for(int i = 0; i < chain.size(); i++) {
-        ir::instruction *y = (ir::instruction*)chain[i];
+  for(const auto& x: clone_info){
+    int depth = 1;
+    std::map<ir::instruction*, ir::instruction*> clone_map;
+    while(x.second.find(depth) != x.second.end()){
+      // clone all users
+      const auto& remat = x.second.at(depth);
+      for(ir::user* u: remat){
+        ir::instruction *y = (ir::instruction*)u;
         ir::instruction *cloned = y->clone();
         bld.set_insert_point(y);
         bld.insert(cloned);
-        if(i > 0)
-          chain[i-1]->replace_uses_of_with(y, cloned);
-        else
+        clone_map[y] = cloned;
+        // replace in above level
+        if(depth > 1){
+          for(ir::user* ux: x.second.at(depth - 1))
+            clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
+        }
+        else{
          x.first->replace_uses_of_with(y, cloned);
+        }
       }
-
-
-//      ir::instruction *y = (ir::instruction*)parent;
-//      for(ir::user *u: chain){
-//        ir::instruction *cloned = y->clone();
-//        bld.set_insert_point(y);
-//        bld.insert(cloned);
-//        std::cout << typeid(*u).name() << std::endl;
-//        u->replace_uses_of_with(y, cloned);
-//        y = (ir::instruction*)u;
-//      }
+      depth += 1;
     }
   }
diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc
index b37dbf332..bc55d65eb 100644
--- a/lib/runtime/function.cc
+++ b/lib/runtime/function.cc
@@ -221,9 +221,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   codegen::transform::cts cts;
   codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
   // run passes
-  std::cout << "begin" << std::endl;
-  disassociate.run(module);
+//  ir::print(module, std::cout);
   dce.run(module);
+//  ir::print(module, std::cout);
+
+  disassociate.run(module);
+
+//  ir::print(module, std::cout);
+
+  dce.run(module);
+//  ir::print(module, std::cout);
+
   peephole.run(module);
   dce.run(module);
   align.run(module);
@@ -245,10 +253,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
   if(allocation.allocated_size() > context->device()->max_shared_memory())
     return std::unique_ptr<driver::module>();
   barriers.run(module);
-  std::cout << "isel" << std::endl;
+//  std::cout << "isel" << std::endl;
 //  ir::print(module, std::cout);
   isel.visit(module, *llvm);
-  std::cout << "done" << std::endl;
+//  std::cout << "done" << std::endl;
   // return binary
   std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
   // done
diff --git a/python/examples/einsum_test.py b/python/examples/einsum_test.py
new file mode 100644
index 000000000..799efbf70
--- /dev/null
+++ b/python/examples/einsum_test.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+import triton
+import blocksparse as bs
+from tensorflow.python.ops import gradient_checker
+
+one = 0
+out = 0
+bench = 0
+class ProdKeyTest(tf.test.TestCase):
+
+    def testEinsum(self):
+        # multi-threading screws up benchmarking
+        conf = tf.ConfigProto(
+            intra_op_parallelism_threads=1,
+            inter_op_parallelism_threads=1)
+
+        with self.test_session(config=conf) as sess, tf.device("/gpu:0"):
+
+            batch_dim = 4
+            ctx_dim = 256
+            head_dim = 8
+            n_keys = 512
+            key_dim = 128
+
+            # batch_dim = 2
+            # ctx_dim = 8
+            # head_dim = 2
+            # n_keys = 16
+            # key_dim = 16
+
+            for a_shape, b_shape, c_shape, einsum in [
+                [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
+                [ [ 4, 1024, 1024 ], [ 1024, 512 ], [ 4, 1024, 512 ], "btc,ck->btk" ],
+                [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2),(head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
+            ]:
+
+                if one:
+                    A = np.ones(a_shape, dtype=np.float32)
+                    B = np.ones(b_shape, dtype=np.float32)
+                    E = np.ones(c_shape, dtype=np.float32)
+                else:
+                    # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32)
+                    # V = np.random.normal(loc=0.0, scale=1.0, size=vw_shape).astype(np.float16).astype(np.float32)
+                    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
+                    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
+                    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
+
+                a = tf.placeholder(tf.float32, a_shape, name="a")
+                b = tf.placeholder(tf.float32, b_shape, name="b")
+                e = tf.placeholder(tf.float32, c_shape, name="e")
+                feed_dict = { a:A, b:B, e:E }
+
+                cc = triton.ops.einsum(einsum, a, b)
+
+                # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
+                # print(error)
+                # error = gradient_checker.compute_gradient_error(b, b_shape, c, c_shape, delta=1e-1, extra_feed_dict={ a:A }) #
+                # print(error)
+                # return
+
+                with tf.control_dependencies([cc.op]):
+                    da, db = tf.gradients(cc, [a, b], e)
+
+                # c, = sess.run( [ c, ], feed_dict )
+                c, da, db = sess.run( [ cc, da, db ], feed_dict )
+
+                if bench == 0:
+
+                    C = np.einsum(einsum, A, B)
+                    id = cc.op.get_attr('id')
+                    ctx = triton.ops._einsum.contexts[id]
+                    t_a = ctx.trans_a
+                    t_b = ctx.trans_b
+                    e_a = ctx.einsum_a
+                    e_b = ctx.einsum_b
+                    e_c = ctx.einsum_c
+
+                    if not t_a and not t_b: # NN
+                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
+                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
+                    elif not t_a and t_b: # NT
+                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
+                        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
+                    elif t_a and not t_b: # TN
+                        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
+                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
+
+                print("testProdKey", einsum)
+                if not bench:
+                    for op, dev, cpu in [
+                        [ "C", c, C ],
+                        [ "DA", da, DA ],
+                        [ "DB", db, DB ],
+                    ]:
+                        self.compare_results(op, dev, cpu)
+
+    def compare_results(self, op, dev, cpu):
+        dev = dev.astype(np.float64)
+        cpu = cpu.astype(np.float64)
+
+        # print(dev.reshape(-1)[0:4])
+        # print(cpu.reshape(-1)[0:4])
+
+        dif = np.abs(cpu - dev)
+        maxval = np.max(abs(cpu))
+        avgval = np.average(abs(cpu))
+        maxdif = dif.max()
+        max_err = maxdif if avgval == 0 else maxdif / avgval
+        l2_err = 0.0 if avgval == 0 else np.sqrt(np.square(dif).sum()) / np.sqrt(np.square(cpu).sum())
+
+        print("op:%3s, max:%18.12f, avg:%18.12f, dif:%18.12f, err:%18.12f, l2_err:%18.12f shape:%15s" % (op, maxval, avgval, maxdif, max_err, l2_err, str(cpu.shape)))
+
+        if out:
+            dim = cpu.shape[-1]
+            np.savetxt("%s_dif.txt" % op, dif.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
+            np.savetxt("%s_cpu.txt" % op, cpu.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
+            np.savetxt("%s_dev.txt" % op, dev.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
+            exit()
+
+if __name__ == "__main__":
+    tf.test.main()
+
diff --git a/python/triton/ops/einsum.py b/python/triton/ops/einsum.py
index d78d9f1a2..7f4457d99 100644
--- a/python/triton/ops/einsum.py
+++ b/python/triton/ops/einsum.py
@@ -1,3 +1,6 @@
+# Special thanks to Scott Gray from OpenAI for writing the einsum parsing function
+
+
 import triton
 
 class _einsum(triton.function):
@@ -31,14 +34,18 @@ class _einsum(triton.function):
      TYPE a[SHAPE_A] = *pa;
      TYPE b[SHAPE_B] = *pb;
      c += USE_A @ USE_B;
-      pa += TK;
-      pb += TK;
+      pa += TK * STRIDE_AK;
+      pb += TK * STRIDE_BK;
    }
 
    // write-back
    TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
                             + rn[newaxis, :, newaxis] * 1
                             + rb[newaxis, newaxis, :] * std_C0;
-    *pc = c;
+    bool checkm[TM] = rm < dim_M;
+    bool checkn[TN] = rn < dim_N;
+    bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
+                              checkn[newaxis, :, newaxis];
+    *?(checkc)pc = c;
 }
 """
@@ -141,12 +148,12 @@ class _einsum(triton.function):
            'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
            'SHAPE_A'     : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
            # handle B transposition
-           'USE_B'       : 'b[^1, ^0, ^2]' if not trans_b else 'b',
+           'USE_B'       : 'b' if not trans_b else 'b[^1, ^0, ^2]',
            'STRIDE_BK'   : 'std_B1' if not trans_b else '1',
            'STRIDE_BN'   : '1' if not trans_b else 'std_B1',
-           'BROADCAST_BK': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
-           'BROADCAST_BN': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
-           'SHAPE_B'     : 'TN, TK, TB' if not trans_b else 'TK, TN, TB'}
+           'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
+           'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
+           'SHAPE_B'     : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
        return _einsum.kernel(a, b, c,
                              bmnk[1], bmnk[2], bmnk[3],
                              std0[0], std0[1], std0[2],