diff --git a/examples/cpp/dot.cpp b/examples/cpp/dot.cpp
index 591237fbe..d4f5adb6e 100644
--- a/examples/cpp/dot.cpp
+++ b/examples/cpp/dot.cpp
@@ -26,7 +26,7 @@ struct perf_t {

 perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
   typedef float NumericT;
-  std::string ty = "half";
+  std::string ty = "float";
   size_t dt_nbytes = sizeof(NumericT);
   triton::driver::context* context = stream->context();
   std::vector<NumericT> hc(M*N);
@@ -48,28 +48,40 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
   stream->synchronize();
   triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
   // benchmark triton
-  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
+  double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
   // benchmark cublas
-  NumericT alpha = 1;
-  NumericT beta = 0;
-  int32_t lda = AT ? K : M;
-  int32_t ldb = BT ? N : K;
-  int32_t ldc = M;
+//  NumericT alpha = 1;
+//  NumericT beta = 0;
+//  int32_t lda = AT ? K : M;
+//  int32_t ldb = BT ? N : K;
+//  int32_t ldc = M;
 //  cublasGemmAlgo_t fastest;
 //  cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
 //             &alpha, da, lda,
 //             db, ldb, &beta,
 //             dc, ldc, &fastest);
-  double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
-                                                              &alpha, da, lda,
-                                                              db, ldb, &beta,
-                                                              dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
+//  double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
+//                                                              &alpha, da, lda,
+//                                                              db, ldb, &beta,
+//                                                              dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);

   // result
   auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };
   perf_t result;
-  result.cublas = tflops(cublas_ns);
+//  result.cublas = tflops(cublas_ns);
   result.triton = tflops(triton_ns);
+
+  // test
+  stream->read(dc, true, 0, hc);
+  std::vector<NumericT> rc(hc.size());
+  dot.cpu_ref(rc, ha, hb);
+  for(size_t i = 0; i < M*N; i++)
+    if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
+      std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
+      exit(EXIT_FAILURE);
+    }
+  std::cout << "Pass!" << std::endl;
+
   // clean-up
   delete dc;
   delete da;
@@ -99,8 +111,8 @@ int main() {
   std::vector configs = {
 //    {false, false, 8192, 512, 512},
 //    {false, true, 8192, 8192, 8192}
-    {false, true, 32768, 256, 256},
-    {false, true, 32768, 256, 512}
+    {false, true, 128, 128, 128},
+//    {false, true, 32768, 256, 512}
 //    {true, false, 8192, 512, 512},
 //    {true, true, 8192, 512, 512}
   };
diff --git a/examples/python/tensorflow/dot.cpp b/examples/python/tensorflow/dot.cpp
index 553ad11fa..bdcb5c62c 100644
--- a/examples/python/tensorflow/dot.cpp
+++ b/examples/python/tensorflow/dot.cpp
@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
     triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
     triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
     // template
-    triton::dnn::dot dot(M, N, K, false, false, "half", "half", 8, 8, 8);
+    triton::dnn::dot dot(M, N, K, false, true, "half", "half", 8, 8, 8);
     dot.enqueue(stream, {&da, &db, &dc});
   }

diff --git a/examples/python/tensorflow/run.py b/examples/python/tensorflow/run.py
index 8dbc6ac55..4b1f7ac53 100644
--- a/examples/python/tensorflow/run.py
+++ b/examples/python/tensorflow/run.py
@@ -23,7 +23,7 @@ def run_dot():
    result = sess.run([c], feed_dict = {a: ha,
                                        b: hb})[0]
    # Test
-   hresult = np.dot(ha.T, hb.T).T
+   hresult = np.dot(ha.T, hb).T
    dif = np.abs(result - hresult)
    np.savetxt('dif.dat', dif, '%2.4f')
    print(hresult)
@@ -131,6 +131,6 @@ def run_batchnorm():
    print(np.max(np.abs(dg_t - dg_n)))
    print(np.max(np.abs(db_t - db_n)))

-#run_dot()
+run_dot()
 #run_shift()
-run_batchnorm()
+#run_batchnorm()
diff --git a/include/triton/runtime/jit.h b/include/triton/runtime/jit.h
index 16f56c0e5..fffec7794 100644
--- a/include/triton/runtime/jit.h
+++ b/include/triton/runtime/jit.h
@@ -73,11 +73,11 @@ public:
     optimize_dot.run(module);
     optimize_trans.run(module);
     optimize_dce.run(module);
+//    ir::print(module, std::cout);
   }

   void target_dependent(ir::module &module) {
     alignment_info.run(module);
-//    ir::print(module, std::cout);
 //    reassociate.run(module);
     if(target_->is_gpu()){
       shmem_info.run(module);
diff --git a/lib/codegen/optimize_dot.cpp b/lib/codegen/optimize_dot.cpp
index ee59145c7..8688e918e 100644
--- a/lib/codegen/optimize_dot.cpp
+++ b/lib/codegen/optimize_dot.cpp
@@ -33,8 +33,7 @@ void optimize_dot::run(ir::module &mod) {
   for(ir::function *fn: mod.get_function_list())
   for(ir::basic_block *block: fn->blocks())
   for(ir::instruction *i: block->get_inst_list())
-  if(auto dot = dynamic_cast<ir::dot_inst*>(i))
-  if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1){
+  if(auto dot = dynamic_cast<ir::dot_inst*>(i)){
     builder.set_insert_point(i);
     ir::value *A = dot->get_operand(0);
     ir::value *B = dot->get_operand(1);
diff --git a/lib/codegen/selection.cpp b/lib/codegen/selection.cpp
index dc8980a28..e419f5a8d 100644
--- a/lib/codegen/selection.cpp
+++ b/lib/codegen/selection.cpp
@@ -135,8 +135,12 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_

 Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
   Value *result = builder.getInt32(0);
   result = builder.CreateAdd(result, idx[0]);
-  for(size_t i = 1; i < idx.size(); i++)
-    result = builder.CreateAdd(result, builder.CreateMul(idx[i], builder.getInt32(shapes[i-1])));
+  Value *ld = builder.getInt32(shapes[0]);
+  for(size_t i = 1; i < idx.size(); i++) {
+    result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld));
+    if(i < idx.size() - 1)
+      ld = builder.CreateMul(ld, builder.getInt32(shapes[i]));
+  }
   return result;
 }
@@ -854,10 +858,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
       Value *&result = x.second;
       indices_t write_idx = x.first;
       write_idx.insert(write_idx.begin() + axis, lane);
+      // shared memory write pointer
       Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
       Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
+      // initialize shared memory
+      tgt_->add_barrier(module, builder);
       builder.CreateStore(result, write_ptr);
       // build result
       for(unsigned i = depth/2; i > 0; i >>= 1){
@@ -993,15 +1000,14 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
     {
       shared_tile *TA = (shared_tile*)tmap_.at(A);
       shared_tile *TB = (shared_tile*)tmap_.at(B);
-      if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN)
-      {
+      if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN) {
        TA->set_vector_size(TC->axis(0).contiguous);
        TB->set_vector_size(TC->axis(1).contiguous);
        result->for_each([&](indices_t idx){
          Value *res = TC->get_value(idx);
          for(unsigned K = 0; K < NK; ++K){
-           indices_t a_idx = {idx[0], builder.getInt32(K)};
-           indices_t b_idx = {builder.getInt32(K), idx[1]};
+           indices_t a_idx = {idx[0], builder.getInt32(K), idx[2]};
+           indices_t b_idx = {builder.getInt32(K), idx[1], idx[2]};
            if(AT)
              std::swap(a_idx[0], a_idx[1]);
            if(BT)
@@ -1013,13 +1019,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
            if(b->getType() != c_ty)
              b = builder.CreateFPCast(b, c_ty);
            res = builder.CreateCall(f_mul_add, {a, b, res});
-
          }
          result->set_value(idx, res);
        });
      }
-      else
-      {
+      else {
        TA->set_vector_size(4*pack_size_0_);
        TB->set_vector_size(4*pack_size_1_);
        TA->set_return_mode(true);
diff --git a/lib/codegen/shmem_allocation.cpp b/lib/codegen/shmem_allocation.cpp
index 641170215..6e9bf86ff 100644
--- a/lib/codegen/shmem_allocation.cpp
+++ b/lib/codegen/shmem_allocation.cpp
@@ -42,8 +42,8 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
 }

 unsigned shmem_allocation::get_num_bytes(ir::value *x) {
-  unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
   if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
+    unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
     size_t axis = red->get_axis();
     ir::value *op = red->get_operand(0);
     auto shapes = op->get_type()->get_tile_shapes();
@@ -54,6 +54,7 @@ unsigned shmem_allocation::get_num_bytes(ir::value *x) {
     size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
     return num_elements * num_bytes * depth;
   }
+  unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
   unsigned pad = is_ld_padded(x);
   if(pad > 0){
     unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
diff --git a/lib/codegen/tune.cpp b/lib/codegen/tune.cpp
index db3ed1c81..35445a72d 100644
--- a/lib/codegen/tune.cpp
+++ b/lib/codegen/tune.cpp
@@ -24,8 +24,7 @@ bool is_hmma(ir::value *v){
     ir::type *b_ty = b->get_type();
     // inputs have to be FP16
     result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
-    // reduction has to be multiple of 4
-    result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
+    // reduction has to be multiple of 4: TODO
   }
   return result;
 }
@@ -66,9 +65,10 @@ void tune::init_c_graph(ir::instruction *v) {
     for(unsigned i = 0; i < in_shapes.size(); i++){
       if(i == axis)
         continue;
-//      std::cout << arg->get_name() << " " << v->get_name() << std::endl;
       add_constraint({reduce, current++}, {arg, i});
     }
+//    add_constraint({reduce, 0}, {arg, 0});
+//    add_constraint({reduce, 1}, {arg, 1});
     return;
   }
   else
@@ -81,8 +81,10 @@ void tune::init_c_graph(ir::instruction *v) {
     for(unsigned i = 0; i < shapes.size(); i ++){
       bool is_one = shapes[i] == one;
       bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
-      if(is_one)
+      if(is_one){
         static_params_.insert({{v, i}, 1});
+        add_constraint({v, i}, {v, i});
+      }
       else if(!is_skewed && is_same)
         add_constraint({v, i}, {op, current++});
       else{
@@ -114,9 +116,17 @@ void tune::init_c_graph(ir::instruction *v) {
   }
   // Matrix multiplication
   else if(dynamic_cast<ir::dot_inst*>(v)){
+    ir::value *A = v->get_operand(0);
+    ir::value *B = v->get_operand(1);
     ir::value *D = v->get_operand(2);
-    add_constraint({v, 0}, {D, 0});
-    add_constraint({v, 1}, {D, 1});
+    for(unsigned i = 0; i < shapes.size(); i++)
+      add_constraint({v, i}, {D, i});
+    for(unsigned i = 2; i < shapes.size(); i++){
+      if(shapes[i] == one)
+        static_params_.insert({{v, i}, 1});
+      add_constraint({v, i}, {A, i});
+      add_constraint({v, i}, {B, i});
+    }
   }
   // Element-wise
   else if(dynamic_cast(v)) {
@@ -242,7 +252,7 @@ void tune::run(ir::module &mod) {
       node_t node = *nodes_.begin();
       if(fragments_[node] == STRIDED_SCAN) {
         ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
-        ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
+        ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
         connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
         nts->set_value(1);
       }
@@ -266,14 +276,14 @@ void tune::run(ir::module &mod) {
         size_t addr_space = ptr_ty->get_pointer_address_space();
         if(addr_space < 4){
           ir::type *ty = mod.get_builder().get_int32_ty();
-          std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
+          std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
           *params_.at(i).at("nts.d0") = *tmp;
         }
       }
       if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
         ir::type *ty = mod.get_builder().get_int32_ty();
-        std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
-        std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
+        std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
+        std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
         *params_.at(i).at("nts.d0") = *tmp1;
         *params_.at(i).at("nts.d1") = *tmp2;
       }
@@ -365,6 +375,7 @@ bool tune::check_constraints(std::map> &er

   // check constraints
   for(ir::instruction *i: grids_){
+//    std::cout << i->get_name() << std::endl;
     ir::type *ty = i->get_type();
     const auto &shapes = ty->get_tile_shapes();
     // for each dimension, the product of layout components
@@ -396,11 +407,15 @@
         errors[i].push_back("HMMA must have only 4 fragments per warp");
     }
     int num_threads = get_req_num_threads(i);
-    if(num_threads % 64 != 0)
+    if(num_threads % 32 != 0)
       errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
     if(num_threads != num_threads_)
       errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
   }
+//  for(auto x: errors)
+//  for(auto e: x.second)
+//    std::cout << x.first->get_name() << ": " << e << std::endl;
+//  exit(EXIT_SUCCESS);
   return errors.empty();
 }
diff --git a/lib/dnn/base.cpp b/lib/dnn/base.cpp
index 1c1ee8ceb..ebbe699c1 100644
--- a/lib/dnn/base.cpp
+++ b/lib/dnn/base.cpp
@@ -54,16 +54,17 @@ std::pair base::get_profile_impl(driver::stream *stream, std::v
     return num_flops() / ts * 1e-3;
   };
   // auto-tune and save result
-  if(autotune != NO_TUNING) {
+  if(autotune == FULL_TUNING || autotune == PARTIAL_TUNING) {
     std::vector<params_t> space = {};
     if(autotune == PARTIAL_TUNING)
       space = search_space();
     rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
     jit->add_module(name_.c_str(), src.c_str(), best.params);
   }
-  else {
-    params_t params = heuristics();
+  else{
+//    params_t params = heuristics();
 //    params_t params = jit->get_valid(name_.c_str(), src.c_str());
+    params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 1};
     jit->add_module(name_.c_str(), src.c_str(), params);
   }
   triton::driver::kernel* kernel = jit->get_function(name_.c_str());
diff --git a/lib/dnn/dot.cpp b/lib/dnn/dot.cpp
index 3b9a2e300..7cc7563dc 100644
--- a/lib/dnn/dot.cpp
+++ b/lib/dnn/dot.cpp
@@ -74,12 +74,14 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
 void dot::triton_c_src(std::ostream &os) const {
   std::string AS0 = "TM", AS1 = "TK";
   std::string BS0 = "TK", BS1 = "TN";
+  std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
+  std::string XBS0 = "TN", XBS1 = "TK/4", XBS2 = "4";
   std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
   std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
   std::string lda0 = "*lda", lda1 = "";
   std::string ldb0 = "", ldb1 = "*ldb";
-  std::string usea = AT_ ? "trans(a)" : "a";
-  std::string useb = BT_ ? "trans(b)" : "b";
+  std::string usea = AT_ ? "trans(xa)" : "xa";
+  std::string useb = BT_ ? "trans(xb)" : "xb";
   if(AT_){
     std::swap(AS0, AS1);
     std::swap(bca0, bca1);
@@ -92,12 +94,15 @@ void dot::triton_c_src(std::ostream &os) const {
   }
   std::string AS = AS0 + ", " + AS1;
   std::string BS = BS0 + ", " + BS1;
+  std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
+  std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
+  std::string XCS = "TM, TN, 4";
   std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
   std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
   std::string res =
 R"(
-const tunable int TM = {16, 32, 64, 128};
-const tunable int TN = {16, 32, 64, 128};
+const tunable int TM = {32};
+const tunable int TN = {32};
 const tunable int TK = {32};
 const tunable int GZ = {1};
@@ -113,7 +118,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   int ryb[TN] = ridy * TN + (0 ... TN);
   int rka[TK] = 0 ... TK;
   int rkb[TK] = 0 ... TK;
-  float c[TM, TN] = 0;
+  float xc[)" + XCS + R"(] = 0;
   )" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
   )" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
   bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
@@ -121,7 +126,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   )" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
   )" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
   for(int k = K; k > 0; k = k - TK){
-    c = dot()" + usea + ", " + useb + R"(, c);
+    )" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
+    )" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
+    xc = dot()" + usea + ", " + useb + R"(, xc);
     pa = pa + TK)" + lda0 + R"(;
     pb = pb + TK)" + ldb0 + R"(;
     bool checka[)" + AS + R"(] = k > TK;
@@ -131,11 +138,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
   }
   int rxc[TM] = ridx * TM + (0 ... TM);
   int ryc[TN] = ridy * TN + (0 ... TN);
-  bool checkc0[TM] = rxc < M;
-  bool checkc1[TN] = ryc < N;
-  bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
   float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
-  @checkc *pc = c;
+  float c[TM, TN] = __sum(xc, 2);
+  *pc = c;
 }
 )";
diff --git a/lib/ir/builder.cpp b/lib/ir/builder.cpp
index b10488161..d47fbbaa5 100644
--- a/lib/ir/builder.cpp
+++ b/lib/ir/builder.cpp
@@ -148,13 +148,13 @@ DEFINE_UNARY_FLOAT(fneg)
 value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
                                             value *rhs, const std::string &name,
                                             bool has_nuw, bool has_nsw) {
-  if(auto *clhs = dynamic_cast(lhs)){
-    if(auto *crhs = dynamic_cast(rhs)){
-      constant_expression* result = constant_expression::create(op, clhs, crhs);
-      if (has_nuw) result->set_has_no_unsigned_wrap();
-      if (has_nsw) result->set_has_no_signed_wrap();
-      return result;
-    }
+  auto *clhs = dynamic_cast(lhs);
+  auto *crhs = dynamic_cast(rhs);
+  if(clhs && crhs){
+    constant_expression* result = constant_expression::create(op, clhs, crhs);
+    if (has_nuw) result->set_has_no_unsigned_wrap();
+    if (has_nsw) result->set_has_no_signed_wrap();
+    return result;
   }
   else {
     binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
diff --git a/lib/lang/expression.cpp b/lib/lang/expression.cpp
index 355a2a369..c54179943 100644
--- a/lib/lang/expression.cpp
+++ b/lib/lang/expression.cpp
@@ -101,7 +101,6 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
 ir::value* binary_expression::codegen(ir::module *mod) const{
   ir::value *lhs = lhs_->codegen(mod);
   ir::value *rhs = rhs_->codegen(mod);
-  std::cout << " " << typeid(*lhs_).name() << " " << typeid(*rhs_).name() << std::endl;
   ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
   return result;
 }
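
Note (illustration only, not part of the patch): the rewritten shared_tile::shared_offset in lib/codegen/selection.cpp linearizes an N-dimensional tile index with a running leading-dimension stride, i.e. offset = idx[0] + idx[1]*shapes[0] + idx[2]*shapes[0]*shapes[1] + ..., instead of scaling every index by shapes[i-1] as before. A minimal standalone sketch of that arithmetic, with plain integers standing in for llvm::Value* and a hypothetical linear_offset helper:

#include <cstddef>
#include <vector>

// Mirrors the new shared_offset loop: accumulate a running stride (ld)
// so that idx[i] is scaled by the product of all lower-dimension shapes.
int linear_offset(const std::vector<int> &shapes, const std::vector<int> &idx) {
  int result = idx[0];
  int ld = shapes[0];
  for (std::size_t i = 1; i < idx.size(); i++) {
    result += idx[i] * ld;
    if (i < idx.size() - 1)
      ld *= shapes[i];
  }
  return result;
}

For a 3-D tile of shape {TM, TK/4, 4}, index {m, k, s} maps to m + k*TM + s*TM*(TK/4); this is the layout the extra idx[2] component added to a_idx/b_idx in lower_tile_instruction relies on.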
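
Note (illustration only, not part of the patch): the test block added to examples/cpp/dot.cpp compares the device output against dot.cpu_ref element by element, using a relative-error tolerance of 1e-4 and skipping NaN entries. A rough standalone sketch of the same check; the check() helper and its arguments are made up for the example:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Returns true when every non-NaN element of `gpu` is within a relative
// tolerance of the reference value, as in the patched do_bench() test.
bool check(const std::vector<float> &gpu, const std::vector<float> &ref, float tol = 1e-4f) {
  for (std::size_t i = 0; i < gpu.size(); i++) {
    if (!std::isnan(gpu[i]) && std::abs(gpu[i] - ref[i]) / std::max(gpu[i], ref[i]) > tol) {
      std::cout << i << " " << gpu[i] << " " << ref[i] << std::endl;
      return false;
    }
  }
  return true;
}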