diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b252c520..7c7a1c0ab 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -48,9 +48,16 @@ endif() # Python module if(BUILD_PYTHON_MODULE) message(STATUS "Adding Python module") - file(GLOB_RECURSE PYTHON_SRC python/src/*.cpp) - include_directories(python/src/ ${PYTHON_INCLUDE_DIRS}) - set(PYTHON_LIBS ) + # PyBind11 wrapper source file + file(GLOB_RECURSE PYTHON_SRC python/src/tensorflow.cpp) + # update include directory + include_directories(python/src/ ${PYTHON_INCLUDE_DIRS} ${TF_INCLUDE_DIRS}) + # update link directories + link_directories(${TF_LIB_DIRS}) + # extra tensorflow ops (e.g., alloc_empty) + file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cpp) + add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC}) + target_link_libraries(extra_tf_ops ${TF_LIBS}) endif() diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index f97cc2021..90287f719 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -164,16 +164,17 @@ perf_t do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int res.cublas = 0; // test -// stream->synchronize(); -// stream->read(dc, true, 0, hc); -// std::vector rc(hc.size()); -// cpu_ref(AT, BT, M, N, K, rc, ha, hb); -// for(size_t i = 0; i < M*N; i++) -// if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ -// std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; -// exit(EXIT_FAILURE); -// } -// std::cout << "Pass!" << std::endl; + stream->synchronize(); + stream->read(dc, true, 0, hc); + std::vector rc(hc.size()); + cpu_ref(AT, BT, M, N, K, rc, ha, hb); + for(size_t i = 0; i < M*N; i++) + if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){ + std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; + exit(EXIT_FAILURE); + } + std::cout << hc[0] << " " << std::endl; + std::cout << "Pass!" << std::endl; // clean-up delete dc; diff --git a/include/triton/ir/enums.h b/include/triton/ir/enums.h new file mode 100644 index 000000000..600c83ade --- /dev/null +++ b/include/triton/ir/enums.h @@ -0,0 +1,84 @@ +#ifndef TRITON_IR_ENUMS_H +#define TRITON_IR_ENUMS_H + +namespace triton{ +namespace ir{ + + +enum binary_op_t { + Add, + FAdd, + Sub, + FSub, + Mul, + FMul, + UDiv, + SDiv, + FDiv, + URem, + SRem, + FRem, + Shl, + LShr, + AShr, + And, + Or, + Xor +}; + +enum cast_op_t { + Trunc, + ZExt, + SExt, + FPTrunc, + FPExt, + UIToFP, + SIToFP, + FPToUI, + FPToSI, + PtrToInt, + IntToPtr, + BitCast, + AddrSpaceCast +}; + +enum cmp_pred_t { + FIRST_FCMP_PREDICATE, + FCMP_FALSE, + FCMP_OEQ, + FCMP_OGT, + FCMP_OGE, + FCMP_OLT, + FCMP_OLE, + FCMP_ONE, + FCMP_ORD, + FCMP_UNO, + FCMP_UEQ, + FCMP_UGT, + FCMP_UGE, + FCMP_ULT, + FCMP_ULE, + FCMP_UNE, + FCMP_TRUE, + LAST_FCMP_PREDICATE, + FIRST_ICMP_PREDICATE, + ICMP_EQ, + ICMP_NE, + ICMP_UGT, + ICMP_UGE, + ICMP_ULT, + ICMP_ULE, + ICMP_SGT, + ICMP_SGE, + ICMP_SLT, + ICMP_SLE, + LAST_ICMP_PREDICATE +}; + + + + +} +} + +#endif diff --git a/python/examples/dot.py b/python/examples/dot.py index 52c7e0a2e..29d6f9470 100644 --- a/python/examples/dot.py +++ b/python/examples/dot.py @@ -4,6 +4,7 @@ import distutils import distutils.log import setuptools.command.build_ext import setuptools +import numpy as np import os src = """ @@ -45,23 +46,25 @@ void matmul(restrict read_only align(16) half *A, } """ +extra_ops = tf.load_op_library('/home/philippe/development/triton/python/build/lib.linux-x86_64-3.6/libextra_tf_ops.so') + + with open('test.cpp', 'w+') as test: src = libtriton.make_tensorflow_src(src, [2], '(M + #TM - 1)/#TM, (N + #TN - 1)/#TN, 1') test.writelines(src) triton_include_dirs = ['/home/philippe/development/triton/include'] tensorflow_include_dirs = [tf.sysconfig.get_include()] -llvm_include_dirs = ['/usr/include/llvm-8/', '/usr/include/llvm-c-8/'] cuda_include_dirs = ['/usr/local/cuda-10.1/targets/x86_64-linux/include/'] -triton_library_dirs = [os.path.realpath(libtriton.__file__)] +triton_library_dirs = [os.path.realpath(os.path.join(libtriton.__file__, os.path.pardir))] tensorflow_library_dirs = [tf.sysconfig.get_lib()] include_dirs = triton_include_dirs + tensorflow_include_dirs + cuda_include_dirs extra_compile_args = [] extra_link_args = [] -library_dirs = tensorflow_library_dirs -libraries = ['tensorflow_framework'] +library_dirs = triton_library_dirs + tensorflow_library_dirs +libraries = ['tensorflow_framework', 'triton'] ext = setuptools.Extension( name = 'test', @@ -92,4 +95,46 @@ args = dict( setuptools.setup(**args) library_dir = os.path.dirname(os.path.realpath(__file__)) module = tf.load_op_library(os.path.join(library_dir, 'build/lib.linux-x86_64-3.6/test.cpython-36m-x86_64-linux-gnu.so')) -print(module.matmul) \ No newline at end of file + +class dot: + + def __init__(self): + trans_a = True + trans_b = False + + def __call__(self, a, b): + shape_a = tf.shape(a) + shape_b = tf.shape(b) + M = shape_a[0] + K = shape_a[1] + N = shape_b[0] + lda = M + ldb = K + ldc = M + c = extra_ops.alloc_empty(tf.stack([M, N])) + return module.matmul(a, b, c, M, N, K, lda, ldb, ldc) + +dot_nt = dot() +def run_dot(): + M, N, K = 128, 128, 128 + a = tf.placeholder(tf.float16, shape=[M, K]) + b = tf.placeholder(tf.float16, shape=[N, K]) + # c = tf.matmul(a, b, transpose_a=True) + c = dot_nt(a, b) + # Reference + ha = np.random.rand(M, K).astype(np.float16) + hb = np.random.rand(N, K).astype(np.float16) + # Run + sess = tf.InteractiveSession() + sess.run(tf.global_variables_initializer()) + result = sess.run([c], feed_dict = {a: ha, + b: hb})[0] + # Test + hresult = np.dot(ha.T, hb).T + dif = np.abs(result - hresult) + np.savetxt('dif.dat', dif, '%2.4f') + print(hresult) + print(result) + print("dif: %f" % np.max(dif)) + +run_dot() \ No newline at end of file diff --git a/python/setup.py b/python/setup.py index 057362b0f..3d98218ac 100644 --- a/python/setup.py +++ b/python/setup.py @@ -35,12 +35,22 @@ class CMakeBuild(build_ext): def build_extension(self, ext): extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + # python directors python_include_dirs = distutils.sysconfig.get_python_inc() python_lib_dirs = distutils.sysconfig.get_config_var('LIBDIR') + # tensorflow directories + import tensorflow as tf + tf_include_dirs = tf.sysconfig.get_include() + tf_lib_dirs = tf.sysconfig.get_lib() + tf_libs = 'tensorflow_framework' + cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir, '-DBUILD_EXAMPLES=OFF', '-DBUILD_PYTHON_MODULE=ON', - '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs] + '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs, + '-DTF_INCLUDE_DIRS=' + tf_include_dirs, + '-DTF_LIB_DIRS=' + tf_lib_dirs, + '-DTF_LIBS=' + tf_libs] cfg = 'Debug' if self.debug else 'Release' build_args = ['--config', cfg] diff --git a/python/src/tensorflow.cpp b/python/src/tensorflow.cpp index 12e64fa4f..c1c224916 100644 --- a/python/src/tensorflow.cpp +++ b/python/src/tensorflow.cpp @@ -161,7 +161,7 @@ result += R"( // extract outputs)"; for(unsigned i = 0; i < n_outputs; i++) result += R"( - context->set_output()" + str_i[i] + ", " + arg_names[outputs[i]] + ");"; + context->set_output()" + str_i[i] + ", " + arg_names[outputs[i]] + ");"; result += R"( @@ -201,15 +201,26 @@ private: rt::function fn_; }; -REGISTER_KERNEL_BUILDER(Name(")" + name + "\").Device(DEVICE_GPU), " + classname + R"(); +REGISTER_KERNEL_BUILDER(Name(")" + name + "\").Device(DEVICE_GPU)"; +for(size_t i = 0; i < tf_scalar_tys.size(); i++){ + std::string arg_name = arg_names[i]; + std::transform(arg_name.begin(), arg_name.end(), arg_name.begin(), [](char c) { return std::tolower(c);}); + if(!fn_ty->get_param_ty(i)->is_pointer_ty()) + result += ".HostMemory(\"" + arg_name + "\")"; +} +result += ", " + classname + R"(); + REGISTER_OP(")" + name + "\")\n"; for(size_t i = 0; i < tf_scalar_tys.size(); i++){ bool is_output = std::find(outputs.begin(), outputs.end(), i) != outputs.end(); - std::string mode = is_output ? "Output" : "Input" ; + std::string mode = is_output ? "Input" : "Input" ; std::string arg_name = arg_names[i]; std::transform(arg_name.begin(), arg_name.end(), arg_name.begin(), [](char c) { return std::tolower(c);}); - result += " ." + mode + "(\"" + arg_name + ": " + tf_scalar_tys[i] + "\")\n"; + result += " .Input(\"" + arg_name + ": " + tf_scalar_tys[i] + "\")\n"; +} +for(size_t i = 0; i < outputs.size(); i++){ + result += " .Output(\"out: " + tf_scalar_tys[outputs[i]] + "\")\n"; } result += ";\n"; diff --git a/python/src/tensorflow/alloc_empty.cpp b/python/src/tensorflow/alloc_empty.cpp new file mode 100644 index 000000000..e60e8436c --- /dev/null +++ b/python/src/tensorflow/alloc_empty.cpp @@ -0,0 +1,30 @@ +#include "tensorflow/core/framework/op_kernel.h" + +using namespace tensorflow; + +class AllocEmptyOp : public OpKernel { + public: + explicit AllocEmptyOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + // fetch input + const Tensor& x = context->input(0); + const int32* x_data = (const int32*)x.tensor_data().data(); + // allocate output + Tensor* y = NULL; + int32 x_rank = x.dims(); + OP_REQUIRES(context, x_rank == 1, errors::InvalidArgument("Input tensor must be 1D")); + int32 y_rank = x.dim_size(0); + TensorShape y_shapes; + for(size_t i = 0; i < y_rank; i++) + y_shapes.AddDim(x_data[i]); + OP_REQUIRES_OK(context, context->allocate_output(0, y_shapes, &y)); + } +}; + + +REGISTER_KERNEL_BUILDER(Name("AllocEmpty").HostMemory("x").Device(DEVICE_CPU).Device(DEVICE_GPU), AllocEmptyOp); +REGISTER_OP("AllocEmpty") + .Input("x: int32") + .Output("y: float16") +; diff --git a/python/triton/tools/build.py b/python/triton/tools/build.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/triton/tools/checksum.py b/python/triton/tools/checksum.py new file mode 100644 index 000000000..e69de29bb