[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for TensorFlow Graph Execution
Philippe Tillet
2019-10-28 17:12:37 -04:00
parent e9c787ef05
commit 448f4433d9
9 changed files with 82 additions and 52 deletions


@@ -5,30 +5,28 @@ def run_tf():
     M, N, K = 2048, 2048, 2048
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
-    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
-    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
-    tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
-    tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
+    triton_c = triton.ops.dot(a, b, False, True, 1)
+    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
+    triton_y = tf.math.reduce_mean(triton_d)
+    fw_c = tf.matmul(a, b, False, True)
+    fw_d = tf.matmul(fw_c, b, True, False)
+    fw_y = tf.math.reduce_mean(fw_d)
     # Gradient
-    tr_da = tf.gradients(tr_d, [a])
-    tf_da = tf.gradients(tf_d, [a])
+    triton_da, triton_db = tf.gradients(triton_y, [a, b])
+    fw_da, fw_db = tf.gradients(fw_y, [a, b])
     # Reference
-    ha = np.random.rand(M, K).astype(np.float32)
-    hb = np.random.rand(K, N).astype(np.float32)
+    feed_dict = {a: np.random.rand(M, K).astype(np.float32),
+                 b: np.random.rand(K, N).astype(np.float32)}
     # Run
     sess = tf.InteractiveSession()
     sess.run(tf.global_variables_initializer())
-    result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
-                                                   b: hb})
+    result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
+    triton_da, fw_da = result[0][0], result[1][0]
+    triton_db, fw_db = result[2][0], result[3][0]
     # Benchmark
-    nanosec = triton.bench_registry[tr_d]
-    print('NANOSEC: ', nanosec)
-    #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
-    # Test
-    print(result[0][0])
-    print(result[1][0])
-    dif = np.abs(result[0][0] - result[1][0])
-    print("dif: %f" % np.max(dif))
+    nanosec = triton.bench_registry[triton_d]
+    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
+    print('Diff DA:', (triton_da - fw_da).max())
+    print('Diff DB:', (triton_db - fw_db).max())

 def run_torch():
@@ -41,7 +39,7 @@ def run_torch():
     torch_c = torch.matmul(a, torch.t(b))
     torch_d = torch.matmul(torch.t(torch_c), b)
     torch_y = torch.mean(torch_d)
-    triton_c = triton.ops.dot(a, b, False, True)
+    triton_c = triton.ops.dot(a, b, False, True, 1)
     triton_d = triton.ops.dot(triton_c, b, True, False, 1)
     triton_y = torch.mean(triton_d)
     # torch gradient
@@ -56,7 +54,6 @@ def run_torch():
     triton_db = b.grad.clone()
     nanosec = triton.bench_registry[triton_d]
-    print(nanosec)
     print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
     print('Diff DA:', (torch_da - triton_da).max())
     print('Diff DB:', (torch_db - triton_db).max())
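The rewritten TF example above is where the commit title becomes visible to users: both placeholders carry fully-specified shapes, so every node in the graph can report a static shape before the session ever runs. A minimal TF 1.x sketch of that property, with tf.matmul standing in for triton.ops.dot (illustrative only, not part of the commit):

    import tensorflow as tf
    M, N, K = 2048, 2048, 2048
    a = tf.placeholder(tf.float32, shape=[M, K])
    b = tf.placeholder(tf.float32, shape=[N, K])
    c = tf.matmul(a, b, transpose_b=True)   # stand-in for triton.ops.dot(a, b, False, True, 1)
    # The static shape is already known at graph-construction time:
    print(c.shape)                      # (2048, 2048)
    print(c.shape.is_fully_defined())   # True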


@@ -53,11 +53,11 @@ class ProdKeyTest(tf.test.TestCase):
         B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
         E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
-        a = tf.placeholder(tf.float16, a_shape, name="a")
-        b = tf.placeholder(tf.float16, b_shape, name="b")
-        e = tf.placeholder(tf.float16, c_shape, name="e")
-        feed_dict = { a: A.astype(np.float16),
-                      b: B.astype(np.float16),
+        a = tf.placeholder(tf.float32, a_shape, name="a")
+        b = tf.placeholder(tf.float32, b_shape, name="b")
+        e = tf.placeholder(tf.float32, c_shape, name="e")
+        feed_dict = { a: A.astype(np.float32),
+                      b: B.astype(np.float32),
                       e: E }
         c = triton.ops.einsum(einsum, a, b, bench=bench)


@@ -156,7 +156,7 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
os << " };\n "; os << " };\n ";
os << " run();"; os << " run();";
os << " if(bench_ > 0)\n "; os << " if(bench_ > 0)\n ";
os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n "; os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
} }
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name, void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -186,6 +186,7 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl; os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
os << " .Input(\"" << name << ": T" << i << "\")\n"; os << " .Input(\"" << name << ": T" << i << "\")\n";
} }
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){ for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i]; std::string name = outputs[i];
size_t idx; size_t idx;
@@ -194,11 +195,19 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
       break;
     if(idx == args.size())
       throw std::runtime_error("unknown output");
-    os << " .Output(\"out" << i << ": T" << idx << "\")\n";
+    out_idx.push_back(idx);
   }
+  for(size_t i = 0; i < out_idx.size(); i++)
+    os << " .Output(\"out" << i << ": T" << out_idx[i] << "\")\n";
   os << " .Attr(\"id: int\")\n";
   os << " .Attr(\"bench: int\")\n";
   os << " .Attr(\"bench_id: int\")\n";
+  os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n";
+  for(size_t i = 0; i < out_idx.size(); i++)
+    os << "    c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n";
+  os << "    return Status::OK();\n";
+  os << "  })\n";
   os << ";\n";
 }
@@ -313,7 +322,7 @@ oss << R"(
 private:
   int id_;
   int bench_;
-  int bench_id_;
+  int64 bench_id_;
 };

 // register kernel builder
@@ -397,6 +406,7 @@ void gen_torch_signature(std::ostringstream& oss,
   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
   oss << "int64_t bench, ";
+  oss << "int64_t bench_id, ";
   for(size_t i = 0; i < args.size(); i++) {
     ir::argument* arg = args[i];
     if(i > 0)
@@ -453,7 +463,7 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
os << " };\n "; os << " };\n ";
os << " run();"; os << " run();";
os << " if(bench > 0)\n "; os << " if(bench > 0)\n ";
os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n "; os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
} }
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) { void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
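The generated .SetShapeFn above is the heart of this change: for every output i, the registered op reports the static shape of input out_idx[i], i.e. the pre-allocated buffer the kernel writes into, so TensorFlow can propagate shapes without ever running the kernel. A rough Python model of that rule (illustrative only, not part of the codebase):

    # input_shapes: static shapes of the op's inputs (e.g. tf.TensorShape objects)
    # out_idx[i]: index of the input tensor that output i aliases
    def infer_output_shapes(input_shapes, out_idx):
        # mirrors the generated c->set_output(i, c->input(out_idx[i])) calls
        return [input_shapes[j] for j in out_idx]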


@@ -1,4 +1,5 @@
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow; using namespace tensorflow;
@@ -28,4 +29,10 @@ REGISTER_OP("AllocEmpty")
.Input("x: int32") .Input("x: int32")
.Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}") .Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}")
.Output("y: T") .Output("y: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle handle;
c->MakeShapeFromShapeTensor(0, &handle);
c->set_output(0, handle);
return Status::OK();
});
; ;
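MakeShapeFromShapeTensor(0, &handle) converts the value of the int32 shape input into a static output shape whenever that input is a constant at graph-construction time; this is also why triton.utils.empty() (last file in this commit) now wraps each dimension in tf.constant. The stock op tf.fill relies on the same kind of shape function, so it makes a convenient Python illustration (illustrative only):

    import tensorflow as tf
    # When the shape argument is a constant tensor, the static output shape is
    # recovered from its value without running the graph -- the behavior the new
    # AllocEmpty shape function provides for Triton's output buffers.
    shape = tf.constant([128, 64], dtype=tf.int32)
    y = tf.fill(shape, 0.0)
    print(y.shape)   # (128, 64), fully defined at graph-construction time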


@@ -5,6 +5,7 @@ import triton._C.libtriton as libtriton
 torch = None
 tensorflow = None
 tf_extra_ops = None
+gen_resource_variable_ops = None

 def _import_torch():
   global torch
@@ -13,8 +14,10 @@ def _import_torch():
 def _import_tensorflow():
   global tensorflow
+  global gen_resource_variable_ops
   if tensorflow is None:
     import tensorflow
+    from tensorflow.python.ops import gen_resource_variable_ops

 def _import_tf_extra_ops():
   global tf_extra_ops


@@ -13,7 +13,6 @@ class OpContext(object):
 class function_meta(type):

     def __init__(cls, name, bases, attrs):
-        cls.contexts = dict()
         cls.registered = False
         return super(function_meta, cls).__init__(name, bases, attrs)
@@ -45,17 +44,20 @@ class function(metaclass = function_meta):
     @classmethod
     def apply_tensorflow(cls, *args, **kwargs):
         ctx = OpContext()
-        result = cls.forward(ctx, *args, **kwargs)
-        id = result.op.get_attr('id')
-        cls.contexts[id] = ctx
+        # Acquire a mutex here to ensure that calls to alloc_empty()
+        # are handled properly
+        mutex = fw.gen_resource_variable_ops.mutex_v2()
+        lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
+        with fw.tensorflow.python.ops.control_dependencies([lock]):
+            result = cls.forward(ctx, *args, **kwargs)
         ctx_registry[result] = ctx
         # register backward
         name = result.op.op_def.name
         if not cls.registered:
             @fw.tensorflow.RegisterGradient(name)
             def gradient(op, dy):
-                id = op.get_attr('id')
-                return cls.backward(cls.contexts[id], dy)
+                with fw.tensorflow.control_dependencies([op]):
+                    return cls.backward(ctx_registry[op.outputs[0]], dy)
             cls.registered = True
         # return result tensor
         return result
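apply_tensorflow now builds the forward op under a control dependency on a freshly created TF mutex; as the in-diff comment puts it, this is so that the alloc_empty() calls issued during forward() are handled properly rather than interleaving with one another. A stripped-down sketch of the same pattern, using the public tf.control_dependencies instead of the fw.* wrappers (illustrative only):

    import tensorflow as tf
    from tensorflow.python.ops import gen_resource_variable_ops

    mutex = gen_resource_variable_ops.mutex_v2()
    lock = gen_resource_variable_ops.mutex_lock(mutex)
    # Everything created under this block only executes once the lock is held.
    with tf.control_dependencies([lock]):
        result = tf.no_op()   # stand-in for cls.forward(ctx, *args, **kwargs)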


@@ -220,18 +220,18 @@ class kernel:
         op_id = self.fw_id[key]
         # register grid
         libtriton.register_grid(op_id, _make_grid(args))
+        # id for the benchmark result
+        bench_id = libtriton.make_scalar_id() if bench > 0 else -1
         # create operands
-        op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
         # call framework function
         if fw.has_tensorflow():
-            bench_id = libtriton.make_scalar_id() if bench > 0 else 0
-            ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
+            args = [x for x in args[:-1]]
+            ret = self.fw_op(*args, id=op_id, bench=bench, bench_id=bench_id)
             if bench > 0:
                 bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
         elif fw.has_torch():
-            args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
-            ret = self.fw_op(op_id, bench, *args)
+            args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
+            ret = self.fw_op(op_id, bench, bench_id, *args)
             if bench > 0:
                 bench_registry[ret] = libtriton.retrieve_scalar(op_id)
         else:
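Note the asymmetry between the two branches above: the torch path runs the kernel eagerly, so the benchmark scalar can be read right away, while the TF path only adds a node to the graph, so bench_registry stores a lazy entry keyed by bench_id that can only be resolved after the op has actually executed inside sess.run(). A minimal sketch of that deferred-lookup idea (illustrative only, not the actual triton.utils.id_dict implementation):

    import triton._C.libtriton as libtriton

    class lazy_entry:
        # remembers which benchmark slot to read once the op has run
        def __init__(self, bench_id):
            self.bench_id = bench_id

    def resolve(entry):
        # called when the user later reads triton.bench_registry[ret]
        return libtriton.retrieve_scalar(entry.bench_id)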


@@ -40,7 +40,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
     kernel = triton.kernel(src, ['C'])

     @staticmethod
-    def _call(a, b, transpose_a, transpose_b, bench = 0):
+    def _call(a, b, transpose_a, transpose_b, bench):
         # extract shapes
         shape_a = triton.shape(a)
         shape_b = triton.shape(b)
@@ -86,24 +86,26 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
         ctx.save_for_backward(a, b)
         ctx.t_a = transpose_a
         ctx.t_b = transpose_b
+        ctx.bench = bench
         return _dot._call(a, b, transpose_a, transpose_b, bench)

     @staticmethod
     def backward(ctx, dy):
         a, b = ctx.saved_tensors
         t_a, t_b = ctx.t_a, ctx.t_b
+        bench = ctx.bench
         if not t_a and not t_b:
-            da = _dot._call(dy, b, False, True)
-            db = _dot._call(a, dy, True, False)
+            da = _dot._call(dy, b, False, True, bench)
+            db = _dot._call(a, dy, True, False, bench)
         elif not t_a and t_b:
-            da = _dot._call(dy, b, False, False)
-            db = _dot._call(dy, a, True, False)
+            da = _dot._call(dy, b, False, False, bench)
+            db = _dot._call(dy, a, True, False, bench)
         elif t_a and not t_b:
-            da = _dot._call(b, dy, False, True)
-            db = _dot._call(a, dy, False, False)
+            da = _dot._call(b, dy, False, True, bench)
+            db = _dot._call(a, dy, False, False, bench)
         elif t_a and t_b:
-            da = _dot._call(b, dy, True, True)
-            db = _dot._call(dy, a, True, True)
+            da = _dot._call(b, dy, True, True, bench)
+            db = _dot._call(dy, a, True, True, bench)
         else:
             assert False
         return da, db, None, None, None, None, None, None, None


@@ -1,13 +1,22 @@
 import triton.frameworks as fw
 import triton._C.libtriton as libtriton
+import numpy as np

 def cdiv(a, b):
     return -(-a // b)

+class tf_empty_proxy:
+    def __init__(self, args, dtype):
+        self.args = args
+        self.dtype = dtype
+
 def empty(shapes, dtype):
     if fw.has_tensorflow():
-        args = [x.handle if isinstance(x, scalar) else x for x in shapes]
+        #return fw.tensorflow.Variable(np.empty(shapes),shape=shapes, dtype=dtype)
+        args = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shapes]
         args = fw.tensorflow.stack(args)
+        #return tf_empty_proxy(args, dtype)
         return fw.tf_extra_ops.alloc_empty(args, T = dtype)
     elif fw.has_torch():
         return fw.torch.empty(*shapes).cuda()