From 321d268a4a105dbf9070c42f1a7fbc15c8be3cf7 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 25 Aug 2019 21:26:09 -0700 Subject: [PATCH] more progress --- CMakeLists.txt | 6 +- include/triton/runtime/function.h | 13 ++- lib/codegen/analysis/tune.cc | 4 +- lib/runtime/function.cc | 38 ++++--- python/examples/dot.py | 88 ++++++++++++---- python/setup.py | 1 + python/src/tensorflow.cc | 162 +++++++++++++++++++----------- python/triton/ops.py | 66 +++++++++++- 8 files changed, 268 insertions(+), 110 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 84e16ddf9..15985cc87 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,20 +33,20 @@ endif() if(BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") # PyBind11 wrapper source file - file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cpp) + file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cc) # update include directory include_directories(python/src/ ${PYTHON_INCLUDE_DIRS} ${TF_INCLUDE_DIRS}) # update link directories link_directories(${TF_LIB_DIRS}) # extra tensorflow ops (e.g., alloc_empty) - file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cpp) + file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc) add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC}) target_link_libraries(extra_tf_ops ${TF_LIBS}) endif() # Triton -file(GLOB_RECURSE LIBTRITON_SRC lib/*.cpp lib/*.cc) +file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc) add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC}) target_link_libraries(triton LLVM) diff --git a/include/triton/runtime/function.h b/include/triton/runtime/function.h index c3f4d53ff..b0054c647 100644 --- a/include/triton/runtime/function.h +++ b/include/triton/runtime/function.h @@ -59,12 +59,11 @@ class metaparameter; namespace runtime{ -typedef std::array grid_t; +typedef std::vector grid_t; typedef std::map params_t; - -template T convert(const std::string& name); -template<> long convert(const std::string& name) { return std::stol(name); } 
-template<> int convert(const std::string& name) { return std::stoi(name); } +template inline T convert(const std::string& name); +template<> inline long convert(const std::string& name) { return std::stol(name); } +template<> inline int convert(const std::string& name) { return std::stoi(name); } class function { public: @@ -91,7 +90,7 @@ private: class caller { public: caller(ir::function *ir, std::shared_ptr program, const options_t& opt_); - void operator()(driver::stream *stream, const std::array& grid, const std::vector& args) const; + void operator()(driver::stream *stream, const grid_t& grid, const std::vector& args) const; const options_t opt() const { return opt_; } private: @@ -113,7 +112,7 @@ private: public: function(const std::string& src, const options_space_t& opt = options_space_t()); - void operator()(const std::vector& args, const std::array& grid, driver::stream* stream); + void operator()(const std::vector& args, const grid_t& grid, driver::stream* stream); void operator()(const std::vector& args, const grid_fn_ty& grid, driver::stream *stream); std::string make_tensorflow_src(const std::vector &outputs, const std::string &macro); diff --git a/lib/codegen/analysis/tune.cc b/lib/codegen/analysis/tune.cc index 5f150ee2c..5ff536849 100644 --- a/lib/codegen/analysis/tune.cc +++ b/lib/codegen/analysis/tune.cc @@ -15,8 +15,8 @@ namespace triton{ namespace codegen{ namespace analysis{ -grids::grids(size_t num_warps): num_warps_(num_warps){ -} +grids::grids(size_t num_warps): num_warps_(num_warps) +{ } bool is_hmma(ir::value *v){ bool result = false; diff --git a/lib/runtime/function.cc b/lib/runtime/function.cc index 8a47c35d4..750952bbb 100644 --- a/lib/runtime/function.cc +++ b/lib/runtime/function.cc @@ -93,7 +93,7 @@ function::caller::caller(ir::function *ir, std::shared_ptr paren } -void function::caller::operator ()(driver::stream *stream, const std::array& grid, const std::vector& args) const { +void function::caller::operator ()(driver::stream 
*stream, const grid_t& _grid, const std::vector& args) const { if(args.size() != param_tys_.size()) throw std::runtime_error("invalid number of arguments"); for(size_t i = 0; i < args.size(); i++){ @@ -106,6 +106,12 @@ void function::caller::operator ()(driver::stream *stream, const std::arraysetArg(i, size_of(ty), arg_i.data()); } + // sanity check + if(_grid.size() > 3) + throw std::runtime_error("grid size must be no greater than 3"); + std::array grid; + for(size_t i = 0; i < 3; i++) + grid[i] = (i < _grid.size()) ? _grid[i] : 1; stream->enqueue(&*bin_, grid, {opt_.num_warps * 32, 1, 1}); } @@ -207,20 +213,21 @@ std::unique_ptr function::make_bin(ir::module &module, driver::c } std::string preheader() { -return R"( - #define bool _Bool - #define true 1 - #define false 0 - #define __bool_true_false_are_defined 1 +return +R"( +#define bool _Bool +#define true 1 +#define false 0 +#define __bool_true_false_are_defined 1 - #define __readonly __attribute__((readonly)) - #define __writeonly __attribute__((writeonly)) - #define __noalias __attribute__((noalias)) - #define __aligned(A) __attribute__((aligned(A))) - #define __multipleof(A) __attribute__((multipleof(A))) +#define __readonly __attribute__((readonly)) +#define __writeonly __attribute__((writeonly)) +#define __noalias __attribute__((noalias)) +#define __aligned(A) __attribute__((aligned(A))) +#define __multipleof(A) __attribute__((multipleof(A))) - extern int get_program_id(int); - )"; +extern int get_program_id(int); +)"; } function::function(const std::string &src, const options_space_t& opt): src_(src), opt_space_(opt) { @@ -228,9 +235,10 @@ function::function(const std::string &src, const options_space_t& opt): src_(sr } void function::operator()(const std::vector& args, const grid_fn_ty& grid_fn, driver::stream *stream) { - /* determine if should re-tune or not */ cache_key_t key; - // re-tune if device is difference + + /* figure out if the kernel should be re-tuned */ + // re-tune if device is 
different key.first = stream->context()->device(); // re-tune if any int argument is different for(size_t i = 0; i < args.size(); i++){ diff --git a/python/examples/dot.py b/python/examples/dot.py index 638d49c20..e807305e6 100644 --- a/python/examples/dot.py +++ b/python/examples/dot.py @@ -3,49 +3,89 @@ import tensorflow as tf import numpy as np src = """ -const tunable int TM = {128}; -const tunable int TN = {128}; -const tunable int TK = {32}; +#if AT == 1 +#define USEA ^a +#else +#define USEA a +#endif -void matmul(restrict read_only align(16) half *A, - restrict read_only align(16) half *B, - restrict read_only align(16) half *C, - int M, int N, int K, - multiple_of(8) int lda, multiple_of(8) int ldb, int ldc) -{ +#if BT == 1 +#define USEB ^b +#else +#define USEB b +#endif + +void dot(TYPE * A __noalias __readonly __aligned(16), + TYPE * B __noalias __readonly __aligned(16), + TYPE * C __noalias __readonly __aligned(16), + int M, int N, int K, + int lda __multipleof(8), + int ldb __multipleof(8), + int ldc) { int ridx = get_program_id(0); int ridy = get_program_id(1); - int rxa[TM] = ridx * TM + (0 ... TM); - int ryb[TN] = ridy * TN + (0 ... TN); + int rxa[TM] = ridx * TM + 0 ... TM; + int ryb[TN] = ridy * TN + 0 ... TN; int rka[TK] = 0 ... TK; int rkb[TK] = 0 ... 
TK; float xc[TM, TN] = 0; - half* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis]; - half* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis]; - half a[TM, TK] = *pa; - half b[TN, TK] = *pb; + + /* pointers for A */ +#if AT == 1 + TYPE* pa[TK, TM] = A + rka[:, newaxis] + rxa[newaxis, :]*lda; + TYPE a[TK, TM] = *pa; +#else + TYPE* pa[TM, TK] = A + rka[newaxis, :]*lda + rxa[:, newaxis]; + TYPE a[TM, TK] = *pa; +#endif + + /* pointers for B */ +#if BT == 1 + TYPE* pb[TN, TK] = B + rkb[newaxis, :]*ldb + ryb[:, newaxis]; + TYPE b[TN, TK] = *pb; +#else + TYPE* pb[TK, TN] = B + rkb[:, newaxis] + ryb[newaxis, :]*ldb; + TYPE b[TK, TN] = *pb; +#endif + + /* reduction loop */ for(int k = K; k > 0; k = k - TK){ - xc = dot(a, trans(b), xc); + xc = USEA @ USEB + xc; +#if AT == 1 + pa = pa + TK; +#else pa = pa + TK*lda; +#endif +#if BT == 1 pb = pb + TK*ldb; +#else + pb = pb + TK; +#endif a = *pa; b = *pb; } + + /* epilogue */ int rxc[TM] = ridx * TM + (0 ... TM); int ryc[TN] = ridy * TN + (0 ... 
TN); - half* pc[TM, TN] = C + ryc[newaxis, :] + rxc[:, newaxis]*ldc; - half c[TM, TN] = xc; + TYPE* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + TYPE c[TM, TN] = xc; bool checkc0[TM] = rxc < M; bool checkc1[TN] = ryc < N; bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - @checkc *pc = c; + *pc = c; } """ +def cdiv(a, b): + return -(-a // b) + class dot: - def __init__(self): - self.matmul = triton.make_tensorflow_op(src, ['C'], ['(M + #TM - 1)/#TM', '(N + #TN - 1)/#TN']) + def __init__(self, trans_a = False, trans_b = True): + self.dot = triton.op(src, ['C']) + self.trans_a = trans_a + self.trans_b = trans_b def __call__(self, a, b): shape_a = tf.shape(a) @@ -57,9 +97,13 @@ class dot: ldb = K ldc = N c = triton.empty([M, N]) - return self.matmul.matmul(a, b, c, M, N, K, lda, ldb, ldc) + return self.dot(a, b, c, M, N, K, lda, ldb, ldc, + lambda opt: [cdiv(M, opt.D('TM')), cdiv(N, opt.D('TN')), 1], + AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16, + TM = [128], TN = [128], TK = [32]) dot_tn = dot() + def run_dot(): M, N, K = 128, 128, 128 a = tf.placeholder(tf.float16, shape=[M, K]) diff --git a/python/setup.py b/python/setup.py index aeba8b5a6..ef5fa9865 100644 --- a/python/setup.py +++ b/python/setup.py @@ -35,6 +35,7 @@ class CMakeBuild(build_ext): self.build_extension(ext) def build_extension(self, ext): + self.debug = True extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) # python directors python_include_dirs = distutils.sysconfig.get_python_inc() diff --git a/python/src/tensorflow.cc b/python/src/tensorflow.cc index 098d338ad..489c545ac 100644 --- a/python/src/tensorflow.cc +++ b/python/src/tensorflow.cc @@ -1,11 +1,14 @@ -#include +#include #include +#include #include #include #include #include "triton/codegen/selection/selection.h" #include "triton/runtime/function.h" -#include "triton/lang/lang.h" +#include "triton/lang/code_gen.h" +#include "triton/lang/parser.h" +#include "triton/lang/cpp.h" 
#include "triton/driver/device.h" #include "triton/driver/stream.h" #include "triton/driver/kernel.h" @@ -14,14 +17,33 @@ #include "triton/ir/function.h" #include "triton/tools/bench.hpp" -typedef struct yy_buffer_state * YY_BUFFER_STATE; -extern int yyparse(); -extern YY_BUFFER_STATE yy_scan_string(const char * str); -extern void yy_delete_buffer(YY_BUFFER_STATE buffer); -extern triton::lang::translation_unit *ast_root; - using namespace triton; +namespace rt = triton::runtime; + + +/* TF triton op properties */ + +std::map id_grid_map; +std::map id_fn_map; + +void register_grid(size_t id, + const rt::function::grid_fn_ty& grid_fn) { + id_grid_map[id] = grid_fn; +} + +size_t register_fn(const std::string& src, + const rt::function::options_space_t& opt) { + size_t id = id_grid_map.size(); + bool is_inserted = id_fn_map.insert({id, new rt::function(src, opt)}).second; + if(!is_inserted) + assert(false); + return id; +} + + +/* TF source-code generation */ + inline std::string to_tf_ty(ir::type *ty) { if(ty->is_integer_ty(1)) return "bool"; @@ -59,21 +81,6 @@ inline std::string ref_to_tf_ty(ir::type *ty) { return res; } -inline triton::lang::translation_unit *make_ast(const char *src) { - YY_BUFFER_STATE buffer = yy_scan_string(src); - yyparse(); - yy_delete_buffer(buffer); - triton::lang::translation_unit *program = ast_root; - return program; -} - -inline std::unique_ptr make_ir(ir::context& ctx, triton::lang::translation_unit *program) { - // create Triton-IR from AST - ir::module* module = new ir::module("", ctx); - program->codegen(module); - return std::unique_ptr(module); -} - void gen_extract_inputs(std::ostream &os, const std::vector& args) { for(unsigned i = 0; i < args.size(); i++){ @@ -102,24 +109,8 @@ void gen_make_handles(std::ostream &os, const std::vector& args) } } -void gen_make_spmd_grid(std::ostream &os, const std::vector& macros) { - std::regex regex("#([a-zA-Z]([a-zA-Z]|[0-9])*)"); - std::vector grids = macros; - for(size_t i = grids.size(); i 
< 3; i++) - grids.push_back("1"); - std::string grid = "rt::grid_t{"; - for(size_t i = 0; i < grids.size(); i++){ - if(i > 0) - grid += ", "; - grid += std::regex_replace(grids[i], regex, "x.at(\"$1\")"); - } - grid += "}"; - - os << " auto grid = [&](const rt::params_t& x) { return " << grid << "; };\n "; -} - void gen_make_launch_function(std::ostream &os, const std::vector& args) { - os << " fn_({"; + os << " (*id_fn_map.at(id_))({"; for(unsigned i = 0; i < args.size() ; i++){ ir::argument *arg = args[i]; std::string name = arg->get_name(); @@ -129,7 +120,7 @@ void gen_make_launch_function(std::ostream &os, const std::vector os << ", "; os << name; } - os << "}, grid, stream); \n"; + os << "}, id_grid_map.at(id_), stream); \n"; } void gen_register_kernel_builder(std::ostream &os, const std::string &name, @@ -168,20 +159,55 @@ void gen_register_op(std::ostream &os, const std::string &name, throw std::runtime_error("unknown output"); os << " .Output(\"out" << i << ": " << to_tf_scalar_ty(args[idx]->get_type()) << "\")\n"; } + os << " .Attr(\"id: int\")" << std::endl; os << ";\n"; } -std::string make_tensorflow_src(const std::string src, +inline std::string preheader() { +return +R"( +#define bool _Bool +#define true 1 +#define false 0 +#define __bool_true_false_are_defined 1 + +#define __readonly __attribute__((readonly)) +#define __writeonly __attribute__((writeonly)) +#define __noalias __attribute__((noalias)) +#define __aligned(A) __attribute__((aligned(A))) +#define __multipleof(A) __attribute__((multipleof(A))) + +extern int get_program_id(int); +)"; +} + +std::tuple make_tensorflow_src(std::string src, const std::vector& outputs, - const std::vector& macros) { - triton::lang::translation_unit *ast = make_ast(src.c_str()); - triton::ir::context context; - std::unique_ptr ir = make_ir(context, ast); + const runtime::function::options_space_t& opt) +{ + src = preheader() + src; + // pre-process + TokenSequence tokens; + Preprocessor cpp(&src, true); + for(auto 
it: opt.defines){ + cpp.AddMacro(it.first, &it.second[0]); + } + cpp.Process(tokens); + // parse + Parser parser(tokens); + parser.Parse(); + // triton-ir code-gen + ir::context ctx; + auto ir = std::unique_ptr(new ir::module("", ctx)); + Generator gen(&parser); + gen.Gen(&*ir); // function ir::function* fn = ir->get_function_list().front(); std::string name = fn->get_name(); - name[0] = static_cast(std::toupper(name[0])); - std::string opname = name + "Op"; + std::string cc_name = name; + cc_name[0] = static_cast(std::toupper(cc_name[0])); + std::string opname = cc_name + "Op"; std::ostringstream oss; oss << R"( @@ -204,12 +230,16 @@ using GPUDevice = Eigen::GpuDevice; namespace rt = triton::runtime; namespace drv = triton::driver; -std::string src = R"TTKERNSRC( )" + src + ")TTKERNSRC\";" + R"( +extern std::map id_grid_map; +extern std::map id_fn_map; + class )" << opname << R"(: public OpKernel { public: explicit )" << opname << R"((OpKernelConstruction* context) - : OpKernel(context), fn_(src) { } + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("id", &id_)); + } void Compute(OpKernelContext* context){ // get device/stream @@ -229,9 +259,7 @@ oss << R"( )"; gen_make_handles(oss, fn->args()); oss << R"( - // create spmd grid )"; -gen_make_spmd_grid(oss, macros); oss << R"( // launch function )"; @@ -240,22 +268,42 @@ oss << R"( } private: - rt::function fn_; + int id_; }; // register kernel builder )"; -gen_register_kernel_builder(oss, name, opname, fn->args()); +gen_register_kernel_builder(oss, cc_name, opname, fn->args()); oss << R"( // register op )"; -gen_register_op(oss, name, fn->args(), outputs); +gen_register_op(oss, cc_name, fn->args(), outputs); - return oss.str(); + return {oss.str(), name}; } +typedef triton::runtime::function::options_t options_t; +typedef triton::runtime::function::options_space_t options_space_t; PYBIND11_MODULE(libtriton, m) { m.doc() = "Python bindings to the C++ Triton API"; - m.def("make_tensorflow_src", 
&make_tensorflow_src, "Creates C++ source code for a custom Tensorflow op corresponding to the specified Triton kernel"); + + // framework binding source code generation + m.def("make_tensorflow_src", &make_tensorflow_src, + "Creates C++ source code for a custom Tensorflow op " + "corresponding to the specified Triton kernel"); + + // bindings for triton classes + pybind11::class_(m, "options") + .def(pybind11::init<>()) + .def("D", &options_t::D); + + pybind11::class_(m, "options_space") + .def(pybind11::init<>()) + .def_readwrite("defines", &options_space_t::defines) + .def_readwrite("num_warps", &options_space_t::num_warps); + + // hooks into triton constructs since frameworks may not use pybind11 + m.def("register_grid", &register_grid); + m.def("register_fn", &register_fn); } diff --git a/python/triton/ops.py b/python/triton/ops.py index a10739903..0099e1289 100644 --- a/python/triton/ops.py +++ b/python/triton/ops.py @@ -91,13 +91,71 @@ def build(src, path): setuptools.setup(**args) shutil.rmtree(tmp) +def _cvt_to_def_str(obj): + if isinstance(obj, bool): + return str(int(obj)) + if isinstance(obj, tf.DType): + return {tf.int8: 'char', + tf.int16: 'short', + tf.int32: 'int', + tf.int64: 'long', + tf.float16: 'half', + tf.float32: 'float', + tf.float64: 'double'}[obj] + return str(obj) + +class op: + + def _make_tensorflow_op(self, src, outputs, options): + src, name = make_bindings(src, outputs, options) + cache_path = make_cache_path(src) + cpp, so = write_bindings(src, cache_path) + build(cpp, cache_path) + result = tf.load_op_library(so) + return result.__dict__[name] + + def __init__(self, src, outputs): + self.fw_ops = dict() + self.src = src + self.outputs = outputs + pass + + def D(self, name): + pass + + def __call__(self, *args, **kwargs): + # recompilation key + key = zip(kwargs.keys(), kwargs.values()) + # create a new op when non-iterable defines are different + if key not in self.fw_ops: + # code generation options + defines = [] + for k, v in 
kwargs.items(): + try: + values = list(map(_cvt_to_def_str, v)) + except TypeError: + values = [_cvt_to_def_str(v)] + defines.append((k, values)) + opt = libtriton.options_space() + opt.defines = defines + opt.num_warps = [1, 2, 4, 8] + # register framework op + id = libtriton.register_fn(self.src, opt) + self.fw_ops[key] = (self._make_tensorflow_op(self.src, self.outputs, opt), id) + # retrieve framework op + op, id = self.fw_ops[key] + libtriton.register_grid(id, args[-1]) + op_args = args[:-1] + return op(*op_args, id=id) + + def make_tensorflow_op(src, outputs, grids): - bindings = make_bindings(src, outputs, grids) - cache_path = make_cache_path(bindings) - cpp, so = write_bindings(bindings, cache_path) + src, name = make_bindings(src, outputs, grids) + cache_path = make_cache_path(src) + cpp, so = write_bindings(src, cache_path) build(cpp, cache_path) result = tf.load_op_library(so) - return result + return result.__dict__[name] def empty(shapes): return extra_ops.alloc_empty(tf.stack(shapes))