From f80441017c5d6c72fd8327ce2137c3f41a8891d3 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 4 May 2019 20:15:34 -0400 Subject: [PATCH] [codegen] added leading dimension padding for transposition in shared memory --- examples/cpp/common.hpp | 11 ++ examples/cpp/dot.cpp | 155 +++++++++++------- examples/python/pytorch/conv.cpp | 184 ++++++++++++++++++++-- include/triton/codegen/shmem_allocation.h | 1 + include/triton/jit.h | 1 - lib/codegen/selection.cpp | 26 +-- lib/codegen/shmem_allocation.cpp | 19 ++- 7 files changed, 314 insertions(+), 83 deletions(-) diff --git a/examples/cpp/common.hpp b/examples/cpp/common.hpp index 8a16b9457..87525eb68 100644 --- a/examples/cpp/common.hpp +++ b/examples/cpp/common.hpp @@ -14,6 +14,17 @@ void simple_gemm(std::vector &c, const std::vector &a, const std::vector +void simple_gemm(bool AT, bool BT, std::vector &c, const std::vector &a, const std::vector &b, size_t M, size_t N, size_t K) { + if(AT && BT) + simple_gemm(c, a, b, M, N, K); + else if(AT && !BT) + simple_gemm(c, a, b, M, N, K); + else if(!AT && BT) + simple_gemm(c, a, b, M, N, K); + else + simple_gemm(c, a, b, M, N, K); +} class timer{ typedef std::chrono::high_resolution_clock high_resolution_clock; diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp index bf44b7cb5..980f83b31 100644 --- a/examples/cpp/dot.cpp +++ b/examples/cpp/dot.cpp @@ -5,63 +5,104 @@ #include "triton/driver/backend.h" #include "triton/driver/stream.h" -const char* src = -R"( -const tunable int32 TM = {16, 32, 64, 128}; -const tunable int32 TN = {16, 32, 64, 128}; -const tunable int32 TK = {8}; -const tunable int32 GZ = {1}; -void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, - int32 M, int32 N, int32 K, - int32 lda, int32 ldb, int32 ldc, - int32 *locks, int32 grid0, int32 grid1) { - int32 rxa[TM] = get_global_range[TM](0); - int32 ryb[TN] = get_global_range[TN](1); - int32 rz = get_global_range[1](2); - int32 rka[TK] = 0 ... TK; - int32 rkb[TK] = 0 ... TK; - fp32 c[TM, TN] = 0; - int32 div = K / GZ; - int32 rem = K % GZ; - K = select(rz < rem, div - 1, div); - int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem); - fp32* pa[TM, TK] = A + (offk + rka[newaxis, :])*lda + rxa[:, newaxis]; - fp32* pb[TN, TK] = B + (offk + rkb[newaxis, :])*ldb + ryb[:, newaxis]; - fp32 a[TM, TK] = *pa; - fp32 b[TN, TK] = *pb; - int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda; - int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb; - last_a = last_a / TK * TK; - last_b = last_b / TK * TK; - int32 bound = K - max(last_a, last_b); - for(int32 k = K; k > bound; k = k - TK){ - c = dot(a, trans(b), c); - pa = pa + TK*lda; - pb = pb + TK*ldb; - a = *pa; - b = *pb; +std::string triton_source(bool AT, bool BT) { + std::string AS0 = "TM", AS1 = "TK"; + std::string BS0 = "TK", BS1 = "TN"; + std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; + std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; + std::string lda0 = "*lda", lda1 = ""; + std::string ldb0 = "", ldb1 = "*ldb"; + std::string usea = AT ? "trans(a)" : "a"; + std::string useb = BT ? "trans(b)" : "b"; + if(AT){ + std::swap(AS0, AS1); + std::swap(bca0, bca1); + std::swap(lda0, lda1); } - int32 rxc[TM] = get_global_range[TM](0); - int32 ryc[TN] = get_global_range[TN](1); - for(int32 k = bound; k > 0; k = k - 1){ - int1 checka[TM, 1] = rxc[:, newaxis] < M; - int1 checkb[TN, 1] = ryc[:, newaxis] < N; - fp32* pa[TM, 1] = A + (offk + K - k)*lda + rxc[:, newaxis]; - fp32* pb[TN, 1] = B + (offk + K - k)*ldb + ryc[:, newaxis]; - fp32 a[TM, 1] = checka ? *pa : 0; - fp32 b[TN, 1] = checkb ? *pb : 0; - c = dot(a, trans(b), c); + if(BT){ + std::swap(BS0, BS1); + std::swap(bcb0, bcb1); + std::swap(ldb0, ldb1); } - fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; - int1 checkc0[TM] = rxc < M; - int1 checkc1[TN] = ryc < N; - int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; - @checkc *pc = c; + std::string res = + R"( + const tunable int32 TM = {16, 32, 64, 128}; + const tunable int32 TN = {16, 32, 64, 128}; + const tunable int32 TK = {8}; + const tunable int32 GZ = {1}; + + void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C, + int32 M, int32 N, int32 K, + int32 lda, int32 ldb, int32 ldc, + int32 *locks, int32 grid0, int32 grid1) { + int32 rxa[TM] = get_global_range[TM](0); + int32 ryb[TN] = get_global_range[TN](1); + int32 rz = get_global_range[1](2); + int32 rka[TK] = 0 ... TK; + int32 rkb[TK] = 0 ... TK; + fp32 c[TM, TN] = 0; + int32 div = K / GZ; + int32 rem = K % GZ; + K = select(rz < rem, div - 1, div); + int32 offk = select(rz < rem, rz*(div + 1), rz*div + rem); + fp32* pa[)" + AS0 + ", " + AS1 + "] = A + (offk + rka" + bca0 + ")" + lda0 + " + rxa" + bca1 + lda1 + R"(; + fp32* pb[)" + BS0 + ", " + BS1 + "] = B + (offk + rkb" + bcb0 + ")" + ldb0 + " + ryb" + bcb1 + ldb1 + R"(; + fp32 a[)" + AS0 + ", " + AS1 + R"(] = *pa; + fp32 b[)" + BS0 + ", " + BS1 + R"(] = *pb; + int32 last_a = ((M*K - 1) - (TM*TK + 1)) / lda; + int32 last_b = ((K*N - 1) - (TN*TK + 1)) / ldb; + last_a = last_a / TK * TK; + last_b = last_b / TK * TK; + int32 bound = K - max(last_a, last_b); + for(int32 k = K; k > bound; k = k - TK){ + c = dot()" + usea + ", " + useb + R"(, c); + pa = pa + TK)" + lda0 + R"(; + pb = pb + TK)" + ldb0 + R"(; + a = *pa; + b = *pb; + } + int32 rxc[TM] = get_global_range[TM](0); + int32 ryc[TN] = get_global_range[TN](1); + for(int32 k = bound; k > 0; k = k - 1){ + int1 checka[TM, 1] = rxc[:, newaxis] < M; + int1 checkb[TN, 1] = ryc[:, newaxis] < N; + fp32* pa[TM, 1] = A + (offk + K - k))" + lda0 + " + rxc[:, newaxis]" + lda1 + R"(; + fp32* pb[TN, 1] = B + (offk + K - k))" + ldb0 + " + ryc[:, newaxis]" + ldb1 + R"(; + fp32 a[TM, 1] = checka ? *pa : 0; + fp32 b[TN, 1] = checkb ? *pb : 0; + c = dot(a, trans(b), c); + } + int32 ridx = get_range_id(0); + int32 ridy = get_range_id(1); + fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis]; + int32 *plock = locks + ridx + ridy*grid0; + while(__atomic_cas(plock, 0, 1)); + int32 *pcount = plock + grid0*grid1; + int32 count = *pcount; + int32 countp1 = select(count == GZ - 1, 0, count + 1); + int1 checkc0[TM] = rxc < M; + int1 checkc1[TN] = ryc < N; + int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :]; + if(count == 0) { + @checkc *pc = c; + *pcount = countp1; + } + else { + @checkc *pc = c + *pc; + *pcount = countp1; + } + __atomic_cas(plock, 1, 0); + } + )"; + return res; } -)"; + int main() { + bool AT = false; + bool BT = true; + // initialize default compute device auto context = triton::driver::backend::contexts::get_default(); triton::jit jit(context); @@ -128,16 +169,16 @@ int main() { // just-in-time compile source-code - std::vector params = { - 16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1 - }; -// jit.autotune("matmul",src, benchmark); - jit.add_module("matmul", src, params); + std::string src = triton_source(AT, BT); +// jit.autotune("matmul",src.c_str(), benchmark); + jit.add_module("matmul", src.c_str(), {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1}); +// jit.add_module("matmul", src.c_str(), {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1}); +// jit.add_module("matmul", src.c_str(), {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1}); triton::driver::kernel* kernel = jit.get_function("matmul"); triton::jit::launch_information info = jit.get_launch_info("matmul"); std::cout << "Performance: " << benchmark(kernel, info) << " TFLOPS " << std::endl; stream->read(dc, true, 0, hc); - simple_gemm(rc, ha, hb, M, N, K); + simple_gemm(AT, BT, rc, ha, hb, M, N, K); for(size_t i = 0; i < M*N; i++) if(std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){ std::cout << i << " " << hc[i] << " " << rc[i] << std::endl; diff --git a/examples/python/pytorch/conv.cpp b/examples/python/pytorch/conv.cpp index d3d2bb212..71ea8e2be 100644 --- a/examples/python/pytorch/conv.cpp +++ b/examples/python/pytorch/conv.cpp @@ -88,6 +88,85 @@ void conv(read_only restrict fp32 *a, @checkc *pc = C; })"; +void build_conv_lut(int TK, + int stride_d, int stride_h, int stride_w, int stride_c, + int pad_d, int pad_h, int pad_w, + int T, int R, int S, + std::vector& res, std::vector& masks) { + /* convolution parameters */ + int F = T * R * S; + int Nlut = (TK + F - 1) / F * F; + int upsample_w = 1; + int upsample_h = 1; + int upsample_d = 1; + /* unpack index wrt filters */ + auto unpack = [&](int32_t trs){ + int32_t tr = trs / S; + int32_t s = trs - tr*S; + int32_t t = tr / R; + int32_t r = tr - t*R; + return std::make_tuple(t, r, s); + }; + /* increments */ + for(size_t i = 0; i < Nlut; ++i) + res[i] = (((i + TK) % Nlut) - i); + /* deltas */ + size_t Ds0 = Nlut; + size_t Ds1 = upsample_w; + size_t Ds2 = upsample_h; + size_t Ds3 = upsample_d; + for(size_t pd = 0; pd < Ds3; ++pd) + for(size_t ph = 0; ph < Ds2; ++ph) + for(size_t pw = 0; pw < Ds1; ++pw){ + int32_t* deltas_ptr = &res[Nlut + pw*Ds0 + ph*Ds0*Ds1 + pd*Ds0*Ds1*Ds2]; + // cumulative increments + for(size_t i = 0; i < Ds0; ++i){ + int32_t ctrs = i; + int32_t c = ctrs / F; + int32_t t, r, s; + std::tie(t, r, s) = unpack(ctrs % F); + // next indices + int32_t nextctrs = ctrs + TK; + int32_t nextc = nextctrs / F; + int32_t nextt, nextr, nexts; + std::tie(nextt, nextr, nexts) = unpack(nextctrs % F); + // diffs + int32_t cdiff = nextc - c; + int32_t tdiff = (nextt + pd)/upsample_d - (t + pd)/upsample_d; + int32_t rdiff = (nextr + ph)/upsample_h - (r + ph)/upsample_h; + int32_t sdiff = (nexts + pw)/upsample_w - (s + pw)/upsample_w; + // delta pointers + deltas_ptr[i] = cdiff*stride_c + sdiff*stride_w + rdiff*stride_h + tdiff*stride_d; + } + } + + /* Masks */ + size_t Ms0 = Nlut; + size_t Ms1 = 2*pad_w + 1; + size_t Ms2 = 2*pad_h + 1; + size_t Ms3 = 2*pad_d + 1; + + for(size_t pd = 0; pd < Ms3; ++pd) + for(size_t ph = 0; ph < Ms2; ++ph) + for(size_t pw = 0; pw < Ms1; ++pw){ + int32_t* masks_ptr = &masks[Nlut + pw*Ms0 + ph*Ms0*Ms1 + pd*Ms0*Ms1*Ms2]; + for(size_t i = 0; i < Ms0; ++i){ + int32_t t, r, s; + int32_t mask = 0x0; + for(size_t j = 0; j < TK; ++j){ + std::tie(t, r, s) = unpack((i + j) % F); + bool in_bounds_d = (t + pd) >= pad_d && (t + pd) < (T + pad_d); + bool in_bounds_h = (r + ph) >= pad_h && (r + ph) < (R + pad_h); + bool in_bounds_w = (s + pw) >= pad_w && (s + pw) < (S + pad_w); + mask |= (in_bounds_d && in_bounds_h && in_bounds_w) << j; + } + masks_ptr[i] = mask; + } + } + for(size_t i = 0; i < Nlut; ++i) + masks[i] = 0x0; +} + torch::Tensor conv_forward( const torch::Tensor data, const torch::Tensor weight) { @@ -95,37 +174,118 @@ torch::Tensor conv_forward( CHECK_INPUT(data); CHECK_INPUT(weight); // Unpack data shapes - const auto B = data.size(0); - const auto Ci = data.size(1); - const auto H = data.size(2); - const auto W = data.size(3); + const int32_t B = data.size(0); + const int32_t Ci = data.size(1); + const int32_t H = data.size(2); + const int32_t W = data.size(3); // Unpack weight shapes - const auto Cf = weight.size(0); - const auto R = weight.size(1); - const auto S = weight.size(2); - const auto K = weight.size(3); + const int32_t Cf = weight.size(0); + const int32_t T = 1; + const int32_t R = weight.size(1); + const int32_t S = weight.size(2); + const int32_t NF = weight.size(3); + // Conv parameters + int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1; + int32_t pad_d = 0, pad_h = 0, pad_w = 0; + int32_t stride_h = 1, stride_w = 1; + // Output shapes + int32_t P = (H*upsample_h - R + 1 + 2*pad_h + stride_h - 1)/stride_h; + int32_t Q = (W*upsample_w - S + 1 + 2*pad_w + stride_w - 1)/stride_w; // Allocate output AT_CHECK(Ci == Cf, "Number of channels in data and weights must match"); - torch::Tensor output = torch::empty({B, K, H, W}, torch::kFloat); + torch::Tensor output = torch::empty({B, NF, P, Q}, torch::kFloat).cuda(); // Wrap CUDA handles - triton::driver::cu_stream sstream(at::cuda::getCurrentCUDAStream(), false); + c10::DeviceIndex device = output.storage().device().index(); + triton::driver::cu_stream sstream((CUstream)at::cuda::getCurrentCUDAStream(device).stream(), false); triton::driver::stream* stream = &sstream; triton::driver::context* ctx = stream->context(); triton::driver::cu_buffer d(ctx, (CUdeviceptr)data.storage().data(), false); triton::driver::cu_buffer w(ctx, (CUdeviceptr)weight.storage().data(), false); + triton::driver::cu_buffer a(ctx, (CUdeviceptr)output.storage().data(), false); // Create JIT triton::jit jit(ctx); std::vector params = { 16, 2, 64, 32, 2, 64, 16, 8, 2, 2, - 8, 8, + 8, 1, 8, 4 }; jit.add_module("conv", src, params); triton::driver::kernel* kernel = jit.get_function("conv"); triton::jit::launch_information info = jit.get_launch_info("conv"); - + // launch info + unsigned TM = info.global_range_size[0]; + unsigned TN = info.global_range_size[1]; + unsigned TK = jit.get_int("TK"); + // initialize constant memory + int FS = T*R*S; + int nlut = (TK + FS - 1) / FS * FS; + std::vector h_delta(nlut + upsample_d*upsample_h*upsample_w*nlut); + std::vector h_masks(nlut + (2*pad_h+1)*(2*pad_w+1)*(2*pad_d+1)*nlut); + // memory stride for images + int32_t stride_i_w = 1; + int32_t stride_i_h = W*stride_i_w; + int32_t stride_i_d = H*stride_i_h; + int32_t stride_i_c = 1*stride_i_d; + int32_t stride_i_n = Ci*stride_i_c; + // memory stride for activations + int32_t stride_o_q = 1; + int32_t stride_o_p = Q*stride_o_q; + int32_t stride_o_m = P*stride_o_p; + int32_t stride_o_k = 1*stride_o_m; + int32_t stride_o_n = NF*stride_o_k; + build_conv_lut(TK, stride_i_d, stride_i_h, stride_i_w, stride_i_c, pad_d, pad_h, pad_w, T, R, S, h_delta, h_masks); + // equivalent matmul dimensions + int32_t M = B*P*Q; + int32_t N = NF; + int32_t K = Ci*R*S; + triton::driver::buffer* delta = jit.get_buffer("delta"); + triton::driver::buffer* masks = jit.get_buffer("masks"); + stream->write(delta, false, 0, h_delta.size()*4, h_delta.data()); + stream->write(masks, false, 0, h_masks.size()*4, h_masks.data()); + // launch info + unsigned nthreads = info.num_threads; + std::array grid = {(M + TM - 1)/TM, (N + TN - 1)/TN, 1}; + // fast bounds-checking + unsigned lasti = (grid[0]*TM - 1)*TM + TM - 1; + unsigned lastj = (grid[1]*TN - 1)*TN + TN - 1; + unsigned lastk = TK - 1; + bool AT = false; + bool BT = true; + unsigned last_safe_a = (AT==false)?(M*K - 1 - lasti)/M - lastk : M*K - 1 - lasti*K - lastk; + unsigned last_safe_b = (BT==true)?(N*K - 1 - lastj)/N - lastk : N*K - 1 - lastj*K - lastk; + int32_t bound = std::max(1, std::max(K - last_safe_a, K - last_safe_b)); + // set arguments + kernel->setArg(0, *d.cu()); + kernel->setArg(1, *w.cu()); + kernel->setArg(2, *a.cu()); + kernel->setArg(3, M); + kernel->setArg(4, N); + kernel->setArg(5, K); + kernel->setArg(6, B); + kernel->setArg(7, H); + kernel->setArg(8, W); + kernel->setArg(9, B); + kernel->setArg(10, NF); + kernel->setArg(11, P); + kernel->setArg(12, Q); + kernel->setArg(13, Ci); + kernel->setArg(14, R); + kernel->setArg(15, S); + kernel->setArg(16, stride_i_n); + kernel->setArg(17, stride_i_c); + kernel->setArg(18, stride_i_h); + kernel->setArg(19, stride_i_w); + kernel->setArg(20, stride_o_n); + kernel->setArg(21, stride_o_k); + kernel->setArg(22, stride_o_p); + kernel->setArg(23, stride_o_q); + kernel->setArg(24, pad_h); + kernel->setArg(25, pad_w); + kernel->setArg(26, bound); +// // dry run + stream->enqueue(kernel, grid, {nthreads, 1, 1}); return output; } diff --git a/include/triton/codegen/shmem_allocation.h b/include/triton/codegen/shmem_allocation.h index 27a96f285..8a6f175a8 100644 --- a/include/triton/codegen/shmem_allocation.h +++ b/include/triton/codegen/shmem_allocation.h @@ -26,6 +26,7 @@ public: // utilities unsigned get_num_bytes(ir::value *x); + bool is_ld_padded(ir::value* x); // accessors unsigned get_offset(ir::value *x) const { return offsets_.at(x); } diff --git a/include/triton/jit.h b/include/triton/jit.h index b001148e5..a3e554c67 100644 --- a/include/triton/jit.h +++ b/include/triton/jit.h @@ -58,7 +58,6 @@ public: target_(target) { } void target_independent(ir::module &module) { -// ir::print(module, std::cout); optimize_dot.run(module); optimize_trans.run(module); } diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp index c59ca2f12..7927e5400 100644 --- a/lib/codegen/selection.cpp +++ b/lib/codegen/selection.cpp @@ -525,10 +525,12 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, for(ir::value *op: user->ops()) create_tile(op, builder, references, seen, sh_mem_ptr); LLVMContext &ctx = builder.getContext(); - const auto& shapes = v->get_type()->get_tile_shapes(); - std::vector shapes2; - for(ir::constant_int* shape: shapes) - shapes2.push_back(shape->get_value()); + const auto& cshapes = v->get_type()->get_tile_shapes(); + std::vector shapes; + for(ir::constant_int* shape: cshapes) + shapes.push_back(shape->get_value()); + if(alloc_->is_ld_padded(v)) + shapes[0] += 4; Type* ty = llvm_type(v->get_type()->get_scalar_ty(), ctx); // create shared tile if(buffer_info_->is_shared(v)){ @@ -550,13 +552,13 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, Value *pre_ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(alloc_->get_offset(phi))); pre_ptr = builder.CreateBitCast(pre_ptr, ptr->getType()); Value *next_ptr = builder.CreateGEP(ptr, offset, "next_ptr"); - tmap_.insert({phi, new shared_tile(ty, shapes2, ptr, builder, offset)}); + tmap_.insert({phi, new shared_tile(ty, shapes, ptr, builder, offset)}); for(unsigned i = 0; i < phi->get_num_incoming(); i++) { ir::basic_block* inc_block = phi->get_incoming_block(i); ir::value* inc_value = phi->get_incoming_value(i); ir::instruction* terminator = inc_block->get_inst_list().back(); bool is_loop_latch = buffer_info_->is_loop_latch(phi, terminator); - tmap_.insert({inc_value, new shared_tile(ty, shapes2, is_loop_latch?next_ptr:pre_ptr, builder)}); + tmap_.insert({inc_value, new shared_tile(ty, shapes, is_loop_latch?next_ptr:pre_ptr, builder)}); } } else { @@ -564,16 +566,16 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, size_t offset = alloc_->get_offset(v); Value *ptr = builder.CreateGEP(sh_mem_ptr, builder.getInt32(offset)); ptr = builder.CreateBitCast(ptr, ptr_ty); - tmap_.insert({v, new shared_tile(ty, shapes2, ptr, builder)}); + tmap_.insert({v, new shared_tile(ty, shapes, ptr, builder)}); } } } // create distributed tile else { - const auto &shapes = v->get_type()->get_tile_shapes(); - std::vector axes(shapes.size()); - for(size_t d = 0; d < shapes.size(); d++){ - if(shapes[d]->get_value() > 1){ + const auto &cshapes = v->get_type()->get_tile_shapes(); + std::vector axes(cshapes.size()); + for(size_t d = 0; d < cshapes.size(); d++){ + if(cshapes[d]->get_value() > 1){ ir::metaparameter *x = params_->get_param(v, "nts.d" + std::to_string(d)); axes[d] = axes_.at(x); } @@ -583,7 +585,7 @@ void selection::create_tile(ir::value *v, IRBuilder<> &builder, } } bool vectorize = dynamic_cast(v); - distributed_tile *T = new distributed_tile(ty, shapes2, axes, builder, vectorize); + distributed_tile *T = new distributed_tile(ty, shapes, axes, builder, vectorize); tmap_.insert({v, T}); // constant range if(dynamic_cast(v) && !dynamic_cast(v)){ diff --git a/lib/codegen/shmem_allocation.cpp b/lib/codegen/shmem_allocation.cpp index 43ab8bc39..90cf7ef2b 100644 --- a/lib/codegen/shmem_allocation.cpp +++ b/lib/codegen/shmem_allocation.cpp @@ -10,8 +10,24 @@ namespace triton{ namespace codegen{ +bool shmem_allocation::is_ld_padded(ir::value *x) { + if(auto* phi = dynamic_cast(x)) { + bool result = false; + for(unsigned i = 0; i < phi->get_num_incoming(); i++) + result = result | is_ld_padded(phi->get_incoming_value(i)); + return result; + } + if(dynamic_cast(x)) + return true; + return false; +} + unsigned shmem_allocation::get_num_bytes(ir::value *x) { unsigned result = x->get_type()->get_primitive_size_in_bits() / 8; + if(is_ld_padded(x)){ + unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value(); + result += 4 * result / ld; + } if(buffer_info_->is_double(x)) result *= 2; return result; @@ -23,8 +39,9 @@ void shmem_allocation::run(){ typedef std::multimap triples_map_type; std::vector I; - for(auto x: liveness_->intervals()) + for(auto x: liveness_->intervals()){ I.push_back(x.first); + } std::vector J = I; triples_map_type H;