From ba9955ae39994b5c0e91612659351277a0c60d01 Mon Sep 17 00:00:00 2001
From: Philippe Tillet
Date: Fri, 7 Aug 2020 00:49:04 -0400
Subject: [PATCH] [CODEGEN][ANALYSIS] Fixed issue in layout inference

---
 include/triton/tools/bench.hpp       | 17 +++++++------
 lib/codegen/analysis/layout.cc       | 26 +------------------
 python/examples/tutorials/conv2d.py  | 32 ++++++++++++------------
 python/examples/tutorials/mat_mul.py |  2 +-
 python/triton/kernel.py              |  4 +--
 tests/bench/CMakeLists.txt           |  2 +-
 tests/bench/dot.cc                   | 37 +++++++++++++++++++++++++---
 tests/common/dot.h                   | 26 ++++++++++++-------
 tests/common/src/dot.h               | 35 ++++++++++++++++++++------
 9 files changed, 109 insertions(+), 72 deletions(-)

diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp
index e132e3b83..fdf17eaa2 100644
--- a/include/triton/tools/bench.hpp
+++ b/include/triton/tools/bench.hpp
@@ -38,19 +38,22 @@ inline double bench(std::function<void()> const & op, driver::stream * stream, b
   double total_time = 0;
   op();
   stream->synchronize();
-  while(total_time*1e-9 < 1e-2){
-    float norm = 1;
+  tmr.start();
+  for(size_t i = 0; i < 10; i++){
+//  while(total_time*1e-9 < 1e-2){
+//    float norm = 1;
     // normalize clock if possible to reduce noise in auto-tuning
 //    if(normalize)
 //      if(auto cu_device = dynamic_cast<driver::cu_device*>(stream->context()->device()))
 //        norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
-    tmr.start();
     op();
-    stream->synchronize();
-    times.push_back(norm*tmr.get().count());
-    total_time+=times.back();
+//    times.push_back(norm*tmr.get().count());
+//    total_time+=times.back();
   }
-  return *std::min_element(times.begin(), times.end());
+  stream->synchronize();
+  return (float)tmr.get().count() / 10;
+
+//  return *std::min_element(times.begin(), times.end());
 }
 
 }
diff --git a/lib/codegen/analysis/layout.cc b/lib/codegen/analysis/layout.cc
index 209e8daeb..14b207eec 100644
--- a/lib/codegen/analysis/layout.cc
+++ b/lib/codegen/analysis/layout.cc
@@ -16,29 +16,6 @@ namespace analysis{
  *          Helper Functions        *
  * -------------------------------- */
 
-inline int gcd_impl(int a, int b, int *x, int *y)
-{
-    // Base Case
-    if (a == 0)
-    {
-        *x = 0;
-        *y = 1;
-        return b;
-    }
-    int x1, y1; // To store results of recursive call
-    int gcd = gcd_impl(b%a, a, &x1, &y1);
-    // Update x and y using results of
-    // recursive call
-    *x = y1 - (b/a) * x1;
-    *y = x1;
-    return gcd;
-}
-
-inline int gcd(int a, int b) {
-  int x, y;
-  return gcd_impl(a, b, &x, &y);
-}
-
 inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
   unsigned lo = std::min(a, b);
   unsigned hi = std::max(a, b);
@@ -210,8 +187,7 @@ scanline_layout::scanline_layout(size_t num_warps,
     if(ptr)
       contiguous = std::min(align->contiguous(ptr)[i], 4);
 
-    int max_contiguous = shape_[i] / (num_warps*32);
-    nts_[i] = clamp(size / num_threads, 1, gcd(contiguous, max_contiguous));
+    nts_[i] = clamp(size / num_threads, 1, std::min(contiguous, shape_[i]));
     mts_[i] = clamp(num_threads, 1, shape_[i] / nts_[i]);
     size /= shape_[i];
     num_threads /= mts_[i];
diff --git a/python/examples/tutorials/conv2d.py b/python/examples/tutorials/conv2d.py
index a997c8d5f..34c5d41ed 100644
--- a/python/examples/tutorials/conv2d.py
+++ b/python/examples/tutorials/conv2d.py
@@ -21,13 +21,11 @@ class _conv(torch.autograd.Function):
     int ridx = get_program_id(0);
     int ridy = get_program_id(1);
     int ridz = get_program_id(2);
-    /*
     int gridx = M / TM;
     int gridy = N / TN;
     int rid = ridx + ridy * gridx;
     ridx = rid / gridy;
     ridy = rid % gridy;
-    */
     int rm[TM] = ridx * TM + 0 ... TM;
     int rn[TN] = ridy * TN + 0 ... TN;
     // reduction splitting
@@ -36,10 +34,10 @@ class _conv(torch.autograd.Function):
 
     // unpack aggregate rows
     // m = (z, p, q)
-    int rq[TM] = rm % QQ;
-    int rzp[TM] = rm / QQ;
-    int rp[TM] = rzp % PP;
-    int rz[TM] = rzp / PP;
+    int rq[TM]  = rm  % QQ;
+    int rzp[TM] = rm  / QQ;
+    int rp[TM]  = rzp % PP;
+    int rz[TM]  = rzp / PP;
     // unpack aggregate reduction
     // k = (ci, r, s)
     int rs [TK] = rk % SS;
@@ -68,10 +66,12 @@ class _conv(torch.autograd.Function):
     TYPE* pb[TK, TN] = B + offb;
 
     // prefetches operands
-    bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
+    bool checkam[TM, TK] = rm[:, newaxis] < M;
+    bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
     bool checkb[TK, TN] = rk[:, newaxis] < K;
     TYPE a[TM, TK] = checka ? *pa : 0;
     TYPE b[TK, TN] = checkb ? *pb : 0;
+    int total = 0;
 
     // reduction loop
     float acc[TM, TN] = 0;
@@ -81,8 +81,6 @@ class _conv(torch.autograd.Function):
       int adelta[TK] = *padelta;
       padelta += TK;
       pa += adelta[newaxis, :];
-      // increment B
-      pb += TK * ldb_s;
       // bounds-checking A
       rk += TK;
       rs = rk % SS;
@@ -90,7 +88,9 @@ class _conv(torch.autograd.Function):
       rr = rcir % RR;
      rh = rh_0[:, newaxis] + rr[newaxis, :];
       rw = rw_0[:, newaxis] + rs[newaxis, :];
-      bool checka[TM, TK] = rh >= 0 && rh < HH && rw >= 0 && rw < WW;
+      bool checka[TM, TK] = checkam && rh >= 0 && rh < HH && rw >= 0 && rw < WW;
+      // increment B
+      pb += TK * ldb_s;
       // bounds-checking B
       bool checkb[TK, TN] = k > TK;
       a = checka ? *pa : 0;
@@ -152,18 +152,18 @@ class _conv(torch.autograd.Function):
         Q = (W + 2*pad[1] - S)//stride[1] + 1
         # compile kernel
         if dtype not in _conv.kernel:
+            TK = 8
             defines = {
                 'TYPE' : dtype,
-                'TM'   : [64, 128],
-                'TN'   : [64, 128],
-                'TK'   : [8],
+                'TM'   : [16, 32, 64, 128],
+                'TN'   : [16, 32, 64, 128],
+                'TK'   : [TK],
                 'TZ'   : [1],
-                'LUTSIZE' : 4*CI*R*S,
                 'HH': H, 'WW': W, 'PP': P, 'QQ': Q, 'SS': S, 'RR': R,
             }
             idx = torch.arange(CI*R*S)
             ci, r, s = _conv.unpack(idx, CI, R, S)
-            nci, nr, ns = _conv.unpack(idx + 8, CI, R, S)
+            nci, nr, ns = _conv.unpack(idx + TK, CI, R, S)
             delta = (nci - ci)*a.stride(1) + (nr - r)*a.stride(2) + (ns - s)*a.stride(3)
             delta = delta.type(torch.int32).cuda()
             _conv.kernel[dtype] = (delta, triton.kernel(_conv.src, num_warps=[2, 4], defines=defines))
@@ -186,7 +186,7 @@ class _conv(torch.autograd.Function):
 conv = _conv.apply
 
 torch.manual_seed(0)
-Z, H, W, CI, CO, R, S = 1, 32, 64, 256, 2048, 3, 3
+Z, H, W, CI, CO, R, S = 1, 56, 56, 1024, 1024, 3, 3
 pad = (1, 1)
 stride = (1, 1)
 a = torch.rand((Z, CI, H, W)).cuda()
diff --git a/python/examples/tutorials/mat_mul.py b/python/examples/tutorials/mat_mul.py
index bc0017f23..419a61f51 100644
--- a/python/examples/tutorials/mat_mul.py
+++ b/python/examples/tutorials/mat_mul.py
@@ -112,7 +112,7 @@ class _dot(torch.autograd.Function):
 
         time = kernel(a, b, c, 1., M, N, K,
                       a.stride(0), b.stride(0), c.stride(0), grid=grid, bench=100)
-        print(time)
+        print(2*M*N*K/(time*1e-6)*1e-9)
         return c
 
 
diff --git a/python/triton/kernel.py b/python/triton/kernel.py
index 19729f677..019819b7e 100644
--- a/python/triton/kernel.py
+++ b/python/triton/kernel.py
@@ -170,8 +170,8 @@ class kernel:
     arg_types = libtriton.get_fn_signature(self.src, opt)
     self.fw_op = _make_framework_op(arg_types)
 
-  def set_constant(self, name, value):
-    libtriton.register_cst(self.op_id, name, value)
+  def set_constant(self, device, name, value):
+    libtriton.register_cst((self.op_id, device), name, value)
 
   def __call__(self, *args, **kwargs):
     for x in args:
diff --git a/tests/bench/CMakeLists.txt b/tests/bench/CMakeLists.txt
index d9978ca3f..f531fadf4 100644
--- a/tests/bench/CMakeLists.txt
+++ b/tests/bench/CMakeLists.txt
@@ -1,4 +1,4 @@
-foreach(PROG dot copy)
+foreach(PROG dot copy conv)
   set(TARGET bench_${PROG})
   add_executable(${TARGET} ${PROG}.cc)
   set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME ${TARGET})
diff --git a/tests/bench/dot.cc b/tests/bench/dot.cc
index fd4a96622..9f1260469 100644
--- a/tests/bench/dot.cc
+++ b/tests/bench/dot.cc
@@ -9,11 +9,40 @@ int main() {
   // shapes to benchmark
   typedef std::tuple<std::vector<int>, bool, bool, int, int, int> config_t;
   std::vector<config_t> configs;
-  for(auto ord: std::vector<std::vector<int>>{{1, 0}})
-  for(auto x: std::vector<std::array<bool, 2>>{{false, false}}){
+  for(auto ord: std::vector<std::vector<int>>{{0, 1}})
+  for(auto x: std::vector<std::array<bool, 2>>{{false, true}, {false, false}, {true, false}, {true, true}}){
     std::vector<config_t> tmp = {
+//      config_t{ord, x[0], x[1], 128, 128, 128},
+//      config_t{ord, x[0], x[1], 256, 256, 256},
+//      config_t{ord, x[0], x[1], 384, 384, 384},
 //      config_t{ord, x[0], x[1], 512, 512, 512},
-      config_t{ord, x[0], x[1], 1024, 1024, 1024},
+//      config_t{ord, x[0], x[1], 768, 768, 768},
+//      config_t{ord, x[0], x[1], 1024, 1024, 1024},
+//      config_t{ord, x[0], x[1], 1280, 1280, 1280},
+//      config_t{ord, x[0], x[1], 1536, 1536, 1536},
+//      config_t{ord, x[0], x[1], 2048, 2048, 2048},
+      config_t{ord, x[0], x[1], 8192, 8192, 8192},
+
+//      config_t{ord, x[0], x[1], 256, 16, 256},
+//      config_t{ord, x[0], x[1], 512, 16, 512},
+//      config_t{ord, x[0], x[1], 768, 16, 768},
+//      config_t{ord, x[0], x[1], 1024, 16, 1024},
+//      config_t{ord, x[0], x[1], 1280, 16, 1280},
+//      config_t{ord, x[0], x[1], 1536, 16, 1536},
+//      config_t{ord, x[0], x[1], 2048, 16, 2048},
+//      config_t{ord, x[0], x[1], 3072, 16, 3072},
+//      config_t{ord, x[0], x[1], 4096, 16, 4096},
+//      config_t{ord, x[0], x[1], 5120, 16, 5120},
+//      config_t{ord, x[0], x[1], 6144, 16, 6144},
+//      config_t{ord, x[0], x[1], 7168, 16, 7168},
+
+//      config_t{ord, x[0], x[1], 64, 64, 4096},
+//      config_t{ord, x[0], x[1], 64, 64, 8192},
+//      config_t{ord, x[0], x[1], 64, 64, 16384},
+//      config_t{ord, x[0], x[1], 64, 64, 32768},
+//      config_t{ord, x[0], x[1], 64, 64, 65536},
+//      config_t{ord, x[0], x[1], 64, 64, 131072}
+
 //      config_t{ord, x[0], x[1], 127008, 768, 576},
 //      config_t{ord, x[0], x[1], 8192, 8192, 8192}
 //      config_t{ord, x[0], x[1], 16, 2048, 2048},
@@ -36,7 +65,7 @@ int main() {
   for(const auto& c: configs){
     std::tie(ord, AT, BT, M, N, K) = c;
     std::cout << "// " << c ;
-    for(auto perf: bench_dot(stream, FLOAT, AT, BT, M, N, K, ord, ord))
+    for(auto perf: bench_dot(stream, HALF, AT, BT, M, N, K, ord, ord))
       std::cout << ", " << perf << std::flush;
     std::cout << std::endl;
   }
diff --git a/tests/common/dot.h b/tests/common/dot.h
index 5dc7ef74e..6433ec562 100644
--- a/tests/common/dot.h
+++ b/tests/common/dot.h
@@ -79,9 +79,11 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
   std::vector<std::string> sb = { "1", "ldb" };
 
   // inputs
-  auto dc = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
-  auto da = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
-  auto db = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
+  auto dc     = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*N*dt_nbytes));
+  auto da     = std::shared_ptr<drv::buffer>(drv::buffer::create(context, M*K*dt_nbytes));
+  auto db     = std::shared_ptr<drv::buffer>(drv::buffer::create(context, K*N*dt_nbytes));
+  auto dlocks = std::shared_ptr<drv::buffer>(drv::buffer::create(context, 1024*1024*2*4));
+  ((drv::cu_buffer*)dlocks.get())->set_zero(stream, dlocks->size());
 
   // macros
   rt::function::options_space_t opt;
@@ -110,15 +112,21 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
   }
   if(mode == BENCH) {
     opt.defines.push_back({"TM", {"128"}});
-    opt.defines.push_back({"TN", {"32"}});
-    opt.defines.push_back({"TK", {to_string<T>::value == "half" ? "16" : "8"}});
+    opt.defines.push_back({"TN", {"128"}});
+    opt.defines.push_back({"TK", {"16"}});
+    opt.defines.push_back({"TZ", {"1"}});
     opt.num_warps = {4};
   }
 
   // kernels
+
   rt::function function(src::dot, opt);
-  std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc};
-  auto grid = grid2d(M, N);
+  std::vector<rt::arg> args = {&*da, &*db, &*dc, (float)1, M, N, K, lda, ldb, ldc, &*dlocks};
+  auto grid = [M, N](const rt::function::options_t& x) {
+    return rt::grid_t{ceil(M, x.D("TM")),
+                      ceil(N, x.D("TN")),
+                      (size_t)x.D("TZ")};
+  };
 
   // metrics
   if(mode == BENCH){
@@ -126,13 +134,13 @@ void triton_dot(drv::stream* stream, bool AT, bool BT,
     double triton_ns = triton::tools::bench([&]() { function(args, grid, stream);}, stream);
     bench.push_back(tflops(triton_ns));
 
-//    // cublas
+    // cublas
 //    if(cublas::cublasinit()){
 //      T alpha(static_cast(1));
 //      T beta(static_cast(0));
 //      cublasGemmAlgo_t fastest;
 //      cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K, &alpha, &*da, lda, &*db, ldb, &beta, &*dc, ldc, &fastest);
-//      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_16F, stream, AT, BT, M, N, K,
+//      double cublas_ms = triton::tools::bench([&]() { cublasGemm(CUDA_R_32F, stream, AT, BT, M, N, K,
 //                                                                  &alpha, &*da, lda, &*db, ldb, &beta, &*dc,
 //                                                                  ldc, nullptr, fastest); }, stream);
 //      bench.push_back(tflops(cublas_ms));
diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h
index 4dcab1efc..e94e03b5c 100644
--- a/tests/common/src/dot.h
+++ b/tests/common/src/dot.h
@@ -6,13 +6,15 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
                     TYPE * B __noalias __readonly __aligned(16),
                     TYPE * C __noalias __aligned(16),
                     float alpha,
-                    int M, int N, int K,
+                    int M, int N, int K __multipleof(16),
                     int lda __multipleof(8),
                     int ldb __multipleof(8),
-                    int ldc __multipleof(8)) {
+                    int ldc __multipleof(8),
+                    int* locks) {
   // prologue
   int ridx = get_program_id(0);
   int ridy = get_program_id(1);
+  int ridz = get_program_id(2);
   int gridx = M / TM;
   int gridy = N / TN;
   int rid = ridx + ridy * gridx;
@@ -20,7 +22,10 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
   ridy = rid % gridy;
   int rm[TM] = ridx * TM + 0 ... TM;
   int rn[TN] = ridy * TN + 0 ... TN;
-  int rk[TK] = 0 ... TK;
+
+  // reduction splitting
+  K = K / TZ;
+  int rk[TK] = ridz * K + 0 ... TK;
 
   // pointers to operands
   int offa[SHAPE_A] = rk[BROADCAST_AK] * STRIDE_AK + rm[BROADCAST_AM] * STRIDE_AM;
@@ -35,9 +40,9 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
   TYPE b[SHAPE_B] = checkb ? *pb : 0;
 
   // reduction loop
-  float c[TM, TN] = 0;
+  float acc[TM, TN] = 0;
   for(int k = K; k > 0; k -= TK){
-    c += USEA @ USEB;
+    acc += USEA @ USEB;
     bool checka[SHAPE_A] = k > TK;
     bool checkb[SHAPE_B] = k > TK;
     pa += TK * STRIDE_AK;
@@ -45,7 +50,8 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
     a = *?(checka)pa;
     b = *?(checkb)pb;
   }
-  //c = c * alpha;
+  acc = acc * alpha;
+  TYPE c[TM, TN] = acc;
 
   // epilogue
   int rxm[TM] = get_program_id(0) * TM + 0 ... TM;
@@ -53,7 +59,22 @@ __global__ void dot(TYPE * A __noalias __readonly __aligned(16),
   int offc[TM, TN] = rxm[:, newaxis] * ldc + rxn[newaxis, :];
   TYPE* pc[TM, TN] = C + offc;
   bool checkc[TM, TN] = (rxm[:, newaxis] < M) && (rxn[newaxis, :] < N);
-  *?(checkc)pc = (TYPE[TM, TN])c;
+
+#if (TZ==1)
+  *?(checkc) pc = c;
+#else
+  // accumulate partial result using spin-locks
+  int *plock = locks + rid;
+  int *pcount = plock + get_num_programs(0) * get_num_programs(1);
+  for(int repeat = 1; repeat == 1; repeat = atomic_cas(plock, 0, 1));
+  int count = *pcount;
+  if(count == 0)
+    *?(checkc) pc = c;
+  else
+    *?(checkc) pc = c + *?(checkc)pc;
+  atomic_xchg(pcount, (count + 1) % TZ);
+  atomic_xchg(plock, 0);
+#endif
 }
 )";
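Note on the bench.hpp change: the old loop timed each launch separately, synchronized after every iteration, and reported the minimum; the new code starts the timer once, enqueues ten launches back to back, synchronizes once, and returns the average. A minimal host-side sketch of the same measurement strategy, using std::chrono and plain callables in place of Triton's timer and driver::stream types (bench_avg_ns and the sync callable are stand-ins, not the library's API):

#include <chrono>
#include <cstddef>
#include <functional>

// Average the cost of `op` over a fixed batch of launches, synchronizing only
// once at the end, as the patched bench() does. `op` and `sync` are stand-ins
// for the kernel launch and for stream->synchronize().
inline double bench_avg_ns(const std::function<void()>& op,
                           const std::function<void()>& sync,
                           std::size_t reps = 10) {
  op();      // warm-up launch
  sync();    // make sure the device is idle before timing starts
  auto start = std::chrono::high_resolution_clock::now();
  for (std::size_t i = 0; i < reps; ++i)
    op();    // enqueue `reps` launches back to back
  sync();    // wait for the whole batch to drain
  auto end = std::chrono::high_resolution_clock::now();
  return std::chrono::duration<double, std::nano>(end - start).count() / reps;
}

The trade-off is that batching amortizes launch and synchronization overhead, while the old per-launch minimum was better at filtering out clock-throttling noise; the commented-out normalization path hints at the same concern.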
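Note on the layout.cc hunks, which carry the fix named in the subject line: the per-thread vector size nts_[i] of a scanline layout is now capped by min(contiguous, shape_[i]) instead of gcd(contiguous, shape_[i] / (num_warps*32)), and the gcd helper that only served that expression is removed. A self-contained sketch of the resulting sizing rule, with a free function standing in for the scanline_layout constructor (the function and parameter names are illustrative, not the class's interface):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

// Same clamp helper as in layout.cc.
inline unsigned clamp(unsigned x, unsigned a, unsigned b) {
  unsigned lo = std::min(a, b);
  unsigned hi = std::max(a, b);
  return std::min(std::max(x, lo), hi);
}

// Illustrative stand-in for the sizing loop in scanline_layout's constructor:
// for each dimension, pick the per-thread vector width nts (capped by the
// measured pointer contiguity and by the dimension itself), then the number
// of threads mts laid out along that dimension.
void scanline_sizes(const std::vector<unsigned>& shape,
                    const std::vector<unsigned>& contiguous,
                    unsigned num_threads,
                    std::vector<unsigned>& nts,
                    std::vector<unsigned>& mts) {
  unsigned size = 1;
  for (unsigned s : shape)
    size *= s;                                  // total number of tile elements
  nts.assign(shape.size(), 1);
  mts.assign(shape.size(), 1);
  for (std::size_t i = 0; i < shape.size(); ++i) {
    nts[i] = clamp(size / num_threads, 1, std::min(contiguous[i], shape[i]));
    mts[i] = clamp(num_threads, 1, shape[i] / nts[i]);
    size /= shape[i];
    num_threads /= mts[i];
  }
}

int main() {
  std::vector<unsigned> nts, mts;
  // e.g. a 128x128 tile, 4-element contiguity on the fastest axis, 128 threads
  scanline_sizes({128, 128}, {4, 1}, 128, nts, mts);
  std::printf("nts = {%u, %u}, mts = {%u, %u}\n", nts[0], nts[1], mts[0], mts[1]);
  return 0;
}

The patch does not spell out the failure mode, but the old cap drops to 1 whenever contiguous and shape_[i]/(num_warps*32) are coprime (and the quotient is 0 for tiles smaller than num_warps*32), whereas the new bound is well defined for any tile shape.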
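Note on conv2d.py: the pointer-increment lookup table is now built from the same TK the kernel is compiled with rather than a hard-coded 8, so each entry still answers how far the A pointer moves when the flattened reduction index k = (ci, r, s) advances by TK, which is what the kernel consumes through `int adelta[TK] = *padelta`. A C++ sketch of that table construction, following the unpacking order used by the tutorial (make_delta and unpack are illustrative names, and the stride parameters correspond to a.stride(1..3) in the Python code):

#include <cstddef>
#include <cstdint>
#include <vector>

// Unpack a flattened reduction index k into (ci, r, s), with s innermost,
// mirroring _conv.unpack(idx, CI, R, S) in the tutorial.
static void unpack(int64_t k, int64_t R, int64_t S,
                   int64_t& ci, int64_t& r, int64_t& s) {
  s = k % S;
  int64_t cr = k / S;
  r = cr % R;
  ci = cr / R;
}

// delta[k] is the element offset of the unpacked form of k + TK relative to
// the unpacked form of k, given the strides of the input tensor A.
std::vector<int32_t> make_delta(int64_t CI, int64_t R, int64_t S, int64_t TK,
                                int64_t stride_ci, int64_t stride_r, int64_t stride_s) {
  std::vector<int32_t> delta(static_cast<std::size_t>(CI * R * S));
  for (int64_t k = 0; k < CI * R * S; ++k) {
    int64_t ci, r, s, nci, nr, ns;
    unpack(k,      R, S, ci,  r,  s);
    unpack(k + TK, R, S, nci, nr, ns);
    delta[static_cast<std::size_t>(k)] =
        static_cast<int32_t>((nci - ci) * stride_ci +
                             (nr  - r)  * stride_r  +
                             (ns  - s)  * stride_s);
  }
  return delta;
}

Keeping TK in one place is the point of the hunk: if the table were still built with + 8 while the kernel were compiled with a different TK, the prefetched A pointers would land on the wrong elements.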
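Note on the new epilogue in tests/common/src/dot.h: when TZ > 1, K is divided by TZ and each program instance along axis 2 computes a partial product over its K/TZ slice; the partial tiles are then merged through a per-tile spin lock plus a counter stored in the locks buffer (zero-initialized on the host via set_zero). The first instance to acquire the lock stores its tile, later ones accumulate, and the counter wraps modulo TZ so the same slots can be reused by the next launch. A host-side C++ analogue of that protocol, with std::atomic standing in for the device-side atomic_cas/atomic_xchg and the tile reduced to a single float to keep the sketch short:

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// One lock and one counter per output tile, like the `locks` buffer in the patch.
struct TileSync {
  std::atomic<int> lock{0};
  std::atomic<int> count{0};
};

// Each "program" contributes one partial result for the same tile: the first
// arrival overwrites whatever is there, the rest accumulate (the #else branch).
void accumulate_partial(float partial, float& tile, TileSync& sync, int TZ) {
  int expected = 0;
  while (!sync.lock.compare_exchange_weak(expected, 1))  // spin, like atomic_cas(plock, 0, 1)
    expected = 0;
  int count = sync.count.load();
  if (count == 0)
    tile = partial;                     // first partial: plain store
  else
    tile += partial;                    // later partials: accumulate
  sync.count.store((count + 1) % TZ);   // wrap so the next launch can reuse the slot
  sync.lock.store(0);                   // release, like atomic_xchg(plock, 0)
}

int main() {
  const int TZ = 4;
  float tile = -1.f;                    // stale value: must be overwritten, not accumulated
  TileSync sync;
  std::vector<std::thread> workers;
  for (int z = 0; z < TZ; ++z)
    workers.emplace_back(accumulate_partial, 1.f, std::ref(tile), std::ref(sync), TZ);
  for (auto& w : workers)
    w.join();
  std::printf("tile = %f\n", tile);     // TZ partials of 1.0
  return 0;
}

The modulo-TZ wrap is what lets the host zero the locks buffer once: after all TZ partials have arrived the counter is back at 0, so a later launch sees the same initial state without another memset.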