diff --git a/python/examples/dot.py b/python/examples/dot.py
index 425fed986..3e061c112 100644
--- a/python/examples/dot.py
+++ b/python/examples/dot.py
@@ -5,30 +5,28 @@ def run_tf():
     M, N, K = 2048, 2048, 2048
     a = tf.placeholder(tf.float32, shape=[M, K])
     b = tf.placeholder(tf.float32, shape=[N, K])
-    tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True, bench=1)
-    tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False, bench=1)
-    tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True)
-    tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False)
+    triton_c = triton.ops.dot(a, b, False, True, 1)
+    triton_d = triton.ops.dot(triton_c, b, True, False, 1)
+    triton_y = tf.math.reduce_mean(triton_d)
+    fw_c = tf.matmul(a, b, False, True)
+    fw_d = tf.matmul(fw_c, b, True, False)
+    fw_y = tf.math.reduce_mean(fw_d)
     # Gradient
-    tr_da = tf.gradients(tr_d, [a])
-    tf_da = tf.gradients(tf_d, [a])
+    triton_da, triton_db = tf.gradients(triton_y, [a, b])
+    fw_da, fw_db = tf.gradients(fw_y, [a, b])
     # Reference
-    ha = np.random.rand(M, K).astype(np.float32)
-    hb = np.random.rand(K, N).astype(np.float32)
-    # Run
+    feed_dict = {a: np.random.rand(M, K).astype(np.float32),
+                 b: np.random.rand(K, N).astype(np.float32)}
     sess = tf.InteractiveSession()
     sess.run(tf.global_variables_initializer())
-    result = sess.run([tr_da, tf_da], feed_dict = {a: ha,
-                                                   b: hb})
+    result = sess.run([triton_da, fw_da, triton_db, fw_db, fw_y, triton_y], feed_dict = feed_dict)
+    triton_da, fw_da = result[0][0], result[1][0]
+    triton_db, fw_db = result[2][0], result[3][0]
     # Benchmark
-    nanosec = triton.bench_registry[tr_d]
-    print('NANOSEC: ', nanosec)
-    #print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
-    # Test
-    print(result[0][0])
-    print(result[1][0])
-    dif = np.abs(result[0][0] - result[1][0])
-    print("dif: %f" % np.max(dif))
+    nanosec = triton.bench_registry[triton_d]
+    print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
+    print('Diff DA:', (triton_da - fw_da).max())
+    print('Diff DB:', (triton_db - fw_db).max())


 def run_torch():
@@ -41,7 +39,7 @@ def run_torch():
     torch_c = torch.matmul(a, torch.t(b))
     torch_d = torch.matmul(torch.t(torch_c), b)
     torch_y = torch.mean(torch_d)
-    triton_c = triton.ops.dot(a, b, False, True)
+    triton_c = triton.ops.dot(a, b, False, True, 1)
     triton_d = triton.ops.dot(triton_c, b, True, False, 1)
     triton_y = torch.mean(triton_d)
     # torch gradient
@@ -56,7 +54,6 @@ def run_torch():
     triton_db = b.grad.clone()

     nanosec = triton.bench_registry[triton_d]
-    print(nanosec)
     print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)
     print('Diff DA:', (torch_da - triton_da).max())
     print('Diff DB:', (torch_db - triton_db).max())
diff --git a/python/examples/einsum_test.py b/python/examples/einsum_test.py
index 3363a88ea..b09f46cab 100644
--- a/python/examples/einsum_test.py
+++ b/python/examples/einsum_test.py
@@ -53,11 +53,11 @@ class ProdKeyTest(tf.test.TestCase):
         B = np.random.uniform(-1.0, 1.0, b_shape).astype(np.float16).astype(np.float32)
         E = np.random.uniform(-1.0, 1.0, c_shape).astype(np.float16).astype(np.float32)

-        a = tf.placeholder(tf.float16, a_shape, name="a")
-        b = tf.placeholder(tf.float16, b_shape, name="b")
-        e = tf.placeholder(tf.float16, c_shape, name="e")
-        feed_dict = { a: A.astype(np.float16),
-                      b: B.astype(np.float16),
+        a = tf.placeholder(tf.float32, a_shape, name="a")
+        b = tf.placeholder(tf.float32, b_shape, name="b")
+        e = tf.placeholder(tf.float32, c_shape, name="e")
+        feed_dict = { a: A.astype(np.float32),
+                      b: B.astype(np.float32),
                       e: E }

         c = triton.ops.einsum(einsum, a, b, bench=bench)
diff --git a/python/src/bindings.cc b/python/src/bindings.cc
index 59b5c54d6..80e2f8ddc 100644
--- a/python/src/bindings.cc
+++ b/python/src/bindings.cc
@@ -156,7 +156,7 @@ void gen_make_launch_function(std::ostream &os, int num_outputs, const std::vect
   os << "  };\n  ";
   os << "  run();";
   os << "  if(bench_ > 0)\n  ";
-  os << "    i64scalar_map[id_] = triton::tools::bench(run, stream);\n  ";
+  os << "    i64scalar_map[bench_id_] = triton::tools::bench(run, stream);\n  ";
 }

 void gen_tf_register_kernel_builder(std::ostream &os, const std::string &name,
@@ -186,6 +186,7 @@
     os << "  .Attr(\"T" << i << " : {bool, int8, int16, int32, int64, float16, float32, float64}\")" << std::endl;
     os << "  .Input(\"" << name << ": T" << i << "\")\n";
   }
+  std::vector<size_t> out_idx;
   for(size_t i = 0; i < outputs.size(); i++){
     std::string name = outputs[i];
     size_t idx;
@@ -194,11 +195,19 @@ void gen_tf_register_op(std::ostream &os, const std::string &name,
         break;
     if(idx == args.size())
       throw std::runtime_error("unknown output");
-    os << "  .Output(\"out" << i << ": T" << idx << "\")\n";
+    out_idx.push_back(idx);
   }
+  for(size_t i = 0; i < out_idx.size(); i++)
+    os << "  .Output(\"out" << i << ": T" << out_idx[i] << "\")\n";
   os << "  .Attr(\"id: int\")\n";
   os << "  .Attr(\"bench: int\")\n";
   os << "  .Attr(\"bench_id: int\")\n";
+  os << "  .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {\n";
+  for(size_t i = 0; i < out_idx.size(); i++)
+    os << "    c->set_output(" << i << ", c->input(" << out_idx[i] << "));\n";
+  os << "    return Status::OK();\n";
+  os << "  })\n";
+  os << ";\n";
 }
@@ -313,7 +322,7 @@ oss << R"(
 private:
   int id_;
   int bench_;
-  int bench_id_;
+  int64 bench_id_;
 };

 // register kernel builder
@@ -397,6 +406,7 @@ void gen_torch_signature(std::ostringstream& oss,
   oss << ret_ty << " " << name << "(";
   oss << "int64_t id, ";
   oss << "int64_t bench, ";
+  oss << "int64_t bench_id, ";
   for(size_t i = 0; i < args.size(); i++) {
     ir::argument* arg = args[i];
     if(i > 0)
@@ -453,7 +463,7 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vector
   os << "  if(bench > 0)\n  ";
-  os << "    i64scalar_map[id] = triton::tools::bench(run, stream);\n  ";
+  os << "    i64scalar_map[bench_id] = triton::tools::bench(run, &stream);\n  ";
 }

 void gen_torch_ret(std::ostream &os, const std::vector<std::string>& outputs) {
diff --git a/python/src/tensorflow/alloc_empty.cc b/python/src/tensorflow/alloc_empty.cc
index 75ab1201d..43f82cbfa 100644
--- a/python/src/tensorflow/alloc_empty.cc
+++ b/python/src/tensorflow/alloc_empty.cc
@@ -1,4 +1,5 @@
 #include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"

 using namespace tensorflow;

@@ -28,4 +29,10 @@ REGISTER_OP("AllocEmpty")
     .Input("x: int32")
     .Attr("T : {bool, int8, int16, int32, int64, float16, float32, float64}")
     .Output("y: T")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      shape_inference::ShapeHandle handle;
+      c->MakeShapeFromShapeTensor(0, &handle);
+      c->set_output(0, handle);
+      return Status::OK();
+    });
 ;
diff --git a/python/triton/frameworks.py b/python/triton/frameworks.py
index 993389a82..f495680f0 100644
--- a/python/triton/frameworks.py
+++ b/python/triton/frameworks.py
@@ -5,6 +5,7 @@ import triton._C.libtriton as libtriton
 torch = None
 tensorflow = None
 tf_extra_ops = None
+gen_resource_variable_ops = None

 def _import_torch():
   global torch
@@ -13,8 +14,10 @@ def _import_torch():

 def _import_tensorflow():
   global tensorflow
+  global gen_resource_variable_ops
   if tensorflow is None:
     import tensorflow
+    from tensorflow.python.ops import gen_resource_variable_ops

 def _import_tf_extra_ops():
   global tf_extra_ops
diff --git a/python/triton/function.py b/python/triton/function.py
index 79a0e5ec8..eb52d145e 100644
--- a/python/triton/function.py
+++ b/python/triton/function.py
@@ -13,7 +13,6 @@ class OpContext(object):

 class function_meta(type):

   def __init__(cls, name, bases, attrs):
-    cls.contexts = dict()
     cls.registered = False
     return super(function_meta, cls).__init__(name, bases, attrs)
@@ -45,17 +44,20 @@ class function(metaclass = function_meta):
   @classmethod
   def apply_tensorflow(cls, *args, **kwargs):
     ctx = OpContext()
-    result = cls.forward(ctx, *args, **kwargs)
-    id = result.op.get_attr('id')
-    cls.contexts[id] = ctx
+    # Acquire a mutex here to ensure that calls to alloc_empty()
+    # are handled properly
+    mutex = fw.gen_resource_variable_ops.mutex_v2()
+    lock = fw.gen_resource_variable_ops.mutex_lock(mutex)
+    with fw.tensorflow.python.ops.control_dependencies([lock]):
+      result = cls.forward(ctx, *args, **kwargs)
     ctx_registry[result] = ctx
     # register backward
     name = result.op.op_def.name
     if not cls.registered:
       @fw.tensorflow.RegisterGradient(name)
       def gradient(op, dy):
-        id = op.get_attr('id')
-        return cls.backward(cls.contexts[id], dy)
+        with fw.tensorflow.control_dependencies([op]):
+          return cls.backward(ctx_registry[op.outputs[0]], dy)
       cls.registered = True
     # return result tensor
     return result
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 57e0afc13..769e47a29 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -220,18 +220,18 @@ class kernel:
       op_id = self.fw_id[key]
     # register grid
     libtriton.register_grid(op_id, _make_grid(args))
+    # id for the benchmark result
+    bench_id = libtriton.make_scalar_id() if bench > 0 else -1
     # create operands
-    op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]]
     # call framework function
     if fw.has_tensorflow():
-      bench_id = libtriton.make_scalar_id() if bench > 0 else 0
-      ret = self.fw_op(*op_args, id=op_id, bench=bench, bench_id=bench_id)
+      args = [x for x in args[:-1]]
+      ret = self.fw_op(*args, id=op_id, bench=bench, bench_id=bench_id)
       if bench > 0:
         bench_registry[ret] = triton.utils.id_dict.lazy_entry(bench_id)
-
     elif fw.has_torch():
-      args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in op_args]
-      ret = self.fw_op(op_id, bench, *args)
+      args = [x.contiguous() if isinstance(x, fw.torch.Tensor) else x for x in args[:-1]]
+      ret = self.fw_op(op_id, bench, bench_id, *args)
       if bench > 0:
         bench_registry[ret] = libtriton.retrieve_scalar(op_id)
     else:
diff --git a/python/triton/ops/dot.py b/python/triton/ops/dot.py
index 7a5069701..140cd82cd 100644
--- a/python/triton/ops/dot.py
+++ b/python/triton/ops/dot.py
@@ -40,7 +40,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   kernel = triton.kernel(src, ['C'])

   @staticmethod
-  def _call(a, b, transpose_a, transpose_b, bench = 0):
+  def _call(a, b, transpose_a, transpose_b, bench):
     # extract shapes
     shape_a = triton.shape(a)
     shape_b = triton.shape(b)
@@ -86,24 +86,26 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
     ctx.save_for_backward(a, b)
     ctx.t_a = transpose_a
     ctx.t_b = transpose_b
+    ctx.bench = bench
     return _dot._call(a, b, transpose_a, transpose_b, bench)

   @staticmethod
   def backward(ctx, dy):
     a, b = ctx.saved_tensors
     t_a, t_b = ctx.t_a, ctx.t_b
+    bench = ctx.bench
     if not t_a and not t_b:
-      da = _dot._call(dy, b, False, True)
-      db = _dot._call(a, dy, True, False)
+      da = _dot._call(dy, b, False, True, bench)
+      db = _dot._call(a, dy, True, False, bench)
     elif not t_a and t_b:
-      da = _dot._call(dy, b, False, False)
-      db = _dot._call(dy, a, True, False)
+      da = _dot._call(dy, b, False, False, bench)
+      db = _dot._call(dy, a, True, False, bench)
     elif t_a and not t_b:
-      da = _dot._call(b, dy, False, True)
-      db = _dot._call(a, dy, False, False)
+      da = _dot._call(b, dy, False, True, bench)
+      db = _dot._call(a, dy, False, False, bench)
     elif t_a and t_b:
-      da = _dot._call(b, dy, True, True)
-      db = _dot._call(dy, a, True, True)
+      da = _dot._call(b, dy, True, True, bench)
+      db = _dot._call(dy, a, True, True, bench)
     else:
       assert False
     return da, db, None, None, None, None, None, None, None
diff --git a/python/triton/utils.py b/python/triton/utils.py
index 5b832668f..eca9f665e 100644
--- a/python/triton/utils.py
+++ b/python/triton/utils.py
@@ -1,13 +1,22 @@
 import triton.frameworks as fw
 import triton._C.libtriton as libtriton
+import numpy as np

 def cdiv(a, b):
   return -(-a // b)

+class tf_empty_proxy:
+
+  def __init__(self, args, dtype):
+    self.args = args
+    self.dtype = dtype
+
 def empty(shapes, dtype):
   if fw.has_tensorflow():
-    args = [x.handle if isinstance(x, scalar) else x for x in shapes]
+    #return fw.tensorflow.Variable(np.empty(shapes),shape=shapes, dtype=dtype)
+    args = [x.handle if isinstance(x, scalar) else fw.tensorflow.constant(x) for x in shapes]
     args = fw.tensorflow.stack(args)
+    #return tf_empty_proxy(args, dtype)
     return fw.tf_extra_ops.alloc_empty(args, T = dtype)
   elif fw.has_torch():
     return fw.torch.empty(*shapes).cuda()
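
Usage note (not part of the patch): a minimal sketch of the benchmarking flow this diff sets up, following the updated python/examples/dot.py. The tensor setup (torch.rand(...).cuda()) is illustrative and assumes a CUDA-capable PyTorch install; triton.ops.dot's positional signature and the triton.bench_registry lookup are taken from the example as changed above.

import torch
import triton

M, N, K = 2048, 2048, 2048
a = torch.rand(M, K).cuda()
b = torch.rand(N, K).cuda()
# dot(a, b, transpose_a, transpose_b, bench): a bench value > 0 makes the
# kernel allocate a fresh bench_id and record the runtime under it.
c = triton.ops.dot(a, b, False, True, 1)
# The recorded time (in nanoseconds) is read back through bench_registry,
# keyed by the op's output tensor, as in the example script.
nanosec = triton.bench_registry[c]
print('TFLOPS:', 2. * M * N * K / nanosec * 1e-3)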