[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for

TensorFlow Graph Execution
This commit is contained in:
Philippe Tillet
2019-10-28 17:12:37 -04:00
parent e9c787ef05
commit 448f4433d9
9 changed files with 82 additions and 52 deletions

View File

@@ -5,30 +5,28 @@ def run_tf():
M, N, K = 2048, 2048, 2048
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
triton_c = triton.ops.dot(a, b, False, True, 1)
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
triton_y = tf.math.reduce_mean(triton_d)
fw_c = tf.matmul(a, b, False, True)
fw_d = tf.matmul(fw_c, b, True, False)
fw_y = tf.math.reduce_mean(fw_d)
# Gradient
tr_da = tf.gradients(tr_d, [a])
tf_da = tf.gradients(tf_d, [a])
triton_da, triton_db = tf.gradients(triton_y, [a, b])
fw_da, fw_db = tf.gradients(fw_y, [a, b])
# Reference
ha = np.random.rand(M, K).astype(np.float32)
hb = np.random.rand(K, N).astype(np.float32)
# Run
feed_dict = {a: np.random.rand(M, K).astype(np.float32),
b: np.random.rand(K, N).astype(np.float32)}
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
b: hb})
result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
triton_da, fw_da = result[0][0], result[1][0]
triton_db, fw_db = result[2][0], result[3][0]
# Benchmark
nanosec = triton.bench_registry[tr_d]
print('NANOSEC: ', nanosec)
#print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
# Test
print(result[0][0])
print(result[1][0])
dif = np.abs(result[0][0] - result[1][0])
print("dif: %f" % np.max(dif))
nanosec = triton.bench_registry[triton_d]
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
print('Diff DA:', (triton_da - fw_da).max())
print('Diff DB:', (triton_db - fw_db).max())
def run_torch():
@@ -41,7 +39,7 @@ def run_torch():
torch_c = torch.matmul(a, torch.t(b))
torch_d = torch.matmul(torch.t(torch_c), b)
torch_y = torch.mean(torch_d)
triton_c = triton.ops.dot(a, b, False, True)
triton_c = triton.ops.dot(a, b, False, True, 1)
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
triton_y = torch.mean(triton_d)
# torch gradient
@@ -56,7 +54,6 @@ def run_torch():
triton_db = b.grad.clone()
nanosec = triton.bench_registry[triton_d]
print(nanosec)
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
print('Diff DA:', (torch_da - triton_da).max())
print('Diff DB:', (torch_db - triton_db).max())

View File

@@ -53,11 +53,11 @@ class ProdKeyTest(tf.test.TestCase):
B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)
a = tf.placeholder(tf.float16, a_shape, name="a")
b = tf.placeholder(tf.float16, b_shape, name="b")
e = tf.placeholder(tf.float16, c_shape, name="e")
feed_dict = { a: A.astype(np.float16),
b: B.astype(np.float16),
a = tf.placeholder(tf.float32, a_shape, name="a")
b = tf.placeholder(tf.float32, b_shape, name="b")
e = tf.placeholder(tf.float32, c_shape, name="e")
feed_dict = { a: A.astype(np.float32),
b: B.astype(np.float32),
e: E }
c = triton.ops.einsum(einsum, a, b, bench=bench)

View File

@@ -156,7 +156,7 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
os << " };\n ";
os << " run();";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n ";
os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}
void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -186,6 +186,7 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
os << " .Input(\"" << name << ": T" << i << "\")\n";
}
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
size_t idx;
@@ -194,11 +195,19 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
os << " .Output(\"out" << i << ": T" << idx << "\")\n";
out_idx.push_back(idx);
}
for(size_t i = 0; i < out_idx.size(); i++)
os << " .Output(\"out" << i << ": T" << out_idx[i] << "\")\n";
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n";
os << " return Status::OK();\n";
os << " })\n";
os << ";\n";
}
@@ -313,7 +322,7 @@ oss << R"(
private:
int id_;
int bench_;
int bench_id_;
int64 bench_id_;
};
// register kernel builder
@@ -397,6 +406,7 @@ void gen_torch_signature(std::ostringstream& oss,
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t bench, ";
oss << "int64_t bench_id, ";
for(size_t i = 0; i < args.size(); i++) {
ir::argument* arg = args[i];
if(i > 0)
@@ -453,7 +463,7 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
os << " };\n ";
os << " run();";
os << " if(bench > 0)\n ";
os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n ";
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}
void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {

View File

@@ -1,4 +1,5 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
@@ -28,4 +29,10 @@ REGISTER_OP("AllocEmpty")
.Input("x: int32")
.Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}")
.Output("y: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle handle;
c->MakeShapeFromShapeTensor(0, &handle);
c->set_output(0, handle);
return Status::OK();
});
;

View File

@@ -5,6 +5,7 @@ import triton._C.libtriton as libtriton
torch = None
tensorflow = None
tf_extra_ops = None
gen_resource_variable_ops = None
def _import_torch():
global torch
@@ -13,8 +14,10 @@ def _import_torch():
def _import_tensorflow():
global tensorflow
global gen_resource_variable_ops
if tensorflow is None:
import tensorflow
from tensorflow.python.ops import gen_resource_variable_ops
def _import_tf_extra_ops():
global tf_extra_ops

View File

@@ -13,7 +13,6 @@ class OpContext(object):
class function_meta(type):
def __init__(cls, name, bases, attrs):
cls.contexts = dict()
cls.registered = False
return super(function_meta, cls).__init__(name, bases, attrs)
@@ -45,17 +44,20 @@ class function(metaclass = function_meta):
@classmethod
def apply_tensorflow(cls, *args, **kwargs):
ctx = OpContext()
result = cls.forward(ctx, *args, **kwargs)
id = result.op.get_attr('id')
cls.contexts[id] = ctx
# Acquire a mutex here to ensure that calls to alloc_empty()
# are handled properly
mutex = fw.gen_resource_variable_ops.mutex_v2()
lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
with fw.tensorflow.python.ops.control_dependencies([lock]):
result = cls.forward(ctx, *args, **kwargs)
ctx_registry[result] = ctx
# register backward
name = result.op.op_def.name
if not cls.registered:
@fw.tensorflow.RegisterGradient(name)
def gradient(op, dy):
id = op.get_attr('id')
return cls.backward(cls.contexts[id], dy)
with fw.tensorflow.control_dependencies([op]):
return cls.backward(ctx_registry[op.outputs[0]], dy)
cls.registered = True
# return result tensor
return result

View File

@@ -220,18 +220,18 @@ class kernel:
op_id = self.fw_id[key]
# register grid
libtriton.register_grid(op_id, _make_grid(args))
# id for the benchmark result
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
# create operands
op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
# call framework function
if fw.has_tensorflow():
bench_id = libtriton.make_scalar_id() if bench > 0 else 0
ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
args = [x for x in args[:-1]]
ret = self.fw_op(*args, id=op_id, bench=bench, bench_id=bench_id)
if bench > 0:
bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
elif fw.has_torch():
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
ret = self.fw_op(op_id, bench, *args)
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
ret = self.fw_op(op_id, bench, bench_id, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(op_id)
else:

View File

@@ -40,7 +40,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
kernel = triton.kernel(src, ['C'])
@staticmethod
def _call(a, b, transpose_a, transpose_b, bench = 0):
def _call(a, b, transpose_a, transpose_b, bench):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
@@ -86,24 +86,26 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
ctx.bench = bench
return _dot._call(a, b, transpose_a, transpose_b, bench)
@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
t_a, t_b = ctx.t_a, ctx.t_b
bench = ctx.bench
if not t_a and not t_b:
da = _dot._call(dy, b, False, True)
db = _dot._call(a, dy, True, False)
da = _dot._call(dy, b, False, True, bench)
db = _dot._call(a, dy, True, False, bench)
elif not t_a and t_b:
da = _dot._call(dy, b, False, False)
db = _dot._call(dy, a, True, False)
da = _dot._call(dy, b, False, False, bench)
db = _dot._call(dy, a, True, False, bench)
elif t_a and not t_b:
da = _dot._call(b, dy, False, True)
db = _dot._call(a, dy, False, False)
da = _dot._call(b, dy, False, True, bench)
db = _dot._call(a, dy, False, False, bench)
elif t_a and t_b:
da = _dot._call(b, dy, True, True)
db = _dot._call(dy, a, True, True)
da = _dot._call(b, dy, True, True, bench)
db = _dot._call(dy, a, True, True, bench)
else:
assert False
return da, db, None, None, None, None, None, None, None

View File

@@ -1,13 +1,22 @@
import triton.frameworks as fw
import triton._C.libtriton as libtriton
import numpy as np
def cdiv(a, b):
return -(-a // b)
class tf_empty_proxy:
def __init__(self, args, dtype):
self.args = args
self.dtype = dtype
def empty(shapes, dtype):
if fw.has_tensorflow():
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
#return fw.tensorflow.Variable(np.empty(shapes),shape=shapes, dtype=dtype)
args = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shapes]
args = fw.tensorflow.stack(args)
#return tf_empty_proxy(args, dtype)
return fw.tf_extra_ops.alloc_empty(args, T = dtype)
elif fw.has_torch():
return fw.torch.empty(*shapes).cuda()