[PYTHON] [OPS] Added einsum implementation

Philippe Tillet
2019-10-26 22:14:50 -04:00
parent 655f43fb5b
commit e11557855f
5 changed files with 183 additions and 44 deletions
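
The op added here is exposed to TensorFlow as triton.ops.einsum, and the new test file further down exercises it end to end. A minimal usage sketch, assuming TF 1.x placeholders and the triton.ops.einsum(expr, a, b) signature used by that test (an illustration, not an official API reference):

# Minimal sketch of driving the new op from TF 1.x, mirroring the test added
# in this commit. Assumes the triton.ops.einsum(expr, a, b) call used there.
import numpy as np
import tensorflow as tf
import triton

a = tf.placeholder(tf.float32, [4, 1024, 1024], name="a")  # "btc"
b = tf.placeholder(tf.float32, [1024, 512], name="b")      # "ck"
c = triton.ops.einsum("btc,ck->btk", a, b)                  # Triton-backed einsum
da, db = tf.gradients(c, [a, b])                            # gradients are registered too

with tf.Session() as sess:
    C = sess.run(c, {a: np.random.rand(4, 1024, 1024).astype(np.float32),
                     b: np.random.rand(1024, 512).astype(np.float32)})
    print(C.shape)  # (4, 1024, 512)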

View File

@@ -187,12 +187,12 @@ generator::generator(analysis::axes *a_axes,
void generator::visit_value(ir::value* v) {
std::cout << "visiting " << typeid(*v).name() << std::endl;
if(!seen_.insert(v).second)
return;
// create machine tile
if(v->get_type()->is_tile_ty())
if(v->get_type()->is_tile_ty()){
tmap_[v] = machine_layouts_.at(layouts_->get(v))->create(v);
}
// visit operands
BasicBlock *current = builder_->GetInsertBlock();
auto *inst = dynamic_cast<ir::instruction*>(v);

View File

@@ -10,67 +10,62 @@ namespace codegen{
namespace transform{
void extract_retile_chain(ir::user *root,
const std::vector<ir::user*>& current,
std::vector<std::vector<ir::user*>>& result,
std::map<int, std::set<ir::user*>>& result,
int depth,
std::set<ir::value*>& seen) {
if(!seen.insert(root).second)
return;
if(dynamic_cast<ir::make_range*>(root) || dynamic_cast<ir::splat_inst*>(root)){
std::vector<ir::user*> next = current;
next.push_back(root);
result.push_back(next);
result[depth].insert(root);
if(dynamic_cast<ir::make_range*>(root) ||
dynamic_cast<ir::splat_inst*>(root)){
return;
}
for(ir::value *op: root->ops()){
ir::user *u = dynamic_cast<ir::user*>(op);
if(!u)
continue;
std::vector<ir::user*> next = current;
next.push_back(u);
extract_retile_chain(u, next, result, seen);
extract_retile_chain(u, result, depth + 1, seen);
}
}
void disassociate::run(ir::module &mod) {
ir::builder &bld = mod.get_builder();
std::map<ir::user*, std::vector<std::vector<ir::user*>>> clone_info;
std::map<ir::user*, std::map<int, std::set<ir::user*>>> clone_info;
ir::for_each_instruction(mod, [&](ir::instruction *i){
if(dynamic_cast<ir::reshape_inst*>(i)){
std::vector<std::vector<ir::user*>> chains;
std::map<int, std::set<ir::user*>> chains;
std::set<ir::value*> seen;
if(!dynamic_cast<ir::user*>(i->get_operand(0)))
return;
extract_retile_chain(i, {}, chains, seen);
extract_retile_chain(i, chains, 0, seen);
if(chains.size())
clone_info[i] = chains;
}
});
for(auto x: clone_info){
for(auto chain: x.second){
for(int i = 0; i < chain.size(); i++) {
ir::instruction *y = (ir::instruction*)chain[i];
for(const auto& x: clone_info){
int depth = 1;
std::map<ir::instruction*, ir::instruction*> clone_map;
while(x.second.find(depth) != x.second.end()){
// clone all users
const auto& remat = x.second.at(depth);
for(ir::user* u: remat){
ir::instruction *y = (ir::instruction*)u;
ir::instruction *cloned = y->clone();
bld.set_insert_point(y);
bld.insert(cloned);
if(i > 0)
chain[i-1]->replace_uses_of_with(y, cloned);
else
clone_map[y] = cloned;
// replace in above level
if(depth > 1){
for(ir::user* ux: x.second.at(depth - 1))
clone_map.at((ir::instruction*)ux)->replace_uses_of_with(y, cloned);
}
else{
x.first->replace_uses_of_with(y, cloned);
}
}
// ir::instruction *y = (ir::instruction*)parent;
// for(ir::user *u: chain){
// ir::instruction *cloned = y->clone();
// bld.set_insert_point(y);
// bld.insert(cloned);
// std::cout << typeid(*u).name() << std::endl;
// u->replace_uses_of_with(y, cloned);
// y = (ir::instruction*)u;
// }
depth += 1;
}
}
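
For reference, the rewrite above replaces the per-chain cloning with a depth-indexed scheme: users reachable from the reshape are bucketed by their distance from it, then cloned one level at a time, and each clone is rewired into the clones produced for the previous level (or into the reshape itself at level 1). A self-contained Python sketch of that idea (illustrative only; the Node class is a stand-in for the C++ IR, not Triton code):

# Illustrative sketch of the depth-indexed rematerialization used by the
# rewritten disassociate pass. Node is a toy stand-in for the IR.
class Node:
    def __init__(self, name, ops=(), leaf=False):
        self.name, self.ops, self.leaf = name, list(ops), leaf
    def clone(self):
        return Node(self.name + "_clone", self.ops, self.leaf)
    def replace_uses_of_with(self, old, new):
        self.ops = [new if o is old else o for o in self.ops]

def extract_retile_chain(root, result, depth, seen):
    # bucket every reachable node by its depth below the reshape
    if root in seen:
        return
    seen.add(root)
    result.setdefault(depth, set()).add(root)
    if root.leaf:                              # make_range / splat in the real pass
        return
    for op in root.ops:
        extract_retile_chain(op, result, depth + 1, seen)

def disassociate(reshape):
    chains, seen = {}, set()
    extract_retile_chain(reshape, chains, 0, seen)
    clone_map, depth = {}, 1
    while depth in chains:
        for u in chains[depth]:
            cloned = u.clone()
            clone_map[u] = cloned
            if depth > 1:                      # rewire into the previous level's clones
                for parent in chains[depth - 1]:
                    clone_map[parent].replace_uses_of_with(u, cloned)
            else:                              # level 1 feeds the reshape directly
                reshape.replace_uses_of_with(u, cloned)
        depth += 1

rng, spl = Node("range", leaf=True), Node("splat", leaf=True)
add = Node("add", [rng, spl])
reshape = Node("reshape", [add])
disassociate(reshape)
print(reshape.ops[0].name)                     # "add_clone": the reshape now owns a private copy of the chain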

View File

@@ -221,9 +221,17 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
codegen::transform::cts cts;
codegen::generator isel(&axes, &layouts, &align, &allocation, target.get(), opt.num_warps);
// run passes
std::cout << "begin" << std::endl;
disassociate.run(module);
// ir::print(module, std::cout);
dce.run(module);
// ir::print(module, std::cout);
disassociate.run(module);
// ir::print(module, std::cout);
dce.run(module);
// ir::print(module, std::cout);
peephole.run(module);
dce.run(module);
align.run(module);
@@ -245,10 +253,10 @@ std::unique_ptr<driver::module> function::make_bin(ir::module &module, driver::c
if(allocation.allocated_size() > context->device()->max_shared_memory())
return std::unique_ptr<driver::module>();
barriers.run(module);
std::cout << "isel" << std::endl;
// std::cout << "isel" << std::endl;
// ir::print(module, std::cout);
isel.visit(module, *llvm);
std::cout << "done" << std::endl;
// std::cout << "done" << std::endl;
// return binary
std::unique_ptr<driver::module> res(driver::module::create(context, std::move(llvm)));
// done

View File

@@ -0,0 +1,129 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
import triton
import blocksparse as bs
from tensorflow.python.ops import gradient_checker

one = 0
out = 0
bench = 0


class ProdKeyTest(tf.test.TestCase):

    def testEinsum(self):
        # multi-threading screws up benchmarking
        conf = tf.ConfigProto(
            intra_op_parallelism_threads=1,
            inter_op_parallelism_threads=1)

        with self.test_session(config=conf) as sess, tf.device("/gpu:0"):

            batch_dim = 4
            ctx_dim = 256
            head_dim = 8
            n_keys = 512
            key_dim = 128
            # batch_dim = 2
            # ctx_dim = 8
            # head_dim = 2
            # n_keys = 16
            # key_dim = 16

            for a_shape, b_shape, c_shape, einsum in [
                [ [ 4, 8, 8 ], [ 8, 8 ], [ 4, 8, 8 ], "btc,ck->btk" ],
                [ [ 4, 1024, 1024 ], [ 1024, 512 ], [ 4, 1024, 512 ], "btc,ck->btk" ],
                [ (batch_dim, ctx_dim, head_dim, 2, key_dim//2), (head_dim, 2, n_keys, key_dim//2), (batch_dim, ctx_dim, head_dim, 2, n_keys), "bchak,hank->bchan" ],
            ]:

                if one:
                    A = np.ones(a_shape, dtype=np.float32)
                    B = np.ones(b_shape, dtype=np.float32)
                    E = np.ones(c_shape, dtype=np.float32)
                else:
                    # QK = np.random.normal(loc=0.0, scale=1.0, size=qk_shape).astype(np.float16).astype(np.float32)
                    # V = np.random.normal(loc=0.0, scale=1.0, size=vw_shape).astype(np.float16).astype(np.float32)
                    A = np.random.uniform(-1.0, 1.0, a_shape).astype(np.float16).astype(np.float32)
                    B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
                    E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)

                a = tf.placeholder(tf.float32, a_shape, name="a")
                b = tf.placeholder(tf.float32, b_shape, name="b")
                e = tf.placeholder(tf.float32, c_shape, name="e")
                feed_dict = { a: A, b: B, e: E }

                cc = triton.ops.einsum(einsum, a, b)

                # error = gradient_checker.compute_gradient_error(a, a_shape, c, c_shape, delta=1e-1, extra_feed_dict={ b:B }) #
                # print(error)
                # error = gradient_checker.compute_gradient_error(b, b_shape, c, c_shape, delta=1e-1, extra_feed_dict={ a:A }) #
                # print(error)
                # return

                with tf.control_dependencies([cc.op]):
                    da, db = tf.gradients(cc, [a, b], e)

                # c, = sess.run( [ c, ], feed_dict )
                c, da, db = sess.run( [ cc, da, db ], feed_dict )

                if bench == 0:
                    C = np.einsum(einsum, A, B)
                    id = cc.op.get_attr('id')
                    ctx = triton.ops._einsum.contexts[id]
                    t_a = ctx.trans_a
                    t_b = ctx.trans_b
                    e_a = ctx.einsum_a
                    e_b = ctx.einsum_b
                    e_c = ctx.einsum_c
                    if not t_a and not t_b: # NN
                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)
                    elif not t_a and t_b: # NT
                        DA = np.einsum(f"{e_c},{e_b}->{e_a}", E, B)
                        DB = np.einsum(f"{e_c},{e_a}->{e_b}", E, A)
                    elif t_a and not t_b: # TN
                        DA = np.einsum(f"{e_b},{e_c}->{e_a}", B, E)
                        DB = np.einsum(f"{e_a},{e_c}->{e_b}", A, E)

                print("testProdKey", einsum)
                if not bench:
                    for op, dev, cpu in [
                        [ "C",  c,  C  ],
                        [ "DA", da, DA ],
                        [ "DB", db, DB ],
                    ]:
                        self.compare_results(op, dev, cpu)

    def compare_results(self, op, dev, cpu):
        dev = dev.astype(np.float64)
        cpu = cpu.astype(np.float64)
        # print(dev.reshape(-1)[0:4])
        # print(cpu.reshape(-1)[0:4])
        dif = np.abs(cpu - dev)
        maxval = np.max(abs(cpu))
        avgval = np.average(abs(cpu))
        maxdif = dif.max()
        max_err = maxdif if avgval == 0 else maxdif / avgval
        l2_err = 0.0 if avgval == 0 else np.sqrt(np.square(dif).sum()) / np.sqrt(np.square(cpu).sum())
        print("op:%3s, max:%18.12f, avg:%18.12f, dif:%18.12f, err:%18.12f, l2_err:%18.12f shape:%15s" % (op, maxval, avgval, maxdif, max_err, l2_err, str(cpu.shape)))
        if out:
            dim = cpu.shape[-1]
            np.savetxt("%s_dif.txt" % op, dif.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
            np.savetxt("%s_cpu.txt" % op, cpu.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
            np.savetxt("%s_dev.txt" % op, dev.reshape((-1,dim)), fmt='%4.1f') #7.5 5.3
            exit()


if __name__ == "__main__":
    tf.test.main()
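
The DA/DB references above follow the usual einsum adjoint identities. For the non-transposed "btc,ck->btk" case used in the test they reduce to the following standalone numpy check (a sketch, independent of Triton and TensorFlow):

# Sketch of the einsum adjoint identities behind the DA/DB references above,
# specialized to the non-transposed "btc,ck->btk" case.
import numpy as np

A  = np.random.rand(4, 8, 8)          # "btc"
B  = np.random.rand(8, 8)             # "ck"
dC = np.random.rand(4, 8, 8)          # upstream gradient, laid out as "btk"

dA = np.einsum("btk,ck->btc", dC, B)  # dL/dA
dB = np.einsum("btc,btk->ck", A, dC)  # dL/dB

# C = einsum("btc,ck->btk", A, B) is linear in A, so a finite difference on one
# entry of A recovers dA exactly (up to rounding):
eps = 1e-6
Ap = A.copy(); Ap[0, 0, 0] += eps
delta = (np.einsum("btc,ck->btk", Ap, B) - np.einsum("btc,ck->btk", A, B)) * dC
assert abs(delta.sum() / eps - dA[0, 0, 0]) < 1e-3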

View File

@@ -1,3 +1,6 @@
# Special thanks to Scott Gray from OpenAI for writing the einsum parsing function
import triton
class _einsum(triton.function):
@@ -31,14 +34,18 @@ class _einsum(triton.function):
TYPE a[SHAPE_A] = *pa;
TYPE b[SHAPE_B] = *pb;
c += USE_A @ USE_B;
pa += TK;
pb += TK;
pa += TK * STRIDE_AK;
pb += TK * STRIDE_BK;
}
// write-back
TYPE *pc[TM, TN, TB] = C + rm[:, newaxis, newaxis] * std_C1
+ rn[newaxis, :, newaxis] * 1
+ rb[newaxis, newaxis, :] * std_C0;
*pc = c;
bool checkm[TM] = rm < dim_M;
bool checkn[TN] = rn < dim_N;
bool checkc[TM, TN, TB] = checkm[:, newaxis, newaxis] &&
checkn[newaxis, :, newaxis];
*?(checkc)pc = c;
}
"""
@@ -141,12 +148,12 @@ class _einsum(triton.function):
'BROADCAST_AM': 'newaxis, :, newaxis' if trans_a else ':, newaxis, newaxis',
'SHAPE_A' : 'TK, TM, TB' if trans_a else 'TM, TK, TB',
# handle B transposition
'USE_B' : 'b[^1, ^0, ^2]' if not trans_b else 'b',
'USE_B' : 'b' if not trans_b else 'b[^1, ^0, ^2]',
'STRIDE_BK' : 'std_B1' if not trans_b else '1',
'STRIDE_BN' : '1' if not trans_b else 'std_B1',
'BROADCAST_BK': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
'BROADCAST_BN': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
'SHAPE_B' : 'TN, TK, TB' if not trans_b else 'TK, TN, TB'}
'BROADCAST_BK': ':, newaxis, newaxis' if not trans_b else 'newaxis, :, newaxis',
'BROADCAST_BN': 'newaxis, :, newaxis' if not trans_b else ':, newaxis, newaxis',
'SHAPE_B' : 'TK, TN, TB' if not trans_b else 'TN, TK, TB'}
return _einsum.kernel(a, b, c,
bmnk[1], bmnk[2], bmnk[3],
std0[0], std0[1], std0[2],