From 3413aad582414cc843b81bd974893ee652e3f214 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Thu, 25 Apr 2019 16:17:36 -0400 Subject: [PATCH] [general] major overhaul of triton-c/triton-ir/triton-jit: - Added alloc const - Added atomics - Pruning tuning space - Added example for dot/conv/shift - Bugfixes --- examples/CMakeLists.txt | 7 +- examples/cpp/CMakeLists.txt | 6 + examples/{matrix.cpp => cpp/blocksparse.cpp} | 87 ++---- examples/cpp/common.hpp | 286 ++++++++++++++++++ examples/cpp/conv.cpp | 236 +++++++++++++++ examples/cpp/dot.cpp | 162 ++++++++++ examples/cpp/shift.cpp | 212 +++++++++++++ examples/cpp/shift.ptx | 93 ++++++ include/triton/ast/ast.h | 68 ++++- include/triton/ast/parser.y | 9 +- include/triton/ast/scanner.l | 8 +- include/triton/codegen/layout.h | 45 --- include/triton/codegen/optimize_cse.h | 27 ++ include/triton/codegen/optimize_dot.h | 31 ++ include/triton/codegen/optimize_trans.h | 33 ++ include/triton/codegen/selection.h | 13 +- include/triton/codegen/shared_copy.h | 41 --- .../{allocation.h => shmem_allocation.h} | 12 +- .../codegen/{barriers.h => shmem_barriers.h} | 12 +- .../codegen/{buffer_info.h => shmem_info.h} | 5 +- .../codegen/{liveness.h => shmem_liveness.h} | 8 +- include/triton/codegen/target.h | 4 + include/triton/driver/module.h | 2 +- include/triton/ir/builder.h | 7 +- include/triton/ir/constant.h | 1 + include/triton/ir/instructions.h | 82 ++++- include/triton/ir/module.h | 1 + include/triton/jit.h | 61 ++-- lib/ast/lowering.cpp | 148 ++++++--- lib/codegen/buffer_info.cpp | 90 ------ lib/codegen/layout.cpp | 56 ---- lib/codegen/loop_info.cpp | 0 lib/codegen/optimize_cse.cpp | 14 + lib/codegen/optimize_dot.cpp | 50 +++ lib/codegen/optimize_trans.cpp | 71 +++++ lib/codegen/selection.cpp | 154 +++++++--- lib/codegen/shared_copy.cpp | 40 --- .../{allocation.cpp => shmem_allocation.cpp} | 13 +- .../{barriers.cpp => shmem_barriers.cpp} | 40 +-- lib/codegen/shmem_info.cpp | 135 +++++++++ .../{liveness.cpp => shmem_liveness.cpp} | 18 +- lib/codegen/target.cpp | 28 +- lib/codegen/tune.cpp | 35 ++- lib/driver/buffer.cpp | 3 - lib/driver/module.cpp | 11 +- lib/ir/builder.cpp | 24 +- lib/ir/instructions.cpp | 89 +++++- lib/ir/module.cpp | 3 + lib/ir/type.cpp | 2 +- lib/jit.cpp | 38 ++- 50 files changed, 2051 insertions(+), 570 deletions(-) create mode 100644 examples/cpp/CMakeLists.txt rename examples/{matrix.cpp => cpp/blocksparse.cpp} (69%) create mode 100644 examples/cpp/common.hpp create mode 100644 examples/cpp/conv.cpp create mode 100644 examples/cpp/dot.cpp create mode 100644 examples/cpp/shift.cpp create mode 100644 examples/cpp/shift.ptx delete mode 100644 include/triton/codegen/layout.h create mode 100644 include/triton/codegen/optimize_cse.h create mode 100644 include/triton/codegen/optimize_dot.h create mode 100644 include/triton/codegen/optimize_trans.h delete mode 100644 include/triton/codegen/shared_copy.h rename include/triton/codegen/{allocation.h => shmem_allocation.h} (79%) rename include/triton/codegen/{barriers.h => shmem_barriers.h} (82%) rename include/triton/codegen/{buffer_info.h => shmem_info.h} (84%) rename include/triton/codegen/{liveness.h => shmem_liveness.h} (90%) delete mode 100644 lib/codegen/buffer_info.cpp delete mode 100644 lib/codegen/layout.cpp delete mode 100644 lib/codegen/loop_info.cpp create mode 100644 lib/codegen/optimize_cse.cpp create mode 100644 lib/codegen/optimize_dot.cpp create mode 100644 lib/codegen/optimize_trans.cpp delete mode 100644 lib/codegen/shared_copy.cpp rename lib/codegen/{allocation.cpp => shmem_allocation.cpp} (91%) rename lib/codegen/{barriers.cpp => shmem_barriers.cpp} (75%) create mode 100644 lib/codegen/shmem_info.cpp rename lib/codegen/{liveness.cpp => shmem_liveness.cpp} (67%) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index e577a1d81..2322a85f7 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,6 +1 @@ -foreach(PROG matrix) - add_executable(${PROG} ${PROG}.cpp) - set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG}) - include_directories(/usr/local/cuda/include/) - target_link_libraries(${PROG} triton) -endforeach(PROG) +add_subdirectory(cpp) diff --git a/examples/cpp/CMakeLists.txt b/examples/cpp/CMakeLists.txt new file mode 100644 index 000000000..db1e5421f --- /dev/null +++ b/examples/cpp/CMakeLists.txt @@ -0,0 +1,6 @@ +foreach(PROG dot conv shift) + add_executable(${PROG} ${PROG}.cpp) + set_target_properties(${PROG} PROPERTIES OUTPUT_NAME ${PROG}) + include_directories(/usr/local/cuda/include/) + target_link_libraries(${PROG} triton) +endforeach(PROG) diff --git a/examples/matrix.cpp b/examples/cpp/blocksparse.cpp similarity index 69% rename from examples/matrix.cpp rename to examples/cpp/blocksparse.cpp index e630e5164..5a816aff1 100644 --- a/examples/matrix.cpp +++ b/examples/cpp/blocksparse.cpp @@ -1,17 +1,18 @@ #include #include +#include "common.hpp" #include "triton/jit.h" #include "triton/driver/backend.h" #include "triton/driver/stream.h" const char* src = R"( -const tunable int32 TM = {16, 32, 64}; -const tunable int32 TN = {16, 32, 64}; +const tunable int32 TM = {16, 32, 64, 128}; +const tunable int32 TN = {8}; const tunable int32 TK = {8}; -void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, - int32 M, int32 N, int32 K, int32 bound){ +void blocksparse(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, + int32 M, int32 N, int32 K, int32 bound){ int32 rxa[TM] = get_global_range[TM](0); int32 ryb[TN] = get_global_range[TN](1); int32 rka[TK] = 0 ... TK; @@ -22,9 +23,9 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, fp32 a[TM, TK] = *pa; fp32 b[TN, TK] = *pb; for(int32 k = K; k > 0;){ - C = dot(a, b, C); + C = dot(a, trans(b), C); pa = pa + TK*M; - pb = pb + TK*K; + pb = pb + TK*N; k = k - TK; int1 checka[TM, TK] = k > bound; int1 checkb[TN, TK] = k > bound; @@ -51,71 +52,24 @@ void matmul(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, } )"; - -template -void simple_gemm(std::vector &c, const std::vector &a, const std::vector &b, size_t M, size_t N, size_t K){ - for(size_t m = 0; m < M; m++) - for(size_t n = 0; n < N; n++){ - T acc = 0; - for(size_t k = 0; k < K; k++) - acc += a[m + k*M] * b[n + k*N]; - c[m + n*M] = acc; +std::vector make_deltas(std::vector mask, int K, int N){ + std::vector>> pairs(N); + unsigned int current = 0; + for(int k = 0; k < K; k++) + for(int n = 0; n < N; n++){ + if(mask[k + n*K]) + pairs[n].push_back({current, k}); } } -class timer{ - typedef std::chrono::high_resolution_clock high_resolution_clock; - typedef std::chrono::nanoseconds nanoseconds; - -public: - explicit timer(bool run = false) - { if (run) start(); } - - void start() - { _start = high_resolution_clock::now(); } - - nanoseconds get() const - { return std::chrono::duration_cast(high_resolution_clock::now() - _start); } - -private: - high_resolution_clock::time_point _start; -}; - -template -T min(std::vector x) -{ return *std::min_element(x.begin(), x.end()); } - - -template -double bench(OP const & op, SYNC const & sync, triton::driver::device const & device) -{ - timer tmr; - std::vector times; - double total_time = 0; - op(); - sync(); - while(total_time*1e-9 < 1e-3){ - float norm = 1; - // normalize clock if possible to get roughly constant result - if(auto cu_device = dynamic_cast(&device)) - norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); - tmr.start(); - op(); - sync(); - times.push_back(norm*tmr.get().count()); - total_time+=times.back(); - } - return min(times); -} - - int main() { // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); triton::jit jit(context); + // matrix multiplication parameters - int32_t M = 512, N = 512, K = 512; + int32_t M = 512, N = 32, K = 2048; std::vector hc(M*N); std::vector rc(M*N); std::vector ha(M*K); @@ -183,14 +137,13 @@ int main() { 8, 8, 4 }; - - jit.autotune(src, benchmark); - jit.add_module(src, params); + jit.autotune("matmul",src, benchmark); + jit.add_module("matmul", src, params); triton::driver::kernel* kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); - std::cout << benchmark(kernel, info) << std::endl; + std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; stream->read(dc, true, 0, hc); - simple_gemm(rc, ha, hb, M, N, K); + simple_gemm(rc, ha, hb, M, N, K); for(size_t i = 0; i < M*N; i++) if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; diff --git a/examples/cpp/common.hpp b/examples/cpp/common.hpp new file mode 100644 index 000000000..8a16b9457 --- /dev/null +++ b/examples/cpp/common.hpp @@ -0,0 +1,286 @@ +#include +#include +#include "triton/driver/device.h" +#include + +template +void simple_gemm(std::vector &c, const std::vector &a, const std::vector &b, size_t M, size_t N, size_t K){ + for(size_t m = 0; m < M; m++) + for(size_t n = 0; n < N; n++){ + T acc = 0; + for(size_t k = 0; k < K; k++) + acc += (AT?a[k + m*K]:a[m + k*M]) * (BT?b[n + k*N]:b[k + n*K]); + c[m + n*M] = acc; + } +} + + +class timer{ + typedef std::chrono::high_resolution_clock high_resolution_clock; + typedef std::chrono::nanoseconds nanoseconds; + +public: + explicit timer(bool run = false) + { if (run) start(); } + + void start() + { _start = high_resolution_clock::now(); } + + nanoseconds get() const + { return std::chrono::duration_cast(high_resolution_clock::now() - _start); } + +private: + high_resolution_clock::time_point _start; +}; + +template +T min(std::vector x) +{ return *std::min_element(x.begin(), x.end()); } + + +template +double bench(OP const & op, SYNC const & sync, triton::driver::device const & device) +{ + timer tmr; + std::vector times; + double total_time = 0; + op(); + sync(); + while(total_time*1e-9 < 1e-3){ + float norm = 1; + // normalize clock if possible to get roughly constant result + if(auto cu_device = dynamic_cast(&device)) + norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock(); + tmr.start(); + op(); + sync(); + times.push_back(norm*tmr.get().count()); + total_time+=times.back(); + } + return min(times); +} + +// + +void build_conv_lut(int TK, + int stride_d, int stride_h, int stride_w, int stride_c, + int pad_d, int pad_h, int pad_w, + int T, int R, int S, + std::vector& res, std::vector& masks) { + /* convolution parameters */ + int F = T * R * S; + int Nlut = (TK + F - 1) / F * F; + int upsample_w = 1; + int upsample_h = 1; + int upsample_d = 1; + /* unpack index wrt filters */ + auto unpack = [&](int32_t trs){ + int32_t tr = trs / S; + int32_t s = trs - tr*S; + int32_t t = tr / R; + int32_t r = tr - t*R; + return std::make_tuple(t, r, s); + }; + /* increments */ + for(size_t i = 0; i < Nlut; ++i) + res[i] = (((i + TK) % Nlut) - i); + /* deltas */ + size_t Ds0 = Nlut; + size_t Ds1 = upsample_w; + size_t Ds2 = upsample_h; + size_t Ds3 = upsample_d; + for(size_t pd = 0; pd < Ds3; ++pd) + for(size_t ph = 0; ph < Ds2; ++ph) + for(size_t pw = 0; pw < Ds1; ++pw){ + int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2]; + // cumulative increments + for(size_t i = 0; i < Ds0; ++i){ + int32_t ctrs = i; + int32_t c = ctrs / F; + int32_t t, r, s; + std::tie(t, r, s) = unpack(ctrs % F); + // next indices + int32_t nextctrs = ctrs + TK; + int32_t nextc = nextctrs / F; + int32_t nextt, nextr, nexts; + std::tie(nextt, nextr, nexts) = unpack(nextctrs % F); + // diffs + int32_t cdiff = nextc - c; + int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d; + int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h; + int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w; + // delta pointers + deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d; + } + } + + /* Masks */ + size_t Ms0 = Nlut; + size_t Ms1 = 2*pad_w + 1; + size_t Ms2 = 2*pad_h + 1; + size_t Ms3 = 2*pad_d + 1; + + for(size_t pd = 0; pd < Ms3; ++pd) + for(size_t ph = 0; ph < Ms2; ++ph) + for(size_t pw = 0; pw < Ms1; ++pw){ + int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2]; + for(size_t i = 0; i < Ms0; ++i){ + int32_t t, r, s; + int32_t mask = 0x0; + for(size_t j = 0; j < TK; ++j){ + std::tie(t, r, s) = unpack((i + j) % F); + bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d); + bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h); + bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w); + mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j; + } + masks_ptr[i] = mask; + } + } + for(size_t i = 0; i < Nlut; ++i) + masks[i] = 0x0; +} + + +// Index computation +inline int32_t idx(int32_t x, int32_t y, int32_t z, int32_t w, int32_t u, + int32_t /*s0*/, int32_t s1, int32_t s2, int32_t s3, int32_t s4) +{ return u + w*s4 + z*s4*s3 + y*s4*s3*s2 + x*s4*s3*s2*s1; } + + +// Pack + +template T clamp(T x, T lo, T hi){ + return std::max(lo, std::min(x, hi)); +} + + +template +T pack(U* tmp, U scale); + +template<> +double pack(double* tmp, double scale) +{ return tmp[0]*scale; } + +template<> +float pack(float* tmp, float scale) +{ return tmp[0]*scale; } + +template<> +int pack(float* tmp, float scale) +{ + int res = 0; + for(int i = 0; i < 4; i++){ + int8_t clamped = std::round(clamp(tmp[i]*scale, (float)-128, (float)127)); + res |= (clamped & 0xFF) << (8*i); + } + return res; +} + +template struct pack_increment +{ enum{ VALUE = 1}; }; + +template<> struct pack_increment +{ enum{ VALUE = 4}; }; + +// Dot +template +inline T dot(T x, T y, T z) +{ + return std::fma(x, y, z); +} + +inline int dot(int x, int y, int z){ + int res = 0; + for(int i = 0; i < 4; i++){ + int32_t a = ((x >> (8*i)) & 0x000000FF); + int32_t b = ((y >> (8*i)) & 0x000000FF); + res += (*(int8_t*)(&a)) * (*(int8_t*)(&b)); + } + return res + z; +} + + + +template +void cpp_conv_nchw(int32_t C, int32_t N, int32_t K, + int32_t D, int32_t H, int32_t W, + int32_t T, int32_t R, int32_t S, + int32_t pad_d, int32_t pad_h, int32_t pad_w, + int32_t stride_d, int32_t stride_h, int32_t stride_w, + int32_t M, int32_t P, int32_t Q, + std::vector& O, + const std::vector& I, + const std::vector& F) +{ + static const int PACK_IN = pack_increment::VALUE; + static const int PACK_OUT = pack_increment::VALUE; + if(C % PACK_IN != 0) throw std::runtime_error("Number of input channels must be a multiple of 4"); + if(K % PACK_OUT != 0) throw std::runtime_error("Number of output channels must be a multiple of 4"); + C /= PACK_IN; + K /= PACK_OUT; + int32_t Kout = K; + IN_DTYPE accs[PACK_OUT]; + float tmp[PACK_OUT]; + for(int32_t m = 0 ; m < M; ++m) + for(int32_t p = 0 ; p < P; ++p) + for(int32_t q = 0; q < Q; ++q) + for(int32_t n = 0; n < N; ++n) + for(int32_t k = 0; k < Kout ; ++k) + { + for(int32_t i = 0; i < PACK_OUT; ++i) + accs[i] = 0; + int32_t mm = m*stride_d - pad_d; + int32_t pp = p*stride_h - pad_h; + int32_t qq = q*stride_w - pad_w; + for(int32_t kk = 0; kk < PACK_OUT; ++kk) + for(int32_t c = 0; c < C; ++c) + for(int32_t t = 0; t < T; ++t) + for(int32_t r = 0; r < R; ++r) + for(int32_t s = 0; s < S; ++s){ + int32_t d = mm + t; + int32_t h = pp + r; + int32_t w = qq + s; + bool in_bounds = (d >= 0 && h >= 0 && w >= 0 && d < D && h < H && w < W); + IN_DTYPE i = in_bounds?I[idx(n, c, d, h, w, N, C, D, H, W)]:0; + IN_DTYPE f = F[idx(c, t, r, s, k*PACK_OUT + kk, C, T, R, S, K*PACK_OUT)]; + accs[kk] = dot(i, f, accs[kk]); + } + for(int32_t kk = 0; kk < PACK_OUT; ++kk){ + tmp[kk] = accs[kk]; + } + O[idx(n, k, m, p, q, N, K, M, P, Q)] = tmp[0]; + } +} + + +// input layout: C, H, W, BS +// filter layout: C, K +// output layout: K, H, W, BS +template +void shift_conv(int32_t C, int32_t H, int32_t W, int32_t BS, + int32_t K, + std::vector& O, + const std::vector& I, + const std::vector& F, + const std::vector shift_h, + const std::vector shift_w) +{ + OUT_DTYPE acc; + for(int32_t p = 0; p < H; ++p) + for(int32_t q = 0; q < W; ++q) + for(int32_t bs = 0; bs < BS; ++bs) + for(int32_t k = 0; k < K; ++k) + { + acc = 0; + for(int32_t c = 0; c < C; ++c){ + int32_t h = p + shift_h[c]; + int32_t w = q + shift_w[c]; + bool in_bounds = (h >= 0 && w >= 0 && h < H && w < W); + IN_DTYPE a = in_bounds?I[bs + w*BS + h*BS*W + c*BS*H*W]:0; + IN_DTYPE b = F[k + c*K]; + acc = dot(a, b, acc); + } + O[bs + q*BS + p*BS*W + k*BS*H*W] = acc; + } +} diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp new file mode 100644 index 000000000..721489b9f --- /dev/null +++ b/examples/cpp/conv.cpp @@ -0,0 +1,236 @@ +#include +#include +#include "common.hpp" +#include "triton/jit.h" +#include "triton/driver/backend.h" +#include "triton/driver/stream.h" + +std::string src = +R"( +const tunable int32 TM = {16, 32, 64}; +const tunable int32 TN = {16, 32, 64}; +const tunable int32 TK = {8}; + +__constant__ int32* delta = alloc_const int32[18]; +__constant__ int32* masks = alloc_const int32[1024]; + +void conv(read_only restrict fp32 *a, + read_only restrict fp32 *b, + fp32 *c, + int32 M, int32 N, int32 K, + int32 AN, int32 AH, int32 AW, + int32 CN, int32 CK, int32 CP, int32 CQ, + int32 AC, int32 AR, int32 AS, + int32 lda_n, int32 lda_c, int32 lda_h, int32 lda_w, + int32 ldc_n, int32 ldc_k, int32 ldc_p, int32 ldc_q, + int32 pad_h, int32 pad_w, + int32 bound){ + int32 rxa[TM] = get_global_range[TM](0); + int32 rb0[TN] = get_global_range[TN](1); + int32 rka[TK] = 0 ... TK; + int32 rb1[TK] = 0 ... TK; + fp32 C[TM, TN] = 0; + int32 ranh[TM] = rxa / CQ; + int32 raw[TM] = rxa % CQ - pad_w; + int32 ran[TM] = ranh / CP; + int32 rah[TM] = ranh % CP - pad_h; + int32 ra0[TM] = ran*lda_n + rah*lda_h + raw*lda_w; + int32 racr[TK] = rka / AS; + int32 ras[TK] = rka % AS; + int32 rac[TK] = racr / AR; + int32 rar[TK] = racr % AR; + int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w; + fp32* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis]; + fp32* pb[TN, TK] = b + rb1[newaxis, :]*CK + rb0[:, newaxis]; + __constant__ int32* pincd[TK] = delta + rka; + __constant__ int32* pd[TK] = delta + AR*AS + rka; + int32 d[TK] = *pd; + int32 incd[TK] = *pincd; + int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0); + int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0); + __constant__ int32* pm[TM] = masks + AR*AS + maskw*AR*AS + maskh*AR*AS*(2*pad_w + 1); + __constant__ int32* pincm[TM] = delta; + int32 incm[TM] = *pincm; + int32 checka0[TM] = *pm; + int32 checka1[TK] = 1 << rka; + int1 checka[TM, TK] = (checka0[:, newaxis] & checka1[newaxis, :]) > 0; + fp32 a[TM, TK] = checka ? *pa : 0; + fp32 b[TN, TK] = *pb; + for(int32 k = K; k > 0; k = k - TK){ + C = dot(a, trans(b), C); + pb = pb + TK*CK; + pa = pa + d[newaxis, :]; + b = *pb; + pd = pd + incd; + pincd = pincd + incd; + d = *pd; + incd = *pincd; + pm = pm + incm; + pincm = pincm + incm; + incm = *pincm; + checka0 = *pm; + checka = (checka0[:, newaxis] & checka1[newaxis, :]) > 0; + a = checka ? *pa : 0; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 rc1[TN] = get_global_range[TN](1); + int32 rcn[TM] = rxc / (CP*CQ); + int32 rcpq[TM] = rxc % (CP*CQ); + int32 rc0[TM] = rcn * ldc_n + rcpq; + fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis]; + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = rc1 < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + @checkc *pc = C; +})"; + + + +int main() { + // initialize default compute device + auto context = triton::driver::backend::contexts::get_default(); + // initialize just-in-time compiler + triton::jit jit(context); + // initialization + int32_t AN = 4, CK = 32; + int32_t AD = 1, AH = 24, AW = 240; + int32_t BC = 64, BT = 1, BR = 3, BS = 3; + int32_t pad_d = 0, pad_h = 1, pad_w = 1; + int32_t stride_d = 1, stride_h = 1, stride_w = 1; + int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; + int32_t CM = (AD*upsample_d - BT + 1 + 2*pad_d + stride_d - 1)/stride_d; + int32_t CP = (AH*upsample_h - BR + 1 + 2*pad_h + stride_h - 1)/stride_h; + int32_t CQ = (AW*upsample_w - BS + 1 + 2*pad_w + stride_w - 1)/stride_w; + // equivalent matmul dimensions + int32_t M = AN*CM*CP*CQ; + int32_t N = CK; + int32_t K = BC*BT*BR*BS; + std::vector hc(AN*CP*CQ*CK); + std::vector rc(AN*CP*CQ*CK); + std::vector ha(AN*BC*AH*AW); + std::vector hb(BC*BR*BS*CK); + srand(0); + for(size_t i = 0; i < ha.size(); i++) + ha[i] = 1; + for(size_t i = 0; i < hb.size(); i++) + hb[i] = 1; + for(size_t i = 0; i < hc.size(); i++) + hc[i] = 0; + triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); + triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4); + triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4); + triton::driver::stream* stream = triton::driver::stream::create(context); + stream->write(da, true, 0, ha); + stream->write(db, true, 0, hb); + stream->write(dc, true, 0, hc); + stream->synchronize(); + // memory strides for data + int32_t stride_i_w = 1; + int32_t stride_i_h = AW*stride_i_w; + int32_t stride_i_d = AH*stride_i_h; + int32_t stride_i_c = AD*stride_i_d; + int32_t stride_i_n = BC*stride_i_c; + // memory strides for filters + int32_t stride_f_k = 1; + int32_t stride_f_s = CK*stride_f_k; + int32_t stride_f_r = BS*stride_f_s; + int32_t stride_f_t = BR*stride_f_r; + int32_t stride_f_c = BT*stride_f_t; + // memory stride for activations + int32_t stride_o_q = 1; + int32_t stride_o_p = CQ*stride_o_q; + int32_t stride_o_m = CP*stride_o_p; + int32_t stride_o_k = CM*stride_o_m; + int32_t stride_o_n = CK*stride_o_k; + // look-up table + int TK = 8; + int F = BT * BR * BS; + int nlut = (TK + F - 1) / F * F; + std::vector h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut); + std::vector h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut); + build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, BT, BR, BS, h_delta, h_masks); + // benchmark a given convolution kernel + auto benchmark = [&](triton::driver::kernel* kernel, + triton::jit::launch_information info) { + // launch info + unsigned TM = info.global_range_size[0]; + unsigned TN = info.global_range_size[1]; + unsigned TK = jit.get_int("TK"); + // initialize constant memory + triton::driver::buffer* delta = jit.get_buffer("delta"); + triton::driver::buffer* masks = jit.get_buffer("masks"); + stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); + stream->write(masks, false, 0, h_masks.size()*4, h_masks.data()); + stream->synchronize(); + // launch info + unsigned nthreads = info.num_threads; + std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}; + // fast bounds-checking + unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1; + unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1; + unsigned lastk = TK - 1; + bool AT = false; + bool BT = true; + unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk; + unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk; + int32_t bound = std::max(1, std::max(K - last_safe_a, K - last_safe_b)); + // set arguments + kernel->setArg(0, da); + kernel->setArg(1, db); + kernel->setArg(2, dc); + kernel->setArg(3, M); + kernel->setArg(4, N); + kernel->setArg(5, K); + kernel->setArg(6, AN); + kernel->setArg(7, AH); + kernel->setArg(8, AW); + kernel->setArg(9, AN); + kernel->setArg(10, CK); + kernel->setArg(11, CP); + kernel->setArg(12, CQ); + kernel->setArg(13, BC); + kernel->setArg(14, BR); + kernel->setArg(15, BS); + kernel->setArg(16, stride_i_n); + kernel->setArg(17, stride_i_c); + kernel->setArg(18, stride_i_h); + kernel->setArg(19, stride_i_w); + kernel->setArg(20, stride_o_n); + kernel->setArg(21, stride_o_k); + kernel->setArg(22, stride_o_p); + kernel->setArg(23, stride_o_q); + kernel->setArg(24, pad_h); + kernel->setArg(25, pad_w); + kernel->setArg(26, bound); + // dry run + stream->enqueue(kernel, grid, {nthreads, 1, 1}); + stream->synchronize(); + // benchmark + double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});}, + [&](){ stream->synchronize(); }, *context->device()); + ts = ts * 1e-9; + double tflops = 2.*M*N*K / ts * 1e-12; + return tflops; + }; + // run + std::vector params = { + 16, 2, 64, + 32, 2, 64, + 16, 8, 2, 2, + 8, 8, + 4 + }; +// jit.autotune("conv", src, benchmark); + jit.add_module("conv", src, params); + triton::driver::kernel* kernel = jit.get_function("conv"); + triton::jit::launch_information info = jit.get_launch_info("conv"); + std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; + stream->read(dc, true, 0, hc); + cpp_conv_nchw(BC, AN, CK, AD, AH, AW, BT, BR, BS, pad_d, pad_h, pad_w, stride_d, stride_h, stride_w, CM, CP, CQ, rc, ha, hb); + for(size_t i = 0; i < M*N; i++) + if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; + exit(EXIT_FAILURE); + } + std::cout << "Pass!" << std::endl; +} diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp new file mode 100644 index 000000000..7bda6c775 --- /dev/null +++ b/examples/cpp/dot.cpp @@ -0,0 +1,162 @@ +#include +#include +#include "common.hpp" +#include "triton/jit.h" +#include "triton/driver/backend.h" +#include "triton/driver/stream.h" + +const char* src = +R"( +const tunable int32 TM = {16, 32, 64, 128}; +const tunable int32 TN = {16, 32, 64, 128}; +const tunable int32 TK = {8}; +const tunable int32 GZ = {1}; + +void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, + int32 M, int32 N, int32 K, + int32 lda, int32 ldb, int32 ldc, + int32 *locks, int32 grid0, int32 grid1) { + int32 rxa[TM] = get_global_range[TM](0); + int32 ryb[TN] = get_global_range[TN](1); + int32 rz = get_global_range[1](2); + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; + fp32 c[TM, TN] = 0; + int32 div = K / GZ; + int32 rem = K % GZ; + K = select(rz < rem, div - 1, div); + int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem); + fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis]; + fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis]; + fp32 a[TM, TK] = *pa; + fp32 b[TN, TK] = *pb; + int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda; + int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb; + last_a = last_a / TK * TK; + last_b = last_b / TK * TK; + int32 bound = K - max(last_a, last_b); + for(int32 k = K; k > bound; k = k - TK){ + c = dot(a, trans(b), c); + pa = pa + TK*lda; + pb = pb + TK*ldb; + a = *pa; + b = *pb; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 ryc[TN] = get_global_range[TN](1); + for(int32 k = bound; k > 0; k = k - 1){ + int1 checka[TM, 1] = rxc[:, newaxis] < M; + int1 checkb[TN, 1] = ryc[:, newaxis] < N; + fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis]; + fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis]; + fp32 a[TM, 1] = checka ? *pa : 0; + fp32 b[TN, 1] = checkb ? *pb : 0; + c = dot(a, trans(b), c); + } + int32 ridx = get_range_id(0); + int32 ridy = get_range_id(1); + fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + int32 *plock = locks + ridx + ridy*grid0; + for(int32 L = __atomic_cas(plock, 0, 1); L == 1; L = __atomic_cas(plock, 0, 1)){} + int32 *pcount = plock + grid0*grid1; + int32 count = *pcount; + int32 countp1 = select(count == GZ - 1, 0, count + 1); + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + if(count == 0) { + @checkc *pc = c; + *pcount = countp1; + } + else { + @checkc *pc = c + (checkc ? *pc : 0); + *pcount = countp1; + } + __atomic_cas(plock, 1, 0); +} +)"; + +int main() { + // initialize default compute device + auto context = triton::driver::backend::contexts::get_default(); + triton::jit jit(context); + + // matrix multiplication parameters + int32_t M = 512, N = 512, K = 512; + std::vector hc(M*N); + std::vector rc(M*N); + std::vector ha(M*K); + std::vector hb(K*N); + std::vector hlocks(2048); + srand(0); + for(size_t i = 0; i < ha.size(); i++) + ha[i] = (float)rand()/RAND_MAX; + for(size_t i = 0; i < hb.size(); i++) + hb[i] = (float)rand()/RAND_MAX; + for(size_t i = 0; i < hc.size(); i++) + hc[i] = 0; + triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); + triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4); + triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4); + triton::driver::buffer* dlocks = triton::driver::buffer::create(context, hlocks.size()*4); + triton::driver::stream* stream = triton::driver::stream::create(context); + stream->write(da, true, 0, ha); + stream->write(db, true, 0, hb); + stream->write(dc, true, 0, hc); + stream->synchronize(); + + + // benchmark a given matrix multiplication kernel + auto benchmark = [&](triton::driver::kernel* kernel, + triton::jit::launch_information info) { + // launch info + unsigned TM = info.global_range_size[0]; + unsigned TN = info.global_range_size[1]; + unsigned nthreads = info.num_threads; + unsigned GZ = jit.get_int("GZ"); + std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, GZ}; + // init locks + stream->write(dlocks, true, 0, hlocks); + // set argument + kernel->setArg(0, da); + kernel->setArg(1, db); + kernel->setArg(2, dc); + kernel->setArg(3, M); + kernel->setArg(4, N); + kernel->setArg(5, K); + kernel->setArg(6, M); + kernel->setArg(7, N); + kernel->setArg(8, M); + kernel->setArg(9, dlocks); + kernel->setArg(10, grid[0]); + kernel->setArg(11, grid[1]); + // dry run + stream->enqueue(kernel, grid, {nthreads, 1, 1}); + stream->synchronize(); + // benchmark + double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});}, + [&](){ stream->synchronize(); }, *context->device()); + ts = ts * 1e-9; + double tflops = 2.*M*N*K / ts * 1e-12; + return tflops; + }; + + + // just-in-time compile source-code + std::vector params = { + 16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1 + }; +// jit.autotune("matmul",src, benchmark); + jit.add_module("matmul", src, params); + triton::driver::kernel* kernel = jit.get_function("matmul"); + triton::jit::launch_information info = jit.get_launch_info("matmul"); + std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; + stream->read(dc, true, 0, hc); + simple_gemm(rc, ha, hb, M, N, K); + for(size_t i = 0; i < M*N; i++) + if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; + exit(EXIT_FAILURE); + } + std::cout << "Pass!" << std::endl; +} diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp new file mode 100644 index 000000000..f75046e2f --- /dev/null +++ b/examples/cpp/shift.cpp @@ -0,0 +1,212 @@ +#include +#include +#include "common.hpp" +#include "triton/jit.h" +#include "triton/driver/backend.h" +#include "triton/driver/stream.h" + +// K = channels +// M = batch * height * width +// N = number of feature maps + +const char* src = +R"( +const tunable int32 TM = {16, 32, 64, 128}; +const tunable int32 TN = {16, 32, 64, 128}; +const tunable int32 TK = {8}; + +__constant__ int32* delta = alloc_const int32[256]; +__constant__ int32* masks = alloc_const int32[8192]; + +void shift(restrict read_only fp32 *a, restrict read_only fp32 *b, fp32 *c, + int32 M, int32 N, int32 K, + int32 ABS, int32 AH, int32 AW, int32 AR, int32 AS){ + int32 rxa[TM] = get_global_range[TM](0); + int32 ryb[TN] = get_global_range[TN](1); + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; + fp32 C[TM, TN] = 0; + fp32* pxa[TM, TK] = a + rxa[:, newaxis]; + fp32* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis]; + __constant__ int32* pd[TK] = delta + rka; + int32 pad_h = AR/2; + int32 pad_w = AS/2; + int32 rawhc[TM] = rxa / ABS; + int32 raw[TM] = rawhc % AW - pad_w; + int32 rahc[TM] = rawhc / AW; + int32 rah[TM] = rahc % AH - pad_h; + int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + AR - AH, 0); + int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + AS - AW, 0); + __constant__ int32* pxm[TM] = masks + maskh*K + maskw*K*(2*pad_h + 1); + __constant__ int32* pm[TM, TK] = pxm[:, newaxis] + rka[newaxis, :]; + for(int32 k = K; k > 0; k = k - TK){ + int32 delta[TK] = *pd; + fp32 *pa[TM, TK] = pxa + delta[newaxis, :]; + int1 m[TM, TK] = *pm > 0; + fp32 a[TM, TK] = m ? *pa : 0; + fp32 b[TN, TK] = *pb; + C = dot(a, trans(b), C); + pb = pb + TK*N; + pd = pd + TK; + pm = pm + TK; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 ryc[TN] = get_global_range[TN](1); + fp32* pc[TM, TN] = c + ryc[newaxis, :]*M + rxc[:, newaxis]; + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + @checkc *pc = C; +} +)"; + +std::vector shift_deltas(// strides + int32_t stride_w, int32_t stride_h, int32_t stride_c, + // shift + int32_t C, + const std::vector& shift_h, + const std::vector& shift_w) { + std::vector res(C); + for(unsigned c = 0; c < C; c++){ + res[c] = c*stride_c; + res[c] += shift_h[c]*stride_h; + res[c] += shift_w[c]*stride_w; + } + return res; +} + +std::vector shift_masks(int32_t C, + const std::vector& shift_h, + const std::vector& shift_w, + int32_t R, int32_t S) { + size_t S0 = C; + size_t S1 = R; + size_t S2 = S; + std::vector res(S0*S1*S2); + for(size_t ph = 0; ph < S1; ++ph) + for(size_t pw = 0; pw < S2; ++pw){ + int32_t* ptr = &res[ph*S0 + pw*S0*S1]; + for(size_t i = 0; i < S0; ++i){ + bool in_bounds_h = shift_h[i] + ph >= 0 && shift_h[i] + ph < R; + bool in_bounds_w = shift_w[i] + pw >= 0 && shift_w[i] + pw < S; + ptr[i] = in_bounds_h && in_bounds_w; + } + } + return res; +} + +int main() { + // initialize default compute device + auto context = triton::driver::backend::contexts::get_default(); + // initialize just-in-time compiler + triton::jit jit(context); + // initialization + int32_t R = 3, S = 3; + int32_t BS = 4, F = 128; + int32_t H = 32, W = 32; + int32_t C = 128; + // equivalent matmul dimensions + int32_t M = BS*H*W; + int32_t N = F; + int32_t K = C; + std::cout << M << " " << N << " " << K << std::endl; + std::vector hc(BS*H*W*F); + std::vector rc(BS*H*W*F); + std::vector ha(BS*C*H*W); + std::vector hb(F*C); + // strides + int32_t stride_i_bs = 1; + int32_t stride_i_w = BS*stride_i_bs; + int32_t stride_i_h = W*stride_i_w; + int32_t stride_i_c = H*stride_i_h; + // random shifts + std::vector shift_h(C); + std::vector shift_w(C); + for(int32_t c = 0; c < C; c++){ + shift_h[c] = rand() % R - R/2; + shift_w[c] = rand() % S - S/2; + } + // initialize buffers + srand(0); + for(int c = 0 ; c < C; c++) + for(int h = 0 ; h < H; h++) + for(int w = 0 ; w < W; w++) + for(int bs = 0 ; bs < BS; bs++){ + float value = (float)rand() / RAND_MAX; + size_t idx = bs + w*stride_i_w + h*stride_i_h + c*stride_i_c; + ha[idx] = value; + } + for(size_t i = 0; i < hb.size(); i++) + hb[i] = (float)rand() / RAND_MAX; + for(size_t i = 0; i < hc.size(); i++) + hc[i] = 0; + triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4); + triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*4); + triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*4); + triton::driver::stream* stream = triton::driver::stream::create(context); + stream->write(da, true, 0, ha); + stream->write(db, true, 0, hb); + stream->write(dc, true, 0, hc); + stream->synchronize(); + std::vector h_delta = shift_deltas(stride_i_w, stride_i_h, stride_i_c, C, shift_h, shift_w); + std::vector h_masks = shift_masks(C, shift_h, shift_w, R, S); + // benchmark a given matrix multiplication kernel + auto benchmark = [&](triton::driver::kernel* kernel, + triton::jit::launch_information info) { + // launch info + unsigned TM = info.global_range_size[0]; + unsigned TN = info.global_range_size[1]; + unsigned nthreads = info.num_threads; + // initialize constant memory + triton::driver::buffer* delta = jit.get_buffer("delta"); + triton::driver::buffer* masks = jit.get_buffer("masks"); + stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); + stream->write(masks, false, 0, h_masks.size()*4, h_masks.data()); + stream->synchronize(); + // set argument + kernel->setArg(0, da); + kernel->setArg(1, db); + kernel->setArg(2, dc); + kernel->setArg(3, M); + kernel->setArg(4, N); + kernel->setArg(5, K); + kernel->setArg(6, BS); + kernel->setArg(7, H); + kernel->setArg(8, W); + kernel->setArg(9, R); + kernel->setArg(10, S); + // dry run + std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}; + stream->enqueue(kernel, grid, {nthreads, 1, 1}); + stream->synchronize(); + // benchmark + double ts = bench([&](){stream->enqueue(kernel, grid, {nthreads, 1, 1});}, + [&](){ stream->synchronize(); }, *context->device()); + ts = ts * 1e-9; + double tflops = 2.*M*N*K / ts * 1e-12; + return tflops; + }; + + // shift + std::vector params = { + 16, 2, 64, + 32, 2, 64, + 16, 8, 2, 2, + 8, 8, + 4 + }; +// jit.autotune("shift", src, benchmark); + jit.add_module("shift", src, params); + triton::driver::kernel* kernel = jit.get_function("shift"); + triton::jit::launch_information info = jit.get_launch_info("shift"); + std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; + stream->read(dc, true, 0, hc); + shift_conv(C, H, W, BS, F, rc, ha, hb, shift_h, shift_w); + for(size_t i = 0; i < M*N; i++) + if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ + std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; + exit(EXIT_FAILURE); + } + std::cout << "Pass!" << std::endl; + +} diff --git a/examples/cpp/shift.ptx b/examples/cpp/shift.ptx new file mode 100644 index 000000000..62a841909 --- /dev/null +++ b/examples/cpp/shift.ptx @@ -0,0 +1,93 @@ +// +// Generated by NVIDIA NVVM Compiler +// +// Compiler Build ID: CL-24817639 +// Cuda compilation tools, release 10.0, V10.0.130 +// Based on LLVM 3.4svn +// + +.version 6.3 +.target sm_60 +.address_size 64 + + // .globl _Z25shift_cuda_forward_kernelPKfPKiPfiiii + +.visible .entry shift( + .param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0, + .param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1, + .param .u64 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2, + .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3, + .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4, + .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5, + .param .u32 _Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6 +) +{ + .reg .pred %p<10>; + .reg .f32 %f<2>; + .reg .b32 %r<31>; + .reg .b64 %rd<13>; + + + ld.param.u64 %rd1, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_0]; + ld.param.u64 %rd3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_1]; + ld.param.u64 %rd2, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_2]; + ld.param.u32 %r3, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_3]; + ld.param.u32 %r4, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_4]; + ld.param.u32 %r5, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_5]; + ld.param.u32 %r6, [_Z25shift_cuda_forward_kernelPKfPKiPfiiii_param_6]; + cvta.to.global.u64 %rd4, %rd3; + mov.u32 %r7, %ntid.x; + mov.u32 %r8, %ctaid.x; + mov.u32 %r9, %tid.x; + mad.lo.s32 %r1, %r7, %r8, %r9; + mul.lo.s32 %r10, %r4, %r3; + mul.lo.s32 %r11, %r10, %r5; + mul.lo.s32 %r12, %r11, %r6; + mul.lo.s32 %r13, %r5, %r4; + mul.lo.s32 %r14, %r13, %r6; + rem.s32 %r15, %r1, %r14; + sub.s32 %r16, %r1, %r15; + mul.lo.s32 %r17, %r6, %r5; + div.s32 %r18, %r15, %r17; + mul.lo.s32 %r19, %r18, %r17; + sub.s32 %r20, %r15, %r19; + div.s32 %r21, %r20, %r5; + mul.lo.s32 %r22, %r21, %r6; + sub.s32 %r23, %r20, %r22; + shl.b32 %r24, %r18, 1; + mul.wide.s32 %rd5, %r24, 4; + add.s64 %rd6, %rd4, %rd5; + ld.global.nc.u32 %r25, [%rd6]; + add.s32 %r26, %r25, %r21; + ld.global.nc.u32 %r27, [%rd6+4]; + add.s32 %r28, %r23, %r27; + add.s32 %r29, %r16, %r19; + mad.lo.s32 %r30, %r26, %r5, %r29; + add.s32 %r2, %r30, %r28; + setp.lt.s32 %p1, %r1, %r12; + setp.gt.s32 %p2, %r26, -1; + and.pred %p3, %p1, %p2; + setp.lt.s32 %p4, %r26, %r5; + and.pred %p5, %p3, %p4; + setp.gt.s32 %p6, %r28, -1; + and.pred %p7, %p5, %p6; + setp.lt.s32 %p8, %r28, %r6; + and.pred %p9, %p7, %p8; + @!%p9 bra BB0_2; + bra.uni BB0_1; + +BB0_1: + cvta.to.global.u64 %rd7, %rd1; + mul.wide.s32 %rd8, %r1, 4; + add.s64 %rd9, %rd7, %rd8; + ld.global.nc.f32 %f1, [%rd9]; + cvta.to.global.u64 %rd10, %rd2; + mul.wide.s32 %rd11, %r2, 4; + add.s64 %rd12, %rd10, %rd11; + st.global.f32 [%rd12], %f1; + +BB0_2: + ret; +} + + diff --git a/include/triton/ast/ast.h b/include/triton/ast/ast.h index b286c5a79..8eccd6f92 100644 --- a/include/triton/ast/ast.h +++ b/include/triton/ast/ast.h @@ -74,8 +74,8 @@ class constant; class node { protected: static ir::value* explicit_cast(ir::builder &builder, ir::value *src, ir::type *dst_ty); + static void implicit_broadcast(ir::module *mod, ir::type *dst_ty, ir::value *&src); static void implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs); - static void implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty); static void implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, bool &is_float, bool &is_ptr, bool &is_int, bool &is_signed); public: @@ -164,6 +164,27 @@ private: const constant* axis_; }; +class get_range_id: public builtin_expression{ +public: + get_range_id(node *axis): axis_((constant*)axis) { } + ir::value* codegen(ir::module *) const; + +private: + const constant* axis_; +}; + +class atomic_cas: public builtin_expression{ +public: + atomic_cas(node *ptr, node *cmp, node *val): ptr_(ptr), cmp_(cmp), val_(val) { } + ir::value* codegen(ir::module *) const; + +private: + const node *ptr_; + const node *cmp_; + const node *val_; +}; + + class matmul_expression: public builtin_expression{ public: matmul_expression(node* A, node *B, node *C): @@ -176,6 +197,49 @@ private: const expression *C_; }; +class max_expression: public builtin_expression{ +public: + max_expression(node* x, node* y) + : x_((expression*)x), y_((expression*)y){ } + ir::value* codegen(ir::module *) const; + +private: + const expression *x_; + const expression *y_; +}; + +class min_expression: public builtin_expression{ +public: + min_expression(node* x, node* y) + : x_((expression*)x), y_((expression*)y){ } + ir::value* codegen(ir::module *mod) const; + +private: + const expression *x_; + const expression *y_; +}; + +class select_expression: public builtin_expression{ +public: + select_expression(node* pred, node* if_value, node* else_value) + : pred_((expression*)pred), if_value_((expression*)if_value), else_value_((expression*)else_value) { } + ir::value* codegen(ir::module *mod) const; + +private: + const expression *pred_; + const expression *if_value_; + const expression *else_value_; +}; + +class trans_expression: public builtin_expression{ +public: + trans_expression(node *arg): arg_(arg) {} + ir::value* codegen(ir::module *mod) const; + +private: + node* arg_; +}; + class indexing_expression: public postfix_expression{ public: @@ -189,6 +253,8 @@ private: const list* slices_; }; + + class named_expression: public expression { public: named_expression(node *id): id_((const identifier*)id) { lvalue_ = this; } diff --git a/include/triton/ast/parser.y b/include/triton/ast/parser.y index ae4b7d4e3..5302c7d14 100644 --- a/include/triton/ast/parser.y +++ b/include/triton/ast/parser.y @@ -55,7 +55,7 @@ STORAGE_SPEC_T get_storage_spec(node *op) { return ((token*)op)->storage_spec;} %token VOID UINT1 UINT8 UINT16 UINT32 UINT64 INT1 INT8 INT16 INT32 INT64 FP32 FP64 %token IF ELSE FOR CONTINUE %token NEWAXIS ELLIPSIS AT -%token GET_GLOBAL_RANGE DOT ALLOC_CONST +%token GET_GLOBAL_RANGE GET_RANGE_ID DOT TRANS MAX MIN SELECT ATOMIC_CAS ALLOC_CONST %start translation_unit %% @@ -118,8 +118,15 @@ identifier builtin : GET_GLOBAL_RANGE '[' primary_expression ']' '(' constant ')' { $$ = new get_global_range($3, $6); } + | GET_RANGE_ID '(' constant ')' { $$ = new get_range_id($3); } | DOT '(' expression ',' expression ',' expression ')' { $$ = new matmul_expression($3, $5, $7); } | ALLOC_CONST type_specifier '[' constant ']' { $$ = new alloc_const(new typed_declaration_specifier(get_type_spec($2)), $4); } + | TRANS '(' expression ')' { $$ = new trans_expression($3); } + | MAX '(' expression ',' expression ')' { $$ = new max_expression($3, $5); } + | MIN '(' expression ',' expression ')' { $$ = new min_expression($3, $5); } + | SELECT '(' expression ',' expression ',' expression ')' { $$ = new select_expression($3, $5, $7); } + | ATOMIC_CAS '(' expression ',' expression ',' expression ')' { $$ = new atomic_cas($3, $5, $7); } + ; primary_expression : identifier { $$ = new named_expression($1); } diff --git a/include/triton/ast/scanner.l b/include/triton/ast/scanner.l index 91b700655..e4e018a14 100644 --- a/include/triton/ast/scanner.l +++ b/include/triton/ast/scanner.l @@ -41,7 +41,13 @@ using triton::ast::return_void; "fp64" { return return_impl(FP64, yytext); } "..." { return return_impl(ELLIPSIS, yytext); } "get_global_range" { return return_impl(GET_GLOBAL_RANGE, yytext); } +"get_range_id" { return return_impl(GET_RANGE_ID, yytext); } +"__atomic_cas" { return return_impl(ATOMIC_CAS, yytext); } "dot" { return return_impl(DOT, yytext); } +"max" { return return_impl(MAX, yytext); } +"min" { return return_impl(MIN, yytext); } +"select" { return return_impl(SELECT, yytext); } +"trans" { return return_impl(TRANS, yytext); } "continue" { return return_impl(CONTINUE, yytext); } "alloc_const" { return return_impl(ALLOC_CONST, yytext); } {L}({L}|{D})* { return return_impl(IDENTIFIER, yytext); } @@ -52,8 +58,6 @@ using triton::ast::return_void; L?'(\\.|[^\\'])+' { return return_impl(CONSTANT, yytext); } {D}+{E}{FS}? { return return_impl(CONSTANT, yytext); } -{D}*"."{D}+({E})?{FS}? { return return_impl(CONSTANT, yytext); } -{D}+"."{D}*({E})?{FS}? { return return_impl(CONSTANT, yytext); } L?\"(\\.|[^\\"])*\" { return return_impl(STRING_LITERAL, yytext); } diff --git a/include/triton/codegen/layout.h b/include/triton/codegen/layout.h deleted file mode 100644 index a18f6439f..000000000 --- a/include/triton/codegen/layout.h +++ /dev/null @@ -1,45 +0,0 @@ -#ifndef TDL_INCLUDE_IR_CODEGEN_LAYOUT_H -#define TDL_INCLUDE_IR_CODEGEN_LAYOUT_H - -#include -#include - -namespace triton { - -namespace ir { - class module; - class instruction; - class value; -} - -namespace codegen{ - -struct shared_view_info{ - ir::value *usr; - bool has_dedicated_storage; -}; - -class layout { -private: - typedef std::vector shared_view_val_t; - - void add_phi_nodes(ir::value *v); - void add_shared_views(ir::value *v); - -public: - // accessors - unsigned get_num_shared_views(ir::value *v); - shared_view_info get_shared_view(ir::value *v, unsigned idx); - - // run - void run(ir::module &mod); - -private: - std::map shared_views_; -}; - - -} -} - -#endif diff --git a/include/triton/codegen/optimize_cse.h b/include/triton/codegen/optimize_cse.h new file mode 100644 index 000000000..d718f318e --- /dev/null +++ b/include/triton/codegen/optimize_cse.h @@ -0,0 +1,27 @@ +#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H +#define TDL_INCLUDE_CODEGEN_OPTIMIZE_CSE_H + +#include +#include +#include + +namespace triton { + +namespace ir { + class module; +} + +namespace codegen{ +class tune; + +class optimize_cse { +public: + optimize_cse() {} + void run(ir::module &mod); +}; + + +} +} + +#endif diff --git a/include/triton/codegen/optimize_dot.h b/include/triton/codegen/optimize_dot.h new file mode 100644 index 000000000..76d8368dc --- /dev/null +++ b/include/triton/codegen/optimize_dot.h @@ -0,0 +1,31 @@ +#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H +#define TDL_INCLUDE_CODEGEN_OPTIMIZE_DOT_H + +#include +#include +#include + +namespace triton { + +namespace ir { + class module; +} + +namespace codegen{ + +class tune; + +class optimize_dot { +public: + optimize_dot(tune* params): params_(params) {} + void run(ir::module &mod); + +private: + tune* params_; +}; + + +} +} + +#endif diff --git a/include/triton/codegen/optimize_trans.h b/include/triton/codegen/optimize_trans.h new file mode 100644 index 000000000..beaace2a5 --- /dev/null +++ b/include/triton/codegen/optimize_trans.h @@ -0,0 +1,33 @@ +#ifndef TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H +#define TDL_INCLUDE_CODEGEN_OPTIMIZE_TRANS_H + +#include +#include +#include + +namespace triton { + +namespace ir { + class module; + class value; + class instruction; + class trans_inst; + class builder; +} + +namespace codegen{ + +class optimize_trans { +private: + ir::value *replace_phi(ir::value* value, std::vector& to_delete, ir::builder &builder); + +public: + optimize_trans() {} + void run(ir::module &mod); +}; + + +} +} + +#endif diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h index 11acf28e7..d9ce08c53 100644 --- a/include/triton/codegen/selection.h +++ b/include/triton/codegen/selection.h @@ -7,7 +7,7 @@ #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/type.h" -#include "triton/codegen/buffer_info.h" +#include "triton/codegen/shmem_info.h" namespace llvm{ @@ -21,9 +21,9 @@ namespace llvm{ namespace triton{ namespace codegen{ -class allocation; +class shmem_allocation; class tune; -class buffer_info_pass; +class shmem_info; class target; typedef std::vector indices_t; @@ -129,7 +129,7 @@ private: void lower_tile_instruction(ir::instruction *src, llvm::IRBuilder<> &builder); public: - selection(allocation *alloc, tune *params, buffer_info_pass *buffer_info, target *tgt) + selection(shmem_allocation *alloc, tune *params, shmem_info *buffer_info, target *tgt) : alloc_(alloc), params_(params), buffer_info_(buffer_info), tgt_(tgt){ } void run(ir::module &src, llvm::Module &dst); @@ -139,11 +139,12 @@ private: tmap_t tmap_; pmap_t pmap_; pmap_t last_block_; - allocation *alloc_; + shmem_allocation *alloc_; tune *params_; target *tgt_; - buffer_info_pass *buffer_info_; + shmem_info *buffer_info_; std::map axes_; + llvm::Value *sh_mem_ptr_; }; } diff --git a/include/triton/codegen/shared_copy.h b/include/triton/codegen/shared_copy.h deleted file mode 100644 index 3a3d7363b..000000000 --- a/include/triton/codegen/shared_copy.h +++ /dev/null @@ -1,41 +0,0 @@ -#ifndef TDL_INCLUDE_CODEGEN_SHARED_COPY_H -#define TDL_INCLUDE_CODEGEN_SHARED_COPY_H - -#include -#include - -namespace triton { - -namespace ir { - class module; - class value; - class builder; - class basic_block; -} - -namespace codegen{ - -class buffer_info_pass; - -class place_shared_copy { -private: - typedef std::pair interval_t; - typedef std::vector interval_vec_t; - -private: - bool intersect(const interval_vec_t &I, interval_t i); - void add_copy(ir::value *x, ir::builder &builder); - -public: - place_shared_copy(buffer_info_pass *info): info_(info) { } - void run(ir::module &mod); - -private: - buffer_info_pass *info_; -}; - - -} -} - -#endif diff --git a/include/triton/codegen/allocation.h b/include/triton/codegen/shmem_allocation.h similarity index 79% rename from include/triton/codegen/allocation.h rename to include/triton/codegen/shmem_allocation.h index 1f2a7656c..27a96f285 100644 --- a/include/triton/codegen/allocation.h +++ b/include/triton/codegen/shmem_allocation.h @@ -16,12 +16,12 @@ namespace codegen{ class layout; class target_tuner; -class liveness; -class buffer_info_pass; +class shmem_liveness; +class shmem_info; -class allocation { +class shmem_allocation { public: - allocation(liveness *live, buffer_info_pass *buffer_info) + shmem_allocation(shmem_liveness *live, shmem_info *buffer_info) : liveness_(live), buffer_info_(buffer_info){ } // utilities @@ -39,8 +39,8 @@ private: std::map num_bytes_; size_t allocated_size_; // dependences - liveness *liveness_; - buffer_info_pass *buffer_info_; + shmem_liveness *liveness_; + shmem_info *buffer_info_; }; } diff --git a/include/triton/codegen/barriers.h b/include/triton/codegen/shmem_barriers.h similarity index 82% rename from include/triton/codegen/barriers.h rename to include/triton/codegen/shmem_barriers.h index 336ec255a..271b745cc 100644 --- a/include/triton/codegen/barriers.h +++ b/include/triton/codegen/shmem_barriers.h @@ -17,10 +17,10 @@ namespace ir { namespace codegen{ -class allocation; -class buffer_info_pass; +class shmem_allocation; +class shmem_info; -class barriers { +class shmem_barriers { private: typedef std::pair interval_t; typedef std::vector interval_vec_t; @@ -36,12 +36,12 @@ private: std::pair transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set &insert_loc); public: - barriers(allocation *alloc, buffer_info_pass *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} + shmem_barriers(shmem_allocation *alloc, shmem_info *buffer_info): alloc_(alloc), buffer_info_(buffer_info) {} void run(ir::module &mod); private: - allocation *alloc_; - buffer_info_pass *buffer_info_; + shmem_allocation *alloc_; + shmem_info *buffer_info_; }; diff --git a/include/triton/codegen/buffer_info.h b/include/triton/codegen/shmem_info.h similarity index 84% rename from include/triton/codegen/buffer_info.h rename to include/triton/codegen/shmem_info.h index 58f140d61..f8325d00b 100644 --- a/include/triton/codegen/buffer_info.h +++ b/include/triton/codegen/shmem_info.h @@ -10,18 +10,19 @@ namespace ir { class module; class value; class phi_node; + class instruction; } namespace codegen{ -class buffer_info_pass { +class shmem_info { public: void run(ir::module &mod); // queries bool is_double(ir::value *x); void add_shared(ir::value *v); bool is_shared(ir::value *x); - bool is_loop_latch(ir::phi_node *phi, ir::value *terminator); + bool is_loop_latch(ir::phi_node *phi, ir::instruction *terminator); ir::value *get_reference(ir::value *x); void replace(ir::value* before, ir::value *after); diff --git a/include/triton/codegen/liveness.h b/include/triton/codegen/shmem_liveness.h similarity index 90% rename from include/triton/codegen/liveness.h rename to include/triton/codegen/shmem_liveness.h index 010bb4e2a..69210d03f 100644 --- a/include/triton/codegen/liveness.h +++ b/include/triton/codegen/shmem_liveness.h @@ -15,7 +15,7 @@ namespace codegen{ typedef unsigned slot_index; -class buffer_info_pass; +class shmem_info; struct segment { slot_index start; @@ -30,7 +30,7 @@ struct segment { } }; -class liveness { +class shmem_liveness { private: typedef std::map indices_map_t; typedef std::map intervals_map_t; @@ -43,7 +43,7 @@ public: public: // constructor - liveness(buffer_info_pass *info): info_(info){ } + shmem_liveness(shmem_info *info): info_(info){ } // accessors const intervals_map_t& intervals() const { return intervals_; } @@ -53,7 +53,7 @@ public: void run(ir::module &mod); private: - buffer_info_pass *info_; + shmem_info *info_; has_storage_map_t has_dedicated_storage_; indices_map_t indices_; intervals_map_t intervals_; diff --git a/include/triton/codegen/target.h b/include/triton/codegen/target.h index e2dc4518a..9079fc869 100644 --- a/include/triton/codegen/target.h +++ b/include/triton/codegen/target.h @@ -24,6 +24,7 @@ public: virtual llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder) = 0; virtual llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax) = 0; virtual llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0; + virtual llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax) = 0; bool is_gpu() const; private: @@ -37,6 +38,7 @@ public: llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder); llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax); llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); + llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); }; class nvidia_cu_target: public target { @@ -46,6 +48,7 @@ public: llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder); llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax); llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); + llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); }; class cpu_target: public target { @@ -55,6 +58,7 @@ public: llvm::Instruction* add_barrier(llvm::Module *module, llvm::IRBuilder<>& builder); llvm::Value* get_global_offset(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned stride, unsigned ax); llvm::Value* get_local_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); + llvm::Value* get_block_id(llvm::Module *module, llvm::IRBuilder<>& builder, unsigned ax); }; } diff --git a/include/triton/driver/module.h b/include/triton/driver/module.h index 89ff3c39d..7fe2d6f06 100755 --- a/include/triton/driver/module.h +++ b/include/triton/driver/module.h @@ -90,7 +90,7 @@ class cu_module: public module { public: cu_module(driver::context* context, llvm::Module *module); cu_module(driver::context* context, const std::string& source); - cu_buffer symbol(const char * name) const; + cu_buffer* symbol(const char * name) const; private: std::string source_; diff --git a/include/triton/ir/builder.h b/include/triton/ir/builder.h index 852f55aa9..51dd656d3 100644 --- a/include/triton/ir/builder.h +++ b/include/triton/ir/builder.h @@ -67,6 +67,7 @@ public: value* create_fp_ext(value *src, type *dst_ty, const std::string &name = ""); value* create_fp_trunc(value *src, type *dst_ty, const std::string &name = ""); value* create_int_cast(value *src, type *dst_ty, bool is_signed, const std::string &name = ""); + value *create_downcast(value *arg, const std::string &name = ""); // Phi instruction phi_node* create_phi(type *ty, unsigned num_reserved, const std::string &name = ""); // Binary instructions @@ -124,7 +125,11 @@ public: value *create_broadcast(value *arg, const type::tile_shapes_t &shapes, const std::string &name = ""); // Built-in instruction value *create_get_global_range(unsigned axis, type::tile_shapes_t::value_type size, const std::string &name = ""); - value *create_matmul(value *A, value *B, value *C, const std::string &name = ""); + value *create_get_range_id(unsigned axis, const std::string &name = ""); + value *create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name = ""); + value *create_dot(value *A, value *B, value *C, const std::string &name = ""); + value *create_trans(value *A, const std::string &name = ""); + value *create_select(value *pred, value *if_value, value *else_value, const std::string &name = ""); // Intrinsics value *create_copy_to_shared(value *arg, const std::string &name = ""); value *create_vectorize(value *arg, const std::string &name = ""); diff --git a/include/triton/ir/constant.h b/include/triton/ir/constant.h index 0c18787ea..43aa41c6d 100644 --- a/include/triton/ir/constant.h +++ b/include/triton/ir/constant.h @@ -54,6 +54,7 @@ public: void set_value(uint64_t value) { has_value_ = true; value_ = value; } bool has_value() { return has_value_; } const std::vector& get_space() { return space_; } + void set_space(const std::vector &space) { space_ = space; } private: std::vector space_; diff --git a/include/triton/ir/instructions.h b/include/triton/ir/instructions.h index 2d8e7d91d..961bb43ce 100644 --- a/include/triton/ir/instructions.h +++ b/include/triton/ir/instructions.h @@ -464,6 +464,17 @@ public: }; +// downcast + +class downcast_inst: public unary_inst { +private: + using unary_inst::unary_inst; + std::string repr_impl() const { return "downcast"; } + +public: + static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); +}; + //===----------------------------------------------------------------------===// // builtin_inst classes //===----------------------------------------------------------------------===// @@ -488,17 +499,76 @@ private: unsigned axis_; }; -class matmul_inst: public builtin_inst { +class get_range_id_inst: public builtin_inst { private: - matmul_inst(value *A, value *B, value *C, const std::string &name, instruction *next); - std::string repr_impl() const { return "dot"; } + get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next); + std::string repr_impl() const { return "get_range_id(" + std::to_string(axis_) + ")"; } public: - static instruction* create(value *A, value *B, value *C, - const std::string &name = "", - instruction *next = nullptr); + static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr); + unsigned get_axis() const { return axis_; } + +private: + unsigned axis_; }; +class atomic_cas_inst: public builtin_inst { +private: + atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next); + std::string repr_impl() const { return "atomic_cas"; } + +public: + static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr); +}; + +class dot_inst: public builtin_inst { +public: + enum TransT { NoTrans, Trans }; + +private: + dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next); + std::string repr_impl() const { return std::string("dot.") + ((AT_==NoTrans)?"n":"t") + ((BT_==NoTrans)?"n":"t"); } + +public: + static instruction* create_nn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); + static instruction* create_nt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); + static instruction* create_tn(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); + static instruction* create_tt(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); + bool is_a_trans() { return AT_ == Trans; } + bool is_b_trans() { return BT_ == Trans; } + +private: + TransT AT_; + TransT BT_; +}; + +//class outer_inst: public builtin_inst { +//private: +// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next); +//public: +// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); +//}; + +class trans_inst: public builtin_inst { +public: + ir::type* get_res_ty(ir::type* in); + +private: + trans_inst(value *arg, const std::string& name, instruction* next); + std::string repr_impl() const { return "trans"; } + +public: + static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr); +}; + +class select_inst: public builtin_inst { +private: + select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next); + std::string repr_impl() const { return "select"; } + +public: + static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr); +}; //===----------------------------------------------------------------------===// // intrinsics classes diff --git a/include/triton/ir/module.h b/include/triton/ir/module.h index 3d2d5afb9..13d99d436 100644 --- a/include/triton/ir/module.h +++ b/include/triton/ir/module.h @@ -66,6 +66,7 @@ public: // Getters value *get_value(const std::string& name, basic_block* block); value *get_value(const std::string& name); + const std::string& get_name(); std::function get_continue_fn(); // Seal block -- no more predecessors will be added void seal_block(basic_block *block); diff --git a/include/triton/jit.h b/include/triton/jit.h index c4809d254..b9c502aad 100644 --- a/include/triton/jit.h +++ b/include/triton/jit.h @@ -10,13 +10,15 @@ #include "triton/driver/kernel.h" #include "triton/codegen/selection.h" #include "triton/codegen/tune.h" -#include "triton/codegen/shared_copy.h" -#include "triton/codegen/allocation.h" -#include "triton/codegen/liveness.h" -#include "triton/codegen/vectorize.h" -#include "triton/codegen/buffer_info.h" -#include "triton/codegen/barriers.h" +#include "triton/codegen/optimize_dot.h" +#include "triton/codegen/optimize_cse.h" +#include "triton/codegen/optimize_trans.h" +#include "triton/codegen/shmem_allocation.h" +#include "triton/codegen/shmem_liveness.h" +#include "triton/codegen/shmem_info.h" +#include "triton/codegen/shmem_barriers.h" #include "triton/codegen/target.h" +#include "triton/codegen/vectorize.h" #include namespace llvm { @@ -45,48 +47,59 @@ public: struct passes_wrapper { passes_wrapper(codegen::target* target) - : shared(&buffer_info), liveness(&buffer_info), - allocation(&liveness, &buffer_info), - barriers(&allocation, &buffer_info), + : shmem_liveness(&shmem_info), + shmem_allocation(&shmem_liveness, &shmem_info), + shmem_barriers(&shmem_allocation, &shmem_info), vectorize(&tune), - selection(&allocation, &tune, &buffer_info, target), + selection(&shmem_allocation, &tune, &shmem_info, target), + optimize_dot(&tune), + optimize_cse(), + optimize_trans(), target_(target) { } - void init(ir::module &module) { + void target_independent(ir::module &module) { + optimize_dot.run(module); + optimize_trans.run(module); +// ir::print(module, std::cout); + } + + void target_dependent(ir::module &module) { if(target_->is_gpu()){ - buffer_info.run(module); - shared.run(module); - liveness.run(module); - allocation.run(); - barriers.run(module); + shmem_info.run(module); + shmem_liveness.run(module); + shmem_allocation.run(); + shmem_barriers.run(module); } vectorize.run(module); } codegen::tune tune; - codegen::buffer_info_pass buffer_info; - codegen::place_shared_copy shared; - codegen::liveness liveness; - codegen::allocation allocation; - codegen::barriers barriers; + codegen::shmem_info shmem_info; + codegen::shmem_liveness shmem_liveness; + codegen::shmem_allocation shmem_allocation; + codegen::shmem_barriers shmem_barriers; codegen::vectorize vectorize; codegen::selection selection; + codegen::optimize_dot optimize_dot; + codegen::optimize_cse optimize_cse; + codegen::optimize_trans optimize_trans; codegen::target* target_; }; private: std::string compute_data_layout(bool is_64bit = true, bool use_short_pointers = true); std::unique_ptr make_llvm_module(triton::ir::module &module, passes_wrapper &passes); - std::unique_ptr make_triton_module(const std::string &src); + std::unique_ptr make_triton_module(const std::string &name, const std::string &src); public: jit(driver::context* context); - void autotune(const std::string &src, benchmark_t benchmark); + void autotune(const std::string &name, const std::string &src, benchmark_t benchmark); void add_module(ir::module &module, const std::vector& params = {}); - void add_module(const std::string &src, const std::vector& params = {}); + void add_module(const std::string &name, const std::string &src, const std::vector& params = {}); driver::kernel* get_function(const std::string &name); launch_information get_launch_info(const std::string &name); unsigned get_int(const std::string &name); + driver::buffer *get_buffer(const std::string &name); private: std::vector modules_; diff --git a/lib/ast/lowering.cpp b/lib/ast/lowering.cpp index 77ba26464..45906cef8 100644 --- a/lib/ast/lowering.cpp +++ b/lib/ast/lowering.cpp @@ -95,55 +95,75 @@ void node::implicit_cast(ir::builder &builder, ir::value *&lhs, ir::value *&rhs, throw std::runtime_error("unreachable"); } -void node::implicit_broadcast(ir::module *mod, ir::value *&arg, ir::type *ty) { - ir::value *tmp = ir::undef_value::get(ty); - implicit_broadcast(mod, arg, tmp); -} - -void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs){ - ir::builder &builder = mod->get_builder(); +void node::implicit_broadcast(ir::module *mod, ir::value *&lhs, ir::value *&rhs) { ir::type *lhs_ty = lhs->get_type(); ir::type *rhs_ty = rhs->get_type(); - ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); - // Both are scalar + ir::type *res_ty = nullptr; if(!lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) return; - // One argument is scalar - if(lhs_ty->is_tile_ty() ^ rhs_ty->is_tile_ty()){ - auto &shapes = lhs_ty->is_tile_ty()?lhs_ty->get_tile_shapes():rhs_ty->get_tile_shapes(); - auto &scalar = lhs_ty->is_tile_ty()?rhs:lhs; - scalar = builder.create_splat(scalar, shapes); + else if(lhs_ty->is_tile_ty() && !rhs_ty->is_tile_ty()) + res_ty = lhs_ty; + else if(!lhs_ty->is_tile_ty() && rhs_ty->is_tile_ty()) + res_ty = rhs_ty; + else{ + auto lhs_shapes = lhs_ty->get_tile_shapes(); + auto rhs_shapes = rhs_ty->get_tile_shapes(); + size_t lhs_size = lhs_shapes.size(); + size_t rhs_size = rhs_shapes.size(); + size_t res_size = std::max(lhs_size, rhs_size); + ir::type::tile_shapes_t res_shapes(res_size); + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); + for(int i = 0; i < res_size; i++){ + if(i >= res_size - lhs_size && i >= res_size - rhs_size) + res_shapes[i] = lhs_shapes[i]==one?rhs_shapes[i]:lhs_shapes[i]; + else if(i >= res_size - lhs_size) + res_shapes[i] = lhs_shapes[i]; + else if(i >= res_size - rhs_size) + res_shapes[i] = rhs_shapes[i]; + } + res_ty = ir::tile_type::get(lhs_ty->get_scalar_ty(), res_shapes); + } + implicit_broadcast(mod, res_ty, rhs); + implicit_broadcast(mod, res_ty, lhs); +} + +void node::implicit_broadcast(ir::module *mod, ir::type *ty, ir::value *&src){ + ir::builder &builder = mod->get_builder(); + ir::type *src_ty = src->get_type(); + ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); + // Both are scalar + if(!ty->is_tile_ty() && !src_ty->is_tile_ty()) + return; + // Broadcast scalar + if(ty->is_tile_ty() && !src_ty->is_tile_ty()){ + src = builder.create_splat(src, ty->get_tile_shapes()); + return; + } + // Downcast tile + if(!ty->is_tile_ty() && src_ty->is_tile_ty()){ + for(ir::constant *shape: src_ty->get_tile_shapes()) + if(shape != one) + throw std::runtime_error("cannot downcast"); + src = builder.create_downcast(src); return; } // Both are arrays - auto lhs_shapes = lhs->get_type()->get_tile_shapes(); - auto rhs_shapes = rhs->get_type()->get_tile_shapes(); - if(lhs_shapes == rhs_shapes) - return; - int lhs_dim = lhs_shapes.size(); - int rhs_dim = rhs_shapes.size(); - auto &shortest = (lhs_dim < rhs_dim)?lhs_shapes:rhs_shapes; - auto &longest = (lhs_dim < rhs_dim)?rhs_shapes:lhs_shapes; - size_t ndim = longest.size(); - int off = longest.size() - shortest.size(); - for(int i = longest.size() - 1; i>= 0; i--){ - if(shortest[off + i] != longest[i] && shortest[off + i] != one && longest[i] != one) - throw std::runtime_error("cannot broadcast"); - } + auto dst_shapes = ty->get_tile_shapes(); + auto src_shapes = src_ty->get_tile_shapes(); + int dst_dim = dst_shapes.size(); + int src_dim = src_shapes.size(); // Pad + int off = dst_dim - src_dim; for(size_t i = 0; i < off; i++) - shortest.insert(shortest.begin(), one); - ir::value *&target = (lhs_dim < rhs_dim)?lhs:rhs; + src_shapes.insert(src_shapes.begin(), one); if(off > 0) - target = builder.create_reshape(target, shortest); + src = builder.create_reshape(src, src_shapes); // Broadcast - ir::type::tile_shapes_t shapes(ndim); - for(size_t i = 0; i < ndim; i++) - shapes[i] = shortest[i]==one?longest[i]:shortest[i]; - if(shapes != lhs_shapes) - lhs = builder.create_broadcast(lhs, shapes); - if(shapes != rhs_shapes) - rhs = builder.create_broadcast(rhs, shapes); + for(int i = dst_dim - 1; i>= 0; i--) + if(dst_shapes[i] != src_shapes[i] && dst_shapes[i] != one && src_shapes[i] != one) + throw std::runtime_error("cannot broadcast"); + if(dst_shapes != src_shapes) + src = builder.create_broadcast(src, dst_shapes); } /* Helper */ @@ -336,7 +356,9 @@ ir::value* iteration_statement::codegen(ir::module *mod) const{ return builder.create_cond_br(cond, loop_bb, next_bb); }); init_->codegen(mod); - builder.create_br(loop_bb); + ir::value *cond = stop_->codegen(mod); + builder.create_cond_br(cond, loop_bb, next_bb); +// builder.create_br(loop_bb); builder.set_insert_point(loop_bb); if(!is_terminator(statements_->codegen(mod))) mod->get_continue_fn()(); @@ -378,6 +400,7 @@ ir::value* selection_statement::codegen(ir::module* mod) const{ builder.create_br(endif_bb); } // Endif + mod->seal_block(endif_bb); builder.set_insert_point(endif_bb); return nullptr; } @@ -422,7 +445,7 @@ ir::value* initializer::codegen(ir::module * mod) const{ else if(expr_){ value = expr_->codegen(mod); value = explicit_cast(mod->get_builder(), value, ty); - implicit_broadcast(mod, value, ty); + implicit_broadcast(mod, ty, value); } value->set_name(name); mod->set_value(name, value); @@ -543,6 +566,19 @@ ir::value* get_global_range::codegen(ir::module *mod) const { return builder.create_get_global_range(axis_->value(), (ir::constant_int*)size_->codegen(mod)); } +// get_range_id +ir::value* get_range_id::codegen(ir::module *mod) const { + return mod->get_builder().create_get_range_id(axis_->value()); +} + +// atomic cas +ir::value* atomic_cas::codegen(ir::module *mod) const { + ir::value *ptr = ptr_->codegen(mod); + ir::value *cmp = cmp_->codegen(mod); + ir::value *val = val_->codegen(mod); + return mod->get_builder().create_atomic_cas(ptr, cmp, val); +} + // matmul ir::value* matmul_expression::codegen(ir::module *mod) const { ir::value *A = A_->codegen(mod); @@ -554,10 +590,37 @@ ir::value* matmul_expression::codegen(ir::module *mod) const { // ir::type *tile_ty = ir::tile_type::get(scalar_ty, {M, N}); // ir::value *tmp = ir::undef_value::get(tile_ty); // implicit_broadcast(mod, tmp, C); - return mod->get_builder().create_matmul(A, B, C); + return mod->get_builder().create_dot(A, B, C); } +// min +ir::value* min_expression::codegen(ir::module *mod) const { + ir::value* cmp = binary_operator(LT, (node*)x_, (node*)y_).codegen(mod); + ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); + ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); + return mod->get_builder().create_select(cmp, x, y); +} +// max +ir::value* max_expression::codegen(ir::module *mod) const { + ir::value* cmp = binary_operator(GT, (node*)x_, (node*)y_).codegen(mod); + ir::value* x = ((ir::cmp_inst*)cmp)->get_operand(0); + ir::value* y = ((ir::cmp_inst*)cmp)->get_operand(1); + return mod->get_builder().create_select(cmp, x, y); +} + +// select +ir::value* select_expression::codegen(ir::module *mod) const { + ir::value* pred = pred_->codegen(mod); + ir::value* if_value = if_value_->codegen(mod); + ir::value* else_value = else_value_->codegen(mod); + return mod->get_builder().create_select(pred, if_value, else_value); +} + +// Trans +ir::value* trans_expression::codegen(ir::module *mod) const { + return mod->get_builder().create_trans(arg_->codegen(mod)); +} /* Postfix expression */ ir::value* indexing_expression::codegen(ir::module *mod) const{ @@ -573,6 +636,7 @@ ir::value* indexing_expression::codegen(ir::module *mod) const{ return mod->get_builder().create_reshape(in, out_shapes); } + /* Unary operator */ ir::value *unary_operator::llvm_op(ir::builder &builder, ir::value *arg, const std::string &name) const{ ir::type *atype = arg->get_type(); @@ -666,7 +730,7 @@ ir::value *assignment_expression::codegen(ir::module *mod) const{ if(auto *x = dynamic_cast(lvalue_)){ ir::type *ty = mod->get_scope().types.at(x->id()->name()); rvalue = explicit_cast(mod->get_builder(), rvalue, ty); - implicit_broadcast(mod, rvalue, ty); + implicit_broadcast(mod, ty, rvalue); mod->set_value(x->id()->name(), rvalue); } else if(auto* x = dynamic_cast(lvalue_)){ diff --git a/lib/codegen/buffer_info.cpp b/lib/codegen/buffer_info.cpp deleted file mode 100644 index dff371a64..000000000 --- a/lib/codegen/buffer_info.cpp +++ /dev/null @@ -1,90 +0,0 @@ -#include "triton/codegen/buffer_info.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" -#include "triton/ir/type.h" - -namespace triton { - -namespace codegen{ - - -// run pass on module -bool buffer_info_pass::is_loop_latch(ir::phi_node *phi, ir::value *terminator){ - if(auto *br = dynamic_cast(terminator)) - return br->get_true_dest() == phi->get_parent() - || br->get_false_dest() == phi->get_parent(); - else if(auto *br = dynamic_cast(terminator)) - return false; - else - throw std::runtime_error("unreachable"); -} - -void buffer_info_pass::replace(ir::value* before, ir::value *after) { - shared_.erase(before); - shared_.insert(after); - if(refs_.find(before) != refs_.end()){ - ir::value* v = refs_.at(before); - refs_.erase(before); - refs_.insert({after, v}); - } -} - -void buffer_info_pass::run(ir::module &mod) { - // Find which buffers are shared - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()) - if(dynamic_cast(i)){ - shared_.insert(i->get_operand(0)); - shared_.insert(i->get_operand(1)); - } - - // Handles phi nodes - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()) { - if(!i->get_type()->is_tile_ty()) - continue; - // handle phi - if(auto *phi = dynamic_cast(i)) - if(is_shared(phi)){ - // determine if the value is in shared memory - bool is_double = false; - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::basic_block *inc_block = phi->get_incoming_block(n); - ir::value *terminator = inc_block->get_inst_list().back(); - is_double = is_double || is_loop_latch(phi, terminator); - } - // add to double-buffered - if(is_double) - double_.insert(phi); - // set references of input - for(unsigned n = 0; n < phi->get_num_incoming(); n++){ - ir::value *inc_val = phi->get_incoming_value(n); - refs_[inc_val] = phi; - } - } - } - - for(auto &ref: refs_) - shared_.insert(ref.first); -} - -// query double-buffered status -bool buffer_info_pass::is_double(ir::value *x) -{ return double_.find(x) != double_.end(); } - -// query shared status -bool buffer_info_pass::is_shared(ir::value *x) -{ return shared_.find(x) != shared_.end(); } - -// get reference if any -ir::value *buffer_info_pass::get_reference(ir::value *x) -{ return refs_[x]; } - - - -} -} diff --git a/lib/codegen/layout.cpp b/lib/codegen/layout.cpp deleted file mode 100644 index 0722321b8..000000000 --- a/lib/codegen/layout.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "triton/codegen/layout.h" -#include "triton/ir/function.h" -#include "triton/ir/module.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" - -namespace triton{ -namespace codegen{ - - -shared_view_info layout::get_shared_view(ir::value *v, unsigned idx){ - return shared_views_.at(v)[idx]; -} - -unsigned layout::get_num_shared_views(ir::value *v){ - return shared_views_.at(v).size(); -} - -// Phi node -void layout::add_phi_nodes(ir::value *v){ - if(ir::phi_node *phi = dynamic_cast(v)) - if(shared_views_.find(phi) != shared_views_.end()) - for(ir::value *v: phi->ops()){ - shared_views_[v] = shared_views_[phi]; - for(shared_view_info &info: shared_views_[v]) - info.has_dedicated_storage = false; - } -} - -// Memory Layout -void layout::add_shared_views(ir::value *v){ - // GEMM has shared inputs - if(dynamic_cast(v)) - shared_views_[v].push_back({v, true}); - if(dynamic_cast(v)) - shared_views_[v].push_back({v, true}); -} - -// Entry point -void layout::run(ir::module &mod) { -for(ir::function *fn: mod.get_function_list()){ - // Non-phis - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *instr: block->get_inst_list()) { - add_shared_views(instr); - } - // Phi nodes - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *instr: block->get_inst_list()) { - add_phi_nodes(instr); - } -} -} - -} -} diff --git a/lib/codegen/loop_info.cpp b/lib/codegen/loop_info.cpp deleted file mode 100644 index e69de29bb..000000000 diff --git a/lib/codegen/optimize_cse.cpp b/lib/codegen/optimize_cse.cpp new file mode 100644 index 000000000..b0c07a99e --- /dev/null +++ b/lib/codegen/optimize_cse.cpp @@ -0,0 +1,14 @@ +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/module.h" +#include "triton/codegen/optimize_cse.h" + +namespace triton { +namespace codegen{ + + +void optimize_cse::run(ir::module &mod) { +} + +} +} diff --git a/lib/codegen/optimize_dot.cpp b/lib/codegen/optimize_dot.cpp new file mode 100644 index 000000000..67e3f8569 --- /dev/null +++ b/lib/codegen/optimize_dot.cpp @@ -0,0 +1,50 @@ +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/module.h" +#include "triton/codegen/optimize_dot.h" +#include "triton/codegen/tune.h" + +namespace triton { +namespace codegen{ + +inline bool is_trans(ir::value *v){ + return dynamic_cast(v) != nullptr; +} + +void optimize_dot::run(ir::module &mod) { + ir::builder &builder = mod.get_builder(); + std::vector to_delete; + // iterate + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()) + if(auto dot = dynamic_cast(i)) + if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1) + if(!dot->is_a_trans() && !dot->is_b_trans()){ + builder.set_insert_point(i); + ir::value *A = dot->get_operand(0); + ir::value *B = dot->get_operand(1); + ir::value *D = dot->get_operand(2); + // dot(op(a), trans(b)) + if(is_trans(B)){ + ir::value* BN = ((ir::trans_inst*)B)->get_operand(0); + ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BN, D)); + dot->replace_all_uses_with(NT); + to_delete.push_back((ir::instruction*)B); + to_delete.push_back(dot); + } + // dot(op(a), b) + if(!is_trans(B)){ + ir::value* BT = builder.create_trans(B); + ir::instruction *NT = builder.insert(ir::dot_inst::create_nt(A, BT, D)); + dot->replace_all_uses_with(NT); + to_delete.push_back(dot); + } + } + + for(ir::instruction* i: to_delete) + i->erase_from_parent(); +} + +} +} diff --git a/lib/codegen/optimize_trans.cpp b/lib/codegen/optimize_trans.cpp new file mode 100644 index 000000000..b6ad7cfd2 --- /dev/null +++ b/lib/codegen/optimize_trans.cpp @@ -0,0 +1,71 @@ +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/codegen/optimize_trans.h" + +namespace triton { +namespace codegen{ + + +ir::value* optimize_trans::replace_phi(ir::value* value, + std::vector& to_delete, + ir::builder& builder){ + if(auto phi = dynamic_cast(value)) { + // transpose operands + std::vector incs; + for(unsigned n = 0; n < phi->get_num_incoming(); n++) + incs.push_back(replace_phi(phi->get_incoming_value(n), to_delete, builder)); + // create phi for transposed values + builder.set_insert_point(phi); + ir::phi_node* result = builder.create_phi(incs[0]->get_type(), incs.size(), phi->get_name()); + for(unsigned n = 0; n < phi->get_num_incoming(); n++) + result->add_incoming(incs[n], phi->get_incoming_block(n)); + phi->replace_all_uses_with(result); + to_delete.push_back(phi); + return result; + } + else if(auto i = dynamic_cast(value)){ + ir::basic_block* block = i->get_parent(); + auto it = std::find(block->begin(), block->end(), i); + it++; + builder.set_insert_point(it); + ir::instruction *trans = (ir::instruction*)builder.create_trans(i); + i->replace_all_uses_with(trans); + trans->set_operand(0, i); + return trans; + } + throw std::runtime_error("cannot transpose phi"); +} + + +void optimize_trans::run(ir::module &mod) { + ir::builder &builder = mod.get_builder(); + std::vector to_delete; + // iterate + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction* i: block->get_inst_list()){ + // filter transposition + if(auto trans = dynamic_cast(i)) { + auto users = trans->get_users(); + auto ops = trans->ops(); + if(users.size() > 1 || ops.size() > 1) + continue; + ir::value* op = *ops.begin(); + // chains of transpositions + // TODO + + // trans(phi) -> phi(trans(), trans()...) + if(dynamic_cast(op)){ + ir::value* new_phi = replace_phi(op, to_delete, builder); + to_delete.push_back(trans); + trans->replace_all_uses_with(new_phi); + } + } + } + // erase dead code + for(ir::instruction* i: to_delete) + i->erase_from_parent(); +} + +} +} diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index 546b0e76f..c04b4cdfb 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -1,6 +1,6 @@ #include "triton/codegen/selection.h" #include "triton/codegen/tune.h" -#include "triton/codegen/allocation.h" +#include "triton/codegen/shmem_allocation.h" #include "triton/codegen/target.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Module.h" @@ -309,7 +309,47 @@ Instruction *selection::llvm_inst(ir::instruction *inst, std::function(inst)){ Value *ptr = value(ii->get_pointer_operand()); - return builder.Insert(new LoadInst(ptr)); + LoadInst *result = new LoadInst(ptr); + return builder.Insert(result); + } + if(ir::store_inst* ii = dynamic_cast(inst)){ + Value *val = value(ii->get_value_operand()); + Value *ptr = value(ii->get_pointer_operand()); + builder.CreateStore(val, ptr); + return nullptr; + } + if(ir::select_inst* ii = dynamic_cast(inst)){ + Value *pred = value(ii->get_operand(0)); + Value *if_value = value(ii->get_operand(1)); + Value *else_value = value(ii->get_operand(2)); + return builder.Insert(SelectInst::Create(pred, if_value, else_value)); + } + if(ir::get_range_id_inst* ii = dynamic_cast(inst)){ + Value *offset = tgt_->get_block_id(builder.GetInsertBlock()->getModule(), builder, ii->get_axis()); + return (Instruction*)builder.CreateAdd(offset, builder.getInt32(0)); + } + if(ir::atomic_cas_inst* ii = dynamic_cast(inst)){ + BasicBlock *current = builder.GetInsertBlock(); + Module *module = current->getModule(); + Value *tid = tgt_->get_local_id(module, builder, 0); + Value *pred = builder.CreateICmpEQ(tid, builder.getInt32(0)); + BasicBlock *tid_0_bb = BasicBlock::Create(ctx, "tid_0", current->getParent()); + BasicBlock *tid_0_done_bb = BasicBlock::Create(ctx, "tid_0_done", current->getParent()); + Value *ptr = builder.CreateGEP(sh_mem_ptr_, builder.getInt32(alloc_->get_offset(ii))); + ptr = builder.CreateBitCast(ptr, PointerType::get(builder.getInt32Ty(), ptr->getType()->getPointerAddressSpace())); + builder.CreateCondBr(pred, tid_0_bb, tid_0_done_bb); + builder.SetInsertPoint(tid_0_bb); + Value *cas_ptr = value(ii->get_operand(0)); + Value *cas_cmp = value(ii->get_operand(1)); + Value *cas_val = value(ii->get_operand(2)); + Value *old = builder.CreateAtomicCmpXchg(cas_ptr, cas_cmp, cas_val, AtomicOrdering::Monotonic, AtomicOrdering::Monotonic); + old = builder.CreateExtractValue(old, {0}); + builder.CreateStore(old, ptr); + builder.CreateBr(tid_0_done_bb); + builder.SetInsertPoint(tid_0_done_bb); + tgt_->add_barrier(module, builder); + Value *res = builder.CreateLoad(ptr); + return (Instruction*)res; } // unknown instruction throw std::runtime_error("unknown conversion from ir::instruction to Instruction"); @@ -446,7 +486,7 @@ void selection::create_grids(std::vector &grids, bind_references(op); // bind const auto& shapes = v->get_type()->get_tile_shapes(); - if(dynamic_cast(v) || buffer_info_->is_double(v)) + if(buffer_info_->is_shared(v)) return; for(size_t d = 0; d < shapes.size(); d++){ if(shapes[d]->get_value() == 1) @@ -490,20 +530,11 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, shapes2.push_back(shape->get_value()); Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx); // create shared tile - if(dynamic_cast(v) || (buffer_info_->is_double(v))){ + if(buffer_info_->is_shared(v)){ // shared copy PointerType *ptr_ty = ty->getPointerTo(sh_mem_ptr->getType()->getPointerAddressSpace()); - // TODO - buffer info not up-to-date with references - if(dynamic_cast(v)) { - if(!has_phi_user(v)){ - size_t offset = alloc_->get_offset(v); - Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); - ptr = builder.CreateBitCast(ptr, ptr_ty); - tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)}); - } - } // phi-node (double-buffering) - else if(auto *phi = dynamic_cast(v)) { + if(auto *phi = dynamic_cast(v)) { BasicBlock *parent = (BasicBlock*)vmap_[phi->get_parent()]; unsigned id_pre = 0, id_loop = 1; if(phi->get_incoming_block(0) == phi->get_parent()) @@ -522,13 +553,19 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, for(unsigned i = 0; i < phi->get_num_incoming(); i++) { ir::basic_block* inc_block = phi->get_incoming_block(i); ir::value* inc_value = phi->get_incoming_value(i); - ir::value* terminator = inc_block->get_inst_list().back(); + ir::instruction* terminator = inc_block->get_inst_list().back(); bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)}); } } - else - throw std::runtime_error("unknown shared memory tile"); + else { + if(!has_phi_user(v)){ + size_t offset = alloc_->get_offset(v); + Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); + ptr = builder.CreateBitCast(ptr, ptr_ty); + tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)}); + } + } } // create distributed tile else { @@ -607,10 +644,16 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & tile *value = tmap_.at(x->get_value_operand()); ptr->for_each([&](indices_t idx){ set_mask_insert_pt(idx); - builder.CreateStore(value->get_value(idx), ptr->get_value(idx)); + StoreInst *store = new StoreInst(value->get_value(idx), ptr->get_value(idx)); +// store->setAlignment(16); + builder.Insert(store); }); } else { + if(auto *x = dynamic_cast(ins)){ + vmap_[x] = tmap_[x->get_operand(0)]->get_value({builder.getInt32(0)}); + return; + } tile *ti = tmap_[ins]; distributed_tile* result = (distributed_tile*)ti; if(!ins->get_type()->is_tile_ty()) @@ -727,31 +770,67 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> & ti->set_value(idx, in->get_value(idx)); }); } - else if(dynamic_cast(ins) || (buffer_info_->is_double(ins))) + // trans + else if(dynamic_cast(ins)) { + distributed_tile* in = (distributed_tile*)tmap_.at(ins->get_operand(0)); + in->for_each([&](indices_t idx){ + indices_t out_idx = idx; + std::rotate(out_idx.begin(), out_idx.begin() + 1, out_idx.end()); + ti->set_value(out_idx, in->get_value(idx)); + }); + } + else if(buffer_info_->is_shared(ins)) return; - // matrix multiplication - else if(dynamic_cast(ins)) { + // dot + else if(auto dot = dynamic_cast(ins)) { ir::value *A = ins->get_operand(0); ir::value *B = ins->get_operand(1); ir::value *C = ins->get_operand(2); - shared_tile *TA = (shared_tile*)tmap_.at(A); - shared_tile *TB = (shared_tile*)tmap_.at(B); + bool AT = dot->is_a_trans(); + bool BT = dot->is_b_trans(); distributed_tile *TC = (distributed_tile*)tmap_.at(C); - TA->set_vector_size(TC->axis(0).contiguous); - TB->set_vector_size(TC->axis(1).contiguous); Function *f_mul_add = Intrinsic::getDeclaration(module, Intrinsic::fmuladd, {llvm_type(C->get_type()->get_scalar_ty(), ctx)}); - result->for_each([&](indices_t idx){ - Value *res = TC->get_value(idx); - unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value(); - for(unsigned K = 0; K < NK; ++K){ - indices_t a_idx = {idx[0], builder.getInt32(K)}; - indices_t b_idx = {idx[1], builder.getInt32(K)}; + if(dot->get_operand(0)->get_type()->get_tile_shapes()[1]->get_value() != 1) + { + shared_tile *TA = (shared_tile*)tmap_.at(A); + shared_tile *TB = (shared_tile*)tmap_.at(B); + TA->set_vector_size(TC->axis(0).contiguous); + TB->set_vector_size(TC->axis(1).contiguous); + result->for_each([&](indices_t idx){ + Value *res = TC->get_value(idx); + unsigned NK = A->get_type()->get_tile_shapes()[1]->get_value(); + for(unsigned K = 0; K < NK; ++K){ + indices_t a_idx = {idx[0], builder.getInt32(K)}; + indices_t b_idx = {builder.getInt32(K), idx[1]}; + if(AT) + std::swap(a_idx[0], a_idx[1]); + if(BT) + std::swap(b_idx[0], b_idx[1]); + Value *a = TA->get_value(a_idx); + Value *b = TB->get_value(b_idx); + res = builder.CreateCall(f_mul_add, {a, b, res}); + } + result->set_value(idx, res); + }); + } + else + { + distributed_tile *TA = (distributed_tile*)tmap_.at(A); + distributed_tile *TB = (distributed_tile*)tmap_.at(B); + result->for_each([&](indices_t idx){ + Value *res = TC->get_value(idx); + indices_t a_idx = {idx[0], builder.getInt32(0)}; + indices_t b_idx = {builder.getInt32(0), idx[1]}; + if(AT) + std::swap(a_idx[0], a_idx[1]); + if(BT) + std::swap(b_idx[0], b_idx[1]); Value *a = TA->get_value(a_idx); Value *b = TB->get_value(b_idx); res = builder.CreateCall(f_mul_add, {a, b, res}); - } - result->set_value(idx, res); - }); + result->set_value(idx, res); + }); + } } // element-wise else { @@ -858,6 +937,7 @@ void selection::run(ir::module &src, Module &dst) { nullptr, "__shared_ptr", nullptr, GlobalVariable::NotThreadLocal, 3); sh_mem_ptr = dst_builder.CreateBitCast(sh_mem_array, ptr_ty); } + sh_mem_ptr_ = sh_mem_ptr; // create grids init_grids(fn, dst_builder, sh_mem_ptr); @@ -890,7 +970,7 @@ void selection::run(ir::module &src, Module &dst) { for(unsigned n = 0; n < phi->get_num_incoming(); n++){ ir::basic_block* inc_block = phi->get_incoming_block(n); ir::value* inc_val = phi->get_incoming_value(n); - ir::value* terminator = inc_block->get_inst_list().back(); + ir::instruction* terminator = inc_block->get_inst_list().back(); BasicBlock *llvm_inc_block = last_block.at(inc_block); shared_tile *inc_shared = (shared_tile*)tmap_.at(inc_val); bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); @@ -920,8 +1000,8 @@ void selection::run(ir::module &src, Module &dst) { }); } else { - PHINode *llvm_phi = (PHINode*)vmap_.at(phi); - Value *llvm_inc_val = vmap_.at(inc_val); + PHINode *llvm_phi = (PHINode*)llvm_value(phi, dst_builder); + Value *llvm_inc_val = llvm_value(inc_val, dst_builder); llvm_phi->addIncoming(llvm_inc_val, llvm_inc_block); } } diff --git a/lib/codegen/shared_copy.cpp b/lib/codegen/shared_copy.cpp deleted file mode 100644 index 6c05b7807..000000000 --- a/lib/codegen/shared_copy.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include -#include "triton/codegen/shared_copy.h" -#include "triton/codegen/buffer_info.h" -#include "triton/ir/module.h" -#include "triton/ir/function.h" -#include "triton/ir/basic_block.h" -#include "triton/ir/instructions.h" - -namespace triton { - -namespace codegen{ - -void place_shared_copy::add_copy(ir::value *x, ir::builder &builder) { - if(auto *i = dynamic_cast(x)){ - ir::basic_block* block = i->get_parent(); - auto it = std::find(block->begin(), block->end(), i); - builder.set_insert_point(++it); - } - ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x); - x->replace_all_uses_with(rx); - rx->set_operand(0, x); -} - -void place_shared_copy::run(ir::module &mod) { - ir::builder &builder = mod.get_builder(); - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()) - if(info_->is_shared(i) && !info_->is_double(i)) - add_copy(i, builder); - - for(ir::function *fn: mod.get_function_list()) - for(ir::basic_block *block: fn->blocks()) - for(ir::instruction *i: block->get_inst_list()) - if(auto* cts = dynamic_cast(i)) - info_->replace(cts->get_operand(0), cts); -} - -} -} diff --git a/lib/codegen/allocation.cpp b/lib/codegen/shmem_allocation.cpp similarity index 91% rename from lib/codegen/allocation.cpp rename to lib/codegen/shmem_allocation.cpp index fd272a243..43ab8bc39 100644 --- a/lib/codegen/allocation.cpp +++ b/lib/codegen/shmem_allocation.cpp @@ -1,7 +1,6 @@ -#include "triton/codegen/allocation.h" -#include "triton/codegen/liveness.h" -#include "triton/codegen/layout.h" -#include "triton/codegen/buffer_info.h" +#include "triton/codegen/shmem_allocation.h" +#include "triton/codegen/shmem_liveness.h" +#include "triton/codegen/shmem_info.h" #include "triton/ir/basic_block.h" #include "triton/ir/type.h" #include "triton/ir/value.h" @@ -11,14 +10,14 @@ namespace triton{ namespace codegen{ -unsigned allocation::get_num_bytes(ir::value *x) { - unsigned result = x->get_type()->get_tile_bitwidth() / 8; +unsigned shmem_allocation::get_num_bytes(ir::value *x) { + unsigned result = x->get_type()->get_primitive_size_in_bits() / 8; if(buffer_info_->is_double(x)) result *= 2; return result; } -void allocation::run(){ +void shmem_allocation::run(){ using std::max; using std::min; typedef std::multimap triples_map_type; diff --git a/lib/codegen/barriers.cpp b/lib/codegen/shmem_barriers.cpp similarity index 75% rename from lib/codegen/barriers.cpp rename to lib/codegen/shmem_barriers.cpp index bb3611f85..717b927fd 100644 --- a/lib/codegen/barriers.cpp +++ b/lib/codegen/shmem_barriers.cpp @@ -1,7 +1,7 @@ #include -#include "triton/codegen/barriers.h" -#include "triton/codegen/allocation.h" -#include "triton/codegen/buffer_info.h" +#include "triton/codegen/shmem_barriers.h" +#include "triton/codegen/shmem_allocation.h" +#include "triton/codegen/shmem_info.h" #include "triton/ir/module.h" #include "triton/ir/function.h" #include "triton/ir/basic_block.h" @@ -12,7 +12,7 @@ namespace triton { namespace codegen{ -bool barriers::intersect(const interval_vec_t &X, interval_t x) { +bool shmem_barriers::intersect(const interval_vec_t &X, interval_t x) { return std::any_of(X.begin(), X.end(), [&](const interval_t &y){ bool left_intersect = y.first <= x.first && x.first < y.second; bool right_intersect = y.first <= x.second && x.second < y.second; @@ -20,31 +20,31 @@ bool barriers::intersect(const interval_vec_t &X, interval_t x) { }); } -bool barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) { +bool shmem_barriers::intersect(const interval_vec_t &X, const interval_vec_t &Y) { return std::any_of(Y.begin(), Y.end(), [&](const interval_t &y){ return intersect(X, y); }); } -void barriers::add_reference(ir::value *v, interval_vec_t &res){ - if(dynamic_cast(v)){ +void shmem_barriers::add_reference(ir::value *v, interval_vec_t &res){ + if(buffer_info_->is_shared(v) && !dynamic_cast(v)){ unsigned offset = alloc_->get_offset(v); unsigned num_bytes = alloc_->get_num_bytes(v); res.push_back(interval_t(offset, offset + num_bytes)); } } -void barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){ +void shmem_barriers::get_read_intervals(ir::instruction *i, interval_vec_t &res){ for(ir::value *op: i->ops()) add_reference(op, res); } -void barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){ +void shmem_barriers::get_written_intervals(ir::instruction *i, interval_vec_t &res){ if(!dynamic_cast(i)) add_reference(i, res); } -void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { +void shmem_barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { if(auto *phi = dynamic_cast(instr)) { std::set incoming; for(unsigned n = 0; n < phi->get_num_incoming(); n++){ @@ -63,16 +63,16 @@ void barriers::insert_barrier(ir::instruction *instr, ir::builder &builder) { } } -barriers::interval_vec_t barriers::join(const std::vector& intervals) { - barriers::interval_vec_t result; +shmem_barriers::interval_vec_t shmem_barriers::join(const std::vector& intervals) { + shmem_barriers::interval_vec_t result; for(auto x: intervals) for(interval_t i: x) result.push_back(i); return result; } -std::pair barriers::transfer(ir::basic_block *block, +std::pair shmem_barriers::transfer(ir::basic_block *block, const interval_vec_t &written_to, const interval_vec_t &read_from, std::set& insert_loc) { @@ -83,13 +83,13 @@ std::pair(i) && + if(buffer_info_->is_shared(i) && buffer_info_->is_double(buffer_info_->get_reference(i))) - written_while_read = false; - if(read_while_written || written_while_read) { + write_after_read = false; + if(read_after_write || write_after_read) { insert_loc.insert(i); new_written_to.clear(); new_read_from.clear(); @@ -100,7 +100,7 @@ std::pair rpo = ir::cfg::reverse_post_order(fn); diff --git a/lib/codegen/shmem_info.cpp b/lib/codegen/shmem_info.cpp new file mode 100644 index 000000000..6d3caafab --- /dev/null +++ b/lib/codegen/shmem_info.cpp @@ -0,0 +1,135 @@ +#include "triton/codegen/shmem_info.h" +#include "triton/ir/module.h" +#include "triton/ir/function.h" +#include "triton/ir/basic_block.h" +#include "triton/ir/instructions.h" +#include "triton/ir/type.h" + +namespace triton { + +namespace codegen{ + + +// run pass on module +bool shmem_info::is_loop_latch(ir::phi_node *phi, ir::instruction *terminator){ + if(phi->get_parent() != terminator->get_parent()) + return false; + if(auto *br = dynamic_cast(terminator)) + return br->get_true_dest() == phi->get_parent() + || br->get_false_dest() == phi->get_parent(); + else if(auto *br = dynamic_cast(terminator)) + return false; + else + throw std::runtime_error("unreachable"); +} + +void shmem_info::replace(ir::value* before, ir::value *after) { + shared_.erase(before); + shared_.insert(after); + if(refs_.find(before) != refs_.end()){ + ir::value* v = refs_.at(before); + refs_.erase(before); + refs_.insert({after, v}); + } +} + +inline bool get_is_shared(ir::value* v) { + if(auto x = dynamic_cast(v)) + return true; + if(auto x = dynamic_cast(v)) + return true; + if(auto x = dynamic_cast(v)) + return true; + if(auto x = dynamic_cast(v)){ + bool res = true; + for(unsigned inc = 0; inc < x->get_num_incoming(); inc++) + res = res && get_is_shared(x->get_incoming_value(inc)); + return res; + } + return false; +} + +void add_copy(ir::value *x, ir::builder &builder) { + if(auto phi = dynamic_cast(x)){ + for(unsigned i = 0; i < phi->get_num_incoming(); ++i) + add_copy(phi->get_incoming_value(i), builder); + } + else { + if(get_is_shared(x)) + return; + if(auto *i = dynamic_cast(x)){ + ir::basic_block* block = i->get_parent(); + auto it = std::find(block->begin(), block->end(), i); + builder.set_insert_point(++it); + } + ir::instruction *rx = (ir::instruction*)builder.create_copy_to_shared(x); + x->replace_all_uses_with(rx); + rx->set_operand(0, x); + } +} + +void shmem_info::run(ir::module &mod) { + // Add shared copies + for(ir::function *fn: mod.get_function_list()){ + ir::builder builder(mod.get_context()); + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()){ + if(dynamic_cast(i)) + if(i->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1){ + add_copy(i->get_operand(0), builder); + add_copy(i->get_operand(1), builder); + } + } + } + + // Find which buffers are shared + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()) + if(get_is_shared(i)) + shared_.insert(i); + + // double-buffering + for(ir::function *fn: mod.get_function_list()) + for(ir::basic_block *block: fn->blocks()) + for(ir::instruction *i: block->get_inst_list()) { + if(!i->get_type()->is_tile_ty()) + continue; + // handle phi + if(auto *phi = dynamic_cast(i)) + if(is_shared(phi)){ + // determine if the value is in shared memory + bool is_double = false; + for(unsigned n = 0; n < phi->get_num_incoming(); n++){ + ir::basic_block *inc_block = phi->get_incoming_block(n); + ir::instruction *terminator = inc_block->get_inst_list().back(); + is_double = is_double || is_loop_latch(phi, terminator); + } + // add to double-buffered + if(is_double) + double_.insert(phi); + // set references of input + for(unsigned n = 0; n < phi->get_num_incoming(); n++){ + ir::value *inc_val = phi->get_incoming_value(n); + refs_[inc_val] = phi; + } + } + } +} + +// query double-buffered status +bool shmem_info::is_double(ir::value *x) +{ return double_.find(x) != double_.end(); } + +// query shared status +bool shmem_info::is_shared(ir::value *x) +{ return shared_.find(x) != shared_.end(); } + +// get reference if any +ir::value *shmem_info::get_reference(ir::value *x) +{ return refs_[x]; } + + + +} +} diff --git a/lib/codegen/liveness.cpp b/lib/codegen/shmem_liveness.cpp similarity index 67% rename from lib/codegen/liveness.cpp rename to lib/codegen/shmem_liveness.cpp index ca33bd487..4d8e9c66b 100644 --- a/lib/codegen/liveness.cpp +++ b/lib/codegen/shmem_liveness.cpp @@ -1,5 +1,5 @@ -#include "triton/codegen/liveness.h" -#include "triton/codegen/buffer_info.h" +#include "triton/codegen/shmem_liveness.h" +#include "triton/codegen/shmem_info.h" #include "triton/ir/basic_block.h" #include "triton/ir/function.h" #include "triton/ir/module.h" @@ -11,19 +11,7 @@ namespace codegen{ // Entry point -inline bool is_shared(ir::value* v) { - if(auto x = dynamic_cast(v)) - return true; - if(auto x = dynamic_cast(v)){ - bool res = true; - for(unsigned inc = 0; inc < x->get_num_incoming(); inc++) - res = res && is_shared(x->get_incoming_value(inc)); - return res; - } - return false; -} - -void liveness::run(ir::module &mod) { +void shmem_liveness::run(ir::module &mod) { for(ir::function *fn: mod.get_function_list()){ // Assigns index to each instruction slot_index index = 0; diff --git a/lib/codegen/target.cpp b/lib/codegen/target.cpp index 27a982a6c..2554bf5c3 100644 --- a/lib/codegen/target.cpp +++ b/lib/codegen/target.cpp @@ -4,6 +4,7 @@ #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" #include "llvm/IR/IRBuilder.h" +#include using namespace llvm; @@ -26,6 +27,12 @@ Instruction* amd_cl_target::add_barrier(Module *module, IRBuilder<>& builder) { } Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { + Value* group_id = get_block_id(module, builder, ax); + Value* result = builder.CreateMul(builder.getInt32(stride), group_id); + return result; +} + +Value* amd_cl_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) { static std::array ids = { Intrinsic::amdgcn_workgroup_id_x, Intrinsic::amdgcn_workgroup_id_y, @@ -33,8 +40,7 @@ Value* amd_cl_target::get_global_offset(Module *module, IRBuilder<>& builder, un }; Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]); Value* group_id = builder.CreateCall(get_group_id, {}); - Value* result = builder.CreateMul(builder.getInt32(stride), group_id); - return result; + return group_id; } Value* amd_cl_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { @@ -65,6 +71,12 @@ Instruction* nvidia_cu_target::add_barrier(Module *module, IRBuilder<>& builder) } Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { + Value* group_id = get_block_id(module, builder, ax); + Value* result = builder.CreateMul(builder.getInt32(stride), group_id); + return result; +} + +Value* nvidia_cu_target::get_block_id(Module *module, IRBuilder<>& builder, unsigned ax) { static std::array ids = { Intrinsic::nvvm_read_ptx_sreg_ctaid_x, Intrinsic::nvvm_read_ptx_sreg_ctaid_y, @@ -72,8 +84,7 @@ Value* nvidia_cu_target::get_global_offset(Module *module, IRBuilder<>& builder, }; Value* get_group_id = Intrinsic::getDeclaration(module, ids[ax]); Value* group_id = builder.CreateCall(get_group_id, {}); - Value* result = builder.CreateMul(builder.getInt32(stride), group_id); - return result; + return group_id; } Value* nvidia_cu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned ax) { @@ -97,7 +108,7 @@ Instruction* cpu_target::add_barrier(Module *module, IRBuilder<>& builder) { return (Instruction*)builder.CreateAdd(builder.getInt32(0), builder.getInt32(0)); } -Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { +Value* cpu_target::get_block_id(Module *module, llvm::IRBuilder<> &builder, unsigned ax) { const Function *fn = builder.GetInsertBlock()->getParent(); size_t num_params = fn->getFunctionType()->getNumParams(); static std::array ids = { @@ -105,7 +116,11 @@ Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsig fn->arg_begin() + num_params - 2, fn->arg_begin() + num_params - 1 }; - Value* result = builder.CreateMul(builder.getInt32(stride), (Argument*)ids[ax]); + return (Argument*)ids[ax]; +} + +Value* cpu_target::get_global_offset(Module *module, IRBuilder<>& builder, unsigned stride, unsigned ax) { + Value* result = builder.CreateMul(builder.getInt32(stride), get_block_id(module, builder, ax)); return result; } @@ -113,6 +128,5 @@ Value* cpu_target::get_local_id(Module *module, IRBuilder<>& builder, unsigned a return builder.getInt32(0); } - } } diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index 4353b1332..1a1562c8f 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -1,5 +1,4 @@ #include "triton/codegen/tune.h" -#include "triton/codegen/shared_copy.h" #include "triton/ir/instructions.h" #include "triton/ir/type.h" #include "triton/ir/module.h" @@ -40,6 +39,8 @@ void tune::init_c_graph(ir::instruction *v) { ir::type::tile_shapes_t shapes; if(auto *store = dynamic_cast(v)) shapes = store->get_pointer_operand()->get_type()->get_tile_shapes(); + else if(auto *downcast = dynamic_cast(v)) + return; else shapes = v->get_type()->get_tile_shapes(); // Reshape @@ -56,6 +57,14 @@ void tune::init_c_graph(ir::instruction *v) { // Splat else if(dynamic_cast(v)){ + } + // Trans + else if(dynamic_cast(v)){ + ir::value *op = v->get_operand(0); + size_t n_shapes = shapes.size(); + for(unsigned i = 0; i < n_shapes; i++){ + add_constraint({v, (i + 1) % n_shapes}, {op, i}); + } } // Broadcast else if(dynamic_cast(v)){ @@ -68,7 +77,7 @@ void tune::init_c_graph(ir::instruction *v) { } } // Matrix multiplication - else if(dynamic_cast(v)){ + else if(dynamic_cast(v)){ ir::value *D = v->get_operand(2); add_constraint({v, 0}, {D, 0}); add_constraint({v, 1}, {D, 1}); @@ -119,6 +128,13 @@ std::vector tune::get_params(ir::module &mod) { if(seen.insert(x.second).second && !x.second->has_value()){ result.push_back(x.second); } + + for(auto x: mod.globals()){ + if(auto mp = dynamic_cast(x.second)) + if(seen.insert(mp).second && !mp->has_value()) + result.push_back(mp); + } + return result; } @@ -145,23 +161,22 @@ void tune::run(ir::module &mod) { // Layout parameters while(!nodes_.empty()){ ir::type *ty = mod.get_builder().get_int32_ty(); - ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 2, 4); + ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1); + nts->set_value(1); ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32); connected_components(*nodes_.begin(), {nts, mts}, nodes_, dependencies_); } } // Simplify metaparameters - std::set fixed_io_nts; for(ir::function *fn: mod.get_function_list()) for(ir::basic_block *block: fn->blocks()) for(ir::instruction *i : block->get_inst_list()) - if(dynamic_cast(i) || dynamic_cast(i)) - if(i->get_type()->is_tile_ty()) - for(unsigned d = 1; d < i->get_type()->get_tile_shapes().size(); d++) - fixed_io_nts.insert(params_.at(i).at("nts.d" + std::to_string(d))); - for(ir::metaparameter* mp: fixed_io_nts) - mp->set_value(1); + if(dynamic_cast(i) && i->get_type()->is_tile_ty()){ + ir::type *ty = mod.get_builder().get_int32_ty(); + std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 2, 2)); + *params_.at(i).at("nts.d0") = *tmp; + } } void tune::init(ir::module &mod) { diff --git a/lib/driver/buffer.cpp b/lib/driver/buffer.cpp index b5030d710..a64e0aeca 100755 --- a/lib/driver/buffer.cpp +++ b/lib/driver/buffer.cpp @@ -64,9 +64,6 @@ buffer* buffer::create(driver::context* ctx, size_t size) { host_buffer::host_buffer(driver::context *context, size_t size) : buffer(context, host_buffer_t(), true){ hst_->data = new char[size]; - std::cout << size << std::endl; - std::cout << "allocating " << (float*)hst_->data << std::endl; - std::cout << *((float*)(hst_->data) + 512*500) << std::endl; } // diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 8346961fe..641b900b4 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -106,7 +106,11 @@ void module::compile_llvm_module(llvm::Module* module, const std::string& triple const std::string& features, file_type_t ft) { init_llvm(); - + // debug +// llvm::legacy::PassManager pm; +// pm.add(llvm::createPrintModulePass(llvm::outs())); +// pm.add(llvm::createVerifierPass()); +// pm.run(*module); // create machine module->setTargetTriple(triple); std::string error; @@ -249,6 +253,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ +// std::cout << source << std::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; @@ -264,11 +269,11 @@ cu_module::cu_module(driver::context * context, std::string const & source) : mo } } -cu_buffer cu_module::symbol(const char *name) const{ +cu_buffer* cu_module::symbol(const char *name) const{ CUdeviceptr handle; size_t size; dispatch::cuModuleGetGlobal_v2(&handle, &size, *cu_, name); - return cu_buffer(ctx_, handle, false); + return new cu_buffer(ctx_, handle, false); } diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp index db0ae9e94..c913c37e8 100644 --- a/lib/ir/builder.cpp +++ b/lib/ir/builder.cpp @@ -285,6 +285,10 @@ value *builder::create_broadcast(value *arg, const type::tile_shapes_t &shapes, return insert(broadcast_inst::create(arg, shapes, name)); } +value *builder::create_downcast(value *arg, const std::string &name) { + return insert(downcast_inst::create(arg, name)); +} + //===----------------------------------------------------------------------===// // built-in instructions //===----------------------------------------------------------------------===// @@ -293,8 +297,24 @@ value *builder::create_get_global_range(unsigned axis, type::tile_shapes_t::valu return insert(get_global_range_inst::create(ctx_, axis, size, name)); } -value *builder::create_matmul(value *A, value *B, value *C, const std::string &name) { - return insert(matmul_inst::create(A, B, C, name)); +value *builder::create_get_range_id(unsigned axis, const std::string &name) { + return insert(get_range_id_inst::create(ctx_, axis, name)); +} + +value *builder::create_atomic_cas(value *ptr, value *cmp, value *val, const std::string &name){ + return insert(atomic_cas_inst::create(ptr, cmp, val, name)); +} + +value *builder::create_dot(value *A, value *B, value *C, const std::string &name) { + return insert(dot_inst::create_nn(A, B, C, name)); +} + +value *builder::create_trans(value *A, const std::string &name) { + return insert(trans_inst::create(A, name)); +} + +value *builder::create_select(value *pred, value *if_value, value *else_value, const std::string &name){ + return insert(select_inst::create(pred, if_value, else_value, name)); } //===----------------------------------------------------------------------===// diff --git a/lib/ir/instructions.cpp b/lib/ir/instructions.cpp index 2a44ec4fb..8a9205c4e 100644 --- a/lib/ir/instructions.cpp +++ b/lib/ir/instructions.cpp @@ -28,6 +28,8 @@ instruction::instruction(type *ty, unsigned num_ops, unsigned num_results, const void instruction::erase_from_parent() { parent_->erase(this); + for(ir::value* op: ops()) + op->erase_use(this); } bool instruction::has_tile_result_or_op() { @@ -482,27 +484,82 @@ instruction* broadcast_inst::create(value *arg, const type::tile_shapes_t &shape return new broadcast_inst(arg, shapes, name, next); } +// downcast + +instruction* downcast_inst::create(value *arg, const std::string &name, instruction *next) { + return new downcast_inst(arg->get_type()->get_scalar_ty(), arg, name, next); +} //===----------------------------------------------------------------------===// // matmul_inst classes //===----------------------------------------------------------------------===// -matmul_inst::matmul_inst(value *A, value *B, value *C, +dot_inst::dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, const std::string &name, instruction *next) - : builtin_inst(C->get_type(), 3, 0, name, next) { + : builtin_inst(C->get_type(), 3, 1, name, next), AT_(AT), BT_(BT) { set_operand(0, A); set_operand(1, B); set_operand(2, C); } -instruction *matmul_inst::create(value *A, value *B, value *C, +instruction *dot_inst::create_nn(value *A, value *B, value *C, const std::string &name, instruction *next) { - return new matmul_inst(A, B, C, name, next); + return new dot_inst(A, B, C, NoTrans, NoTrans, name, next); +} + +instruction *dot_inst::create_nt(value *A, value *B, value *C, + const std::string &name, instruction *next) { + return new dot_inst(A, B, C, NoTrans, Trans, name, next); +} + +instruction *dot_inst::create_tn(value *A, value *B, value *C, + const std::string &name, instruction *next) { + return new dot_inst(A, B, C, Trans, NoTrans, name, next); +} + +instruction *dot_inst::create_tt(value *A, value *B, value *C, + const std::string &name, instruction *next) { + return new dot_inst(A, B, C, Trans, Trans, name, next); } +//===----------------------------------------------------------------------===// +// trans instructions +//===----------------------------------------------------------------------===// + +ir::type* trans_inst::get_res_ty(ir::type* ty) { + auto shapes = ty->get_tile_shapes(); + std::rotate(shapes.begin(), shapes.begin() + 1, shapes.end()); + return tile_type::get(ty->get_scalar_ty(), shapes); +} + +trans_inst::trans_inst(value *arg, const std::string &name, instruction *next) + : builtin_inst(get_res_ty(arg->get_type()), 1, 1, name, next) { + set_operand(0, arg); +} + +instruction* trans_inst::create(value *arg, const std::string &name, instruction *next) { + return new trans_inst(arg, name, next); +} + +//===----------------------------------------------------------------------===// +// select instructions +//===----------------------------------------------------------------------===// + +select_inst::select_inst(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) + : builtin_inst(if_value->get_type(), 3, 1, name, next){ + set_operand(0, pred); + set_operand(1, if_value); + set_operand(2, else_value); +} + +instruction* select_inst::create(value *pred, value *if_value, value *else_value, const std::string &name, instruction *next) { + return new select_inst(pred, if_value, else_value, name, next); +} //===----------------------------------------------------------------------===// // builtin instructions //===----------------------------------------------------------------------===// + +// get_global_range get_global_range_inst::get_global_range_inst(type *ty, unsigned axis, const std::string &name, instruction *next) : builtin_inst(ty, 0, 1, name, next), axis_(axis) { @@ -516,6 +573,28 @@ instruction* get_global_range_inst::create(context &ctx, unsigned axis, type::ti return new get_global_range_inst(tile_ty, axis, name, next); } +// get_range_id +get_range_id_inst::get_range_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next) + : builtin_inst(ty, 0, 1, name, next), axis_(axis){ + +} + +instruction* get_range_id_inst::create(context &ctx, unsigned axis, const std::string &name, instruction *next) { + return new get_range_id_inst(type::get_int32_ty(ctx), axis, name, next); +} + +// atomic cas + +atomic_cas_inst::atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) + : builtin_inst(ptr->get_type()->get_pointer_element_ty(), 3, 1, name, next) { + set_operand(0, ptr); + set_operand(1, cmp); + set_operand(2, val); +} + +instruction* atomic_cas_inst::create(value *ptr, value *cmp, value *val, const std::string &name, instruction *next) { + return new atomic_cas_inst(ptr, cmp, val, name, next); +} //===----------------------------------------------------------------------===// // intrinsic instructions //===----------------------------------------------------------------------===// @@ -530,7 +609,7 @@ vectorize_inst* vectorize_inst::create(value *arg, const std::string &name, inst barrier_inst::barrier_inst(context &ctx, const std::string &name, instruction *next) - : instruction(type::get_void_ty(ctx), 0, 1, name, next){ } + : instruction(type::get_void_ty(ctx), 0, 0, name, next){ } barrier_inst* barrier_inst::create(context &ctx, const std::string &name, instruction *next) { return new barrier_inst(ctx, name, next); diff --git a/lib/ir/module.cpp b/lib/ir/module.cpp index 14f1337e1..d8f07ecc4 100644 --- a/lib/ir/module.cpp +++ b/lib/ir/module.cpp @@ -128,6 +128,9 @@ ir::value *module::get_value(const std::string& name) { return get_value(name, builder_.get_insert_block()); } +const std::string& module::get_name() { + return name_; +} void module::seal_block(ir::basic_block *block){ for(auto &x: incomplete_phis_[block]){ diff --git a/lib/ir/type.cpp b/lib/ir/type.cpp index 862039220..215e8f746 100644 --- a/lib/ir/type.cpp +++ b/lib/ir/type.cpp @@ -172,7 +172,7 @@ unsigned tile_type::get_bitwidth() const { tile_type* tile_type::get(type *elt_ty, const tile_shapes_t &shapes) { assert(elt_ty && "Can't get a tile of type!"); assert(shapes.size() && "Can't create a tile with empty shapes!"); - assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!"); + assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!"); // look-up context_impl *impl = elt_ty->get_context().p_impl.get(); tile_type *&entry = impl->tile_tys[std::make_pair(elt_ty, shapes)]; diff --git a/lib/jit.cpp b/lib/jit.cpp index 068a824f0..9a4181e2a 100644 --- a/lib/jit.cpp +++ b/lib/jit.cpp @@ -68,7 +68,7 @@ void loop_nest(std::vector> const & iterates, std::function jit::make_llvm_module(ir::module &module, passes_wrapper &passes) { - llvm::Module* result = new llvm::Module("matmul", llvm_context_); + llvm::Module* result = new llvm::Module(module.get_name(), llvm_context_); passes.selection.run(module, *result); // launch information auto &launch_info_map = launch_info_map_[result->getName()]; @@ -79,14 +79,14 @@ std::unique_ptr jit::make_llvm_module(ir::module &module, passes_w return std::unique_ptr(result); } -std::unique_ptr jit::make_triton_module(const std::string &src) { +std::unique_ptr jit::make_triton_module(const std::string &name, const std::string &src) { // create AST from Triton-C source YY_BUFFER_STATE buffer = yy_scan_string(src.c_str()); yyparse(); yy_delete_buffer(buffer); translation_unit *program = ast_root; // create Triton-IR from AST - ir::module* module = new ir::module("matrix", triton_context_); + ir::module* module = new ir::module(name, triton_context_); program->codegen(module); return std::unique_ptr(module); } @@ -97,18 +97,20 @@ jit::jit(driver::context *context): driver_context_(context), } -void jit::autotune(const std::string &src, benchmark_t benchmark) { +void jit::autotune(const std::string &name, const std::string &src, benchmark_t benchmark) { // find metaparameters - auto ptt_module = make_triton_module(src); + auto ptt_module = make_triton_module(name, src); ir::module &tt_module = *ptt_module; // set parameters passes_wrapper passes(target_.get()); + passes.target_independent(tt_module); passes.tune.run(tt_module); auto mps = passes.tune.get_params(tt_module); // create parameter ranges std::vector> ranges; for(ir::metaparameter *mp: mps) ranges.push_back(mp->get_space()); +// std::cout << ranges.size() << std::endl; // iterate over parameters unsigned i; double best = 0; @@ -117,51 +119,56 @@ void jit::autotune(const std::string &src, benchmark_t benchmark) { i = 0; for(ir::metaparameter *mp: mps) mp->set_value(params[i++]); + passes.target_independent(tt_module); passes.tune.init(tt_module); if(!passes.tune.check_constraints(errors)) return; // Deep copy of the module and tuner - auto ptt_module = make_triton_module(src); + auto ptt_module = make_triton_module(name, src); ir::module &tt_module = *ptt_module; passes_wrapper passes(target_.get()); + passes.target_independent(tt_module); passes.tune.run(tt_module); i = 0; for(ir::metaparameter* mp: passes.tune.get_params(tt_module)){ mp->set_value(params[i++]); } passes.tune.init(tt_module); - passes.init(tt_module); + passes.target_dependent(tt_module); driver::device* device = driver_context_->device(); - if(passes.allocation.get_allocated_size() > device->max_shared_memory()) + if(passes.shmem_allocation.get_allocated_size() > device->max_shared_memory()) return; if(passes.tune.get_num_threads() > device->max_threads_per_block()) return; // Compile auto ll_module = make_llvm_module(tt_module, passes); std::unique_ptr module(driver::module::create(driver_context_, &*ll_module)); - std::unique_ptr kernel(driver::kernel::create(module.get(), "matmul")); - launch_information info = launch_info_map_.at("matmul"); + std::unique_ptr kernel(driver::kernel::create(module.get(), name.c_str())); + launch_information info = launch_info_map_.at(name.c_str()); for(unsigned p: params) std::cout << p << " " << std::flush; // add globals for(auto x: tt_module.globals()) global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); + modules_.push_back(module.get()); double perf; perf = benchmark(kernel.get(), info); best = std::max(perf, best); std::cout << perf << " [ " << best << " ] " << std::endl; + modules_.pop_back(); }); } void jit::add_module(ir::module &tt_module, const std::vector ¶ms) { // set parameters passes_wrapper passes(target_.get()); + passes.target_independent(tt_module); passes.tune.run(tt_module); unsigned i = 0; for(ir::metaparameter* mp: passes.tune.get_params(tt_module)) mp->set_value(params[i++]); passes.tune.init(tt_module); - passes.init(tt_module); + passes.target_dependent(tt_module); // check constraints std::map> errors; passes.tune.check_constraints(errors); @@ -184,8 +191,8 @@ void jit::add_module(ir::module &tt_module, const std::vector ¶ms) global_ints_[x.first] = ((ir::metaparameter*)x.second)->get_value(); } -void jit::add_module(const std::string &src, const std::vector ¶ms) { - auto ptt_module = make_triton_module(src); +void jit::add_module(const std::string &name, const std::string &src, const std::vector ¶ms) { + auto ptt_module = make_triton_module(name, src); add_module(*ptt_module, params); } @@ -201,4 +208,9 @@ unsigned jit::get_int(const std::string &name){ return global_ints_.at(name); } +driver::buffer *jit::get_buffer(const std::string &name){ + driver::cu_module *mod = (driver::cu_module*)modules_.front(); + return mod->symbol(name.c_str()); +} + }