[PYTHON][KERNEL] Enforcing shapes to be known at compile-time for TensorFlow Graph Execution
@@ -5,30 +5,28 @@ def run_tf():
M, N, K = 2048, 2048, 2048
a = tf.placeholder(tf.float32, shape=[M, K])
b = tf.placeholder(tf.float32, shape=[N, K])
tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
triton_c = triton.ops.dot(a, b, False, True, 1)
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
triton_y = tf.math.reduce_mean(triton_d)
fw_c = tf.matmul(a, b, False, True)
fw_d = tf.matmul(fw_c, b, True, False)
fw_y = tf.math.reduce_mean(fw_d)
# Gradient
tr_da = tf.gradients(tr_d, [a])
tf_da = tf.gradients(tf_d, [a])
triton_da, triton_db = tf.gradients(triton_y, [a, b])
fw_da, fw_db = tf.gradients(fw_y, [a, b])
# Reference
ha = np.random.rand(M, K).astype(np.float32)
hb = np.random.rand(K, N).astype(np.float32)
# Run
feed_dict = {a: np.random.rand(M, K).astype(np.float32),
b: np.random.rand(K, N).astype(np.float32)}
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
b: hb})
result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
triton_da, fw_da = result[0][0], result[1][0]
triton_db, fw_db = result[2][0], result[3][0]
# Benchmark
nanosec = triton.bench_registry[tr_d]
print('NANOSEC: ', nanosec)
#print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
# Test
print(result[0][0])
print(result[1][0])
dif = np.abs(result[0][0] - result[1][0])
print("dif: %f" % np.max(dif))
nanosec = triton.bench_registry[triton_d]
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
print('Diff DA:', (triton_da - fw_da).max())
print('Diff DB:', (triton_db - fw_db).max())

def run_torch():
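The hunk above is the user-facing side of the commit: the placeholders now carry fully-defined shapes, so triton.ops.dot can be specialized while the graph is being built rather than at run time. A minimal sketch, assuming TensorFlow 1.x as used here, of what "shape known at compile time" means for a placeholder:

import tensorflow as tf

# Fully-defined static shape: readable while the graph is being constructed.
a_static = tf.placeholder(tf.float32, shape=[2048, 2048])
# Partially-defined shape: only known once data is fed at run time.
a_dynamic = tf.placeholder(tf.float32, shape=[None, 2048])

print(a_static.shape.is_fully_defined())    # True
print(a_dynamic.shape.is_fully_defined())   # False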
@@ -41,7 +39,7 @@ def run_torch():
torch_c = torch.matmul(a, torch.t(b))
torch_d = torch.matmul(torch.t(torch_c), b)
torch_y = torch.mean(torch_d)
triton_c = triton.ops.dot(a, b, False, True)
triton_c = triton.ops.dot(a, b, False, True, 1)
triton_d = triton.ops.dot(triton_c, b, True, False, 1)
triton_y = torch.mean(triton_d)
# torch gradient
@@ -56,7 +54,6 @@ def run_torch():
triton_db = b.grad.clone()

nanosec = triton.bench_registry[triton_d]
print(nanosec)
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
print('Diff DA:', (torch_da - triton_da).max())
print('Diff DB:', (torch_db - triton_db).max())
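Both the TensorFlow and the PyTorch benchmark paths report throughput with the same formula. A quick, self-contained check of that arithmetic (the timing value below is an assumed example, not a measured number):

M = N = K = 2048
nanosec = 8.6e6                   # assumed example: 8.6 ms expressed in nanoseconds
flops = 2. * M * N * K            # one multiply and one add per inner-product term
tflops = flops / nanosec * 1e-3   # (FLOP / ns) * 1e-3  ==  TFLOP/s
print(round(tflops, 2))           # ~2.0 for this assumed timing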
@@ -53,11 +53,11 @@ class ProdKeyTest(tf.test.TestCase):
B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)

a = tf.placeholder(tf.float16, a_shape, name="a")
b = tf.placeholder(tf.float16, b_shape, name="b")
e = tf.placeholder(tf.float16, c_shape, name="e")
feed_dict = { a: A.astype(np.float16),
b: B.astype(np.float16),
a = tf.placeholder(tf.float32, a_shape, name="a")
b = tf.placeholder(tf.float32, b_shape, name="b")
e = tf.placeholder(tf.float32, c_shape, name="e")
feed_dict = { a: A.astype(np.float32),
b: B.astype(np.float32),
e: E }

c = triton.ops.einsum(einsum, a, b, bench=bench)
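The test data above is rounded through float16 before being stored as float32. A small NumPy check of why that trick works: after the round-trip every value is exactly representable in half precision, so switching the placeholders between float16 and float32 does not introduce any extra rounding error into the comparison.

import numpy as np

x = np.random.uniform(-1.0, 1.0, (4, 4)).astype(np.float16).astype(np.float32)
# A second float16 round-trip changes nothing: the values are already fp16-exact.
assert np.array_equal(x, x.astype(np.float16).astype(np.float32))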
@@ -156,7 +156,7 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
os << " };\n ";
os << " run();";
os << " if(bench_ > 0)\n ";
os << " i64scalar_map[id_] = triton::tools::bench(run, stream);\n ";
os << " i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n ";
}

void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
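The generated launch wrapper now stores the benchmark result under bench_id_ instead of id_. A hedged plain-Python sketch of the distinction (names are illustrative, not the generated C++): the op id identifies a compiled kernel and is reused across calls, while a bench id is allocated per benchmarked call, so timings no longer overwrite one another.

i64scalar_map = {}

def launch(op_id, bench_id, run_nanosec):
    # mirrors: i64scalar_map[bench_id_] = triton::tools::bench(run, stream);
    i64scalar_map[bench_id] = run_nanosec

launch(op_id=7, bench_id=101, run_nanosec=8.6e6)
launch(op_id=7, bench_id=102, run_nanosec=8.4e6)   # same kernel, separate entry
print(i64scalar_map)                               # {101: 8600000.0, 102: 8400000.0}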
@@ -186,6 +186,7 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
os << " .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
os << " .Input(\"" << name << ": T" << i << "\")\n";
}
std::vector<int> out_idx;
for(size_t i = 0; i < outputs.size(); i++){
std::string name = outputs[i];
size_t idx;
@@ -194,11 +195,19 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
break;
if(idx == args.size())
throw std::runtime_error("unknown output");
os << " .Output(\"out" << i << ": T" << idx << "\")\n";
out_idx.push_back(idx);
}
for(size_t i = 0; i < out_idx.size(); i++)
os << " .Output(\"out" << i << ": T" << out_idx[i] << "\")\n";
os << " .Attr(\"id: int\")\n";
os << " .Attr(\"bench: int\")\n";
os << " .Attr(\"bench_id: int\")\n";
os << " .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n";
for(size_t i = 0; i < out_idx.size(); i++)
os << " c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n";
os << " return Status::OK();\n";
os << " })\n";

os << ";\n";
}
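The new SetShapeFn is what actually propagates compile-time shapes through the generated TensorFlow op: each output is given the static shape of the input argument it aliases. A plain-Python sketch of that rule (out_idx comes from the loop above; the function name here is illustrative):

def infer_output_shapes(input_shapes, out_idx):
    # mirrors: c->set_output(i, c->input(out_idx[i]));
    return [input_shapes[j] for j in out_idx]

print(infer_output_shapes([(2048, 2048), (2048, 1024)], out_idx=[1]))  # [(2048, 1024)]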
@@ -313,7 +322,7 @@ oss << R"(
private:
int id_;
int bench_;
int bench_id_;
int64 bench_id_;
};

// register kernel builder
@@ -397,6 +406,7 @@ void gen_torch_signature(std::ostringstream& oss,
oss << ret_ty << " " << name << "(";
oss << "int64_t id, ";
oss << "int64_t bench, ";
oss << "int64_t bench_id, ";
for(size_t i = 0; i < args.size(); i++) {
ir::argument* arg = args[i];
if(i > 0)
@@ -453,7 +463,7 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector<ir::argu
os << " };\n ";
os << " run();";
os << " if(bench > 0)\n ";
os << " i64scalar_map[id] = triton::tools::bench(run, stream);\n ";
os << " i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n ";
}

void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
@@ -1,4 +1,5 @@
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

using namespace tensorflow;

@@ -28,4 +29,10 @@ REGISTER_OP("AllocEmpty")
.Input("x: int32")
.Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}")
.Output("y: T")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle handle;
c->MakeShapeFromShapeTensor(0, &handle);
c->set_output(0, handle);
return Status::OK();
});
;
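The AllocEmpty registration gains a shape function built on MakeShapeFromShapeTensor, which turns the value of the int32 shape input into the output's static shape whenever that value is known while the graph is being built. A sketch of the same inference pattern with a stock op, assuming TensorFlow 1.x (tf.fill is only a stand-in for AllocEmpty):

import tensorflow as tf

shape_tensor = tf.constant([2048, 2048], dtype=tf.int32)
y = tf.fill(shape_tensor, 0.0)   # shape inference reads the constant shape tensor
print(y.shape)                   # (2048, 2048), known before the session runs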
@@ -5,6 +5,7 @@ import triton._C.libtriton as libtriton
torch = None
tensorflow = None
tf_extra_ops = None
gen_resource_variable_ops = None

def _import_torch():
global torch
@@ -13,8 +14,10 @@ def _import_torch():

def _import_tensorflow():
global tensorflow
global gen_resource_variable_ops
if tensorflow is None:
import tensorflow
from tensorflow.python.ops import gen_resource_variable_ops

def _import_tf_extra_ops():
global tf_extra_ops
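frameworks.py keeps every backend import lazy; gen_resource_variable_ops is simply added to the set of modules resolved on first use. A minimal sketch of the pattern, assuming nothing beyond the framework itself being installed:

_tensorflow = None

def _get_tensorflow():
    global _tensorflow
    if _tensorflow is None:        # first call pays the import cost, later calls do not
        import tensorflow as tf
        _tensorflow = tf
    return _tensorflow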
@@ -13,7 +13,6 @@ class OpContext(object):
class function_meta(type):

def __init__(cls, name, bases, attrs):
cls.contexts = dict()
cls.registered = False
return super(function_meta, cls).__init__(name, bases, attrs)

@@ -45,17 +44,20 @@ class function(metaclass = function_meta):
@classmethod
def apply_tensorflow(cls, *args, **kwargs):
ctx = OpContext()
# Acquire a mutex here to ensure that calls to alloc_empty()
# are handled properly
mutex = fw.gen_resource_variable_ops.mutex_v2()
lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
with fw.tensorflow.python.ops.control_dependencies([lock]):
result = cls.forward(ctx, *args, **kwargs)
id = result.op.get_attr('id')
cls.contexts[id] = ctx
ctx_registry[result] = ctx
# register backward
name = result.op.op_def.name
if not cls.registered:
@fw.tensorflow.RegisterGradient(name)
def gradient(op, dy):
id = op.get_attr('id')
return cls.backward(cls.contexts[id], dy)
with fw.tensorflow.control_dependencies([op]):
return cls.backward(ctx_registry[op.outputs[0]], dy)
cls.registered = True
# return result tensor
return result
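apply_tensorflow now keys the forward context by the op's output tensor (ctx_registry[result]) rather than by the integer 'id' attribute, and the gradient function recovers it from op.outputs[0]. A hedged, framework-free sketch of that bookkeeping (the gradient rule below is a placeholder, not Triton's):

ctx_registry = {}

class OpContext:
    def __init__(self):
        self.saved = None

def forward(x):
    ctx = OpContext()
    ctx.saved = x
    result = ("out", id(ctx))    # stand-in for the op's output tensor
    ctx_registry[result] = ctx   # mirrors: ctx_registry[result] = ctx
    return result

def backward(result, dy):
    ctx = ctx_registry[result]   # mirrors: ctx_registry[op.outputs[0]]
    return dy * ctx.saved        # placeholder gradient rule

out = forward(3.0)
print(backward(out, dy=2.0))     # 6.0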
@@ -220,18 +220,18 @@ class kernel:
op_id = self.fw_id[key]
# register grid
libtriton.register_grid(op_id, _make_grid(args))
# id for the benchmark result
bench_id = libtriton.make_scalar_id() if bench > 0 else -1
# create operands
op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
# call framework function
if fw.has_tensorflow():
bench_id = libtriton.make_scalar_id() if bench > 0 else 0
ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
args = [x for x in args[:-1]]
ret = self.fw_op(*args, id=op_id, bench=bench, bench_id=bench_id)
if bench > 0:
bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)

elif fw.has_torch():
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
ret = self.fw_op(op_id, bench, *args)
args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
ret = self.fw_op(op_id, bench, bench_id, *args)
if bench > 0:
bench_registry[ret] = libtriton.retrieve_scalar(op_id)
else:
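On the TensorFlow path the kernel call only adds a node to the graph, so the benchmark value cannot exist yet; bench_registry therefore stores a lazy entry keyed by bench_id and resolves it when read, after the op has actually run. A hedged sketch of that mechanism (retrieve here stands in for libtriton.retrieve_scalar; the class names are illustrative):

class LazyEntry:
    def __init__(self, key):
        self.key = key

class BenchRegistry(dict):
    def __init__(self, retrieve):
        super().__init__()
        self.retrieve = retrieve              # resolves a bench_id to nanoseconds

    def __getitem__(self, tensor):
        entry = super().__getitem__(tensor)
        return self.retrieve(entry.key) if isinstance(entry, LazyEntry) else entry

measured = {101: 8.6e6}                       # filled in once the kernel has executed
bench_registry = BenchRegistry(retrieve=measured.get)
bench_registry["dot:0"] = LazyEntry(101)      # stored at graph-construction time
print(bench_registry["dot:0"])                # 8600000.0, resolved at read time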
@@ -40,7 +40,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
kernel = triton.kernel(src, ['C'])

@staticmethod
def _call(a, b, transpose_a, transpose_b, bench = 0):
def _call(a, b, transpose_a, transpose_b, bench):
# extract shapes
shape_a = triton.shape(a)
shape_b = triton.shape(b)
@@ -86,24 +86,26 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
ctx.save_for_backward(a, b)
ctx.t_a = transpose_a
ctx.t_b = transpose_b
ctx.bench = bench
return _dot._call(a, b, transpose_a, transpose_b, bench)

@staticmethod
def backward(ctx, dy):
a, b = ctx.saved_tensors
t_a, t_b = ctx.t_a, ctx.t_b
bench = ctx.bench
if not t_a and not t_b:
da = _dot._call(dy, b, False, True)
db = _dot._call(a, dy, True, False)
da = _dot._call(dy, b, False, True, bench)
db = _dot._call(a, dy, True, False, bench)
elif not t_a and t_b:
da = _dot._call(dy, b, False, False)
db = _dot._call(dy, a, True, False)
da = _dot._call(dy, b, False, False, bench)
db = _dot._call(dy, a, True, False, bench)
elif t_a and not t_b:
da = _dot._call(b, dy, False, True)
db = _dot._call(a, dy, False, False)
da = _dot._call(b, dy, False, True, bench)
db = _dot._call(a, dy, False, False, bench)
elif t_a and t_b:
da = _dot._call(b, dy, True, True)
db = _dot._call(dy, a, True, True)
da = _dot._call(b, dy, True, True, bench)
db = _dot._call(dy, a, True, True, bench)
else:
assert False
return da, db, None, None, None, None, None, None, None
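The backward rules above are the standard matmul gradients, now also forwarding the bench flag. A NumPy sanity check of the no-transpose case, where dA = dY @ B.T and dB = A.T @ dY (finite differences are only used to confirm dA):

import numpy as np

A = np.random.rand(4, 3)
B = np.random.rand(3, 5)
dY = np.ones((4, 5))                       # gradient of sum(A @ B)

dA = dY @ B.T                              # matches _dot._call(dy, b, False, True, bench)
dB = A.T @ dY                              # matches _dot._call(a, dy, True, False, bench)

# Finite-difference check of dA.
eps = 1e-6
dA_num = np.zeros_like(A)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        Ap = A.copy()
        Ap[i, j] += eps
        dA_num[i, j] = ((Ap @ B).sum() - (A @ B).sum()) / eps
print(np.allclose(dA, dA_num, atol=1e-3))  # True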
@@ -1,13 +1,22 @@
import triton.frameworks as fw
import triton._C.libtriton as libtriton
import numpy as np

def cdiv(a, b):
return -(-a // b)

class tf_empty_proxy:

def __init__(self, args, dtype):
self.args = args
self.dtype = dtype

def empty(shapes, dtype):
if fw.has_tensorflow():
args = [x.handle if isinstance(x, scalar) else x for x in shapes]
#return fw.tensorflow.Variable(np.empty(shapes),shape=shapes, dtype=dtype)
args = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shapes]
args = fw.tensorflow.stack(args)
#return tf_empty_proxy(args, dtype)
return fw.tf_extra_ops.alloc_empty(args, T = dtype)
elif fw.has_torch():
return fw.torch.empty(*shapes).cuda()
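The cdiv() helper introduced above is ceiling division, the usual way problem sizes are turned into grid sizes. A quick check of the identity:

def cdiv(a, b):
    return -(-a // b)

print(cdiv(2048, 128))   # 16, divides exactly
print(cdiv(2049, 128))   # 17, rounds up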