diff --git a/examples/cpp/conv.cpp b/examples/cpp/conv.cpp
index 2d6d7a845..dbe0591f0 100644
--- a/examples/cpp/conv.cpp
+++ b/examples/cpp/conv.cpp
@@ -18,12 +18,12 @@ int main() {
   int32_t pad_d = 0, pad_h = 0, pad_w = 0;
   int32_t stride_d = 1, stride_h = 1, stride_w = 1;
   int32_t upsample_d = 1, upsample_h = 1, upsample_w = 1;
-//  triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "fp32", "fp32", triton::dnn::conv::FPROP, 0);
+//  triton::dnn::conv configuration(128, 256, 1, 14, 14, 1, 5, 5, 512, 1, 1, 1, 0, 0, 0, 1, 1, 1, "float", "float", triton::dnn::conv::FPROP, 0);
   triton::dnn::conv configuration(B, NC, D, H, W, T, R, S, NF,
                                   stride_d, stride_h, stride_w,
                                   pad_d, pad_h, pad_w,
                                   upsample_d, upsample_h, upsample_w,
-                                  "fp32", "fp32", ty, 0);
+                                  "float", "float", ty, 0);
   // convolution configuration
   std::vector<float> hc(configuration.c_size());
   std::vector<float> rc(configuration.c_size());
diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index 771e44c1f..591237fbe 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -26,7 +26,7 @@ struct perf_t {
 
 perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef float NumericT;
-  std::string ty = "fp16";
+  std::string ty = "half";
   size_t dt_nbytes = sizeof(NumericT);
   triton::driver::context* context = stream->context();
   std::vector<NumericT> hc(M*N);
@@ -46,7 +46,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->write(db, true, 0, hb);
   stream->write(dc, true, 0, hc);
   stream->synchronize();
-  triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8);
+  triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
   // benchmark triton
   double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
   // benchmark cublas
diff --git a/examples/cpp/shift.cpp b/examples/cpp/shift.cpp
index 38e0e37bf..1495de3c4 100644
--- a/examples/cpp/shift.cpp
+++ b/examples/cpp/shift.cpp
@@ -134,7 +134,7 @@ int main() {
   };
   for(config_t c: resnet18){
     for(op_t op: {op_t::FPROP, op_t::BPROP, op_t::WGRAD}){
-      configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "fp16"});
+      configs.push_back({c.B, c.C, c.H, c.W, c.R, c.S, c.F, c.stride_h, c.stride_w, op, layout_t::CHWN, "half"});
     }
   }
diff --git a/examples/python/pytorch/batchnorm.cpp b/examples/python/pytorch/batchnorm.cpp
index 521137a9e..64559e197 100644
--- a/examples/python/pytorch/batchnorm.cpp
+++ b/examples/python/pytorch/batchnorm.cpp
@@ -37,7 +37,7 @@ std::vector<torch::Tensor>
   triton::driver::cu_buffer m(ctx, (CUdeviceptr)fw_m.storage().data(), false);
   triton::driver::cu_buffer v(ctx, (CUdeviceptr)fw_v.storage().data(), false);
   // create template
-  triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
+  triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float");
   batchnorm.enqueue(&stream, {&y, &m, &v, &x, &g, &b});
   stream.synchronize();
   return {fw_y, fw_m, fw_v};
@@ -79,7 +79,7 @@ std::vector<torch::Tensor>
   triton::driver::cu_buffer dg(ctx, (CUdeviceptr)fw_dg.storage().data(), false);
   triton::driver::cu_buffer db(ctx, (CUdeviceptr)fw_db.storage().data(), false);
   // create config
-  triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32", eps);
+  triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", eps);
   batchnorm.enqueue(&stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
   stream.synchronize();
   return {fw_dx, fw_dg, fw_db};
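The examples above all switch from the old "fp16"/"fp32"/"fp64" type strings to the new C-like "half"/"float"/"double" spellings. A minimal sketch (not part of the patch) of the helper a binding might use to pick the new strings; it mirrors the switch that examples/python/pytorch/shift.cpp adopts further down in this diff:

```cpp
#include <torch/extension.h>

// Hypothetical helper, for illustration only: map a torch scalar type
// to the new Triton-C type string used by the triton::dnn templates.
std::string triton_dtype(at::ScalarType t) {
  switch(t) {
    case at::ScalarType::Double: return "double";  // was "fp64"
    case at::ScalarType::Float:  return "float";   // was "fp32"
    case at::ScalarType::Half:   return "half";    // was "fp16"
    default: AT_ERROR("unknown data-type");
  }
}
```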
diff --git a/examples/python/pytorch/conv.cpp b/examples/python/pytorch/conv.cpp
index eab6ba9e7..91cef5441 100644
--- a/examples/python/pytorch/conv.cpp
+++ b/examples/python/pytorch/conv.cpp
@@ -30,7 +30,7 @@ torch::Tensor conv_common(
                                   stride_d, stride_h, stride_w,
                                   pad_d, pad_h, pad_w,
                                   1, 1, 1,
-                                  "fp32", "fp32", ty, has_bias);
+                                  "float", "float", ty, has_bias);
   // Bind memory
   triton::driver::cu_buffer a(ctx, (CUdeviceptr)torcha.storage().data(), false);
   triton::driver::cu_buffer b(ctx, (CUdeviceptr)torchb.storage().data(), false);
diff --git a/examples/python/pytorch/shift.cpp b/examples/python/pytorch/shift.cpp
index 7c86b227e..bd80d73d9 100644
--- a/examples/python/pytorch/shift.cpp
+++ b/examples/python/pytorch/shift.cpp
@@ -49,9 +49,9 @@ torch::Tensor shift_common(
   std::string dtype;
   at::ScalarType type = torcha.scalar_type();
   switch(type){
-    case at::ScalarType::Double: dtype = "fp64"; break;
-    case at::ScalarType::Float: dtype = "fp32"; break;
-    case at::ScalarType::Half: dtype = "fp16"; break;
+    case at::ScalarType::Double: dtype = "double"; break;
+    case at::ScalarType::Float: dtype = "float"; break;
+    case at::ScalarType::Half: dtype = "half"; break;
     default: AT_ERROR("unknown data-type for shift-conv");
   }
   // Get configuration
diff --git a/examples/python/tensorflow/batchnorm.cpp b/examples/python/tensorflow/batchnorm.cpp
index 137a84809..956ecef24 100644
--- a/examples/python/tensorflow/batchnorm.cpp
+++ b/examples/python/tensorflow/batchnorm.cpp
@@ -58,7 +58,7 @@ public:
     triton::driver::cu_buffer m(ctx, fw_m->tensor_data().size(), (CUdeviceptr)fw_m->tensor_data().data(), false);
     triton::driver::cu_buffer v(ctx, fw_v->tensor_data().size(), (CUdeviceptr)fw_v->tensor_data().data(), false);
     // create config
-    triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "fp32");
+    triton::dnn::batchnorm_forward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
     batchnorm.enqueue(stream, {&y, &m, &v, &x, &g, &b});
   }
@@ -126,7 +126,7 @@ public:
     triton::driver::cu_buffer dg(ctx, fw_dg->tensor_data().size(), (CUdeviceptr)fw_dg->tensor_data().data(), false);
     triton::driver::cu_buffer db(ctx, fw_db->tensor_data().size(), (CUdeviceptr)fw_db->tensor_data().data(), false);
     // create config
-    triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "fp32");
+    triton::dnn::batchnorm_backward batchnorm(C, 1, H, W, B, "float", triton::dnn::FULL_TUNING);
     batchnorm.enqueue(stream, {&dx, &dg, &db, &dy, &x, &g, &m, &v});
   }
diff --git a/examples/python/tensorflow/blocksparse.cpp b/examples/python/tensorflow/blocksparse.cpp
index 0d37d382d..3a6a2505c 100644
--- a/examples/python/tensorflow/blocksparse.cpp
+++ b/examples/python/tensorflow/blocksparse.cpp
@@ -128,9 +128,9 @@ public:
    triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
    triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
    // create profile
-   triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
+   triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
    // blocksparse matmul
-   triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
+   triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
    triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
    Tensor *tmp = nullptr;
    TensorShape tmp_shapes;
diff --git a/examples/python/tensorflow/conv.cpp b/examples/python/tensorflow/conv.cpp
index f06bf679c..00bf05473 100644
--- a/examples/python/tensorflow/conv.cpp
+++ b/examples/python/tensorflow/conv.cpp
@@ -61,7 +61,7 @@ public:
                              stride_d, stride_h, stride_w,
                              pad_d, pad_h, pad_w,
                              1, 1, 1,
-                             "fp16", "fp16",
+                             "half", "half",
                              triton::dnn::conv::FPROP, has_bias);
     // allocate output
     auto c_shapes = conv.c_shapes();
diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index 7acedb7e9..553ad11fa 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
     triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
     triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     // template
-    triton::dnn::dot dot(M, N, K, false, false, "fp16", "fp16", 8, 8);
+    triton::dnn::dot dot(M, N, K, false, false, "half", "half", 8, 8, 8);
     dot.enqueue(stream, {&da, &db, &dc});
   }
diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index 88fe7ef3d..8dbc6ac55 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -105,7 +105,7 @@ def batch_norm_grad(op, dy, mean, var):
 
 def run_batchnorm():
-    C, H, W, B = 32, 14, 14, 64
+    C, H, W, B = 8, 4, 4, 32
     np.random.seed(0)
     # Placeholders
     x = tf.placeholder(tf.float32, shape=[C, H, W, B])
@@ -131,6 +131,6 @@ def run_batchnorm():
     print(np.max(np.abs(dg_t - dg_n)))
     print(np.max(np.abs(db_t - db_n)))
 
-run_dot()
+#run_dot()
 #run_shift()
-#run_batchnorm()
+run_batchnorm()
diff --git a/examples/python/tensorflow/shift.cpp b/examples/python/tensorflow/shift.cpp
index 28e10b679..cb28ce281 100644
--- a/examples/python/tensorflow/shift.cpp
+++ b/examples/python/tensorflow/shift.cpp
@@ -106,7 +106,7 @@ public:
     triton::dnn::shift shift(B, C, D, H, W, T, R_, S_, F,
                              stride_h_, stride_w_,
                              shift_h_data, shift_w_data,
-                             "fp16", "fp16", OP, has_bias, layout_);
+                             "half", "half", OP, has_bias, layout_);
     // shapes for c
     std::vector c_shapes;
diff --git a/include/triton/codegen/selection.h b/include/triton/codegen/selection.h
index e1d2dbf0b..317fc7f2b 100644
--- a/include/triton/codegen/selection.h
+++ b/include/triton/codegen/selection.h
@@ -91,6 +91,7 @@ public:
   void set_value(indices_t idx, llvm::Value *v);
   llvm::Value* get_value(indices_t idx);
   unsigned get_linear_index(indices_t idx);
+  indices_t get_ordered_indices(unsigned id);
   void for_each(std::function<void(indices_t)> fn);
   const distributed_axis &axis(unsigned dim) { return axes_.at(dim); }
diff --git a/include/triton/dnn/base.h b/include/triton/dnn/base.h
index b9e2b886b..b991c3726 100644
--- a/include/triton/dnn/base.h
+++ b/include/triton/dnn/base.h
@@ -52,12 +52,15 @@ struct launch_context_t{
 
 typedef std::vector<unsigned> params_t;
 
 class base {
-  friend class cmp_recompile;
+  friend class recompile_hash;
+  friend class recompile_equal;
 
 protected:
   // leading dimensions
   static void set_ld(const std::vector<int32_t>& shapes, std::vector<int32_t>& ld);
+  // list of retuning parameters
+  virtual std::vector<int64_t> retune_params() const = 0;
 
 private:
   // initialize
@@ -70,8 +73,6 @@ private:
                             triton::runtime::launch_information info) = 0;
   // number of flops
   virtual size_t num_flops() const = 0;
-  // comparison for maps
-  virtual bool operator<(const base& other) const = 0;
   // default parameters
   virtual std::vector<params_t> search_space() const;
   virtual params_t heuristics() const;
@@ -94,12 +95,21 @@ private:
   std::string name_;
 };
 
-struct cmp_recompile{
+
+struct recompile_equal{
   bool operator()(base* x, base* y) const{
-    return *x < *y;
+    return typeid(*x) == typeid(*y) &&
+           x->retune_params() == y->retune_params();
   }
 };
 
+struct recompile_hash{
+  unsigned operator()(base* x) const{
+    return x->retune_params()[0];
+  }
+};
+
+
 }
 }
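recompile_equal and recompile_hash replace the old ordered-map comparator: two op instances now share one compiled kernel exactly when they have the same dynamic type and identical retune_params(). A minimal sketch (assumption: the alias names are illustrative) of the cache shape these functors are designed for, matching the lib/dnn/base.cpp hunk later in this diff:

```cpp
#include <memory>
#include <unordered_map>
#include "triton/dnn/base.h"
#include "triton/runtime/jit.h"

namespace rt = triton::runtime;

// Hash + equality functors plug into std::unordered_map; the hash only
// looks at the first retuning parameter, so equality resolves collisions.
using jit_cache_t = std::unordered_map<triton::dnn::base*,
                                       std::unique_ptr<rt::jit>,
                                       triton::dnn::recompile_hash,
                                       triton::dnn::recompile_equal>;
```

Hashing a single parameter is deliberately cheap; correctness rests entirely on recompile_equal, so a weak hash only costs extra probes, never a wrong kernel.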
diff --git a/include/triton/dnn/batchnorm.h b/include/triton/dnn/batchnorm.h
index 32c006b99..204ab631b 100644
--- a/include/triton/dnn/batchnorm.h
+++ b/include/triton/dnn/batchnorm.h
@@ -47,15 +47,15 @@ private:
                     triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;
 
 public:
   // constructor
   batchnorm_forward(int C, int D, int H, int W, int B,
-                    std::string ty = "fp32", float eps = 1e-5);
+                    std::string ty = "float", float eps = 1e-5);
   // triton-c source
   void triton_c_src(std::ostream &os) const;
@@ -82,15 +82,15 @@ private:
                     runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;
 
 public:
   // constructor
   batchnorm_backward(int C, int D, int H, int W, int B,
-                     std::string ty = "fp32", float eps = 1e-5);
+                     std::string ty = "float", float eps = 1e-5);
   // triton-c source
   void triton_c_src(std::ostream &os) const;
diff --git a/include/triton/dnn/blocksparse/dot.h b/include/triton/dnn/blocksparse/dot.h
index 98a1ce6fa..488c26c31 100644
--- a/include/triton/dnn/blocksparse/dot.h
+++ b/include/triton/dnn/blocksparse/dot.h
@@ -20,8 +20,8 @@ private:
                     triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // default parameters
   std::vector<params_t> search_space() const;
   params_t heuristics() const;
diff --git a/include/triton/dnn/conv.h b/include/triton/dnn/conv.h
index 2745d72bc..5a167531d 100644
--- a/include/triton/dnn/conv.h
+++ b/include/triton/dnn/conv.h
@@ -37,8 +37,8 @@ private:
                     triton::runtime::launch_information info);
   // number of flops
   size_t num_flops() const;
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;
 
@@ -50,7 +50,7 @@ public:
        int stride_d, int stride_h, int stride_w,
        int pad_d, int pad_h, int pad_w,
        int upsample_d, int upsample_h, int upsample_w,
-       std::string a_ty = "fp32", std::string b_ty = "fp32",
+       std::string a_ty = "float", std::string b_ty = "float",
        type ty = FPROP, bool bias = false);
 
   // accessors
diff --git a/include/triton/dnn/dot.h b/include/triton/dnn/dot.h
index 30836357f..c655d12b5 100644
--- a/include/triton/dnn/dot.h
+++ b/include/triton/dnn/dot.h
@@ -16,8 +16,8 @@ private:
   void enqueue_impl(driver::stream *stream, driver::kernel *kernel,
                     std::vector<driver::buffer*> args,
                     triton::runtime::launch_information info);
-  // comparison for maps
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // default parameters
   virtual std::vector<params_t> search_space() const;
   virtual params_t heuristics() const;
 
@@ -25,7 +25,7 @@ public:
   dot(int M, int N, int K, bool AT, bool BT,
       std::string a_ty, std::string b_ty,
-      unsigned alignment_lda, unsigned alignment_ldb);
+      unsigned align_lda, unsigned align_ldb, unsigned align_ldc);
   // number of flops
   size_t num_flops() const;
@@ -70,6 +70,7 @@ private:
   std::string b_ty_;
   unsigned align_lda_;
   unsigned align_ldb_;
+  unsigned align_ldc_;
   driver::buffer *locks_;
 };
diff --git a/include/triton/dnn/shift.h b/include/triton/dnn/shift.h
index 35ad312e0..4590c476e 100644
--- a/include/triton/dnn/shift.h
+++ b/include/triton/dnn/shift.h
@@ -64,7 +64,7 @@ public:
         int T, int R, int S, int NF,
         int stride_h, int stride_w,
         const int32_t* shift_h, const int32_t* shift_w,
-        std::string a_ty = "fp32", std::string b_ty = "fp32",
+        std::string a_ty = "float", std::string b_ty = "float",
         op_t ty = FPROP, bool bias = false, layout_t layout = CHWN);
   // look-up table
@@ -86,8 +86,8 @@ public:
   size_t num_flops() const;
   // source
   void triton_c_src(std::ostream &os) const;
-  // comparison
-  bool operator<(const base& other) const;
+  // retuning parameters
+  std::vector<int64_t> retune_params() const;
   // clone
   base* clone() const;
   // cpu reference
diff --git a/include/triton/lang/scanner.l b/include/triton/lang/scanner.l
index af691349d..fc791ae94 100644
--- a/include/triton/lang/scanner.l
+++ b/include/triton/lang/scanner.l
@@ -30,19 +30,18 @@ using triton::lang::return_void;
 "for"             { return return_impl(FOR, yytext); }
 "while"           { return return_impl(WHILE, yytext); }
 "void"            { return return_impl(VOID, yytext); }
-"uint1"           { return return_impl(UINT1, yytext); }
-"uint8"           { return return_impl(UINT8, yytext); }
-"uint16"          { return return_impl(UINT16, yytext); }
-"uint32"          { return return_impl(UINT32, yytext); }
-"uint64"          { return return_impl(UINT64, yytext); }
-"int1"            { return return_impl(INT1, yytext); }
-"int8"            { return return_impl(INT8, yytext); }
-"int16"           { return return_impl(INT16, yytext); }
-"int32"           { return return_impl(INT32, yytext); }
-"int64"           { return return_impl(INT64, yytext); }
-"fp16"            { return return_impl(FP16, yytext); }
-"fp32"            { return return_impl(FP32, yytext); }
-"fp64"            { return return_impl(FP64, yytext); }
+"uchar"           { return return_impl(UCHAR, yytext); }
+"ushort"          { return return_impl(USHORT, yytext); }
+"uint"            { return return_impl(UINT, yytext); }
+"ulong"           { return return_impl(ULONG, yytext); }
+"bool"            { return return_impl(BOOL, yytext); }
+"char"            { return return_impl(CHAR, yytext); }
+"short"           { return return_impl(SHORT, yytext); }
+"int"             { return return_impl(INT, yytext); }
+"long"            { return return_impl(LONG, yytext); }
+"half"            { return return_impl(HALF, yytext); }
+"float"           { return return_impl(FLOAT, yytext); }
+"double"          { return return_impl(DOUBLE, yytext); }
 "..."             { return return_impl(ELLIPSIS, yytext); }
 "get_range_id"    { return return_impl(GET_RANGE_ID, yytext); }
 "get_num_program" { return return_impl(GET_NUM_PROGRAM, yytext); }
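With the lexer change, Triton-C adopts C's own type keywords. A minimal illustrative kernel in the new spelling (not part of the patch), embedded in a raw string the way the dnn templates below emit their sources; it uses only constructs that appear in the kernels updated in this diff (tunable constants, range expressions, predicated loads and stores):

```cpp
// Sketch only: int/float/bool replace int32/fp32/int1.
const char* copy_src = R"(
const tunable int TM = {32, 64, 128};

void copy(restrict read_only float *X, float *Y, int N) {
  int rx[TM] = get_global_range[TM](0);
  bool check[TM] = rx < N;           // bounds predicate (was int1)
  float x[TM] = check ? *(X + rx) : 0;  // masked load
  @check *(Y + rx) = x;              // masked store
}
)";
```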
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
index 19fde0e84..939aebbfe 100644
--- a/include/triton/runtime/jit.h
+++ b/include/triton/runtime/jit.h
@@ -78,7 +78,7 @@ public:
   void target_dependent(ir::module &module) {
     alignment_info.run(module);
 //    ir::print(module, std::cout);
-    reassociate.run(module);
+//    reassociate.run(module);
     if(target_->is_gpu()){
       shmem_info.run(module);
       shmem_liveness.run(module);
@@ -86,7 +86,7 @@ public:
       shmem_barriers.run(module);
     }
     vectorize.run(module);
-    optimize_dce.run(module);
+//    optimize_dce.run(module);
 //    ir::print(module, std::cout);
   }
diff --git a/include/triton/tools/bench.hpp b/include/triton/tools/bench.hpp
index 6d71d27ae..74053b717 100644
--- a/include/triton/tools/bench.hpp
+++ b/include/triton/tools/bench.hpp
@@ -37,7 +37,7 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
   stream->synchronize();
   while(total_time*1e-9 < 1e-3){
     float norm = 1;
-    // normalize clock if possible to get roughly constant result
+    // normalize clock if possible to reduce noise in auto-tuning
 //    if(auto cu_device = dynamic_cast<driver::cu_device*>(device))
 //      norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
     tmr.start();
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index a57713f38..84326529d 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -74,6 +74,11 @@ unsigned distributed_tile::get_linear_index(indices_t idx) {
   return indices_[idx];
 }
 
+indices_t distributed_tile::get_ordered_indices(unsigned id) {
+  return ordered_indices_.at(id);
+}
+
+
 void distributed_tile::for_each(std::function<void(indices_t)> fn) {
   for(unsigned i = 0; i < ordered_indices_.size(); i++)
     if(i % vector_size_ == 0)
@@ -779,13 +784,21 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
   // store
   if(auto *x = dynamic_cast<ir::masked_store_inst*>(ins)){
     distributed_tile* ptrs = (distributed_tile*)tmap_.at(x->get_pointer_operand());
-    tile *scalars = tmap_.at(x->get_value_operand());
+    distributed_tile* scalars = (distributed_tile*)tmap_.at(x->get_value_operand());
     ir::value *mask = x->get_mask_operand();
     distributed_tile* preds = (distributed_tile*)tmap_.at(mask);
     ptrs->for_each([&](indices_t idx){
       Value *scalar = scalars->get_value(idx);
       Value *ptr = ptrs->get_value(idx);
       Value *pred = preds->get_value(idx);
+      BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
+      BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
+      builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
+      builder.SetInsertPoint(mask_then_bb);
+      builder.CreateStore(scalar, ptr);
+      builder.CreateBr(mask_done_bb);
+      builder.SetInsertPoint(mask_done_bb);
+
 //      std::string offset = "";
 //      if(GetElementPtrInst *gep = dyn_cast<GetElementPtrInst>(ptr))
 //        if(gep->getNumIndices() == 1)
@@ -796,14 +809,6 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
 //      std::string asm_str = "@$0 st.global.b32 [$1" + offset + "], $2;";
 //      InlineAsm *iasm = InlineAsm::get(ty, asm_str, "b,l,f", true);
 //      builder.CreateCall(iasm, {pred, ptr, scalar});
-
-      BasicBlock *mask_then_bb = BasicBlock::Create(ctx, "mask_then", fn);
-      BasicBlock *mask_done_bb = BasicBlock::Create(ctx, "mask_done", fn);
-      builder.CreateCondBr(pred, mask_then_bb, mask_done_bb);
-      builder.SetInsertPoint(mask_then_bb);
-      builder.CreateStore(scalar, ptr);
-      builder.CreateBr(mask_done_bb);
-      builder.SetInsertPoint(mask_done_bb);
     });
   }
   else if(auto *x = dynamic_cast(ins)) {
@@ -893,11 +898,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
     ir::value* in = ins->get_operand(0);
     distributed_tile *in_tile = (distributed_tile*)tmap_.at(in);
     result->for_each([&](indices_t out_idx){
-      indices_t in_idx;
-      for(size_t k = 0; k < shapes.size(); k++){
-        if(shapes[k]->get_value() > 1)
-          in_idx.push_back(out_idx[k]);
-      }
+      unsigned pos = result->get_linear_index(out_idx);
+      indices_t in_idx = in_tile->get_ordered_indices(pos);
       result->set_value(out_idx, in_tile->get_value(in_idx));
     });
   }
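The masked-store lowering above keeps the plain-LLVM branch-around-store form (moved ahead of the retired inline-PTX alternative), which stays portable across backends. A standalone sketch of the same IRBuilder pattern, for reference; the parameters stand in for the per-element values used inside lower_tile_instruction:

```cpp
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"

// Emit: br pred, mask_then, mask_done; store on the taken path; fall through.
void emit_masked_store(llvm::LLVMContext &ctx, llvm::Function *fn,
                       llvm::IRBuilder<> &builder,
                       llvm::Value *pred, llvm::Value *scalar, llvm::Value *ptr) {
  llvm::BasicBlock *mask_then_bb = llvm::BasicBlock::Create(ctx, "mask_then", fn);
  llvm::BasicBlock *mask_done_bb = llvm::BasicBlock::Create(ctx, "mask_done", fn);
  builder.CreateCondBr(pred, mask_then_bb, mask_done_bb); // skip store when pred is false
  builder.SetInsertPoint(mask_then_bb);
  builder.CreateStore(scalar, ptr);                       // unconditional store, guarded by the branch
  builder.CreateBr(mask_done_bb);
  builder.SetInsertPoint(mask_done_bb);                   // resume straight-line emission
}
```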
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index 820db29b3..9cdf2767d 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -63,14 +63,19 @@ void tune::init_c_graph(ir::instruction *v) {
   else
     shapes = v->get_type()->get_tile_shapes();
   // Reshape
-  if(dynamic_cast<ir::reshape_inst*>(v)){
+  if(dynamic_cast<ir::reshape_inst*>(v)) {
     ir::value *op = v->get_operand(0);
     unsigned current = 0;
+    bool is_skewed = false;
     for(unsigned i = 0; i < shapes.size(); i ++){
-      if(shapes[i] == one)
+      bool is_one = shapes[i] == one;
+      bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
+      if(is_one)
         static_params_.insert({{v, i}, 1});
-      else
+      else if(!is_skewed && is_same)
         add_constraint({v, i}, {op, current++});
+      else
+        is_skewed = true;
     }
   }
   // Splat
@@ -81,9 +86,8 @@ void tune::init_c_graph(ir::instruction *v) {
   else if(dynamic_cast<ir::trans_inst*>(v)){
     ir::value *op = v->get_operand(0);
     size_t n_shapes = shapes.size();
-    for(unsigned i = 0; i < n_shapes; i++){
+    for(unsigned i = 0; i < n_shapes; i++)
       add_constraint({v, (i + 1) % n_shapes}, {op, i});
-    }
   }
   // Broadcast
   else if(dynamic_cast<ir::broadcast_inst*>(v)){
@@ -247,14 +251,14 @@ void tune::run(ir::module &mod) {
       size_t addr_space = ptr_ty->get_pointer_address_space();
       if(addr_space < 4){
         ir::type *ty = mod.get_builder().get_int32_ty();
-        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
+        std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
         *params_.at(i).at("nts.d0") = *tmp;
       }
     }
     if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
       ir::type *ty = mod.get_builder().get_int32_ty();
-      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
-      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
+      std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
+      std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
       *params_.at(i).at("nts.d0") = *tmp1;
       *params_.at(i).at("nts.d1") = *tmp2;
     }
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
index 033e2497c..1c1ee8ceb 100644
--- a/lib/dnn/base.cpp
+++ b/lib/dnn/base.cpp
@@ -1,4 +1,5 @@
 #include
+#include <unordered_map>
 #include "triton/dnn/base.h"
 #include "triton/runtime/jit.h"
 #include "triton/tools/bench.hpp"
@@ -31,7 +32,7 @@ params_t base::heuristics() const {
 }
 
 std::pair base::get_profile_impl(driver::stream *stream, std::vector args, autotuning_t autotune) {
-  static std::map<base*, std::unique_ptr<rt::jit>, cmp_recompile> m_jit;
+  static std::unordered_map<base*, std::unique_ptr<rt::jit>, recompile_hash, recompile_equal> m_jit;
   driver::context* ctx = stream->context();
   rt::jit* jit;
   /* the current template has not already been compiled */
diff --git a/lib/dnn/batchnorm.cpp b/lib/dnn/batchnorm.cpp
index 34275a931..dcc9d6a4e 100644
--- a/lib/dnn/batchnorm.cpp
+++ b/lib/dnn/batchnorm.cpp
@@ -30,7 +30,7 @@ namespace dnn{
  *  ---------------  */
 
 batchnorm_forward::batchnorm_forward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_forward"),
     C_(C), D_(D), H_(H), W_(W), B_(B),
     ty_(ty), eps_(eps) {
   DHWB_ = D_*H_*W_*B_;
   rcpDHWB_ = (float)1 / DHWB_;
@@ -40,12 +40,9 @@ size_t batchnorm_forward::num_flops() const {
   return C_*DHWB_;
 }
 
-bool batchnorm_forward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_forward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+
+std::vector<int64_t> batchnorm_forward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }
 
 base* batchnorm_forward::clone() const {
@@ -74,50 +71,50 @@ void batchnorm_forward::enqueue_impl(driver::stream *stream, driver::kernel *ker
 void batchnorm_forward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};
 
-void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *B,
-               int32 DHWN,
-               fp32 rcpDHWN, fp32 eps) {
-  int32 rx[TM] = 0 ... TM;
-  fp32 *px[TM];
-  fp32 x[TM];
-  int32 c = get_range_id(1);
-  fp32 g = *(G + c);
-  fp32 b = *(B + c);
+void batchnorm_forward(float *Y, float *M, float *V,
+                       restrict read_only float *X,
+                       restrict read_only float *G,
+                       restrict read_only float *B,
+                       int DHWN,
+                       float rcpDHWN, float eps) {
+  int rx[TM] = 0 ... TM;
+  float *px[TM];
+  float x[TM] = 0;
+  int c = get_range_id(1);
+  float g = *(G + c);
+  float b = *(B + c);
 
-  fp32 mean[TM] = 0;
+  float mean[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     mean = mean + x;
    px = px + TM;
   }
-  fp32 *pm = M + c;
-  fp32 m = __sum(mean) * rcpDHWN;
+  float *pm = M + c;
+  float m = __sum(mean) * rcpDHWN;
   *pm = m;
 
-  fp32 var[TM] = 0;
+  float var[TM] = 0;
   px = X + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
     x = x - m;
     var = var + x*x;
     px = px + TM;
   }
-  fp32 v = __sum(var) * rcpDHWN;
-  fp32 *pv = V + c;
+  float v = __sum(var) * rcpDHWN;
+  float *pv = V + c;
   *pv = v;
-  fp32 rstdg = 1 / sqrt(v + eps) * g;
+  float rstdg = 1 / sqrt(v + eps) * g;
 
   px = X + rx + c*DHWN;
-  fp32* py[TM] = Y + rx + c*DHWN;
-  for(int32 i = 0; i < DHWN; i = i + TM){
+  float* py[TM] = Y + rx + c*DHWN;
+  for(int i = 0; i < DHWN; i = i + TM){
     x = *px;
-    fp32 y[TM] = (x - m)*rstdg + b;
+    float y[TM] = (x - m)*rstdg + b;
     *py = y;
     px = px + TM;
     py = py + TM;
@@ -130,7 +127,7 @@ void batchnorm(fp32 *Y, fp32 *M, fp32 *V,
  *  ---------------  */
 
 batchnorm_backward::batchnorm_backward(int C, int D, int H, int W, int B, std::string ty, float eps)
-  : base("batchnorm"),
+  : base("batchnorm_backward"),
     C_(C), D_(D), H_(H), W_(W), B_(B),
     ty_(ty), eps_(eps)
 { }
@@ -139,12 +136,8 @@ size_t batchnorm_backward::num_flops() const {
   return C_*D_*H_*W_*B_;
 }
 
-bool batchnorm_backward::operator <(const base& other) const {
-  auto *y = dynamic_cast<const batchnorm_backward*>(&other);
-  if(!y)
-    return true;
-  return std::tie(C_, D_, H_, W_, B_, ty_)
-       < std::tie(y->C_, y->D_, y->H_, y->W_, y->B_, y->ty_);
+std::vector<int64_t> batchnorm_backward::retune_params() const {
+  return {C_, D_, H_, W_, B_};
 }
 
 base* batchnorm_backward::clone() const {
@@ -174,54 +167,54 @@ void batchnorm_backward::enqueue_impl(driver::stream *stream, driver::kernel *ke
 void batchnorm_backward::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {32, 64, 128};
+const tunable int TM = {32, 64, 128};
 
-void batchnorm(fp32 *DX, fp32 *DG, fp32 *DB,
-               restrict read_only fp32 *DY,
-               restrict read_only fp32 *X,
-               restrict read_only fp32 *G,
-               restrict read_only fp32 *M,
-               restrict read_only fp32 *V,
-               int32 DHWN, fp32 rcpDHWN, fp32 epsilon) {
-  int32 rx[TM] = 0 ... TM;
-  int32 c = get_range_id(1);
-  int32 offset = c*DHWN;
-  fp32 g = *(G + c);
-  fp32 mean = *(M + c);
-  fp32 var = *(V + c);
-  fp32 rstd = 1 / sqrt(var + epsilon);
-  fp32* px[TM];
-  fp32* pdx[TM];
-  fp32* pdy[TM];
+void batchnorm_backward(float *DX, float *DG, float *DB,
+                        restrict read_only float *DY,
+                        restrict read_only float *X,
+                        restrict read_only float *G,
+                        restrict read_only float *M,
+                        restrict read_only float *V,
+                        int DHWN, float rcpDHWN, float epsilon) {
+  int rx[TM] = 0 ... TM;
+  int c = get_range_id(1);
+  int offset = c*DHWN;
+  float g = *(G + c);
+  float mean = *(M + c);
+  float var = *(V + c);
+  float rstd = 1 / sqrt(var + epsilon);
+  float* px[TM];
+  float* pdx[TM];
+  float* pdy[TM];
 
   px = X + rx + offset;
   pdy = DY + rx + offset;
-  fp32 dg[TM] = 0;
-  fp32 db[TM] = 0;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
+  float dg[TM] = 0;
+  float db[TM] = 0;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
     dg = dg + dy*(x - mean)*rstd;
     db = db + dy;
     px = px + TM;
     pdy = pdy + TM;
   }
-  fp32 sdg = __sum(dg);
-  fp32 sdb = __sum(db);
-  fp32 *pdg = DG + c;
-  fp32 *pdb = DB + c;
+  float sdg = __sum(dg);
+  float sdb = __sum(db);
+  float *pdg = DG + c;
+  float *pdb = DB + c;
   *pdg = sdg;
   *pdb = sdb;
 
   px = X + rx + offset;
   pdy = DY + rx + offset;
   pdx = DX + rx + offset;
-  for(int32 i = 0; i < DHWN; i = i + TM){
-    fp32 x[TM] = *px;
-    fp32 dy[TM] = *pdy;
-    fp32 xhat[TM] = (x - mean) * rstd;
-    fp32 xtmp[TM] = (xhat * dg + db) * rcpDHWN;
-    fp32 dx[TM] = (dy - xtmp) * rstd * g;
+  for(int i = 0; i < DHWN; i = i + TM){
+    float x[TM] = *px;
+    float dy[TM] = *pdy;
+    float xhat[TM] = (x - mean) * rstd;
+    float xtmp[TM] = (xhat * dg + db) * rcpDHWN;
+    float dx[TM] = (dy - xtmp) * rstd * g;
     *pdx = dx;
     px = px + TM;
    pdy = pdy + TM;
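For reference, the backward kernel above implements the standard batch-norm gradients; a sketch restating its two loops as formulas, with the names taken directly from the code:

```cpp
// Restating batchnorm_backward, with R = 1/sqrt(var + epsilon)
// and xhat = (x - mean) * R, reductions taken over the DHWN axis:
//   dg = sum(dy * xhat)                          // first loop, via __sum
//   db = sum(dy)                                 // first loop, via __sum
//   dx = (dy - (xhat*dg + db) / DHWN) * R * g    // second loop, elementwise
```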
diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp
index c7e3a9a85..9c7fd95d9 100644
--- a/lib/dnn/blocksparse/dot.cpp
+++ b/lib/dnn/blocksparse/dot.cpp
@@ -10,12 +10,8 @@ size_t dot::num_flops() const {
   return 2.*nblocks_*BS_*BS_*N_;
 }
 
-bool dot::operator <(const base& other) const {
-  auto *y = dynamic_cast<const dot*>(&other);
-  if(!y)
-    return true;
-  return std::tie(N_, S_, C_, BS_, nlocks_, ab_ty_, c_ty_, op_)
-       < std::tie(y->N_, y->S_, y->C_, y->BS_, y->nlocks_, y->ab_ty_, y->c_ty_, y->op_);
+std::vector<int64_t> dot::retune_params() const{
+  return {N_, S_, C_, BS_, nlocks_, op_};
 }
 
 std::vector<params_t> dot::search_space() const {
@@ -92,35 +88,35 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
   std::string result =
 R"(
-  const tunable int32 TM = {16, 32, 64, 128};
-  const tunable int32 TN = {)" + std::to_string(BS_) + R"(};
-  const tunable int32 TK = {)" + std::to_string(BS_) + R"(};
+  const tunable int TM = {16, 32, 64, 128};
+  const tunable int TN = {)" + std::to_string(BS_) + R"(};
+  const tunable int TK = {)" + std::to_string(BS_) + R"(};
 
   void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
              restrict read_only align(16) )" + ab_ty_ + R"( *B,
             )" + c_ty_ + R"(* C,
-            int32 lda, int32 ldc, int32 N,
-            int32* lut, int32* locks, int32 nlocks){
-    int32 ridx = get_range_id(0);
-    int32 ridy = get_range_id(1);
-    fp32 acc[TM, TN] = 0;
-    int32 rxa[TM] = ridx * TM + (0 ... TM);
-    int32 ryb[TN] = 0 ... TN;
-    int32 rka[TK] = 0 ... TK;
-    int32 rkb[TK] = 0 ... TK;
-    int1 checka[TM, TK] = (rxa < N)[:, newaxis];
-    int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
-    int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
-    int32 *header = lut + ridy * 4;
-    int32 offset = *(header + 0);
-    int32 K = *(header + 1);
-    int32 column = *(header + 2);
-    int32 lockid = *(header + 3);
-    int32 *plut = lut + offset * 2;
-    for(int32 k = K; k > 0; k = k - 1)
+            int lda, int ldc, int N,
+            int* lut, int* locks, int nlocks){
+    int ridx = get_range_id(0);
+    int ridy = get_range_id(1);
+    float acc[TM, TN] = 0;
+    int rxa[TM] = ridx * TM + (0 ... TM);
+    int ryb[TN] = 0 ... TN;
+    int rka[TK] = 0 ... TK;
+    int rkb[TK] = 0 ... TK;
+    bool checka[TM, TK] = (rxa < N)[:, newaxis];
+    int offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
+    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
+    int *header = lut + ridy * 4;
+    int offset = *(header + 0);
+    int K = *(header + 1);
+    int column = *(header + 2);
+    int lockid = *(header + 3);
+    int *plut = lut + offset * 2;
+    for(int k = K; k > 0; k = k - 1)
     {
-      int32 ak = *(plut + 0);
-      int32 bk = *(plut + 1);
+      int ak = *(plut + 0);
+      int bk = *(plut + 1);
      )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
      )" + ab_ty_ + "* pb[" + sizeb + R"(] = B + offb + bk * TK * TN;
      )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
@@ -128,19 +124,19 @@ void dot::triton_c_src(std::ostream &os) const {
       acc = dot()" + usea + ", " + useb + R"(, acc);
       plut = plut + 2;
     }
-    int32 rxc[TM] = ridx * TM + (0 ... TM);
-    int32 ryc[TN] = column * TN + (0 ... TN);
+    int rxc[TM] = ridx * TM + (0 ... TM);
+    int ryc[TN] = column * TN + (0 ... TN);
     )" + c_ty_ + R"( c[TM, TN] = acc;
     )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis] + ryc[newaxis, :]*ldc;
-    int1 checkc[TM, TN] = (rxc < N)[:, newaxis];
+    bool checkc[TM, TN] = (rxc < N)[:, newaxis];
     if(lockid == 0)
       @checkc *pc = c;
     else {
-      int32 *plock = locks + ridx*nlocks + lockid - 1;
-      int32 *pcount = plock + get_num_program(0)*nlocks;
+      int *plock = locks + ridx*nlocks + lockid - 1;
+      int *pcount = plock + get_num_program(0)*nlocks;
       while(__atomic_cas(plock, 0, 1));
-      int32 count = *pcount;
+      int count = *pcount;
       if(count == 0){
         @checkc *pc = c;
       }
diff --git a/lib/dnn/conv.cpp b/lib/dnn/conv.cpp
index f54c63560..0f32455ea 100644
--- a/lib/dnn/conv.cpp
+++ b/lib/dnn/conv.cpp
@@ -98,20 +98,12 @@ conv::conv(int B, int NC,
 }
 
 // comparison for maps
-bool conv::operator<(const base& other) const {
-  auto *y = dynamic_cast<const conv*>(&other);
-  if(!y)
-    return true;
-  return std::tie(NB_, NC_, AD_, AH_, AW_,
-                  NF_, BD_, BH_, BW_,
-                  pad_d_, pad_h_, pad_w_,
-                  stride_d_, stride_h_, stride_w_,
-                  a_ty_, b_ty_, ty_, bias_)
-       < std::tie(y->NB_, y->NC_, y->AD_, y->AH_, y->AW_,
-                  y->NF_, y->BD_, y->BH_, y->BW_,
-                  y->pad_d_, y->pad_h_, y->pad_w_,
-                  y->stride_d_, y->stride_h_, y->stride_w_,
-                  y->a_ty_, y->b_ty_, y->ty_, y->bias_);
+std::vector<int64_t> conv::retune_params() const {
+  return {NB_, NC_, AD_, AH_, AW_,
+          NF_, BD_, BH_, BW_,
+          pad_d_, pad_h_, pad_w_,
+          stride_d_, stride_h_, stride_w_,
+          ty_, bias_};
 }
 
 // clone
@@ -549,114 +541,114 @@ void conv::triton_c_src(std::ostream &os) const {
   os <<
 R"(
-const tunable int32 TM = {16, 32, 64};
-const tunable int32 TN = {16, 32, 64};
-const tunable int32 TK = {)" << TK_ << R"(};
-const tunable int32 GZ = {1};
+const tunable int TM = {16, 32, 64};
+const tunable int TN = {16, 32, 64};
+const tunable int TK = {)" << TK_ << R"(};
+const tunable int GZ = {1};
 )";
   if(is_a_deltas_cst)
-    os << "__constant__ int32* delta = alloc_const int32[" + std::to_string(h_a_deltas_.size()) + "];\n";
+    os << "__constant__ int* delta = alloc_const int[" + std::to_string(h_a_deltas_.size()) + "];\n";
   if(b_lut_ && is_b_deltas_cst_)
-    os << "__constant__ int32* b_delta = alloc_const int32[" + std::to_string(h_b_deltas_.size()) + "];\n";
+    os << "__constant__ int* b_delta = alloc_const int[" + std::to_string(h_b_deltas_.size()) + "];\n";
   if(is_mask_cst_)
-    os << "__constant__ int32* masks = alloc_const int32[" + std::to_string(h_masks_.size()) + "];\n";
+    os << "__constant__ int* masks = alloc_const int[" + std::to_string(h_masks_.size()) + "];\n";
   os << R"(
 void conv(read_only restrict )" << a_ty_ << R"( *a,
           read_only restrict )" << b_ty_ << R"( *b,
-          fp32 *c,
-          fp32 *bias,
-          int32 M, int32 N, int32 K,
-          int32 AH, int32 AW,
-          int32 BH, int32 BW,
-          int32 CH, int32 CW,
-          int32 NC,
-          int32 lda_n, int32 lda_c, int32 lda_d, int32 lda_h, int32 lda_w,
-          int32 ldb_c, int32 ldb_t, int32 ldb_r, int32 ldb_s, int32 ldb_k,
-          int32 ldc_n, int32 ldc_k, int32 ldc_m, int32 ldc_p, int32 ldc_q,
-          int32 pad_h, int32 pad_w,
-          int32 stride_h, int32 stride_w,
-          int32 upsample_h, int32 upsample_w,
-          int32 off_uh, int32 off_uw,
-          int32 off_uah, int32 off_uaw,
-          int32 off_uch, int32 off_ucw,
-          int32 *locks, int32 grid0, int32 grid1)";
+          float *c,
+          float *bias,
+          int M, int N, int K,
+          int AH, int AW,
+          int BH, int BW,
+          int CH, int CW,
+          int NC,
+          int lda_n, int lda_c, int lda_d, int lda_h, int lda_w,
+          int ldb_c, int ldb_t, int ldb_r, int ldb_s, int ldb_k,
+          int ldc_n, int ldc_k, int ldc_m, int ldc_p, int ldc_q,
+          int pad_h, int pad_w,
+          int stride_h, int stride_w,
+          int upsample_h, int upsample_w,
+          int off_uh, int off_uw,
+          int off_uah, int off_uaw,
+          int off_uch, int off_ucw,
+          int *locks, int grid0, int grid1)";
   if(!is_a_deltas_cst)
-    os << ", int32* delta";
+    os << ", int* delta";
   if(b_lut_ && !is_b_deltas_cst_)
-    os << ", int32* b_delta";
+    os << ", int* b_delta";
   if(!is_mask_cst_)
-    os << ", int32* masks";
+    os << ", int* masks";
   os << R"(){
-  int32 rxa[TM] = get_global_range[TM](0);
-  int32 rb0[TN] = get_global_range[TN](1);
-  int32 rz = get_global_range[1](2);
-  int32 rka[TK] = 0 ... TK;
-  int32 rkb[TK] = 0 ... TK;
-  fp32 C[TM, TN] = 0;
-  int32 ldlut = )" + std::to_string(Luts_) + R"(;
-  int32 div = K / GZ;
-  int32 rem = K % GZ;
+  int rxa[TM] = get_global_range[TM](0);
+  int rb0[TN] = get_global_range[TN](1);
+  int rz = get_global_range[1](2);
+  int rka[TK] = 0 ... TK;
+  int rkb[TK] = 0 ... TK;
+  float C[TM, TN] = 0;
+  int ldlut = )" + std::to_string(Luts_) + R"(;
+  int div = K / GZ;
+  int rem = K % GZ;
   K = select(rz < rem, div, div + rem);
-  int32 offk = rz*div;
+  int offk = rz*div;
   rka = rka + offk;
   rkb = rkb + offk;
-  int32 rabh[TM] = rxa / CW;
-  int32 raw[TM] = rxa % CW;
-  int32 rab[TM] = rabh / CH;
-  int32 rah[TM] = rabh % CH;
+  int rabh[TM] = rxa / CW;
+  int raw[TM] = rxa % CW;
+  int rab[TM] = rabh / CH;
+  int rah[TM] = rabh % CH;
   rah = rah)" + upaw + R"( - off_uah;
   raw = raw)" + upah + R"( - off_uaw;
-  int32 ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
-  int32 ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
-  int32 ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
-  int32 ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
-  int32 ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
+  int ra0[TM] = rab*lda_n + rah*lda_h + raw*lda_w;
+  int ra)" + ax[0] + ax[1] + "[TK] = rka / " + redax[2] + R"(;
+  int ra)" + ax[2] + "[TK] = rka % " + redax[2] + R"(;
+  int ra)" + ax[0] + "[TK] = ra" + ax[0] + ax[1] + " / " + redax[1] + R"(;
+  int ra)" + ax[1] + "[TK] = ra" + ax[0] + ax[1] + " % " + redax[1] + R"(;
   rar = )" + flipr + R"( rar;
   ras = )" + flips + R"( ras;
   rar = )" + upar + R"( rar;
   ras = )" + upas + R"( ras;
-  int32 ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
+  int ra1[TK] = rac*lda_c + rar*lda_h + ras*lda_w;
   )" << a_ty_ << R"(* pa[TM, TK] = a + ra1[newaxis, :] + ra0[:, newaxis];)";
 
 if(b_lut_){
   os << R"(
-  int32 rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
-  int32 rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(;
-  int32 rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(;
-  int32 rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(;
+  int rb)" + ax[0] + ax[1] + "[TK] = rkb / " + redax[2] + R"(;
+  int rb)" + ax[2] + "[TK] = rkb % " + redax[2] + R"(;
+  int rb)" + ax[0] + "[TK] = rb" + ax[0] + ax[1] + " / " + redax[1] + R"(;
+  int rb)" + ax[1] + "[TK] = rb" + ax[0] + ax[1] + " % " + redax[1] + R"(;
   rbr = rbr*upsample_h + off_uh;
   rbs = rbs*upsample_w + off_uw;
-  int32 offdb[TK] = rkb % ldlut;
-  int32 rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
-  )" + b_delta_mem + R"( int32* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w;
-  int32 db[TK] = *pdb;)";
+  int offdb[TK] = rkb % ldlut;
+  int rb1[TK] = rbc*ldb_c + rbr*ldb_r + rbs*ldb_s;
+  )" + b_delta_mem + R"( int* pdb[TK] = b_delta + offdb + off_uw*ldlut + off_uh*ldlut*upsample_w;
+  int db[TK] = *pdb;)";
 }
 else{
   os << R"(
-  int32 rb1[TK] = rkb)" + ldb0 + ";";
+  int rb1[TK] = rkb)" + ldb0 + ";";
 }
 os << R"(
   )" << b_ty_ << R"(* pb)" + BS + " = b + rb1" + bcb1 + " + rb0" + bcb0 + R"(*ldb_k;
-  int32 offda[TK] = rka % ldlut;
-  )" + a_delta_mem + R"( int32* pincd[TK] = delta + offda;
-  )" + a_delta_mem + R"( int32* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
-  int32 da[TK] = *pda;
-  int32 incd[TK] = *pincd;
-  int32 maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
-  int32 maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
-  int32 offma = offk % ldlut;
-  )" + masks_mem + R"( int32* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w;
-  )" + a_delta_mem + R"( int32* pincm[TM] = delta + offma;
-  int32 incm[TM] = *pincm;
-  int32 maska0[TM] = *pm;
-  int32 maska1[TK] = 1 << (0 ... TK);
-  int1 checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
-  int1 checkb0[TN] = rb0 < N;
-  int1 checkb)" + BS + " = checkb0" + bcb0 + R"(;
+  int offda[TK] = rka % ldlut;
+  )" + a_delta_mem + R"( int* pincd[TK] = delta + offda;
+  )" + a_delta_mem + R"( int* pda[TK] = delta + ldlut + offda + off_uw*ldlut + off_uh*ldlut*upsample_w;
+  int da[TK] = *pda;
+  int incd[TK] = *pincd;
+  int maskh[TM] = pad_h + min(rah, 0) + max(rah + BH - AH, 0);
+  int maskw[TM] = pad_w + min(raw, 0) + max(raw + BW - AW, 0);
+  int offma = offk % ldlut;
+  )" + masks_mem + R"( int* pm[TM] = masks + ldlut + offma + maskw*ldlut + maskh*ldlut*(2*pad_w + 1) + off_uw*ldlut*(2*pad_w+1)*(2*pad_h+1) + off_uh*ldlut*(2*pad_w+1)*(2*pad_h+1)*upsample_w;
+  )" + a_delta_mem + R"( int* pincm[TM] = delta + offma;
+  int incm[TM] = *pincm;
+  int maska0[TM] = *pm;
+  int maska1[TK] = 1 << (0 ... TK);
+  bool checka[TM, TK] = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
+  bool checkb0[TN] = rb0 < N;
+  bool checkb)" + BS + " = checkb0" + bcb0 + R"(;
   )" << a_ty_ << R"( a[TM, TK] = checka ? *pa : 0;
   )" << b_ty_ << R"( b)" + BS + R"( = checkb ? *pb : 0;
-  int32 rkamin[TK] = rka - offk + TK;
-  for(int32 k = K; k > 0; k = k - TK){
+  int rkamin[TK] = rka - offk + TK;
+  for(int k = K; k > 0; k = k - TK){
     C = dot(a, )" + useb + R"(, C);
     pa = pa + da[newaxis, :];
     pb = pb + )" + inc_pb + R"(;
@@ -673,7 +665,7 @@ if(b_lut_){
     pm = pm + incm;
     pincm = pincm + incm;
     incm = *pincm;
-    int1 checka1[TK] = (rkamin < k);
+    bool checka1[TK] = (rkamin < k);
     maska0 = *pm;
     checka = (maska0[:, newaxis] & maska1[newaxis, :]) > 0;
     checka = checka && checka1[newaxis,:];
@@ -681,31 +673,31 @@ if(b_lut_){
     checkb = checkb && (k > TK);
     @checkb b = *pb;
   }
-  int32 rxc[TM] = get_global_range[TM](0);
-  int32 rc1[TN] = get_global_range[TN](1);
-  int32 rcn[TM] = rxc / (CH*CW);
-  int32 rcpq[TM] = rxc % (CH*CW);
-  int32 rcp[TM] = rcpq / CW;
-  int32 rcq[TM] = rcpq % CW;
+  int rxc[TM] = get_global_range[TM](0);
+  int rc1[TN] = get_global_range[TN](1);
+  int rcn[TM] = rxc / (CH*CW);
+  int rcpq[TM] = rxc % (CH*CW);
+  int rcp[TM] = rcpq / CW;
+  int rcq[TM] = rcpq % CW;
   rcp = rcp * upsample_h + off_uch;
   rcq = rcq * upsample_w + off_ucw;
-  int1 checkc1[TN] = rc1 < N;
-  int32 rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
-  fp32* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
-  int1 checkc0[TM] = rxc < M;
-  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  int32 ridx = get_range_id(0);
-  int32 ridy = get_range_id(1);
-  int32 *plock = locks + ridx + ridy*grid0;
+  bool checkc1[TN] = rc1 < N;
+  int rc0[TM] = rcn * ldc_n + rcp * ldc_p + rcq * ldc_q;
+  float* pc[TM, TN] = c + rc1[newaxis, :]*ldc_k + rc0[:, newaxis];
+  bool checkc0[TM] = rxc < M;
+  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+  int ridx = get_range_id(0);
+  int ridy = get_range_id(1);
+  int *plock = locks + ridx + ridy*grid0;
   while(__atomic_cas(plock, 0, 1) == 1);
-  int32 *pcount = plock + grid0*grid1;
-  int32 count = *pcount;
-  int32 countp1 = select(count == GZ - 1, 0, count + 1);
+  int *pcount = plock + grid0*grid1;
+  int count = *pcount;
+  int countp1 = select(count == GZ - 1, 0, count + 1);
   if(count == 0) {)";
 
   if(bias_ && ty_==FPROP){
     os << R"(
-    fp32* pbias[TN] = bias + rc1;
-    fp32 bias[TN] = checkc1 ? *pbias : 0;
+    float* pbias[TN] = bias + rc1;
+    float bias[TN] = checkc1 ? *pbias : 0;
     C = C + bias[newaxis, :];)";
   }
   os << R"(
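Both the conv kernel above and the matmul kernel below finish with the same split-K reduction: GZ partial products serialize on a per-tile lock, the first arrival stores its tile, later arrivals accumulate, and the counter wraps so the lock buffer can be reused. A Triton-C sketch of the protocol (illustrative; the accumulate branch and the release are elided in the hunks above and are assumptions here):

```cpp
// Sketch of the split-K epilogue, in the style of the kernels' raw strings.
const char* splitk_epilogue = R"(
  while(__atomic_cas(plock, 0, 1) == 1);               // spin on the per-tile lock
  int count = *pcount;                                 // partials committed so far
  int countp1 = select(count == GZ - 1, 0, count + 1); // wrap for the next launch
  if(count == 0)
    @checkc *pc = c;                                   // first writer stores its tile
  else
    @checkc *pc = c + *pc;                             // later writers accumulate (assumed)
  *pcount = countp1;
  *plock = 0;                                          // release (assumed)
)";
```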
diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp
index 1b5e061d3..3b9a2e300 100644
--- a/lib/dnn/dot.cpp
+++ b/lib/dnn/dot.cpp
@@ -10,11 +10,11 @@ namespace dnn{
 
 dot::dot(int M, int N, int K,
          bool AT, bool BT,
         std::string a_ty, std::string b_ty,
-         unsigned alignment_lda, unsigned alignment_ldb)
+         unsigned align_lda, unsigned align_ldb, unsigned align_ldc)
   : base("matmul"),
     M_(M), N_(N), K_(K), AT_(AT), BT_(BT),
     a_ty_(a_ty), b_ty_(b_ty),
-    align_lda_(alignment_lda), align_ldb_(alignment_ldb),
+    align_lda_(align_lda), align_ldb_(align_ldb), align_ldc_(align_ldc),
     locks_(nullptr) {
 
 }
@@ -23,15 +23,10 @@ size_t dot::num_flops() const {
   return 2.*M_*N_*K_;
 }
 
-// comparison for maps
-bool dot::operator<(const base& other) const {
-  auto *y = dynamic_cast<const dot*>(&other);
-  if(!y)
-    return true;
-  return std::tie(M_, N_, K_, AT_, BT_,
-                  a_ty_, b_ty_, align_lda_, align_ldb_)
-       < std::tie(y->M_, y->N_, y->K_, y->AT_, y->BT_,
-                  y->a_ty_, y->b_ty_, y->align_lda_, y->align_ldb_);
+// retune parameters
+std::vector<int64_t> dot::retune_params() const {
+  return {M_, N_, K_, AT_, BT_,
+          (int)align_lda_, (int)align_ldb_};
 }
 
 // clone
@@ -101,45 +96,45 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
 R"(
-const tunable int32 TM = {16, 32, 64, 128, 256};
-const tunable int32 TN = {16, 32, 64, 128, 256};
-const tunable int32 TK = {32};
-const tunable int32 GZ = {1};
+const tunable int TM = {16, 32, 64, 128};
+const tunable int TN = {16, 32, 64, 128};
+const tunable int TK = {32};
+const tunable int GZ = {1};
 
 void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
             restrict read_only align(16) )" + b_ty_ + R"( *B,
-            fp32 *C,
-            int32 M, int32 N, int32 K,
-            )" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"( int32 ldb, int32 ldc,
-            int32 bound, int32 *locks, int32 grid0, int32 grid1) {
-  int32 ridx = get_range_id(0);
-  int32 ridy = get_range_id(1);
-  int32 rxa[TM] = ridx * TM + (0 ... TM);
-  int32 ryb[TN] = ridy * TN + (0 ... TN);
-  int32 rka[TK] = 0 ... TK;
-  int32 rkb[TK] = 0 ... TK;
-  fp32 c[TM, TN] = 0;
+            restrict read_only align(16) float *C,
+            int M, int N, int K,
+            )" + align_lda_str + R"( int lda, )" + align_ldb_str + R"( int ldb, int ldc,
+            int bound, int *locks, int grid0, int grid1) {
+  int ridx = get_range_id(0);
+  int ridy = get_range_id(1);
+  int rxa[TM] = ridx * TM + (0 ... TM);
+  int ryb[TN] = ridy * TN + (0 ... TN);
+  int rka[TK] = 0 ... TK;
+  int rkb[TK] = 0 ... TK;
+  float c[TM, TN] = 0;
   )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
   )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
-  int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
-  int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
+  bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
+  bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
   )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
   )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
-  for(int32 k = K; k > 0; k = k - TK){
+  for(int k = K; k > 0; k = k - TK){
     c = dot()" + usea + ", " + useb + R"(, c);
     pa = pa + TK)" + lda0 + R"(;
     pb = pb + TK)" + ldb0 + R"(;
-    int1 checka[)" + AS + R"(] = k > TK;
-    int1 checkb[)" + BS + R"(] = k > TK;
+    bool checka[)" + AS + R"(] = k > TK;
+    bool checkb[)" + BS + R"(] = k > TK;
     a = checka ? *pa : 0;
     b = checkb ? *pb : 0;
   }
-  int32 rxc[TM] = ridx * TM + (0 ... TM);
-  int32 ryc[TN] = ridy * TN + (0 ... TN);
-  int1 checkc0[TM] = rxc < M;
-  int1 checkc1[TN] = ryc < N;
-  int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
-  fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
+  int rxc[TM] = ridx * TM + (0 ... TM);
+  int ryc[TN] = ridy * TN + (0 ... TN);
+  bool checkc0[TM] = rxc < M;
+  bool checkc1[TN] = ryc < N;
+  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
+  float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
   @checkc *pc = c;
 }
 )";
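A minimal usage sketch (not part of the patch) of the widened constructor: alignment hints are now supplied for all three leading dimensions, matching the call sites updated in the examples earlier in this diff; the surrounding stream/buffer variables are assumed context.

```cpp
// Hypothetical call site; values mirror examples/cpp/dot.cpp above.
triton::dnn::dot dot(M, N, K, /*AT=*/false, /*BT=*/true,
                     "half", "half",
                     /*align_lda=*/8, /*align_ldb=*/8, /*align_ldc=*/8);
dot.enqueue(stream, {&da, &db, &dc}, triton::dnn::PARTIAL_TUNING);
```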
diff --git a/lib/dnn/shift.cpp b/lib/dnn/shift.cpp
index 3bf5e1035..5b50a73b4 100644
--- a/lib/dnn/shift.cpp
+++ b/lib/dnn/shift.cpp
@@ -28,7 +28,7 @@ shift::shift(int B, int C,
     layout_(layout){
 //  std::cout << B_ << " " << C_ << " " << F_ << " " << stride_h_ << " " << stride_w_ << " " << a_ty_ << " " << b_ty_ << " " << ty_ << " " << layout_ << std::endl;
   // max number of channels
-  TK_ = (ty == FPROP && a_ty_ == "fp32") ? 8 : 32;
+  TK_ = (ty == FPROP && a_ty_ == "float") ? 8 : 32;
   MAX_C_ = 8192 + TK_;
   // activation sizes
   CD_ = AD_ / stride_d_;
@@ -204,26 +204,15 @@ size_t shift::ldb() const
 size_t shift::ldc() const
 { return M_; }
 
-bool shift::operator <(const base& other) const{
-  auto *y = dynamic_cast<const shift*>(&other);
-  if(!y)
-    return true;
-  return std::tie(B_, C_, F_,
-                  AD_, AH_, AW_,
-                  BD_, BH_, BW_,
-                  CD_, CH_, CW_,
-                  shift_h_, shift_w_,
-                  stride_h_, stride_w_,
-                  layout_, op_,
-                  bias_)
-       < std::tie(y->B_, y->C_, y->F_,
-                  y->AD_, y->AH_, y->AW_,
-                  y->BD_, y->BH_, y->BW_,
-                  y->CD_, y->CH_, y->CW_,
-                  y->shift_h_, y->shift_w_,
-                  y->stride_h_, y->stride_w_,
-                  y->layout_, y->op_,
-                  y->bias_);
+std::vector<int64_t> shift::retune_params() const {
+  return {B_, C_, F_,
+          AD_, AH_, AW_,
+          BD_, BH_, BW_,
+          CD_, CH_, CW_,
+          (int64_t)shift_h_, (int64_t)shift_w_,
+          stride_h_, stride_w_,
+          layout_, op_,
+          bias_};
 }
 
 void shift::init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info) {
@@ -325,56 +314,56 @@ void shift::triton_c_src(std::ostream &os) const {
     if(is_chwn) {
       return R"(
-  int32 )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(;
-  int32 )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(;
-  int32 )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w;
-  int32 )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)";
+  int )" + rx + "wh[" + sz + "] = " + rkx + " / " + B + R"(;
+  int )" + rx + "b[" + sz + "] = " + rkx + " % " + B + R"(;
+  int )" + rx + "w[" + sz + "] = (" + rx + "wh % " + CW + R"() + pad_w;
+  int )" + rx + "h[" + sz + "] = (" + rx + "wh / " + CW + R"() + pad_h;)";
     }
     else {
      return R"(
-  int32 )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(;
-  int32 )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w;
-  int32 )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h;
-  int32 )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";";
+  int )" + rx + "bh[" + sz + "] = " + rkx + " / " + CW + R"(;
+  int )" + rx + "w[" + sz + "] = (" + rkx + " % " + CW + R"() + pad_w;
+  int )" + rx + "h[" + sz + "] = (" + rx + "bh % " + CH + R"() + pad_h;
+  int )" + rx + "b[" + sz + "] = " + rx + "bh / " + CH + ";";
     }
   };
TK; + float acc[TM, TN] = 0; + int pad_h = BH / 2; + int pad_w = BW / 2;)"; /* A offsets */ if(op_ == FPROP){ @@ -382,49 +371,49 @@ if(op_ == FPROP){ compute_bhw("ra", "TM", "rxa") + R"( raw = raw * )" + stride_w + R"(; rah = rah * )" + stride_h + R"(; - int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - int32 offa0[TM, TK] = offxa[:, newaxis]; - __constant__ int32* pd[TK] = delta_a + rka; - multiple_of(8) int32 d[TK] = *pd; - int32 offa1[TM, TK] = d[newaxis, :];)"; + int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; + int offa0[TM, TK] = offxa[:, newaxis]; + __constant__ int* pd[TK] = delta_a + rka; + multiple_of(8) int d[TK] = *pd; + int offa1[TM, TK] = d[newaxis, :];)"; } if(op_ == BPROP){ result += compute_bhw("ra", "TM", "rxa") + R"( - int32 offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - int32 offa0[TM, TK] = offxa[:, newaxis]; - int32 offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; + int offxa[TM] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; + int offa0[TM, TK] = offxa[:, newaxis]; + int offa1[TM, TK] = rka[newaxis, :] * lda_c;)"; } if(op_ == WGRAD){ result += compute_bhw("ra", "TK", "rka") + R"( - int32 offa0[TK, TM] = rxa[newaxis, :] * lda_c; - int32 offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; - int32 offa1[TK, TM] = offxa[:, newaxis];)"; + int offa0[TK, TM] = rxa[newaxis, :] * lda_c; + int offxa[TK] = rab*)" + lda_b + R"( + raw*lda_w + rah*lda_h; + int offa1[TK, TM] = offxa[:, newaxis];)"; } /* B offsets */ if(op_ == FPROP){ result += R"( - int32 offb0[TN, TK] = ryb[:, newaxis]; - int32 offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)"; + int offb0[TN, TK] = ryb[:, newaxis]; + int offb1[TN, TK] = rkb[newaxis, :] * ldb_c;)"; } if(op_ == BPROP){ result += R"( - int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; - int32 offb1[TK, TN] = rkb[:, newaxis];)"; + int offb0[TK, TN] = ryb[newaxis, :] * ldb_c; + int offb1[TK, TN] = rkb[:, newaxis];)"; } if(op_ == WGRAD){ result += compute_bhw("rb", "TK", "rkb") + R"( - __constant__ int32* pd[TN] = delta_a + ryb; - multiple_of(8) int32 d[TN] = *pd; - multiple_of(8) int32 shift[TK, TN] = d[newaxis, :]; + __constant__ int* pd[TN] = delta_a + ryb; + multiple_of(8) int d[TN] = *pd; + multiple_of(8) int shift[TK, TN] = d[newaxis, :]; rbw = rbw * )" + stride_w + R"(; rbh = rbh * )" + stride_h + R"(; - int32 offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; - int32 offb0[TK, TN] = ryb[newaxis, :] * ldb_c; - int32 offb1[TK, TN] = offkb[:, newaxis]; + int offkb[TK] = rbb*)" + ldb_b + R"( + rbw*ldb_w + rbh*ldb_h; + int offb0[TK, TN] = ryb[newaxis, :] * ldb_c; + int offb1[TK, TN] = offkb[:, newaxis]; )" + a_ty_ + "* pa_base[" + AS + R"(] = A + offa0; )" + b_ty_ + "* pb_base[" + BS + R"(] = B + offb0 + shift; )" + a_ty_ + "* pa[" + AS + R"(] = pa_base + offa1; @@ -439,14 +428,14 @@ else{ /* Main loop */ /* Increment A pointers */ result += R"( - int1 checka[)" + AS + "] = (rka < K)" + bca0 + R"(; - int1 checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; + bool checka[)" + AS + "] = (rka < K)" + bca0 + R"(; + bool checkb[)" + BS + "] = (rkb < K)" + bcb0 + R"(; )" + a_ty_ + " a[" + AS + R"(] = checka ? *pa : 0; )" + b_ty_ + " b[" + BS + R"(] = checkb ? 
*pb : 0; - for(int32 k = K; k > 0; k = k - TK){ + for(int k = K; k > 0; k = k - TK){ acc = dot()" + usea + "," + useb + R"(, acc); - int1 checka[)" + AS + R"(] = k > TK; - int1 checkb[)" + BS + R"(] = k > TK;)"; + bool checka[)" + AS + R"(] = k > TK; + bool checkb[)" + BS + R"(] = k > TK;)"; /* Increment A pointers */ if(op_ == FPROP){ @@ -490,8 +479,8 @@ if(op_ == BPROP){ result += R"( b = checkb ? *pb : 0; } - int32 rxc[TM] = ridx*TM + (0 ... TM); - int32 ryc[TN] = ridy*TN + (0 ... TN);)"; + int rxc[TM] = ridx*TM + (0 ... TM); + int ryc[TN] = ridy*TN + (0 ... TN);)"; /* C offsets */ if(op_ == BPROP){ @@ -499,26 +488,26 @@ if(op_ == BPROP){ compute_bhw("rc", "TM", "rxc") + R"( rcw = rcw * )" + stride_w + R"(; rch = rch * )" + stride_h + R"(; - int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; + int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == FPROP){ result += compute_bhw("rc", "TM", "rxc") + R"( - int32 offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; + int offxc[TM] = rcb*)" + ldc_b + R"( + rcw*ldc_w + rch*ldc_h;)"; } if(op_ == WGRAD){ result += R"( - int32 offxc[TM] = rxc;)"; + int offxc[TM] = rxc;)"; } result += R"(" )" + c_ty_ + R"( c[TM, TN] = acc; )" + c_ty_ + R"(* pc[TM, TN] = C + offxc[:, newaxis] + ryc[newaxis, :]*ldc_c; - int1 checkc0[TM] = rxc < M; - int1 checkc1[TN] = ryc < N; - int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; + bool checkc0[TM] = rxc < M; + bool checkc1[TN] = ryc < N; + bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];)"; if(op_ == BPROP){ result += R"( - __constant__ int32* pd[TN] = delta_a + ryc; + __constant__ int* pd[TN] = delta_a + ryc; )" + c_ty_ + R"(* shift_pc[TM, TN] = pc + (*pd)[newaxis, :]; @checkc *shift_pc = c; )"; diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp index 928ec0812..1ce0a77e6 100644 --- a/lib/runtime/jit.cpp +++ b/lib/runtime/jit.cpp @@ -174,8 +174,15 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben std::lock_guard lock(mutex); for(ir::metaparameter *mp: mps) mp->set_value(params[i++]); +// for(size_t i = 0; i < params.size(); i++) +// std::cout << ((i==0)?"":", ") << params[i] << std::flush; +// std::cout << std::endl; passes_0.tune.init(tt_module_0); passes_0.tune.check_constraints(errors); +// for(auto x: errors) +// for(auto e: x.second){ +// std::cout << x.first->get_name() << ": " << e << std::endl; +// } } if(!errors.empty()) return; @@ -212,9 +219,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben best.perf = perf; best.params = params; } - for(size_t i = 0; i < params.size(); i++) - std::cout << ((i==0)?"":", ") << params[i] << std::flush; - std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; +// for(size_t i = 0; i < params.size(); i++) +// std::cout << ((i==0)?"":", ") << params[i] << std::flush; +// std::cout << ", " << perf << " [ " << best.perf << " ] " << std::endl; } };