From ed0f7060052c03513cc2677b2e6c7cfb1fc0305d Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 5 Sep 2019 00:19:42 -0400 Subject: [PATCH] [python] fixed various issues in pytorch supoport --- CMakeLists.txt | 22 ++++---- python/examples/dot.py | 20 +++++-- python/setup.py | 22 +++++--- python/src/tensorflow.cc | 110 ++++++++++++++++++++++++++---------- python/triton/frameworks.py | 7 +++ python/triton/kernel.py | 49 ++++++++++------ python/triton/ops/dot.py | 26 +++++---- python/triton/utils.py | 18 ++++-- 8 files changed, 182 insertions(+), 92 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 201f14c5a..20add646f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,16 +34,18 @@ if(BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") # PyBind11 wrapper source file file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cc) - # update include directory - include_directories(python/src/ ${PYTHON_INCLUDE_DIRS} ${TF_INCLUDE_DIRS}) - # update link directories - link_directories(${TF_LIB_DIRS}) - # extra tensorflow ops (e.g., alloc_empty) - file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc) - add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC}) - target_link_libraries(extra_tf_ops triton ${TF_LIBS}) - target_compile_definitions(extra_tf_ops PRIVATE "-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI}") - + if(TF_LIBS) + # extra tensorflow ops (e.g., alloc_empty) + # update directories + link_directories(${TF_LIB_DIRS}) + include_directories(python/src/ ${PYTHON_INCLUDE_DIRS} ${TF_INCLUDE_DIRS}) + # get sources + file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc) + add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC}) + # create target + target_link_libraries(extra_tf_ops triton ${TF_LIBS}) + target_compile_definitions(extra_tf_ops PRIVATE "-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI}") + endif() endif() diff --git a/python/examples/dot.py b/python/examples/dot.py index 84ae9b6f3..ce8e45c34 100644 --- a/python/examples/dot.py +++ b/python/examples/dot.py @@ -1,14 +1,13 @@ import numpy as np -import tensorflow as tf import triton -def run_dot(): +def run_tf(): + import tensorflow as tf M, N, K = 128, 128, 128 a = tf.placeholder(tf.float32, shape=[M, K]) b = tf.placeholder(tf.float32, shape=[N, K]) - _dot = triton.ops.dot.apply - tr_c = _dot(a, b, transpose_a = False, transpose_b = True) - tr_d = _dot(tr_c, b, transpose_a = True, transpose_b = False) + tr_c = triton.ops.dot(a, b, transpose_a = False, transpose_b = True) + tr_d = triton.ops.dot(tr_c, b, transpose_a = True, transpose_b = False) tf_c = tf.matmul(a, b, transpose_a = False, transpose_b = True) tf_d = tf.matmul(tf_c, b, transpose_a = True, transpose_b = False) # Gradient @@ -28,4 +27,13 @@ def run_dot(): dif = np.abs(result[0][0] - result[1][0]) print("dif: %f" % np.max(dif)) -run_dot() \ No newline at end of file +def run_torch(): + import torch as th + M, N, K = 128, 128, 128 + a = th.randn(M, K).cuda() + b = th.randn(K, N).cuda() + th_c = th.matmul(a, b) + tr_c = triton.ops.dot(a, b) + print(c) + +run_torch() \ No newline at end of file diff --git a/python/setup.py b/python/setup.py index a70aa6c51..49317af9f 100644 --- a/python/setup.py +++ b/python/setup.py @@ -41,18 +41,22 @@ class CMakeBuild(build_ext): python_include_dirs = distutils.sysconfig.get_python_inc() python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR') # tensorflow directories - import tensorflow as tf - tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0 - tf_include_dirs = tf.sysconfig.get_include() - tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '') cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, '-DBUILD_TESTS=OFF', '-DBUILD_PYTHON_MODULE=ON', - '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs, - '-DTF_INCLUDE_DIRS=' + tf_include_dirs, - '-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(), - '-DTF_LIBS=' + tf_libs, - '-DTF_ABI=' + str(tf_abi)] + '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs] + # tensorflow compatibility + try: + import tensorflow as tf + tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0 + tf_include_dirs = tf.sysconfig.get_include() + tf_libs = tf.sysconfig.get_link_flags()[1].replace('-l', '') + cmake_args += ['-DTF_INCLUDE_DIRS=' + tf_include_dirs, + '-DTF_LIB_DIRS=' + tf.sysconfig.get_lib(), + '-DTF_LIBS=' + tf_libs, + '-DTF_ABI=' + str(tf_abi)] + except ModuleNotFoundError: + pass cfg = 'Debug' if self.debug else 'Release' build_args = ['--config', cfg] diff --git a/python/src/tensorflow.cc b/python/src/tensorflow.cc index 95ac51620..2450f35ef 100644 --- a/python/src/tensorflow.cc +++ b/python/src/tensorflow.cc @@ -315,16 +315,8 @@ gen_tf_register_op(oss, cc_name, fn->args(), outputs); inline std::string to_torch_ty(ir::type *ty) { - if(ty->is_integer_ty(1)) - return "bool"; - if(ty->is_integer_ty(8)) - return "int8"; - if(ty->is_integer_ty(16)) - return "int16"; - if(ty->is_integer_ty(32)) - return "int32"; - if(ty->is_integer_ty(64)) - return "int64"; + if(ty->is_integer_ty()) + return "int64_t"; if(ty->is_half_ty()) return "float16"; if(ty->is_float_ty()) @@ -332,7 +324,29 @@ inline std::string to_torch_ty(ir::type *ty) { if(ty->is_double_ty()) return "float64"; if(ty->is_pointer_ty()) - return "Tensor"; + return "torch::Tensor"; + throw std::runtime_error("unknown type"); +} + +inline std::string to_c_ty(ir::type *ty) { + if(ty->is_integer_ty(1)) + return "bool"; + if(ty->is_integer_ty(8)) + return "int8_t"; + if(ty->is_integer_ty(16)) + return "int16_t"; + if(ty->is_integer_ty(32)) + return "int32_t"; + if(ty->is_integer_ty(64)) + return "int64_t"; + if(ty->is_half_ty()) + return "float16"; + if(ty->is_float_ty()) + return "float32"; + if(ty->is_double_ty()) + return "float64"; + if(ty->is_pointer_ty()) + return "drv::cu_buffer"; throw std::runtime_error("unknown type"); } @@ -352,15 +366,22 @@ void gen_torch_signature(std::ostringstream& oss, out_types.push_back((*it)->get_type()); } - oss << "std::tuple<"; - for(size_t i = 0; i < out_types.size(); i++){ - if(i > 0) - oss << ", "; - oss << to_torch_ty(out_types[i]); + std::string ret_ty; + if(out_types.empty()) + ret_ty = "void"; + else{ + ir::type* ty = out_types[0]; + ret_ty = to_torch_ty(ty); + if(out_types.size() > 1){ + for(size_t i = 1; i < out_types.size(); i++) + if(out_types[i] != ty) + throw std::runtime_error("outputs of different types not supported by pytorch"); + ret_ty = "std::vector<" + ret_ty + ">"; + } } - oss << "> "; - oss << name << "("; - oss << "int64 id" << std::endl; + + oss << ret_ty << " " << name << "("; + oss << "int64_t id, "; for(size_t i = 0; i < args.size(); i++) { ir::argument* arg = args[i]; if(i > 0) @@ -370,9 +391,16 @@ void gen_torch_signature(std::ostringstream& oss, oss << ")"; } -void gen_torch_init_driver(std::ostringstream &oss) { +void gen_torch_init_driver(std::ostringstream &oss, + const std::vector&args) { + ir::argument* tensor = nullptr; + for(ir::argument* arg: args) + if(arg->get_type()->is_pointer_ty()){ + tensor = arg; + break; + } oss << " // Wrap CUDA handles" << std::endl; - oss << " c10::DeviceIndex device = torcha.storage().device().index();" << std::endl; + oss << " c10::DeviceIndex device = " << tensor->get_name() << ".storage().device().index();" << std::endl; oss << " // Get stream" << std::endl; oss << " CUstream custream = (CUstream)at::cuda::getCurrentCUDAStream(device).stream();" << std::endl; oss << " triton::driver::cu_stream stream(custream, false);" << std::endl; @@ -383,10 +411,12 @@ void gen_torch_make_handles(std::ostream &os, const std::vector& args) { for(unsigned i = 0; i < args.size(); i++){ ir::argument *arg = args[i]; - if(!arg->get_type()->is_pointer_ty()) - continue; const std::string& name = arg->get_name(); - os << " drv::cu_buffer cu_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage.data(), false);\n "; + ir::type* ty = arg->get_type(); + if(!ty->is_pointer_ty()) + os << " " << to_c_ty(ty) << " arg_" << name << " = " << name << ";" << std::endl; + else + os << " drv::cu_buffer arg_" + name + "(ctx, " + name + ".storage().size(), (CUdeviceptr)" + name + ".storage().data(), false);" << std::endl; } } @@ -394,19 +424,28 @@ void gen_torch_make_launch_function(std::ostream &os, const std::vectorget_name(); + std::string name = "arg_" + arg->get_name(); if(arg->get_type()->is_pointer_ty()) - name = "&cu_" + name; + name = "&" + name; if(i > 0) os << ", "; os << name; } - os << "}, *id_grid_map.at(id), stream); \n"; + os << "}, *id_grid_map.at(id), &stream);\n"; } +void gen_torch_ret(std::ostream &os, const std::vector& outputs) { + os << " return {"; + for(size_t i = 0; i < outputs.size(); i++){ + if(i > 0) + os << ", "; + os << outputs[i]; + } + os << "};" << std::endl; +} std::tuple make_pytorch_src(const std::string& src, + std::string> make_torch_src(const std::string& src, const std::vector& outputs, const runtime::function::options_space_t& opt) { // triton-ir code-gen @@ -423,6 +462,10 @@ std::tuple> id_fn_map; gen_torch_signature(oss, fn, outputs, name); oss << " {" << std::endl; - gen_torch_init_driver(oss); + gen_torch_init_driver(oss, fn->args()); gen_torch_make_handles(oss, fn->args()); gen_torch_make_launch_function(oss, fn->args()); - oss << std::endl << "}"; + gen_torch_ret(oss, outputs); + oss << "}" << std::endl; + oss << std::endl; + oss << std::endl; oss << "static auto registry = torch::jit::RegisterOperators(\"triton::" << name << "\", &" << name << ");" << std::endl; + + return {oss.str(), name}; } @@ -453,7 +501,7 @@ PYBIND11_MODULE(libtriton, m) { m.def("make_tensorflow_src", &make_tensorflow_src, "Creates C++ source code for a custom Tensorflow op " "corresponding to the specified Triton kernel"); - m.def("make_pytorch_src", &make_pytorch_src, + m.def("make_torch_src", &make_torch_src, "Creates C++ source code for a custom PyTorch op "); // bindings for triton classes diff --git a/python/triton/frameworks.py b/python/triton/frameworks.py index 60c0728f1..4d10697ad 100644 --- a/python/triton/frameworks.py +++ b/python/triton/frameworks.py @@ -9,6 +9,13 @@ torch = None tensorflow = None tf_extra_ops = None +def to_str(framework): + if framework == tensorflow_id: + return 'tensorflow' + elif framework == torch_id: + return 'torch' + else: + assert False def _import_torch(): global torch diff --git a/python/triton/kernel.py b/python/triton/kernel.py index b3d2be50a..2a7f2c929 100644 --- a/python/triton/kernel.py +++ b/python/triton/kernel.py @@ -66,16 +66,19 @@ def _build(src, path, framework): include_dirs += [fw.tensorflow.sysconfig.get_include()] include_dirs += ['/usr/local/cuda/include/'] libraries += [fw.tensorflow.sysconfig.get_link_flags()[1].replace('-l', '')] - ABI = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0 - extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)] + abi = fw.tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in fw.tensorflow.__dict__ else 0 + extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] elif framework == fw.torch_id: - prefix = os.path.dirname(torch.__file__) + prefix = os.path.dirname(fw.torch.__file__) library_dirs += [os.path.join(prefix, 'lib')] - include_dirs += [os.path.join(prefix, 'lib', 'include'), + include_dirs += ['/usr/local/cuda/include/', + os.path.join(prefix, 'lib', 'include'), os.path.join(prefix, 'lib', 'include', 'torch', 'csrc', 'api', 'include'), os.path.join(prefix, 'include'), os.path.join(prefix, 'include', 'torch', 'csrc', 'api', 'include')] libraries += ['torch'] + abi = fw.torch._C._GLIBCXX_USE_CXX11_ABI + extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={abi}'.format(abi=abi)] else: assert False # extra arguments @@ -84,7 +87,7 @@ def _build(src, path, framework): depends = [os.path.realpath(libtriton.__file__)] # create extension module ext = setuptools.Extension( - name = 'tensorflow', + name = fw.to_str(framework), language = 'c++', sources = [src], include_dirs = include_dirs, @@ -124,14 +127,14 @@ def _cvt_to_def_str(obj, framework): fw.tensorflow.float64: 'double'}[obj] # torch type elif framework == fw.torch_id: - if isinstance(obj, torch.dtype): - return {torch.int8: 'char', - torch.int16: 'short', - torch.int32: 'int', - torch.int64: 'long', - torch.float16: 'half', - torch.float32: 'float', - torch.float64: 'double'}[obj] + if isinstance(obj, fw.torch.dtype): + return {fw.torch.int8: 'char', + fw.torch.int16: 'short', + fw.torch.int32: 'int', + fw.torch.int64: 'long', + fw.torch.float16: 'half', + fw.torch.float32: 'float', + fw.torch.float64: 'double'}[obj] else: assert False # default @@ -146,8 +149,8 @@ def _make_framework_op(src, outputs, options, framework): if framework == fw.tensorflow_id: return fw.tensorflow.load_op_library(so).__dict__[name] elif framework == fw.torch_id: - torch.ops.load_library(so) - return torch.ops.triton.__dict__[name] + fw.torch.ops.load_library(so) + return getattr(fw.torch.ops.triton, name) else: assert False @@ -171,7 +174,12 @@ class kernel: self.fw_op = None self.src = src self.outputs = outputs - self.framework = fw._find_framework(framework) + self.framework = framework + + def _init_framework(self): + if self.framework is not None: + return + self.framework = fw._find_framework(self.framework) if self.framework == fw.tensorflow_id: fw._import_tensorflow() fw._import_tf_extra_ops() @@ -180,8 +188,8 @@ class kernel: else: assert False - def __call__(self, *args, **kwargs): + self._init_framework() # create a new framework op when defines are different key = '-'.join(['{key}-{val}'.format(key=key, val=val) for key, val in kwargs.items()]) if key not in self.fw_id.keys(): @@ -212,4 +220,9 @@ class kernel: # create operands op_args = [x.handle if isinstance(x, triton.utils.scalar) else x for x in args[:-1]] # call framework function - return self.fw_op(*op_args, id=op_id) \ No newline at end of file + if self.framework == fw.tensorflow_id: + return self.fw_op(*op_args, id=op_id) + elif self.framework == fw.torch_id: + return self.fw_op(op_id, *op_args) + else: + assert False \ No newline at end of file diff --git a/python/triton/ops/dot.py b/python/triton/ops/dot.py index f799be983..36bde11fe 100644 --- a/python/triton/ops/dot.py +++ b/python/triton/ops/dot.py @@ -1,6 +1,6 @@ import triton -class dot(triton.function): +class _dot(triton.function): src = """ void dot(TYPE * A, TYPE * B, TYPE * C, @@ -78,30 +78,32 @@ void dot(TYPE * A, TYPE * B, TYPE * C, 'BROADCAST_BK': 'newaxis, :' if transpose_b else ':, newaxis', 'BROADCAST_BN': ':, newaxis' if transpose_b else 'newaxis, :', 'SHAPE_B' : 'TN, TK' if transpose_b else 'TK, TN'} - return dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid, + return _dot.kernel(a, b, c, M, N, Ka, lda, ldb, ldc, grid, AT = transpose_a, BT = transpose_b, TYPE = dtype, TM = [64, 128], TN = [64, 128], TK = [8], **macros) @staticmethod def forward(ctx, a, b, transpose_a = False, transpose_b = False): ctx.save_for_backward(a, b, transpose_a, transpose_b) - return dot._call(a, b, transpose_a, transpose_b) + return _dot._call(a, b, transpose_a, transpose_b) @staticmethod def backward(ctx, dy): a, b, t_a, t_b = ctx.saved_tensors if not t_a and not t_b: - da = dot._call(dy, b, False, True) - db = dot._call(a, dy, True, False) + da = _dot._call(dy, b, False, True) + db = _dot._call(a, dy, True, False) elif not t_a and t_b: - da = dot._call(dy, b, False, False) - db = dot._call(dy, a, True, False) + da = _dot._call(dy, b, False, False) + db = _dot._call(dy, a, True, False) elif t_a and not t_b: - da = dot._call(b, dy, False, True) - db = dot._call(a, dy, False, False) + da = _dot._call(b, dy, False, True) + db = _dot._call(a, dy, False, False) elif t_a and t_b: - da = dot._call(b, dy, True, True) - db = dot._call(dy, a, True, True) + da = _dot._call(b, dy, True, True) + db = _dot._call(dy, a, True, True) else: assert False - return [da, db, None, None, None, None, None, None, None] \ No newline at end of file + return [da, db, None, None, None, None, None, None, None] + +dot = _dot.apply \ No newline at end of file diff --git a/python/triton/utils.py b/python/triton/utils.py index 98380bf37..422f1117b 100644 --- a/python/triton/utils.py +++ b/python/triton/utils.py @@ -7,12 +7,13 @@ def cdiv(a, b): def empty(shapes, dtype, framework = None): framework = fw._find_framework(framework) if framework == fw.tensorflow_id: + fw._import_tensorflow() args = [x.handle if isinstance(x, scalar) else x for x in shapes] args = fw.tensorflow.stack(args) return fw.tf_extra_ops.alloc_empty(args, T = dtype) elif framework == fw.torch_id: - _import_torch() - return fw.torch.empty(*shapes) + fw._import_torch() + return fw.torch.empty(*shapes).cuda() class lazy_shape: @@ -22,15 +23,20 @@ class lazy_shape: def __getitem__(self, key): return scalar(self.shape[key]) -def shape(A) : - fw._import_tensorflow() - return lazy_shape(fw.tensorflow.shape(A)) +def shape(A, framework = None) : + framework = fw._find_framework(framework) + if framework == fw.tensorflow_id: + fw._import_tensorflow() + return lazy_shape(fw.tensorflow.shape(A)) + else: + return A.shape class scalar: - def __init__(self, x): + def __init__(self, x, framework = None): self.id = libtriton.make_scalar_id() + fw._import_tf_extra_ops() self.handle = fw.tf_extra_ops.register_scalar(x, id=self.id) self.assume_initialized = False