[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for TensorFlow Graph Execution
@@ -5,30 +5,28 @@ def run_tf():
     M, N, K = 2048, 2048, 2048
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
-    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
-    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
-    tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
-    tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
+    triton_c = triton.ops.dot(a, b, False, True, 1)
+    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
+    triton_y = tf.math.reduce_mean(triton_d)
+    fw_c = tf.matmul(a, b, False, True)
+    fw_d = tf.matmul(fw_c, b, True, False)
+    fw_y = tf.math.reduce_mean(fw_d)
     # Gradient
-    tr_da = tf.gradients(tr_d, [a])
-    tf_da = tf.gradients(tf_d, [a])
+    triton_da, triton_db = tf.gradients(triton_y, [a, b])
+    fw_da, fw_db = tf.gradients(fw_y, [a, b])
     # Reference
-    ha = np.random.rand(M, K).astype(np.float32)
-    hb = np.random.rand(K, N).astype(np.float32)
-    # Run
+    feed_dict = {a: np.random.rand(M, K).astype(np.float32),
+                 b: np.random.rand(K, N).astype(np.float32)}
     sess = tf.InteractiveSession()
     sess.run(tf.global_variables_initializer())
-    result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
-                                                   b: hb})
+    result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
+    triton_da, fw_da = result[0][0], result[1][0]
+    triton_db, fw_db = result[2][0], result[3][0]
     # Benchmark
-    nanosec = triton.bench_registry[tr_d]
-    print('NANOSEC: ', nanosec)
-    #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
-    # Test
-    print(result[0][0])
-    print(result[1][0])
-    dif = np.abs(result[0][0] - result[1][0])
-    print("dif: %f" % np.max(dif))
+    nanosec = triton.bench_registry[triton_d]
+    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
+    print('Diff DA:', (triton_da - fw_da).max())
+    print('Diff DB:', (triton_db - fw_db).max())


 def run_torch():
@@ -41,7 +39,7 @@ def run_torch():
     torch_c = torch.matmul(a, torch.t(b))
     torch_d = torch.matmul(torch.t(torch_c), b)
     torch_y = torch.mean(torch_d)
-    triton_c = triton.ops.dot(a, b, False, True)
+    triton_c = triton.ops.dot(a, b, False, True, 1)
     triton_d = triton.ops.dot(triton_c, b, True, False, 1)
     triton_y = torch.mean(triton_d)
     # torch gradient
@@ -56,7 +54,6 @@ def run_torch():
     triton_db = b.grad.clone()

     nanosec = triton.bench_registry[triton_d]
-    print(nanosec)
     print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
     print('Diff DA:', (torch_da - triton_da).max())
     print('Diff DB:', (torch_db - triton_db).max())
@@ -53,11 +53,11 @@ class ProdKeyTest(tf.test.TestCase):
         B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
         E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)

-        a = tf.placeholder(tf.float16, a_shape, name="a")
-        b = tf.placeholder(tf.float16, b_shape, name="b")
-        e = tf.placeholder(tf.float16, c_shape, name="e")
-        feed_dict = { a: A.astype(np.float16),
-                      b: B.astype(np.float16),
+        a = tf.placeholder(tf.float32, a_shape, name="a")
+        b = tf.placeholder(tf.float32, b_shape, name="b")
+        e = tf.placeholder(tf.float32, c_shape, name="e")
+        feed_dict = { a: A.astype(np.float32),
+                      b: B.astype(np.float32),
                       e: E }

         c = triton.ops.einsum(einsum, a, b, bench=bench)
@@ -156,7 +156,7 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
   os << " };\n ";
   os << " run();";
   os << " if(bench_ > 0)\n ";
-  os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n ";
+  os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
 }

 void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -186,6 +186,7 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
     os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
     os << " .Input(\"" << name << ": T" << i << "\")\n";
   }
+  std::vector<int> out_idx;
   for(size_t i = 0; i < outputs.size(); i++){
     std::string name = outputs[i];
     size_t idx;
@@ -194,11 +195,19 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
         break;
     if(idx == args.size())
       throw std::runtime_error("unknown output");
-    os << " .Output(\"out" << i << ": T" << idx << "\")\n";
+    out_idx.push_back(idx);
   }
+  for(size_t i = 0; i < out_idx.size(); i++)
+    os << " .Output(\"out" << i << ": T" << out_idx[i] << "\")\n";
   os << " .Attr(\"id: int\")\n";
   os << " .Attr(\"bench: int\")\n";
   os << " .Attr(\"bench_id: int\")\n";
+  os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n";
+  for(size_t i = 0; i < out_idx.size(); i++)
+    os << " c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n";
+  os << " return Status::OK();\n";
+  os << " })\n";
+
   os << ";\n";
 }

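For reference, the TensorFlow op registration emitted by gen_tf_register_op after this change expands to roughly the following. This is a sketch only: the op name "Dot" and the argument names a, b, c are illustrative (they are not taken from this commit), and it assumes three inputs of which the third is also the declared output.

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

REGISTER_OP("Dot")
  .Attr("T0 : {bool, int8, int16, int32, int64, float16, float32, float64}")
  .Input("a: T0")
  .Attr("T1 : {bool, int8, int16, int32, int64, float16, float32, float64}")
  .Input("b: T1")
  .Attr("T2 : {bool, int8, int16, int32, int64, float16, float32, float64}")
  .Input("c: T2")
  .Output("out0: T2")
  .Attr("id: int")
  .Attr("bench: int")
  .Attr("bench_id: int")
  // New in this commit: each output shape is forwarded from the corresponding
  // input, so it is known to TensorFlow at graph-construction time.
  .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
    c->set_output(0, c->input(2));
    return Status::OK();
  });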
@@ -313,7 +322,7 @@ oss << R"(
 private:
   int id_;
   int bench_;
-  int bench_id_;
+  int64 bench_id_;
 };

 // register kernel builder
@@ -397,6 +406,7 @@ void gen_torch_signature(std::ostringstream& oss,
   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
   oss << "int64_t bench, ";
+  oss << "int64_t bench_id, ";
   for(size_t i = 0; i < args.size(); i++) {
     ir::argument* arg = args[i];
     if(i > 0)
@@ -453,7 +463,7 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
   os << " };\n ";
   os << " run();";
   os << " if(bench > 0)\n ";
-  os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n ";
+  os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
 }

 void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
@@ -1,4 +1,5 @@
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"

 using namespace tensorflow;

@@ -28,4 +29,10 @@ REGISTER_OP("AllocEmpty")
   .Input("x: int32")
   .Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}")
   .Output("y: T")
+  .SetShapeFn([](shape_inference::InferenceContext* c) {
+    shape_inference::ShapeHandle handle;
+    c->MakeShapeFromShapeTensor(0, &handle);
+    c->set_output(0, handle);
+    return Status::OK();
+  })
 ;
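The shape function added above is what makes AllocEmpty usable under graph execution with statically-known shapes: MakeShapeFromShapeTensor reads the value of the int32 input "x" and converts it into a ShapeHandle for the output, so when "x" is assembled from constants (as triton.utils.empty now does) the output shape is fully defined at graph-construction time; if the value is not statically known, inference falls back to a partially-known shape rather than failing. A commented standalone sketch of the same pattern, with an illustrative op name ("AllocLike") that is not part of this commit:

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

// Illustrative only: an op that, like AllocEmpty, takes a shape tensor and
// produces an uninitialized tensor of that shape.
REGISTER_OP("AllocLike")
  .Input("x: int32")
  .Attr("T : {float16, float32, float64}")
  .Output("y: T")
  .SetShapeFn([](shape_inference::InferenceContext* c) {
    shape_inference::ShapeHandle handle;
    // Interpret the *value* of input 0 as the output shape. With a constant
    // input this yields a fully-defined shape during graph construction.
    c->MakeShapeFromShapeTensor(0, &handle);
    c->set_output(0, handle);
    return Status::OK();
  });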
@@ -5,6 +5,7 @@ import triton._C.libtriton as libtriton
 torch = None
 tensorflow = None
 tf_extra_ops = None
+gen_resource_variable_ops = None

 def _import_torch():
     global torch
@@ -13,8 +14,10 @@ def _import_torch():

 def _import_tensorflow():
     global tensorflow
+    global gen_resource_variable_ops
     if tensorflow is None:
         import tensorflow
+        from tensorflow.python.ops import gen_resource_variable_ops

 def _import_tf_extra_ops():
     global tf_extra_ops
@@ -13,7 +13,6 @@ class OpContext(object):
 class function_meta(type):

     def __init__(cls, name, bases, attrs):
-        cls.contexts = dict()
         cls.registered = False
         return super(function_meta, cls).__init__(name, bases, attrs)

|
|||||||
@classmethod
|
@classmethod
|
||||||
def apply_tensorflow(cls, *args, **kwargs):
|
def apply_tensorflow(cls, *args, **kwargs):
|
||||||
ctx = OpContext()
|
ctx = OpContext()
|
||||||
result = cls.forward(ctx, *args, **kwargs)
|
# Acquire a mutex here to ensure that calls to alloc_empty()
|
||||||
id = result.op.get_attr('id')
|
# are handled properly
|
||||||
cls.contexts[id] = ctx
|
mutex = fw.gen_resource_variable_ops.mutex_v2()
|
||||||
|
lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
|
||||||
|
with fw.tensorflow.python.ops.control_dependencies([lock]):
|
||||||
|
result = cls.forward(ctx, *args, **kwargs)
|
||||||
ctx_registry[result] = ctx
|
ctx_registry[result] = ctx
|
||||||
# register backward
|
# register backward
|
||||||
name = result.op.op_def.name
|
name = result.op.op_def.name
|
||||||
if not cls.registered:
|
if not cls.registered:
|
||||||
@fw.tensorflow.RegisterGradient(name)
|
@fw.tensorflow.RegisterGradient(name)
|
||||||
def gradient(op, dy):
|
def gradient(op, dy):
|
||||||
id = op.get_attr('id')
|
with fw.tensorflow.control_dependencies([op]):
|
||||||
return cls.backward(cls.contexts[id], dy)
|
return cls.backward(ctx_registry[op.outputs[0]], dy)
|
||||||
cls.registered = True
|
cls.registered = True
|
||||||
# return result tensor
|
# return result tensor
|
||||||
return result
|
return result
|
||||||
|
@@ -220,18 +220,18 @@ class kernel:
         op_id = self.fw_id[key]
         # register grid
         libtriton.register_grid(op_id, _make_grid(args))
+        # id for the benchmark result
+        bench_id = libtriton.make_scalar_id() if bench > 0 else -1
         # create operands
-        op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
         # call framework function
         if fw.has_tensorflow():
-            bench_id = libtriton.make_scalar_id() if bench > 0 else 0
-            ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
+            args = [x for x in args[:-1]]
+            ret = self.fw_op(*args, id=op_id, bench=bench, bench_id=bench_id)
             if bench > 0:
                 bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)

         elif fw.has_torch():
-            args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
-            ret = self.fw_op(op_id, bench, *args)
+            args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
+            ret = self.fw_op(op_id, bench, bench_id, *args)
             if bench > 0:
                 bench_registry[ret] = libtriton.retrieve_scalar(op_id)
         else:
@@ -40,7 +40,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
     kernel = triton.kernel(src, ['C'])

     @staticmethod
-    def _call(a, b, transpose_a, transpose_b, bench = 0):
+    def _call(a, b, transpose_a, transpose_b, bench):
         # extract shapes
         shape_a = triton.shape(a)
         shape_b = triton.shape(b)
@@ -86,24 +86,26 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
         ctx.save_for_backward(a, b)
         ctx.t_a = transpose_a
         ctx.t_b = transpose_b
+        ctx.bench = bench
         return _dot._call(a, b, transpose_a, transpose_b, bench)

     @staticmethod
     def backward(ctx, dy):
         a, b = ctx.saved_tensors
         t_a, t_b = ctx.t_a, ctx.t_b
+        bench = ctx.bench
         if not t_a and not t_b:
-            da = _dot._call(dy, b, False, True)
-            db = _dot._call(a, dy, True, False)
+            da = _dot._call(dy, b, False, True, bench)
+            db = _dot._call(a, dy, True, False, bench)
         elif not t_a and t_b:
-            da = _dot._call(dy, b, False, False)
-            db = _dot._call(dy, a, True, False)
+            da = _dot._call(dy, b, False, False, bench)
+            db = _dot._call(dy, a, True, False, bench)
         elif t_a and not t_b:
-            da = _dot._call(b, dy, False, True)
-            db = _dot._call(a, dy, False, False)
+            da = _dot._call(b, dy, False, True, bench)
+            db = _dot._call(a, dy, False, False, bench)
         elif t_a and t_b:
-            da = _dot._call(b, dy, True, True)
-            db = _dot._call(dy, a, True, True)
+            da = _dot._call(b, dy, True, True, bench)
+            db = _dot._call(dy, a, True, True, bench)
         else:
             assert False
         return da, db, None, None, None, None, None, None, None
@@ -1,13 +1,22 @@
 import triton.frameworks as fw
 import triton._C.libtriton as libtriton
+import numpy as np

 def cdiv(a, b):
     return -(-a // b)

+class tf_empty_proxy:
+
+    def __init__(self, args, dtype):
+        self.args = args
+        self.dtype = dtype
+
 def empty(shapes, dtype):
     if fw.has_tensorflow():
-        args = [x.handle if isinstance(x, scalar) else x for x in shapes]
+        #return fw.tensorflow.Variable(np.empty(shapes),shape=shapes, dtype=dtype)
+        args = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shapes]
         args = fw.tensorflow.stack(args)
+        #return tf_empty_proxy(args, dtype)
         return fw.tf_extra_ops.alloc_empty(args, T = dtype)
     elif fw.has_torch():
         return fw.torch.empty(*shapes).cuda()