basic split-k across warps working for GEMM

Philippe Tillet
2019-08-05 19:33:28 -07:00
parent 899b2b72e1
commit d62e581ab3
12 changed files with 99 additions and 63 deletions

View File

@@ -26,7 +26,7 @@ struct perf_t {
perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int32_t N, int32_t K){
typedef float NumericT;
std::string ty = "half";
std::string ty = "float";
size_t dt_nbytes = sizeof(NumericT);
triton::driver::context* context = stream->context();
std::vector<NumericT> hc(M*N);
@@ -48,28 +48,40 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, 8, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
// benchmark cublas
- NumericT alpha = 1;
- NumericT beta = 0;
- int32_t lda = AT ? K : M;
- int32_t ldb = BT ? N : K;
- int32_t ldc = M;
+ // NumericT alpha = 1;
+ // NumericT beta = 0;
+ // int32_t lda = AT ? K : M;
+ // int32_t ldb = BT ? N : K;
+ // int32_t ldc = M;
+ // cublasGemmAlgo_t fastest;
+ // cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
+ //            &alpha, da, lda,
+ //            db, ldb, &beta,
+ //            dc, ldc, &fastest);
- double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
-                                                            &alpha, da, lda,
-                                                            db, ldb, &beta,
-                                                            dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
+ // double cublas_ns = triton::tools::bench([&]() { cublasGemm(HALF_TYPE, stream, AT, BT, M, N, K,
+ //                                                            &alpha, da, lda,
+ //                                                            db, ldb, &beta,
+ //                                                            dc, ldc, nullptr, CUBLAS_GEMM_DEFAULT_TENSOR_OP); }, stream);
// result
auto tflops = [&](double nanosec) { return dot.num_flops() / nanosec * 1e-3; };
perf_t result;
- result.cublas = tflops(cublas_ns);
+ // result.cublas = tflops(cublas_ns);
result.triton = tflops(triton_ns);
// test
stream->read(dc, true, 0, hc);
std::vector<float> rc(hc.size());
dot.cpu_ref(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(!std::isnan(hc[i]) && std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-4){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
std::cout << "Pass!" << std::endl;
// clean-up
delete dc;
delete da;
@@ -99,8 +111,8 @@ int main() {
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
// {false, true, 8192, 8192, 8192}
- {false, true, 32768, 256, 256},
- {false, true, 32768, 256, 512}
+ {false, true, 128, 128, 128},
+ // {false, true, 32768, 256, 512}
// {true, false, 8192, 512, 512},
// {true, true, 8192, 512, 512}
};
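Note on the schedule being debugged here: "split-k" cuts the reduction dimension K into independent slices, each producing a partial C tile, and sums the partials at the end; the tiny 128x128x128 config above makes that easy to verify against cpu_ref. A minimal CPU sketch of the idea, assuming `slices` divides K (names like `split_k_gemm` are illustrative, not Triton API):

#include <vector>
#include <cstddef>

void split_k_gemm(const std::vector<float>& A,   // M x K, row-major
                  const std::vector<float>& B,   // K x N, row-major
                  std::vector<float>& C,         // M x N, row-major
                  size_t M, size_t N, size_t K, size_t slices) {
  size_t chunk = K / slices;                     // assume slices divides K
  std::vector<float> partial(slices * M * N, 0.f);
  // each slice reduces over its own K-range; on the GPU these run
  // concurrently on different warps
  for(size_t s = 0; s < slices; s++)
    for(size_t m = 0; m < M; m++)
      for(size_t n = 0; n < N; n++)
        for(size_t k = s*chunk; k < (s+1)*chunk; k++)
          partial[s*M*N + m*N + n] += A[m*K + k] * B[k*N + n];
  // final reduction across slices
  for(size_t m = 0; m < M; m++)
    for(size_t n = 0; n < N; n++){
      float acc = 0.f;
      for(size_t s = 0; s < slices; s++)
        acc += partial[s*M*N + m*N + n];
      C[m*N + n] = acc;
    }
}

In the generated kernel further down, the slice axis becomes the third tile dimension and the final reduction is the `__sum(xc, 2)` call.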

View File

@@ -49,7 +49,7 @@ class DotOp : public OpKernel {
triton::driver::cu_buffer db(ctx, b.tensor_data().size(), (CUdeviceptr)b.tensor_data().data(), false);
triton::driver::cu_buffer dc(ctx, c->tensor_data().size(), (CUdeviceptr)c->tensor_data().data(), false);
// template
- triton::dnn::dot dot(M, N, K, false, false, "half", "half", 8, 8, 8);
+ triton::dnn::dot dot(M, N, K, false, true, "half", "half", 8, 8, 8);
dot.enqueue(stream, {&da, &db, &dc});
}

View File

@@ -23,7 +23,7 @@ def run_dot():
result = sess.run([c], feed_dict = {a: ha,
b: hb})[0]
# Test
- hresult = np.dot(ha.T, hb.T).T
+ hresult = np.dot(ha.T, hb).T
dif = np.abs(result - hresult)
np.savetxt('dif.dat', dif, '%2.4f')
print(hresult)
@@ -131,6 +131,6 @@ def run_batchnorm():
print(np.max(np.abs(dg_t - dg_n)))
print(np.max(np.abs(db_t - db_n)))
- #run_dot()
+ run_dot()
#run_shift()
- run_batchnorm()
+ #run_batchnorm()

View File

@@ -73,11 +73,11 @@ public:
optimize_dot.run(module);
optimize_trans.run(module);
optimize_dce.run(module);
// ir::print(module, std::cout);
}
void target_dependent(ir::module &module) {
alignment_info.run(module);
// ir::print(module, std::cout);
// reassociate.run(module);
if(target_->is_gpu()){
shmem_info.run(module);

View File

@@ -33,8 +33,7 @@ void optimize_dot::run(ir::module &mod) {
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i: block->get_inst_list())
- if(auto dot = dynamic_cast<ir::dot_inst*>(i))
- if(dot->get_operand(1)->get_type()->get_tile_shapes()[1]->get_value() != 1){
+ if(auto dot = dynamic_cast<ir::dot_inst*>(i)){
builder.set_insert_point(i);
ir::value *A = dot->get_operand(0);
ir::value *B = dot->get_operand(1);

View File

@@ -135,8 +135,12 @@ void shared_tile::extract_constant(const indices_t &arg_idx, indices_t &non_cst_
Value* shared_tile::shared_offset(llvm::IRBuilder<> &builder, const shapes_t& shapes, indices_t idx) {
Value *result = builder.getInt32(0);
result = builder.CreateAdd(result, idx[0]);
- for(size_t i = 1; i < idx.size(); i++)
-   result = builder.CreateAdd(result, builder.CreateMul(idx[i], builder.getInt32(shapes[i-1])));
+ Value *ld = builder.getInt32(shapes[0]);
+ for(size_t i = 1; i < idx.size(); i++) {
+   result = builder.CreateAdd(result, builder.CreateMul(idx[i], ld));
+   if(i < idx.size() - 1)
+     ld = builder.CreateMul(ld, builder.getInt32(shapes[i]));
+ }
return result;
}
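The old formula used `shapes[i-1]` alone as the stride for every axis, which is only correct for 2-D tiles; with the third split-k axis, each stride must be the running product of all lower-dimension shapes. A scalar analogue of the fixed computation (`shared_offset_ref` is a hypothetical stand-in for the IR-emitting loop above):

#include <cstdint>
#include <vector>

int32_t shared_offset_ref(const std::vector<int32_t>& idx,
                          const std::vector<int32_t>& shapes) {
  int32_t result = idx[0];
  int32_t ld = shapes[0];               // running leading dimension
  for(size_t i = 1; i < idx.size(); i++) {
    result += idx[i] * ld;              // offset = i0 + i1*s0 + i2*s0*s1 + ...
    ld *= shapes[i];                    // accumulate the stride
  }
  return result;
}
// e.g. idx = {1, 2, 3}, shapes = {4, 5, 6}  ->  1 + 2*4 + 3*20 = 69

The IR version guards its last multiply (`i < idx.size() - 1`) only to avoid emitting a dead instruction; the scalar version can skip the guard.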
@@ -854,10 +858,13 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
Value *&result = x.second;
indices_t write_idx = x.first;
write_idx.insert(write_idx.begin() + axis, lane);
+ // shared memory write pointer
Value *write_offset = shared_tile::shared_offset(builder, op_tile->get_shapes(), write_idx);
Value *write_ptr = builder.CreateGEP(base_ptr, write_offset);
+ // initialize shared memory
+ tgt_->add_barrier(module, builder);
builder.CreateStore(result, write_ptr);
// build result
for(unsigned i = depth/2; i > 0; i >>= 1){
@@ -993,15 +1000,14 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
{
shared_tile *TA = (shared_tile*)tmap_.at(A);
shared_tile *TB = (shared_tile*)tmap_.at(B);
- if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN)
- {
+ if(params_->get_fragment(ins, 0) == tune::STRIDED_SCAN) {
TA->set_vector_size(TC->axis(0).contiguous);
TB->set_vector_size(TC->axis(1).contiguous);
result->for_each([&](indices_t idx){
Value *res = TC->get_value(idx);
for(unsigned K = 0; K < NK; ++K){
- indices_t a_idx = {idx[0], builder.getInt32(K)};
- indices_t b_idx = {builder.getInt32(K), idx[1]};
+ indices_t a_idx = {idx[0], builder.getInt32(K), idx[2]};
+ indices_t b_idx = {builder.getInt32(K), idx[1], idx[2]};
if(AT)
std::swap(a_idx[0], a_idx[1]);
if(BT)
@@ -1013,13 +1019,11 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
if(b->getType() != c_ty)
b = builder.CreateFPCast(b, c_ty);
res = builder.CreateCall(f_mul_add, {a, b, res});
}
result->set_value(idx, res);
});
}
- else
- {
+ else {
TA->set_vector_size(4*pack_size_0_);
TB->set_vector_size(4*pack_size_1_);
TA->set_return_mode(true);

View File

@@ -42,8 +42,8 @@ unsigned shmem_allocation::is_ld_padded(ir::value *x) {
}
unsigned shmem_allocation::get_num_bytes(ir::value *x) {
- unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
  if(auto *red = dynamic_cast<ir::reduce_inst*>(x)){
+   unsigned num_bytes = x->get_type()->get_scalar_ty()->get_primitive_size_in_bits() / 8;
size_t axis = red->get_axis();
ir::value *op = red->get_operand(0);
auto shapes = op->get_type()->get_tile_shapes();
@@ -54,6 +54,7 @@ unsigned shmem_allocation::get_num_bytes(ir::value *x) {
size_t depth = params_->get_param(op, "mts.d" + std::to_string(axis))->get_value();
return num_elements * num_bytes * depth;
}
+ unsigned num_bytes = x->get_type()->get_primitive_size_in_bits() / 8;
unsigned pad = is_ld_padded(x);
if(pad > 0){
unsigned ld = x->get_type()->get_tile_shapes()[0]->get_value();
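Sizing note for the reduce scratch-pad above: the buffer must hold one partial per surviving element and per thread slice along the reduced axis (the `mts.d<axis>` parameter), using the scalar element size rather than the whole type's size, which is what moving `num_bytes` inside the branch fixes. A stand-alone sketch of the same arithmetic (`reduce_scratch_bytes` is illustrative, not Triton API):

#include <cstddef>
#include <vector>

size_t reduce_scratch_bytes(const std::vector<size_t>& shapes, // operand tile
                            size_t axis,        // dimension being reduced
                            size_t depth,       // thread slices along `axis`
                            size_t elt_bytes) { // scalar size, e.g. 4 for float
  size_t num_elements = 1;
  for(size_t i = 0; i < shapes.size(); i++)
    if(i != axis)
      num_elements *= shapes[i];           // elements that survive the reduce
  return num_elements * elt_bytes * depth; // one partial per slice
}
// e.g. a 32x64 tile reduced over axis 1 with depth 4 and float partials:
// 32 * 4 * 4 = 512 bytes of shared scratch.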

View File

@@ -24,8 +24,7 @@ bool is_hmma(ir::value *v){
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
- // reduction has to be multiple of 4
- result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
+ // reduction has to be multiple of 4: TODO
}
return result;
}
@@ -66,9 +65,10 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < in_shapes.size(); i++){
if(i == axis)
continue;
// std::cout << arg->get_name() << " " << v->get_name() << std::endl;
add_constraint({reduce, current++}, {arg, i});
}
// add_constraint({reduce, 0}, {arg, 0});
// add_constraint({reduce, 1}, {arg, 1});
return;
}
else
@@ -81,8 +81,10 @@ void tune::init_c_graph(ir::instruction *v) {
for(unsigned i = 0; i < shapes.size(); i ++){
bool is_one = shapes[i] == one;
bool is_same = shapes[i] == op->get_type()->get_tile_shapes()[current];
- if(is_one)
+ if(is_one){
    static_params_.insert({{v, i}, 1});
+   add_constraint({v, i}, {v, i});
+ }
else if(!is_skewed && is_same)
add_constraint({v, i}, {op, current++});
else{
@@ -114,9 +116,17 @@ void tune::init_c_graph(ir::instruction *v) {
}
// Matrix multiplication
else if(dynamic_cast<ir::dot_inst*>(v)){
+ ir::value *A = v->get_operand(0);
+ ir::value *B = v->get_operand(1);
  ir::value *D = v->get_operand(2);
- add_constraint({v, 0}, {D, 0});
- add_constraint({v, 1}, {D, 1});
+ for(unsigned i = 0; i < shapes.size(); i++)
+   add_constraint({v, i}, {D, i});
+ for(unsigned i = 2; i < shapes.size(); i++){
+   if(shapes[i] == one)
+     static_params_.insert({{v, i}, 1});
+   add_constraint({v, i}, {A, i});
+   add_constraint({v, i}, {B, i});
+ }
}
}
// Element-wise
else if(dynamic_cast<ir::user*>(v)) {
@@ -242,7 +252,7 @@ void tune::run(ir::module &mod) {
node_t node = *nodes_.begin();
if(fragments_[node] == STRIDED_SCAN) {
ir::metaparameter *nts = ir::metaparameter::create(ctx, ty, 1, 1);
- ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 2, 64);
+ ir::metaparameter *mts = ir::metaparameter::create(ctx, ty, 4, 32);
connected_components(node, {nts, mts}, {"nts", "mts"}, nodes_, dependencies_, group_id++);
nts->set_value(1);
}
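For reference on these ranges (assumed semantics, inferred from the pass: `nts` = contiguous elements per thread, `mts` = threads along a dimension; the commit narrows `mts` from {2..64} to {4..32}): a dimension of size S is covered in S/(nts*mts) strided repetitions, which the following sketch prints out:

#include <cstdio>

int main() {
  const int S = 32, nts = 1, mts = 8;           // one possible configuration
  for(int tid = 0; tid < mts; tid++) {          // thread index along the dim
    std::printf("thread %d:", tid);
    for(int rep = 0; rep < S/(nts*mts); rep++)  // repetitions across the tile
      for(int e = 0; e < nts; e++)              // contiguous chunk per thread
        std::printf(" %d", rep*nts*mts + tid*nts + e);
    std::printf("\n");
  }
  return 0;                                     // thread 0 owns 0, 8, 16, 24
}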
@@ -266,14 +276,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
- std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 8));
+ std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
- std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 8));
- std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 8));
+ std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 1, 1));
+ std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 1, 1));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}
@@ -365,6 +375,7 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
// check constraints
for(ir::instruction *i: grids_){
+ // std::cout << i->get_name() << std::endl;
ir::type *ty = i->get_type();
const auto &shapes = ty->get_tile_shapes();
// for each dimension, the product of layout components
@@ -396,11 +407,15 @@ bool tune::check_constraints(std::map<ir::value *, std::vector<std::string>> &er
errors[i].push_back("HMMA must have only 4 fragments per warp");
}
int num_threads = get_req_num_threads(i);
- if(num_threads % 64 != 0)
+ if(num_threads % 32 != 0)
errors[i].push_back("number of threads per block (" + to_string(num_threads) + ") must be multiple of warp size");
if(num_threads != num_threads_)
errors[i].push_back("Number of threads must be the same for all tiles (" + to_string(num_threads_) + ")");
}
+ // for(auto x: errors)
+ //   for(auto e: x.second)
+ //     std::cout << x.first->get_name() << ": " << e << std::endl;
+ // exit(EXIT_SUCCESS);
return errors.empty();
}

View File

@@ -54,16 +54,17 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
return num_flops() / ts * 1e-3;
};
// auto-tune and save result
- if(autotune != NO_TUNING) {
+ if(autotune == FULL_TUNING || autotune == PARTIAL_TUNING) {
std::vector<params_t> space = {};
if(autotune == PARTIAL_TUNING)
space = search_space();
rt::jit::tune_res_t best = jit->autotune(name_.c_str(), src.c_str(), benchmark, space);
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
- else {
-   params_t params = heuristics();
+ else{
+   // params_t params = heuristics();
+   // params_t params = jit->get_valid(name_.c_str(), src.c_str());
+   params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 1};
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());

View File

@@ -74,12 +74,14 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
void dot::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK/4", XAS2 = "4";
std::string XBS0 = "TN", XBS1 = "TK/4", XBS2 = "4";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
std::string ldb0 = "", ldb1 = "*ldb";
std::string usea = AT_ ? "trans(a)" : "a";
std::string useb = BT_ ? "trans(b)" : "b";
std::string usea = AT_ ? "trans(xa)" : "xa";
std::string useb = BT_ ? "trans(xb)" : "xb";
if(AT_){
std::swap(AS0, AS1);
std::swap(bca0, bca1);
@@ -92,12 +94,15 @@ void dot::triton_c_src(std::ostream &os) const {
}
std::string AS = AS0 + ", " + AS1;
std::string BS = BS0 + ", " + BS1;
+ std::string XAS = XAS0 + ", " + XAS1 + ", " + XAS2;
+ std::string XBS = XBS0 + ", " + XBS1 + ", " + XBS2;
+ std::string XCS = "TM, TN, 4";
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
- const tunable int TM = {16, 32, 64, 128};
- const tunable int TN = {16, 32, 64, 128};
+ const tunable int TM = {32};
+ const tunable int TN = {32};
const tunable int TK = {32};
const tunable int GZ = {1};
@@ -113,7 +118,7 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
int ryb[TN] = ridy * TN + (0 ... TN);
int rka[TK] = 0 ... TK;
int rkb[TK] = 0 ... TK;
- float c[TM, TN] = 0;
+ float xc[)" + XCS + R"(] = 0;
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
bool checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
@@ -121,7 +126,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
for(int k = K; k > 0; k = k - TK){
c = dot()" + usea + ", " + useb + R"(, c);
)" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
)" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
bool checka[)" + AS + R"(] = k > TK;
@@ -131,11 +138,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
- bool checkc0[TM] = rxc < M;
- bool checkc1[TN] = ryc < N;
- bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
float* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
- @checkc *pc = c;
+ float c[TM, TN] = __sum(xc, 2);
+ *pc = c;
}
)";

View File

@@ -148,13 +148,13 @@ DEFINE_UNARY_FLOAT(fneg)
value* builder::create_insert_nuwnswb_binop(binary_operator::op_t op, value *lhs,
value *rhs, const std::string &name,
bool has_nuw, bool has_nsw) {
- if(auto *clhs = dynamic_cast<constant_int*>(lhs)){
-   if(auto *crhs = dynamic_cast<constant_int*>(rhs)){
-     constant_expression* result = constant_expression::create(op, clhs, crhs);
-     if (has_nuw) result->set_has_no_unsigned_wrap();
-     if (has_nsw) result->set_has_no_signed_wrap();
-     return result;
-   }
+ auto *clhs = dynamic_cast<constant_int*>(lhs);
+ auto *crhs = dynamic_cast<constant_int*>(rhs);
+ if(clhs && crhs){
+   constant_expression* result = constant_expression::create(op, clhs, crhs);
+   if (has_nuw) result->set_has_no_unsigned_wrap();
+   if (has_nsw) result->set_has_no_signed_wrap();
+   return result;
+ }
else {
binary_operator* result = insert(binary_operator::create(op, lhs, rhs), name);
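The restructuring here also fixes a fall-through: in the old nested form, a constant lhs with a non-constant rhs entered the outer `if`, skipped the fold, and bypassed the `else` that creates the instruction, so no value was ever produced for that case. A reduced sketch of the control flow (hypothetical stand-ins for the IR types):

#include <optional>

std::optional<int> fold_or_emit_old(std::optional<int> clhs, std::optional<int> crhs) {
  if(clhs){
    if(crhs)
      return *clhs + *crhs;   // fold both constants
  }
  else
    return -1;                // "emit instruction" path
  return std::nullopt;        // bug: constant lhs, non-constant rhs -> nothing
}

std::optional<int> fold_or_emit_new(std::optional<int> clhs, std::optional<int> crhs) {
  if(clhs && crhs)
    return *clhs + *crhs;     // fold only when both operands are constants
  return -1;                  // otherwise always emit the instruction
}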

View File

@@ -101,7 +101,6 @@ ir::value *binary_expression::llvm_op(ir::module *mod, ir::builder &builder, ir:
ir::value* binary_expression::codegen(ir::module *mod) const{
ir::value *lhs = lhs_->codegen(mod);
ir::value *rhs = rhs_->codegen(mod);
std::cout << " " << typeid(*lhs_).name() << " " << typeid(*rhs_).name() << std::endl;
ir::value *result = llvm_op(mod, mod->get_builder(), lhs, rhs, "");
return result;
}