From c172bd518b571250e880401bd23c27d3c738ac56 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 30 Jun 2019 16:55:02 -0700 Subject: [PATCH] more stuff --- examples/cpp/shift.cpp | 7 +- include/triton/dnn/shift.h | 25 ++++--- include/triton/lang/expression.h | 6 +- include/triton/lang/parser.y | 2 +- include/triton/runtime/jit.h | 6 +- lib/codegen/tune.cpp | 2 +- lib/dnn/shift.cpp | 110 +++++++++++++++++++++---------- lib/driver/module.cpp | 2 +- lib/lang/expression.cpp | 54 ++++++--------- 9 files changed, 124 insertions(+), 90 deletions(-) diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp index 83082ec4d..ba4f7fa43 100644 --- a/examples/cpp/shift.cpp +++ b/examples/cpp/shift.cpp @@ -16,6 +16,7 @@ int main() { auto context = triton::driver::backend::contexts::get_default(); // initialize just-in-time compiler triton::jit jit(context); + // initialization int32_t R = 3, S = 3; int32_t BS = 32, F = 1024; @@ -30,7 +31,7 @@ int main() { shift_w[c] = rand() % S - S/2; } // configuration - triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str); + triton::dnn::shift shift(BS, C, 1, H, W, 1, R, S, F, shift_h, shift_w, numeric_t_str, numeric_t_str, triton::dnn::shift::FPROP); // host buffers std::vector hc(shift.c_size()); std::vector rc(shift.c_size()); @@ -58,7 +59,7 @@ int main() { auto benchmark = [&](triton::driver::kernel* kernel, triton::jit::launch_information info) { shift.init(stream, (triton::driver::cu_module*)kernel->module()); - // launch info + // launch infoRR unsigned TM = info.global_range_size[0]; unsigned TN = info.global_range_size[1]; unsigned nthreads = info.num_threads; @@ -78,7 +79,7 @@ int main() { std::ostringstream oss; shift.src(oss); std::string src = oss.str(); -// jit.autotune("shift", src.c_str(), benchmark); + jit.autotune("shift", src.c_str(), benchmark); jit.add_module("shift", src.c_str(), params); triton::driver::kernel* kernel = jit.get_function("shift"); triton::jit::launch_information info = jit.get_launch_info("shift"); diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h index 1b407aa43..3c4b53037 100644 --- a/include/triton/dnn/shift.h +++ b/include/triton/dnn/shift.h @@ -38,7 +38,9 @@ class shift { public: enum type { - FPROP + FPROP, + BPROP, + WGRAD }; private: @@ -85,11 +87,11 @@ public: OUT_DTYPE acc; for(int32_t p = 0; p < AH_; ++p) for(int32_t q = 0; q < AW_; ++q) - for(int32_t bs = 0; bs < NB_; ++bs) - for(int32_t k = 0; k < NF_; ++k) + for(int32_t bs = 0; bs < B_; ++bs) + for(int32_t k = 0; k < F_; ++k) { acc = 0; - for(int32_t c = 0; c < NC_; ++c){ + for(int32_t c = 0; c < C_; ++c){ int32_t h = p; int32_t w = q; if(h >= BH_/2 && h < AH_ - BH_/2 @@ -97,11 +99,11 @@ public: h += shift_h_[c]; w += shift_w_[c]; } - IN_DTYPE a = I[bs + w*NB_ + h*NB_*AW_ + c*NB_*AH_*AW_]; - IN_DTYPE b = F[k + c*NF_]; + IN_DTYPE a = I[bs + w*B_ + h*B_*AW_ + c*B_*AH_*AW_]; + IN_DTYPE b = F[k + c*F_]; acc = std::fma(a, b, acc); } - O[bs + q*NB_ + p*NB_*AW_ + k*NB_*AH_*AW_] = acc; + O[bs + q*B_ + p*B_*AW_ + k*B_*AH_*AW_] = acc; } } @@ -109,8 +111,8 @@ private: int32_t MAX_C_; int32_t TK_; // image size - int32_t NB_; - int32_t NC_; + int32_t B_; + int32_t C_; int32_t AD_; int32_t AH_; int32_t AW_; @@ -118,7 +120,7 @@ private: int32_t BD_; int32_t BH_; int32_t BW_; - int32_t NF_; + int32_t F_; // activation size int32_t CD_; int32_t CH_; @@ -149,6 +151,9 @@ private: // convolution type type ty_; bool bias_; + // transpose + bool AT_; + bool BT_; }; } diff --git a/include/triton/lang/expression.h b/include/triton/lang/expression.h index 6ce0819cb..dc9a6a449 100644 --- a/include/triton/lang/expression.h +++ b/include/triton/lang/expression.h @@ -160,13 +160,13 @@ private: class indexing_expression: public postfix_expression{ public: - indexing_expression(node *id, node *slices) - : id_((const identifier*)id), slices_((const list*)slices) {} + indexing_expression(node *lhs, node *slices) + : lhs_((const expression*)lhs), slices_((const list*)slices) {} ir::value* codegen(ir::module *) const; private: - const identifier* id_; + const expression* lhs_; const list* slices_; }; diff --git a/include/triton/lang/parser.y b/include/triton/lang/parser.y index 2c942b86c..579099e80 100644 --- a/include/triton/lang/parser.y +++ b/include/triton/lang/parser.y @@ -157,7 +157,7 @@ slice_list postfix_expression : primary_expression { $$ = $1;} - | identifier '[' slice_list ']' { $$ = new indexing_expression($1, $3);} + | primary_expression '[' slice_list ']' { $$ = new indexing_expression($1, $3);} ; /* Unary */ diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h index 684bc6875..b74ae7c83 100644 --- a/include/triton/runtime/jit.h +++ b/include/triton/runtime/jit.h @@ -65,9 +65,9 @@ public: target_(target) { } void target_independent(ir::module &module) { - optimize_dot.run(module); - optimize_trans.run(module); -// ir::print(module, std::cout); + ir::print(module, std::cout); + optimize_dot.run(module); + optimize_trans.run(module); } void target_dependent(ir::module &module) { diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp index ac56bd5ed..3821ecdb2 100644 --- a/lib/codegen/tune.cpp +++ b/lib/codegen/tune.cpp @@ -59,7 +59,7 @@ void tune::init_c_graph(ir::instruction *v) { else if(auto *downcast = dynamic_cast(v)) return; else{ -// std::cout << v->get_name() << std::endl; + std::cout << v->get_name() << std::endl; shapes = v->get_type()->get_tile_shapes(); } // Reshape diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp index 102a970df..099192080 100644 --- a/lib/dnn/shift.cpp +++ b/lib/dnn/shift.cpp @@ -8,42 +8,63 @@ void shift::set_ld(const std::vector& shapes, std::vector& ld) { size_t size = shapes.size(); ld.resize(size); - ld[3] = 1; - ld[2] = shapes[3]*ld[3]; - ld[1] = shapes[2]*ld[2]; - ld[0] = shapes[1]*ld[1]; + ld[size - 1] = 1; + for(int i = size - 1; i >= 1; i--) + ld[i - 1] = shapes[i] * ld[i]; } -shift::shift(int B, int NC, +shift::shift(int B, int C, int D, int H, int W, int T, int R, int S, - int NF, + int F, const std::vector& shift_h, const std::vector& shift_w, std::string a_ty, std::string b_ty, type ty, bool bias) - : NB_(B), NC_(NC), + : B_(B), C_(C), AD_(D), AH_(H), AW_(W), BD_(T), BH_(R), BW_(S), - NF_(NF), + F_(F), shift_h_(shift_h), shift_w_(shift_w), a_ty_(a_ty), b_ty_(b_ty), ty_(ty), bias_(bias) { // max number of channels TK_ = 16; MAX_C_ = 8192 + TK_; + // transpose + AT_ = false; + BT_ = true; // equivalent matmul - M_ = NB_*AH_*AW_; - N_ = NF_; - K_ = NC_; + M_ = B_*AH_*AW_; + N_ = F_; + K_ = C_; // shapes - // input layout: C, H, W, BS - // filter layout: C, K - // output layout: K, H, W, BS - shapes_a_ = {NC, H, W, B}; - shapes_b_ = {NC, NF}; - shapes_c_ = {NF, H, W, B}; + // input layout: C, H, W, B + // filter layout: C, F + // output layout: F, H, W, B + shapes_a_ = {C, H, W, B}; + shapes_b_ = {C, F}; + shapes_c_ = {F, H, W, B}; + if(ty_ == WGRAD){ + shapes_b_.swap(shapes_c_); + shapes_a_.swap(shapes_b_); + AT_ = true; + BT_ = false; + M_ = K_; + N_ = C_; + K_ = B_*AH_*AW_; + } + if(ty_ == BPROP){ + shapes_a_.swap(shapes_c_); + AT_ = false; + BT_ = false; + K_ = F_; + M_ = B_*AH_*AW_; + N_ = C_; + } // memory strides set_ld(shapes_a_, ld_a_); + set_ld(shapes_b_, ld_b_); + set_ld(shapes_c_, ld_c_); // build LUTs build_deltas(); } @@ -57,7 +78,7 @@ void shift::build_deltas() { // populate look-up table for(unsigned c = 0; c < TK_; c++) h_deltas_[c] = offset(c); - for(unsigned c = 0; c < NC_; c++) + for(unsigned c = 0; c < C_; c++) h_deltas_[TK_ + c] = offset(c + TK_) - offset(c); } @@ -99,18 +120,36 @@ void shift::enqueue(driver::stream *stream, driver::kernel *kernel, kernel->setArg(3, M_); kernel->setArg(4, N_); kernel->setArg(5, K_); - kernel->setArg(6, NB_*AH_*AW_); - kernel->setArg(7, NB_); + kernel->setArg(6, B_*AH_*AW_); + kernel->setArg(7, B_); kernel->setArg(8, AH_); kernel->setArg(9, AW_); kernel->setArg(10, BH_); kernel->setArg(11, BW_); - // dry run std::array grid = {(M_ + TM - 1)/TM, (N_ + TN - 1)/TN, 1}; stream->enqueue(kernel, grid, {nthreads, 1, 1}); } void shift::src(std::ostream &os) { + std::string AS0 = "TM", AS1 = "TK"; + std::string BS0 = "TK", BS1 = "TN"; + std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]"; + std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]"; + std::string lda0 = "*lda", lda1 = ""; + std::string ldb0 = "", ldb1 = "*ldb"; + std::string usea = AT_ ? "trans(a)" : "a"; + std::string useb = BT_ ? "trans(b)" : "b"; + if(AT_){ + std::swap(AS0, AS1); + std::swap(bca0, bca1); + std::swap(lda0, lda1); + } + if(BT_){ + std::swap(BS0, BS1); + std::swap(bcb0, bcb1); + std::swap(ldb0, ldb1); + } + os << R"( const tunable int32 TM = {16, 32, 64, 128}; @@ -136,26 +175,27 @@ void shift(restrict read_only align(16) )" << a_ty_ << R"( *a, int32 raw[TM] = rawhc % AW; int32 rahc[TM] = rawhc / AW; int32 rah[TM] = rahc % AH; + __constant__ int32* pd[TK] = delta + rka; + multiple_of(4) int32 d[TK] = *pd; int1 maskh[TM] = (rah >= pad_h) && (rah < (AH - pad_h)); int1 maskw[TM] = (raw >= pad_w) && (raw < (AW - pad_w)); - int1 mask[TM, TK] = maskh[:, newaxis] && maskw[:, newaxis]; - __constant__ int32* pd[TK] = delta + rka; - multiple_of(4) int32 d[TK]; - d = *pd; - int32 offa1[TK] = rka*lda; - int32 inc[TM, TK] = mask ? d[newaxis, :] : offa1[newaxis, :]; - )" << a_ty_ << R"(* pa[TM, TK] = a + rxa[:, newaxis] + inc; - )" << b_ty_ << R"(* pb[TN, TK] = b + rkb[newaxis, :]*N + ryb[:, newaxis]; - )" << a_ty_ << R"( a[TM, TK] = *pa; - )" << b_ty_ << R"( b[TN, TK] = *pb; + int1 mask[)" << AS0 << ", " << AS1 << "] = maskh" << bca1 << " && maskw" << bca1 << R"(; + int32 inc_true[)" << AS0 << ", " << AS1 << "] = d" << bca0 << R"(; + int32 inc_false[)" << AS0 << ", " << AS1 << "] = rka" << bca0 << R"( * lda; + )" << a_ty_ << "* pa[" << AS0 << ", " << AS1 << R"(] = a + rxa)" << bca1 << R"( + (mask ? inc_true : inc_false); + )" << b_ty_ << "* pb[" << BS0 << ", " << BS1 << "] = b + ryb" << bcb1 << " + rkb" << bcb0 << R"(*N; + )" << a_ty_ << " a[" << AS0 << ", " << AS1 << R"(] = *pa; + )" << b_ty_ << " b[" << BS0 << ", " << BS1 << R"(] = *pb; for(int32 k = K; k > 0; k = k - TK){ - C = dot(a, trans(b), C); + C = dot()" << usea << "," << useb << R"(, C); pb = pb + TK*N; pd = pd + TK; d = *pd; - pa = pa + (mask ? d[newaxis, :] : TK*lda); - int1 checka[TM, TK] = k > TK; - int1 checkb[TN, TK] = k > TK; + inc_true = d)" << bca0 << R"(; + inc_false = TK * lda; + pa = pa + (mask ? inc_true : inc_false); + int1 checka[)" << AS0 << ", " << AS1 << R"(] = k > TK; + int1 checkb[)" << BS0 << ", " << BS1 << R"(] = k > TK; @checka a = *pa; @checkb b = *pb; } diff --git a/lib/driver/module.cpp b/lib/driver/module.cpp index 4ff863666..f11118401 100755 --- a/lib/driver/module.cpp +++ b/lib/driver/module.cpp @@ -255,7 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) { cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { } cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){ -// std::cout << source << std::endl; +// std::cout << source << sd::endl; cu_context::context_switcher ctx_switch(*context); // JIT compile source-code CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER}; diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp index 6054a2694..388815164 100644 --- a/lib/lang/expression.cpp +++ b/lib/lang/expression.cpp @@ -175,7 +175,7 @@ ir::value* trans_expression::codegen(ir::module *mod) const { /* Postfix expression */ ir::value* indexing_expression::codegen(ir::module *mod) const{ - ir::value *in = mod->get_value(id_->name()); + ir::value *in = lhs_->codegen(mod); const std::vector &slices = slices_->values(); auto in_shapes = in->get_type()->get_tile_shapes(); ir::type::tile_shapes_t::value_type one = ir::tile_type::make_one(mod->get_context()); @@ -234,44 +234,32 @@ ir::value* cast_expression::codegen(ir::module *mod) const{ /* Conditional expression */ ir::value *conditional_expression::codegen(ir::module *mod) const{ ir::builder &builder = mod->get_builder(); + ir::basic_block::inst_list_t &instructions = builder.get_insert_block()->get_inst_list(); ir::value *pred = cond_->codegen(mod); ir::instruction *mask = (ir::instruction*)builder.create_mask(pred); + /* true value */ ir::value *true_mask = mask->get_result(0); - ir::value *false_mask = mask->get_result(1); + auto it_true_begin = instructions.end(); + it_true_begin--; ir::value *true_value = true_value_->codegen(mod); - ir::value *false_value = false_value_->codegen(mod); - if(auto *itn = dynamic_cast(true_value)) - itn->set_mask_pred(true_mask); - if(auto *itn = dynamic_cast(false_value)) - itn->set_mask_pred(false_mask); - bool is_float, is_ptr, is_int, is_signed; - ir::value *uncasted_true_value = true_value; - ir::value *uncasted_false_value = false_value; - implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); implicit_broadcast(mod, pred, true_value); + it_true_begin++; + auto it_true_end = instructions.end(); + for(auto it = it_true_begin; it != it_true_end; it++) + (*it)->set_mask_pred(true_mask); + /* false value */ + ir::value *false_mask = mask->get_result(1); + auto it_false_begin = instructions.end(); + it_false_begin--; + ir::value *false_value = false_value_->codegen(mod); + it_false_begin++; implicit_broadcast(mod, pred, false_value); - { - ir::value *current = true_value; - while(current != uncasted_true_value) { - if(auto *itn = dynamic_cast(current)){ - itn->set_mask_pred(true_mask); - current = itn->get_operand(0); - } - else - break; - } - } - { - ir::value *current = false_value; - while(current != uncasted_false_value) { - if(auto *itn = dynamic_cast(current)){ - itn->set_mask_pred(false_mask); - current = itn->get_operand(0); - } - else - break; - } - } + auto it_false_end = instructions.end(); + for(auto it = it_false_begin; it != it_false_end; it++) + (*it)->set_mask_pred(false_mask); + /* cast */ + bool is_float, is_ptr, is_int, is_signed; + implicit_cast(builder, true_value, false_value, is_float, is_ptr, is_int, is_signed); ir::value *result = builder.create_merge(true_mask, true_value, false_mask, false_value); return result; }