diff --git a/CMakeLists.txt b/CMakeLists.txt
index 9e05aca5d..d857a96ea 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -42,6 +42,8 @@ if(BUILD_PYTHON_MODULE)
     file(GLOB_RECURSE EXTRA_TF_OPS_SRC python/src/tensorflow/*.cc)
     add_library(extra_tf_ops SHARED ${EXTRA_TF_OPS_SRC})
     target_link_libraries(extra_tf_ops triton ${TF_LIBS})
+    target_compile_definitions(extra_tf_ops PRIVATE "-D_GLIBCXX_USE_CXX11_ABI=${TF_ABI}")
+
 endif()
diff --git a/lib/driver/module.cc b/lib/driver/module.cc
index 34462e8ab..497fc332c 100755
--- a/lib/driver/module.cc
+++ b/lib/driver/module.cc
@@ -250,10 +250,10 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo
   try{
     dispatch::cuModuleLoadDataEx(&*cu_, source_.data(), 2, opt, optval);
   }catch(exception::cuda::base const &){
-//#ifdef TRITON_LOG_PTX_ERROR
+#ifdef TRITON_LOG_PTX_ERROR
     std::cerr << "Compilation Failed! Log: " << std::endl;
     std::cerr << errbuf << std::endl;
-//#endif
+#endif
     throw;
   }
 }
diff --git a/python/examples/blocksparse.py b/python/examples/blocksparse.py
new file mode 100644
index 000000000..27b7d1e9b
--- /dev/null
+++ b/python/examples/blocksparse.py
@@ -0,0 +1,158 @@
+import tensorflow as tf
+import triton
+import numpy as np
+
+src = '''
+  #if AT == 1
+  #define USE_A ^a
+  #define STRIDE_AK lda
+  #define STRIDE_AM 1
+  #define BROADCAST_AK :, newaxis
+  #define BROADCAST_AM newaxis, :
+  #define SHAPE_A TK, TM
+  #else
+  #define USE_A a
+  #define STRIDE_AK 1
+  #define STRIDE_AM lda
+  #define BROADCAST_AK newaxis, :
+  #define BROADCAST_AM :, newaxis
+  #define SHAPE_A TM, TK
+  #endif
+
+  #if BT == 1
+  #define USE_B ^b
+  #define STRIDE_BK 1
+  #define STRIDE_BN ldb
+  #define BROADCAST_BN newaxis, :
+  #define BROADCAST_BK :, newaxis
+  #define SHAPE_B TN, TK
+  #else
+  #define USE_B b
+  #define STRIDE_BK ldb
+  #define STRIDE_BN 1
+  #define BROADCAST_BN :, newaxis
+  #define BROADCAST_BK newaxis, :
+  #define SHAPE_B TK, TN
+  #endif
+
+  void dot (TYPE* A __readonly __noalias __align(16),
+            TYPE* B __readonly __noalias __align(16),
+            TYPE* C __writeonly __noalias __align(16),
+            int lda, int ldb, int ldc,
+            int N, int* lut,
+            int* locks, int nlocks) {
+    int ridx = get_program_id(0);
+    float c[TM, TN] = 0;
+    int rka[TK] = 0 ... TK;
+    int rkb[TK] = 0 ... TK;
+    // load LUT header
+    int *header = lut + get_program_id(1) * 4;
+    int offset = *(header + 0);
+    int K = *(header + 1);
+    int column = *(header + 2);
+    int lockid = *(header + 3);
+    int *plut = lut + offset * 2;
+    int offx = ridx;
+    int offy = 0;
+    // compute x, y offsets
+    int rxa[TM] = offx * TM + (0 ... TM);
+    int ryb[TN] = offy * TN + (0 ... TN);
+    // bounds checking
+    bool checka[SHAPE_A] = (rxa < N)[:, newaxis];
+    bool checkb[SHAPE_B] = 1;
+    // base offset
+    int offa[SHAPE_A] = rxa[BROADCAST_AM] * STRIDE_AM + rka[BROADCAST_AK] * STRIDE_AK;
+    int offb[SHAPE_B] = ryb[BROADCAST_BN] * STRIDE_BN + rkb[BROADCAST_BK] * STRIDE_BK;
+    for(int k = K; k > 0; k -= 1) {
+      // fetch block indices
+      int ak = *(plut + 0);
+      int bk = *(plut + 1);
+      lut += 2;
+      // compute pointers to blocks
+      TYPE* pa[SHAPE_A] = A + offa + ak * TK * lda;
+      TYPE* pb[SHAPE_B] = B + offb + bk * TK * TN;
+      // load blocks
+      TYPE a[SHAPE_A] = checka ? *pa : 0;
+      TYPE b[SHAPE_B] = *pb;
+      // multiply blocks
+      c += USE_A @ USE_B;
+    }
+    int rxc[TM] = ridx * TM + (0 ... TM);
+    int ryc[TN] = column * TN + (0 ... TN);
+    TYPE* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
+    bool checkc[TM, TN] = (rxc < N)[:, newaxis];
+    if(lockid == 0) {
+      *?(checkc) pc = c;
+    }
+    else {
+      int *plock = locks + ridx*nlocks + lockid - 1;
+      int *pcount = plock + get_num_program(0)*nlocks;
+      while(__atomic_cas(plock, 0, 1));
+      int count = *pcount;
+      if(count == 0)
+        *?(checkc) pc = c;
+      else
+        *?(checkc) pc = c + *pc;
+      __atomic_exch(pcount, 1);
+      __atomic_exch(plock, 0);
+    }
+  }
+'''
+
+
+# std::string dot::triton_c_src_dw() const {
+#   bool AT = (op_ == WGRAD);
+#   bool BT = (op_ == FPROP);
+#   std::string usea = AT ? "trans(a)" : "a";
+#   std::string useb = BT ? "trans(b)" : "b";
+#   std::string sizea = AT ? "TK, TM" : "TM, TK";
+#   std::string sizeb = BT ? "TN, TK" : "TK, TN";
+#   std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
+#   std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
+#   std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
+#   std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
+#   std::string lda0 = AT ? "*lda" : "";
+#   std::string lda1 = AT ? "" : "*lda";
+#   std::string ldb0 = BT ? "" : "*ldb";
+#   std::string ldb1 = BT ? "*ldb" : "" ;
+#   std::string result =
+#   R"(
+#   const tunable int TM = {)" + std::to_string(BS_) + R"(};
+#   const tunable int TN = {)" + std::to_string(BS_) + R"(};
+#   const tunable int TK = {32};
+#   void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
+#              restrict read_only align(16) )" + ab_ty_ + R"( *B,
+#              )" + c_ty_ + R"(* C,
+#              int lda, int ldb, int ldc,
+#              int N, int* lut,
+#              int* locks, int nlocks) {
+#     int ridx = get_range_id(0);
+#     float acc[TM, TN] = 0;
+#     int rka[TK] = 0 ... TK;
+#     int rkb[TK] = 0 ... TK;
+#     int *header = lut + ridx * 2;
+#     int offx = *(header + 0);
+#     int offy = *(header + 1);
+#     int rxa[TM] = offx*TM + (0 ... TM);
+#     int ryb[TN] = offy*TN + (0 ... TN);
+#     bool checka[TK, TM] = (rka < N)[:, newaxis];
+#     bool checkb[TK, TN] = (rkb < N)[:, newaxis];
+#     int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
+#     int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
+#     )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
+#     )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
+#     )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
+#     )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
+#     for(int k = N; k > 0; k = k - TK) {
+#       acc = dot()" + usea + ", " + useb + R"(, acc);
+#       pa = pa + TK)" + lda1 + R"(;
+#       pb = pb + TK)" + ldb1 + R"(;
+#       a = checka ? *pa : 0;
+#       b = checkb ? *pb : 0;
+#     }
+#     int rxc[TM] = (0 ... TM);
+#     int ryc[TN] = (0 ... TN);
+#     )" + c_ty_ + R"( c[TM, TN] = acc;
+#     )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
+#     *pc = c;
+#   })";
\ No newline at end of file
diff --git a/python/examples/dot.py b/python/examples/dot.py
index ffb93fd33..4ea6fcf04 100644
--- a/python/examples/dot.py
+++ b/python/examples/dot.py
@@ -3,15 +3,16 @@ import triton
 import numpy as np
 
 src = """
+// Templates for accessing A
 #if AT == 1
-#define USEA ^a
+#define USE_A ^a
 #define STRIDE_AK lda
 #define STRIDE_AM 1
 #define BROADCAST_AK :, newaxis
 #define BROADCAST_AM newaxis, :
 #define SHAPE_A TK, TM
 #else
-#define USEA a
+#define USE_A a
 #define STRIDE_AK 1
 #define STRIDE_AM lda
 #define BROADCAST_AK newaxis, :
@@ -19,15 +20,16 @@ src = """
 #define SHAPE_A TM, TK
 #endif
 
+// Templates for accessing B
 #if BT == 1
-#define USEB ^b
+#define USE_B ^b
 #define STRIDE_BK 1
 #define STRIDE_BN ldb
 #define BROADCAST_BK newaxis, :
 #define BROADCAST_BN :, newaxis
 #define SHAPE_B TN, TK
 #else
-#define USEB b
+#define USE_B b
 #define STRIDE_BK ldb
 #define STRIDE_BN 1
 #define BROADCAST_BK :, newaxis
@@ -56,7 +58,7 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
   TYPE b[SHAPE_B] = *pb;
   // reduction loop
   for(int k = K; k > 0; k-= TK){
-    c += USEA @ USEB;
+    c += USE_A @ USE_B;
     pa = pa + TK * STRIDE_AK;
     pb = pb + TK * STRIDE_BK;
     a = *pa;
@@ -71,57 +73,54 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
 }
 """
 
-def cdiv(a, b):
-  return -(-a // b)
-
 class dot_op:
 
-  def __init__(self, trans_a = False, trans_b = False):
+  def __init__(self, transpose_a = False, transpose_b = False):
     self.dot = triton.op(src, ['C'])
-    self.trans_a = trans_a
-    self.trans_b = trans_b
+    self.transpose_a = transpose_a
+    self.transpose_b = transpose_b
 
   def __call__(self, a, b):
+    # extract shapes
     shape_a = triton.shape(a)
     shape_b = triton.shape(b)
-    M = shape_a[0]
-    Ka = shape_a[1]
-    Kb = shape_b[0]
-    N = shape_b[1]
+    M, Ka = shape_a[0], shape_a[1]
+    Kb, N = shape_b[0], shape_b[1]
     # transpose shapes
-    if self.trans_a:
+    if self.transpose_a:
       M, Ka = Ka, M
-    if self.trans_b:
+    if self.transpose_b:
       Kb, N = N, Kb
-    K = Ka
     # contiguous dimensions
-    lda = Ka
-    ldb = N
+    lda = M if self.transpose_a else Ka
+    ldb = Kb if self.transpose_b else N
     ldc = N
+    # allocate output
     c = triton.empty([M, N])
-    return self.dot(a, b, c, M, N, K, lda, ldb, ldc,
-                    lambda opt: [cdiv(M, opt.d('TM')), cdiv(N, opt.d('TN'))],
-                    AT = self.trans_a, BT = self.trans_b, TYPE = tf.float16,
-                    TM = [128], TN = [ 128], TK = [32])
+    # compute
+    return self.dot(a, b, c, M, N, Ka, lda, ldb, ldc,
+                    lambda opt: [triton.cdiv(M, opt.d('TM')), triton.cdiv(N, opt.d('TN'))],
+                    AT = self.transpose_a, BT = self.transpose_b, TYPE = tf.float16,
+                    TM = [128], TN = [128], TK = [32])
 
-def dot(a, b, trans_a = False, trans_b = False):
-  if (trans_a, trans_b) not in dot.ops:
-    dot.ops[trans_a, trans_b] = dot_op(trans_a, trans_b)
-  return dot.ops[trans_a, trans_b](a, b)
+def dot(a, b, transpose_a = False, transpose_b = False):
+  if (transpose_a, transpose_b) not in dot.ops:
+    dot.ops[transpose_a, transpose_b] = dot_op(transpose_a, transpose_b)
+  return dot.ops[transpose_a, transpose_b](a, b)
 dot.ops = dict()
 
-# @triton.register_gradient(dot_op)
-# def _dot_grad(op, dy):
-#   a = op.inputs[0]
-#   b = op.inputs[1]
-#   return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
+@tf.RegisterGradient("Dot")
+def _dot_grad(op, dy):
+  a = op.inputs[0]
+  b = op.inputs[1]
+  return [dot_tn(dy, b), dot_nt(a, dy), None, None, None, None, None, None, None]
 
 def run_dot():
   M, N, K = 128, 128, 128
   a = tf.placeholder(tf.float16, shape=[M, K])
   b = tf.placeholder(tf.float16, shape=[N, K])
-  c = dot(a, b, trans_a = False, trans_b = True)
+  c = dot(a, b, transpose_a = False, transpose_b = False)
   # Reference
   ha = np.random.rand(M, K).astype(np.float16)
   hb = np.random.rand(K, N).astype(np.float16)
@@ -131,7 +130,8 @@ def run_dot():
   result = sess.run([c], feed_dict = {a: ha, b: hb})[0]
 
   # Test
-  hresult = np.dot(ha, hb.T)
+  print(result)
+  hresult = np.dot(ha, hb)
   dif = np.abs(result - hresult)
   np.savetxt('dif.dat', dif, '%2.4f')
   print("dif: %f" % np.max(dif))
diff --git a/python/setup.py b/python/setup.py
index b9285f84f..1cfe0a881 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -44,6 +44,7 @@ class CMakeBuild(build_ext):
         import tensorflow as tf
         tf_include_dirs = tf.sysconfig.get_include()
         tf_lib_dirs = tf.sysconfig.get_lib()
+        tf_abi = tf.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tf.__dict__ else 0
         tf_libs = 'tensorflow_framework'
 
         cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
@@ -52,7 +53,8 @@ class CMakeBuild(build_ext):
                       '-DPYTHON_INCLUDE_DIRS=' + python_include_dirs,
                       '-DTF_INCLUDE_DIRS=' + tf_include_dirs,
                       '-DTF_LIB_DIRS=' + tf_lib_dirs,
-                      '-DTF_LIBS=' + tf_libs]
+                      '-DTF_LIBS=' + tf_libs,
+                      '-DTF_ABI=' + str(tf_abi)]
 
         cfg = 'Debug' if self.debug else 'Release'
         build_args = ['--config', cfg]
diff --git a/python/src/tensorflow.cc b/python/src/tensorflow.cc
index 1932402e0..b01d5231c 100644
--- a/python/src/tensorflow.cc
+++ b/python/src/tensorflow.cc
@@ -4,7 +4,7 @@
 #include
 #include
 #include
-#include "triton/codegen/selection/selection.h"
+#include "triton/codegen/selection.h"
 #include "triton/runtime/function.h"
 #include "triton/lang/code_gen.h"
 #include "triton/lang/parser.h"
diff --git a/python/triton/ops.py b/python/triton/ops.py
index f0b1ed86b..b4c4a7a54 100644
--- a/python/triton/ops.py
+++ b/python/triton/ops.py
@@ -102,12 +102,15 @@ def _build(src, path, framework):
   # libraries
   libraries = ['triton']
   # add framework
+  extra_compile_args = []
   if framework == tensorflow_id:
     _import_tensorflow()
     library_dirs += [tensorflow.sysconfig.get_lib()]
     include_dirs += [tensorflow.sysconfig.get_include()]
     include_dirs += ['/usr/local/cuda/include/']
     libraries += ['tensorflow_framework']
+    ABI = tensorflow.__cxx11_abi_flag__ if "__cxx11_abi_flag__" in tensorflow.__dict__ else 0
+    extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI={ABI}'.format(ABI=ABI)]
   elif framework == torch_id:
     _import_torch()
     prefix = os.path.dirname(torch.__file__)
@@ -120,7 +123,6 @@ def _build(src, path, framework):
   else:
     assert False
   # extra arguments
-  extra_compile_args = []
   extra_link_args = []
   # dependences
   depends = [os.path.realpath(libtriton.__file__)]
@@ -254,14 +256,14 @@ class op:
     return op(*op_args, id=op_id)
 
 
-# class register_gradient:
+class register_gradient:
 
-#   def __init__(self, op):
-#     self.op = op
+  def __init__(self, op):
+    self.op = op
 
-#   def __call__(self, f):
-#     name = 'Dot'
-#     ops.RegisterGradient(name)(f)
+  def __call__(self, f):
+    name = 'Dot'
+    ops.RegisterGradient(name)(f)
 
 
 def empty(shapes, framework = None):
@@ -276,6 +278,9 @@ def empty(shapes, framework = None):
     _import_torch()
     return torch.empty(*shapes)
 
+def cdiv(a, b):
+  return -(-a // b)
+
 class scalar:
 
   def __init__(self, x):
diff --git a/tests/bench/copy1d.cc b/tests/bench/copy1d.cc
index 2e2fe20d2..51afbacd6 100644
--- a/tests/bench/copy1d.cc
+++ b/tests/bench/copy1d.cc
@@ -22,8 +22,8 @@ std::vector do_bench(drv::stream* stream, int32_t N){
   // create options
   rt::function::options_space_t opt;
   opt.defines.push_back({"TYPE", {ty}});
-  opt.defines.push_back({"TN", {"512"}});
-  opt.num_warps = {4};
+  opt.defines.push_back({"TN", {"128"}});
+  opt.num_warps = {1, 2, 4, 8};
   // create function
   rt::function function(src::copy1d, opt);
   // benchmark available libraries
@@ -42,7 +42,7 @@ int main() {
   triton::driver::stream* stream = triton::driver::stream::create(context);
   // shapes to benchmark
   typedef std::tuple config_t;
-  std::vector configs = { 1024*1024*16 };
+  std::vector configs = { 1024*1024*32 };
   int N;
   for(const auto& c: configs){
     std::tie(N) = c;
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index 3fecb8e58..fc2243bfc 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -29,6 +29,7 @@ inline rt::function::grid_fn_ty grid2d(size_t M, size_t N) {
 std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef float NumericT;
   std::string ty = "float";
+  cublasDataType_t cuty = CUDA_R_32F;
   size_t dt_nbytes = sizeof(NumericT);
   drv::context* context = stream->context();
   // leading dimensions
@@ -44,10 +45,10 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   opt.defines.push_back({"TYPE", {ty}});
   opt.defines.push_back({"AT", {AT?"1":"0"}});
   opt.defines.push_back({"BT", {BT?"1":"0"}});
-  opt.defines.push_back({"TM", {"128"}});
-  opt.defines.push_back({"TN", {"64"}});
+  opt.defines.push_back({"TM", {"64", "128"}});
+  opt.defines.push_back({"TN", {"64", "128"}});
   opt.defines.push_back({"TK", {"8"}});
-  opt.num_warps = {4};
+  opt.num_warps = {2, 4, 8};
   // create function
   rt::function function(src::dot, opt);
   // benchmark available libraries
@@ -57,10 +58,11 @@ std::vector do_bench(drv::stream* stream, bool AT, bool BT, int32_t M, i
   if(cublas::cublasinit()){
     NumericT alpha(static_cast<NumericT>(1));
     NumericT beta(static_cast<NumericT>(0));
-    cublasGemmAlgo_t fastest = CUBLAS_GEMM_ALGO5;
-//    cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-    double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
-                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, nullptr, fastest); }, stream);
+    cublasGemmAlgo_t fastest;
+    cublasGemm(cuty, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
+    double cublas_ms = triton::tools::bench([&]() { cublasGemm(cuty, stream, AT, BT, M, N, K,
+                                                               &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
+                                                               ldc, nullptr, fastest); }, stream);
     result.push_back(tflops(cublas_ms));
   }
   // triton