diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index 20b5bc72f..e9ad43f71 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->synchronize();
   triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::FULL_TUNING);}, stream);
   // benchmark cublas
 //  NumericT alpha = 1;
 //  NumericT beta = 0;
@@ -111,7 +111,7 @@ int main() {
   // shapes to benchmark
   std::vector configs = {
 //    {false, false, 8192, 512, 512},
-    {false, true, 128, 128, 128}
+    {false, true, 64, 64, 128}
 //    {false, true, 128, 128, 128},
 //    {false, false, 128, 128, 128},
 //    {true, false, 128, 128, 128},
diff --git a/examples/python/tensorflow/blocksparse.cpp b/examples/python/tensorflow/blocksparse.cpp
index e2a0b5144..1ff5e9f6f 100644
--- a/examples/python/tensorflow/blocksparse.cpp
+++ b/examples/python/tensorflow/blocksparse.cpp
@@ -46,6 +46,19 @@ Status XpropShape(InferenceContext* ctx)
   return Status::OK();
 }
 
+Status UpdatShape(InferenceContext* ctx)
+{
+  //printf("UpdatShape: %d\n", ctx->Rank(ctx->input(0)));
+
+  int blocks, bsize;
+  TF_RETURN_IF_ERROR(ctx->GetAttr("blocks", &blocks));
+  TF_RETURN_IF_ERROR(ctx->GetAttr("bsize", &bsize));
+
+  // (blocks, block_size, block_size)
+  DimensionHandle bsize_dim = ctx->MakeDim(bsize);
+  ctx->set_output(0, ctx->MakeShape({ ctx->MakeDim(blocks), bsize_dim, bsize_dim }));
+  return Status::OK();
+}
 
 typedef struct bsmm_params
 {
@@ -72,34 +85,46 @@ typedef struct bsmm_params
 
 template class BlocksparseMatmulOp : public OpKernel {
-public:
-  explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("locks",    &params_.locks   ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks",   &params_.blocks  ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize",    &params_.bsize   ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("C",        &params_.C       ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("K",        &params_.K       ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared",   &params_.shared  ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha",    &params_.alpha   ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta",     &params_.beta    ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_       ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis",     &axis_           ));
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_));
-    OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536"));
-    OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536"));
-    params_.pcount = 1;
-    params_.blk_A = 0;
-    is_gpu_ = ctx->device_type() == DEVICE_GPU;
-    if (bench_) {
-      repeat_ = bench_;
-      flops_ = (float)(params_.blocks * params_.bsize*params_.bsize);
-      const char* op = "FPROP";
-      sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks);
-    }
+private:
+  void ComputeDw(OpKernelContext* context){
+    // get device/stream
+    GPUDevice device = context->eigen_device();
+    triton::driver::cu_stream sstream(device.stream(), false);
+    triton::driver::context* ctx = sstream.context();
+    triton::driver::stream* stream = &sstream;
+    // extract input
+    OpInputList x, dy, gate;
+    context->input_list( "x", &x);
+    context->input_list( "dy", &dy);
+    context->input_list("gate", &gate);
+    // sanity checks
+    params_.pcount = x.size();
+    if (params_.pcount > 1)
+      errors::Internal("No more than 1 input allowed.");
+    if (params_.beta != 0.0f || params_.alpha != 1.0f)
+      errors::Internal("Not supported yet");
+    // N
+    int N = 1;
+    int rank = x[0].dims();
+    for (int i = 0; i < rank; i++)
+      if (i != axis_)
+        N *= x[0].dim_size(i);
+    // allocate output
+    Tensor* C;
+    TensorShape shapeC({ params_.blocks, params_.bsize, params_.bsize });
+    OP_REQUIRES_OK(context, context->allocate_output(0, shapeC, &C));
+    // wrap tensorflow handles
+    triton::driver::cu_buffer da(ctx, x[0].tensor_data().size(), (CUdeviceptr)x[0].tensor_data().data(), false);
+    triton::driver::cu_buffer db(ctx, dy[0].tensor_data().size(), (CUdeviceptr)dy[0].tensor_data().data(), false);
+    triton::driver::cu_buffer dc(ctx, C->tensor_data().size(), (CUdeviceptr)C->tensor_data().data(), false);
+    triton::driver::cu_buffer dlut(ctx, context->input(params_.pcount*2).tensor_data().size(), (CUdeviceptr)context->input(params_.pcount*2).tensor_data().data(), false);
+    // create profile
+    triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
+    // enqueue
+    dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
   }
 
-  void Compute(OpKernelContext* context){
+  void ComputeYDx(OpKernelContext* context){
     // get device/stream
     GPUDevice device = context->eigen_device();
     triton::driver::cu_stream sstream(device.stream(), false);
@@ -129,8 +154,8 @@ public:
     triton::driver::cu_buffer dlut(ctx, lut.tensor_data().size(), (CUdeviceptr)lut.tensor_data().data(), false);
     // create profile
     triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "half", params_.bsize, params_.locks, params_.blocks, OP);
-    // blocksparse matmul
-    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::PARTIAL_TUNING);
+    // enqueue
+    triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
     triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
     Tensor *tmp = nullptr;
     TensorShape tmp_shapes;
@@ -138,6 +163,41 @@ public:
     OP_REQUIRES_OK(context, context->allocate_output(1, tmp_shapes, &tmp));
   }
 
+public:
+
+  explicit BlocksparseMatmulOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("segments", &params_.segments));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("locks",    &params_.locks   ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("blocks",   &params_.blocks  ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("bsize",    &params_.bsize   ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("C",        &params_.C       ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("K",        &params_.K       ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("shared",   &params_.shared  ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("alpha",    &params_.alpha   ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("beta",     &params_.beta    ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("gated_dw", &gated_dw_       ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis",     &axis_           ));
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("bench", &bench_));
+    OP_REQUIRES(ctx, params_.K < params_.bsize*65536, errors::InvalidArgument("K < bsize*65536"));
+    OP_REQUIRES(ctx, params_.C < params_.bsize*65536, errors::InvalidArgument("C < bsize*65536"));
+    params_.pcount = 1;
+    params_.blk_A = 0;
+    is_gpu_ = ctx->device_type() == DEVICE_GPU;
+    if (bench_) {
+      repeat_ = bench_;
+      flops_ = (float)(params_.blocks * params_.bsize*params_.bsize);
+      const char* op = "FPROP";
+      sprintf(bench_string_, "%s %02d-%d C:%05d K:%05d blks:%d", op, params_.bsize, axis_, params_.C, params_.K, params_.blocks);
+    }
+  }
+
+  void Compute(OpKernelContext* context) override{
+    if(OP == triton::dnn::blocksparse::WGRAD)
+      ComputeDw(context);
+    else
+      ComputeYDx(context);
+  }
+
 private:
   bsmm_params params_;
   int axis_, bench_, repeat_, SMs_, major_, grid_n_;
@@ -212,3 +272,33 @@ Multiply the matrix "a" by the blocksparse matrix "b".
 
 REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint("T"),BlocksparseMatmulOp);
 REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDX").Device(DEVICE_GPU).TypeConstraint("T"),BlocksparseMatmulOp);
+
+REGISTER_OP("TritonBlocksparseMatmulDW")
+    .Input("x: params * T")
+    .Input("dy: params * T")
+    .Input("lut: int64")
+    .Input("gate: ngate * float")
+    .Output("dw: T")
+    .Attr("T: {half, float, bfloat16}")
+    .Attr("params: int")
+    .Attr("blocks: int >=0")
+    .Attr("bsize: int")
+    .Attr("segments: int = 0")
+    .Attr("locks: int = 0")
+    .Attr("axis: int = 1")
+    .Attr("C: int >=0")
+    .Attr("K: int >=0")
+    .Attr("shared: int = 0")
+    .Attr("alpha: float = 1.0")
+    .Attr("beta: float = 0.0")
+    .Attr("gated_dw: bool = false")
+    .Attr("gate_grad: bool = false")
+    .Attr("bench: int = 0")
+    .Attr("ngate: int >= 0")
+    .SetShapeFn(UpdatShape)
+    .Doc(R"doc(
+Multiply the matrix "a" by the blocksparse matrix "b".
+)doc");
+
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDW").Device(DEVICE_GPU).TypeConstraint("T"),BlocksparseMatmulOp);
+REGISTER_KERNEL_BUILDER(Name("TritonBlocksparseMatmulDW").Device(DEVICE_GPU).TypeConstraint("T"),BlocksparseMatmulOp);
diff --git a/include/triton/dnn/blocksparse/dot.h b/include/triton/dnn/blocksparse/dot.h
index 488c26c31..f42d5b9d8 100644
--- a/include/triton/dnn/blocksparse/dot.h
+++ b/include/triton/dnn/blocksparse/dot.h
@@ -29,6 +29,9 @@ private:
   void init_impl(driver::stream *stream, driver::cu_module *module, triton::runtime::launch_information info);
   // deinit
   void deinit_impl();
+  // source
+  std::string triton_c_src_ydx() const;
+  std::string triton_c_src_dw() const;
 public:
   // constructor
   dot(int32_t N, int32_t K, int32_t S, int32_t C, const std::string &ty, int32_t BS, int32_t nlocks, int32_t nblocks, op_t op = FPROP);
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
index f43f94e8f..0fbd21938 100644
--- a/include/triton/runtime/jit.h
+++ b/include/triton/runtime/jit.h
@@ -75,7 +75,6 @@ public:
     optimize_dot.run(module);
     optimize_trans.run(module);
     dce.run(module);
-//    ir::print(module, std::cout);
   }
 
   void target_dependent(ir::module &module) {
@@ -89,6 +88,7 @@ public:
     }
     vectorize.run(module);
     dce.run(module);
+//    ir::print(module, std::cout);
   }
 
   codegen::selection selection;
diff --git a/lib/codegen/analysis/tune.cpp b/lib/codegen/analysis/tune.cpp
index 7fba702cc..bd2b0cbce 100644
--- a/lib/codegen/analysis/tune.cpp
+++ b/lib/codegen/analysis/tune.cpp
@@ -250,15 +250,15 @@ void tune::run(ir::module &mod) {
     node_t node = *nodes_.begin();
     if(fragments_[node] == STRIDED_SCAN) {
       ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
-      ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
+      ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 1, 8);
       connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
       nts->set_value(1);
     }
     else {
-      ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
+      ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 1);
       if(node.second == 2)
         fpw->set_value(1);
-      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
+      ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 1);
       connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
     }
   }
@@ -277,7 +277,7 @@ void tune::run(ir::module &mod) {
       size_t addr_space = ptr_ty->get_pointer_address_space();
       if(addr_space < 4){
         ir::type *ty = mod.get_builder().get_int32_ty();
-        std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 4, 8));
+        std::unique_ptr tmp(ir::metaparameter::create(ctx, ty, 1, 1));
         *params_.at(i).at("nts.d0") = *tmp;
       }
     }
@@ -287,8 +287,8 @@ void tune::run(ir::module &mod) {
 //      *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
 //      *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
       if(fragments_.at({i, 0}) == STRIDED_SCAN){
-        std::unique_ptr tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
-        std::unique_ptr tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
+        std::unique_ptr tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
+        std::unique_ptr tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
         *params_.at(i).at("nts.d0") = *tmp1;
         *params_.at(i).at("nts.d1") = *tmp2;
 //        for(size_t k = 2; k < shapes.size(); k++)
@@ -423,7 +423,7 @@ bool tune::check_constraints(std::map> &er
       for(size_t k = 0; k < shapes.size(); k++){
         prod *= params_[i]["fpw.d" + std::to_string(k)]->get_value();
       }
-      if(prod != 4)
+      if(prod > 4)
        errors[i].push_back("HMMA must have only 4 fragments per warp");
     }
     int num_threads = get_req_num_threads(i);
diff --git a/lib/codegen/selection/selection.cpp b/lib/codegen/selection/selection.cpp
index 491ff870f..43e08ee1e 100644
--- a/lib/codegen/selection/selection.cpp
+++ b/lib/codegen/selection/selection.cpp
@@ -554,7 +554,6 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
   num_packs_0_ = num_rep_0 / pack_size_0_;
   num_packs_1_ = num_rep_1 / pack_size_1_;
 
-  /* intra warp offset */
   // offset of quad in pair
   Value *in_pair_off_a = builder.CreateMul(builder.CreateUDiv(builder.CreateAnd(u_thread_id, _16), builder.getInt32(4)),
@@ -566,7 +565,7 @@
   Value *pair_a_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
   Value *pair_b_id = builder.CreateUDiv(builder.CreateURem(u_thread_id, _16), _4);
   pair_a_id = builder.CreateURem(pair_a_id, builder.getInt32(fpw_0));
-  pair_b_id = builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0));
+  pair_b_id = builder.CreateURem(builder.CreateUDiv(pair_b_id, builder.getInt32(fpw_0)), builder.getInt32(fpw_1));
   // Quad pair offset
   Value *pair_a_off = builder.CreateMul(pair_a_id, builder.getInt32(4 * pack_size_0_));
   Value *pair_b_off = builder.CreateMul(pair_b_id, builder.getInt32(4 * pack_size_1_));
diff --git a/lib/dnn/blocksparse/dot.cpp b/lib/dnn/blocksparse/dot.cpp
index 054904f27..97823e309 100644
--- a/lib/dnn/blocksparse/dot.cpp
+++ b/lib/dnn/blocksparse/dot.cpp
@@ -52,40 +52,56 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
   driver::buffer *b = args[1];
   driver::buffer *c = args[2];
   driver::buffer *lut = args[3];
-  int32_t lda = N_;
-  int32_t ldc = N_;
   kernel->setArg(0, a);
   kernel->setArg(1, b);
   kernel->setArg(2, c);
-  kernel->setArg(3, lda);
-  kernel->setArg(4, ldc);
-  kernel->setArg(5, N_);
-  kernel->setArg(6, lut);
-  kernel->setArg(7, locks_.get());
-  kernel->setArg(8, nlocks_);
-  int32_t TM = info.globals["TM"];
-  size_t grid_0 = (N_ + TM - 1) / TM;
-  size_t grid_1 = S_;
-  if(nlocks_)
-    ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
-  stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
+  if(op_ == FPROP || op_ == BPROP){
+    kernel->setArg(3, N_);
+    kernel->setArg(4, BS_);
+    kernel->setArg(5, N_);
+  }
+  else{
+    kernel->setArg(3, N_);
+    kernel->setArg(4, N_);
+    kernel->setArg(5, BS_);
+  }
+  kernel->setArg(6, N_);
+  kernel->setArg(7, lut);
+  kernel->setArg(8, locks_.get());
+  kernel->setArg(9, nlocks_);
+  if(op_ == FPROP || op_ == BPROP){
+    int32_t TM = info.globals["TM"];
+    size_t grid_0 = (N_ + TM - 1) / TM;
+    size_t grid_1 = S_;
+    if(nlocks_)
+      ((driver::cu_buffer*)locks_.get())->set_zero(stream, grid_0 * nlocks_ * 2 * 4);
+    stream->enqueue(kernel, {grid_0, grid_1, 1}, {info.num_threads, 1, 1});
+  }
+  else{
+    size_t grid_0 = nblocks_;
+    stream->enqueue(kernel, {grid_0, 1, 1}, {info.num_threads, 1, 1});
+  }
 }
 
 driver::buffer* dot::get_locks() const {
   return locks_.get();
 }
 
-void dot::triton_c_src(std::ostream &os) const {
-  std::string usea = (op_ == WGRAD) ? "trans(a)" : "a";
-  std::string useb = (op_ == FPROP) ? "trans(b)" : "b";
+std::string dot::triton_c_src_ydx() const {
+  bool AT = (op_ == WGRAD);
+  bool BT = (op_ == FPROP);
+  std::string usea = AT ? "trans(a)" : "a";
+  std::string useb = BT ? "trans(b)" : "b";
   std::string sizea = "TM, TK";
-  std::string sizeb = (op_ == FPROP) ? "TN, TK" : "TK, TN";
+  std::string sizeb = BT ? "TN, TK" : "TK, TN";
   std::string bca0 = ":, newaxis";
   std::string bca1 = "newaxis, :";
-  std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
-  std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
-  std::string ldb0 = (op_ == FPROP) ? "" : "*TK";
-  std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
+  std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
+  std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
+  std::string lda0 = AT ? "*lda" : "";
+  std::string lda1 = AT ? "" : "*lda";
+  std::string ldb0 = BT ? "" : "*ldb";
+  std::string ldb1 = BT ? "*ldb" : "" ;
   std::string result =
 R"(
   const tunable int TM = {16, 32, 64, 128};
@@ -95,26 +111,25 @@ void dot::triton_c_src(std::ostream &os) const {
   void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
              restrict read_only align(16) )" + ab_ty_ + R"( *B,
              )" + c_ty_ + R"(* C,
-             int lda, int ldc, int N,
-             int* lut, int* locks, int nlocks) {
+             int lda, int ldb, int ldc,
+             int N, int* lut,
+             int* locks, int nlocks) {
     int ridx = get_range_id(0);
-    int ridy = get_range_id(1);
     float acc[TM, TN] = 0;
-    int rxa[TM] = ridx * TM + (0 ... TM);
-    int ryb[TN] = 0 ... TN;
     int rka[TK] = 0 ... TK;
     int rkb[TK] = 0 ... TK;
-    bool checka[TM, TK] = (rxa < N)[:, newaxis];
-    int offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
-    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
-    int *header = lut + ridy * 4;
+    int *header = lut + get_range_id(1) * 4;
     int offset = *(header + 0);
     int K = *(header + 1);
     int column = *(header + 2);
     int lockid = *(header + 3);
+    int rxa[TM] = ridx * TM + (0 ... TM);
+    int ryb[TN] = 0 ... TN;
     int *plut = lut + offset * 2;
-    for(int k = K; k > 0; k = k - 1)
-    {
+    int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
+    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
+    bool checka[TM, TK] = (rxa < N)[:, newaxis];
+    for(int k = K; k > 0; k = k - 1) {
       int ak = *(plut + 0);
       int bk = *(plut + 1);
       )" + ab_ty_ + "* pa[" + sizea + R"(] = A + offa + ak * TK * lda;
@@ -137,17 +152,83 @@ void dot::triton_c_src(std::ostream &os) const {
       int *pcount = plock + get_num_program(0)*nlocks;
       while(__atomic_cas(plock, 0, 1));
      int count = *pcount;
-      if(count == 0){
+      if(count == 0)
        @checkc *pc = c;
-      }
-      else{
+      else
        @checkc *pc = c + *pc;
-      }
      __atomic_exch(pcount, 1);
      __atomic_exch(plock, 0);
    }
 })";
-  os << result;
+
+  return result;
+}
+
+std::string dot::triton_c_src_dw() const {
+  bool AT = (op_ == WGRAD);
+  bool BT = (op_ == FPROP);
+  std::string usea = AT ? "trans(a)" : "a";
+  std::string useb = BT ? "trans(b)" : "b";
+  std::string sizea = AT ? "TK, TM" : "TM, TK";
+  std::string sizeb = BT ? "TN, TK" : "TK, TN";
+  std::string bca0 = AT ? "newaxis, :" : ":, newaxis";
+  std::string bca1 = AT ? ":, newaxis" : "newaxis, :";
+  std::string bcb0 = BT ? ":, newaxis" : "newaxis, :";
+  std::string bcb1 = BT ? "newaxis, :" : ":, newaxis";
+  std::string lda0 = AT ? "*lda" : "";
+  std::string lda1 = AT ? "" : "*lda";
+  std::string ldb0 = BT ? "" : "*ldb";
+  std::string ldb1 = BT ? "*ldb" : "" ;
+  std::string result =
+R"(
+  const tunable int TM = {)" + std::to_string(BS_) + R"(};
+  const tunable int TN = {)" + std::to_string(BS_) + R"(};
+  const tunable int TK = {32};
+
+  void bsdot(restrict read_only align(16) )" + ab_ty_ + R"( *A,
+             restrict read_only align(16) )" + ab_ty_ + R"( *B,
+             )" + c_ty_ + R"(* C,
+             int lda, int ldb, int ldc,
+             int N, int* lut,
+             int* locks, int nlocks) {
+    int ridx = get_range_id(0);
+    float acc[TM, TN] = 0;
+    int rka[TK] = 0 ... TK;
+    int rkb[TK] = 0 ... TK;
+    int *header = lut + ridx * 2;
+    int offx = *(header + 0);
+    int offy = *(header + 1);
+    int rxa[TM] = offx*TM + (0 ... TM);
+    int ryb[TN] = offy*TN + (0 ... TN);
+    bool checka[TK, TM] = (rka < N)[:, newaxis];
+    bool checkb[TK, TN] = (rkb < N)[:, newaxis];
+    int offa[)" + sizea + "] = rxa[" + bca0 + "]" + lda0 + " + rka[" + bca1 + "]" + lda1 + R"(;
+    int offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
+    )" + ab_ty_ + " * pa[" + sizea + R"(] = A + offa;
+    )" + ab_ty_ + " * pb[" + sizeb + R"(] = B + offb;
+    )" + ab_ty_ + " a[" + sizea + R"(] = checka ? *pa : 0;
+    )" + ab_ty_ + " b[" + sizeb + R"(] = checkb ? *pb : 0;
+    for(int k = N; k > 0; k = k - TK) {
+      acc = dot()" + usea + ", " + useb + R"(, acc);
+      pa = pa + TK)" + lda1 + R"(;
+      pb = pb + TK)" + ldb1 + R"(;
+      a = checka ? *pa : 0;
+      b = checkb ? *pb : 0;
+    }
+    int rxc[TM] = (0 ... TM);
+    int ryc[TN] = (0 ... TN);
+    )" + c_ty_ + R"( c[TM, TN] = acc;
+    )" + c_ty_ + R"(* pc[TM, TN] = C + rxc[:, newaxis]*TM + ryc[newaxis, :] + ridx*TM*TN;
+    *pc = c;
+})";
+
+  return result;
+}
+
+void dot::triton_c_src(std::ostream &os) const {
+  if(op_ == FPROP || op_ == BPROP)
+    os << triton_c_src_ydx();
+  else
+    os << triton_c_src_dw();
 }
diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp
index 22198b7af..30cec06a4 100644
--- a/lib/dnn/dot.cpp
+++ b/lib/dnn/dot.cpp
@@ -106,8 +106,8 @@ void dot::triton_c_src(std::ostream &os) const {
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
 R"(
-const tunable int TM = {32};
-const tunable int TN = {32};
+const tunable int TM = {8};
+const tunable int TN = {8};
 const tunable int TK = {32};
 const tunable int GZ = {1};
diff --git a/lib/runtime/jit.cpp b/lib/runtime/jit.cpp
index 141bb8054..4c7a030f8 100644
--- a/lib/runtime/jit.cpp
+++ b/lib/runtime/jit.cpp
@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector const & ranges,
   size_t D = ranges.size();
   std::vector values(D, 0);
   // thread pools
-  ThreadPool pool(nthreads);
+//  ThreadPool pool(nthreads);
   // Start with innermost loop
   size_t i = D - 1;
   while(true){
     // Execute function
-    pool.enqueue(f,values);
-//    f(values);
+//    pool.enqueue(f,values);
+    f(values);
     while(values[i]++ == ranges[i] - 1){
       if(i == 0)
         return;