[codegen] some cleaning for batched matmul

This commit is contained in:
Philippe Tillet
2019-08-07 11:08:04 -07:00
parent 7b75b68edc
commit 392b55280d
11 changed files with 82 additions and 80 deletions

View File

@@ -40,7 +40,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
hb[i] = static_cast<NumericT>((double)rand()/RAND_MAX);
for(size_t i = 0; i < hc.size(); i++)
hc[i] = static_cast<NumericT>((double)0);
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*dt_nbytes);
triton::driver::buffer* dc = triton::driver::buffer::create(context, hc.size()*4);
triton::driver::buffer* da = triton::driver::buffer::create(context, ha.size()*dt_nbytes);
triton::driver::buffer* db = triton::driver::buffer::create(context, hb.size()*dt_nbytes);
stream->write(da, true, 0, ha);
@@ -49,7 +49,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
stream->synchronize();
triton::dnn::dot dot(M, N, K, AT, BT, ty, ty, ty, 8, 8, 8);
// benchmark triton
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::NO_TUNING);}, stream);
double triton_ns = triton::tools::bench([&]() { dot.enqueue(stream, {da, db, dc}, triton::dnn::PARTIAL_TUNING);}, stream);
// benchmark cublas
// NumericT alpha = 1;
// NumericT beta = 0;
@@ -77,7 +77,7 @@ perf_t do_bench(triton::driver::stream* stream, bool AT, bool BT, int32_t M, int
std::vector<NumericT> rc(hc.size());
dot.cpu_ref(rc, ha, hb);
for(size_t i = 0; i < M*N; i++)
if(std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
if(std::isinf(hc[i]) || std::isnan(hc[i]) || std::abs(hc[i] - rc[i])/std::max(hc[i], rc[i]) > 1e-2){
std::cout << i << " " << hc[i] << " " << rc[i] << std::endl;
exit(EXIT_FAILURE);
}
@@ -111,8 +111,8 @@ int main() {
// shapes to benchmark
std::vector<config_t> configs = {
// {false, false, 8192, 512, 512},
// {false, true, 8192, 8192, 8192}
{false, true, 128, 128, 128},
{false, true, 128, 128, 128}
// {false, true, 128, 128, 128},
// {false, false, 128, 128, 128},
// {true, false, 128, 128, 128},
// {true, true, 128, 128, 128}

View File

@@ -153,8 +153,8 @@ private:
alignment_info *axis_info_;
std::map<unsigned, distributed_axis> axes_;
llvm::Value *sh_mem_ptr_;
llvm::Value *offset_a_i_, *offset_a_k_, *offset_a_z_;
llvm::Value *offset_b_j_, *offset_b_k_, *offset_b_z_;
llvm::Value *offset_a_i_, *offset_a_k_;
llvm::Value *offset_b_j_, *offset_b_k_;
unsigned num_packs_0_, num_packs_1_;
unsigned pack_size_0_, pack_size_1_;
};

View File

@@ -71,6 +71,7 @@ public:
void target_independent(ir::module &module) {
optimize_dot.run(module);
optimize_dce.run(module);
optimize_trans.run(module);
optimize_dce.run(module);
}
@@ -86,7 +87,7 @@ public:
}
vectorize.run(module);
optimize_dce.run(module);
ir::print(module, std::cout);
// ir::print(module, std::cout);
}
codegen::tune tune;

View File

@@ -38,8 +38,8 @@ inline double bench(std::function<void()> const & op, driver::stream * stream)
while(total_time*1e-9 < 1e-3){
float norm = 1;
// normalize clock if possible to reduce noise in auto-tuning
// if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
// norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
if(auto cu_device = dynamic_cast<const triton::driver::cu_device*>(device))
norm = (float)cu_device->current_sm_clock()/cu_device->max_sm_clock();
tmr.start();
op();
stream->synchronize();

View File

@@ -30,8 +30,8 @@ inline bool is_hmma(ir::value *v){
ir::type *b_ty = b->get_type();
// inputs have to be FP16
result = a_ty->get_scalar_ty()->is_half_ty() && b_ty->get_scalar_ty()->is_half_ty();
// reduction has to be multiple of 4
result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
// reduction has to be multiple of 4
// result = result && ((a_ty->get_tile_shapes()[1]->get_value() % 4) == 0);
}
return result;
}
@@ -70,13 +70,13 @@ void optimize_dot::run(ir::module &mod) {
BB = ((ir::trans_inst*)B)->get_operand(0);
}
else{
if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
std::vector<ir::constant_int*> perm(T->get_perm());
std::swap(perm[0], perm[1]);
AA = builder.create_trans(T->get_operand(0), perm);
T->replace_all_uses_with(AA);
trans_a = true;
}
// if(auto *T = dynamic_cast<ir::trans_inst*>(A)){
// std::vector<ir::constant_int*> perm(T->get_perm());
// std::swap(perm[0], perm[1]);
// AA = builder.create_trans(T->get_operand(0), perm);
// T->replace_all_uses_with(AA);
// trans_a = true;
// }
}
ir::instruction *dot_atbt = builder.insert(ir::dot_inst::create(AA, BB, D, trans_a, trans_b));
dot->replace_all_uses_with(dot_atbt);

View File

@@ -516,6 +516,10 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
}
}
else {
if(shapes.size() > 3)
throw std::runtime_error("unsupported");
bool is_batched = shapes.size() >= 3;
Value *_1 = builder.getInt32(1);
Value *_2 = builder.getInt32(2);
Value *_3 = builder.getInt32(3);
@@ -526,23 +530,23 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
// fragments per warp
unsigned fpw_0 = params_->get_param(v, "fpw.d0")->get_value();
unsigned fpw_1 = params_->get_param(v, "fpw.d1")->get_value();
unsigned fpw_2 = params_->get_param(v, "fpw.d2")->get_value();
unsigned fpw_2 = is_batched ? params_->get_param(v, "fpw.d2")->get_value() : 1;
// warps per tile
unsigned wpt_0 = params_->get_param(v, "wpt.d0")->get_value();
unsigned wpt_1 = params_->get_param(v, "wpt.d1")->get_value();
unsigned wpt_2 = params_->get_param(v, "wpt.d2")->get_value();
unsigned wpt_2 = is_batched ? params_->get_param(v, "wpt.d2")->get_value() : 1;
// hmma warp tile size
unsigned hmma_wts_0 = fpw_0 * 8;
unsigned hmma_wts_1 = fpw_1 * 8;
unsigned hmma_wts_2 = 1;
unsigned hmma_wts_2 = is_batched ? fpw_2 : 1;
// hmma block tile size
unsigned hmma_bts_0 = hmma_wts_0 * wpt_0;
unsigned hmma_bts_1 = hmma_wts_1 * wpt_1;
unsigned hmma_bts_2 = hmma_wts_2 * wpt_2;
unsigned hmma_bts_2 = is_batched ? hmma_wts_2 * wpt_2 : 1;
// number of repetition
unsigned num_rep_0 = shapes[0]->get_value() / hmma_bts_0;
unsigned num_rep_1 = shapes[1]->get_value() / hmma_bts_1;
unsigned num_rep_2 = shapes[2]->get_value() / hmma_bts_2;
unsigned num_rep_2 = is_batched ? shapes[2]->get_value() / hmma_bts_2 : 1;
// size of each pack (interleaving)
pack_size_0_ = std::min<unsigned>(num_rep_0, 1);
pack_size_1_ = std::min<unsigned>(num_rep_1, 1);
@@ -579,19 +583,15 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
// a offset
offset_a_i_ = builder.CreateAdd(warp_offset_i, builder.CreateAdd(pair_a_off, in_pair_off_a));
offset_a_k_ = builder.CreateAnd(u_thread_id, _3);
offset_a_z_ = warp_id_2;
// b offsets
offset_b_j_ = builder.CreateAdd(warp_offset_j, builder.CreateAdd(pair_b_off, in_pair_off_b));
offset_b_k_ = builder.CreateAnd(u_thread_id, _3);
offset_b_z_ = warp_id_2;
// c offsets
Value *offset_c_i = builder.CreateAdd(builder.CreateAnd(u_thread_id, _1), offset_a_i_);
Value *offset_c_j = builder.CreateAdd(builder.CreateAnd(u_thread_id, _2),
builder.CreateAdd(warp_offset_j, pair_b_off));
/* indices */
// i indices
std::vector<Value*> idx_i;
@@ -617,7 +617,8 @@ void selection::init_axes(ir::value *v, IRBuilder<> &builder, Value *u_thread_id
/* axes */
axes_[params_->get_param_group(v, 0)] = distributed_axis{1, idx_i, warp_id_0};
axes_[params_->get_param_group(v, 1)] = distributed_axis{1, idx_j, warp_id_1};
axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
if(is_batched)
axes_[params_->get_param_group(v, 2)] = distributed_axis{1, idx_z, warp_id_2};
}
}
@@ -1062,10 +1063,10 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
std::map<Value*, std::vector<Value*>> fcs;
result->for_each([&](indices_t idx){
fcs[idx[2]].push_back(TC->get_value(idx));
// fc.push_back(UndefValue::get(TC->get_value(idx)->getType()));
fcs[{builder.getInt32(0)}].push_back(TC->get_value(idx));
});
Type *fp32_ty = builder.getFloatTy();
Type *fp16x2_ty = VectorType::get(builder.getHalfTy(), 2);
Type *fp32_pack8_ty = StructType::get(ctx, {fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty, fp32_ty});
@@ -1121,8 +1122,8 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
std::swap(idx_a[0], idx_a[1]);
if(!dot->is_b_trans())
std::swap(idx_b[0], idx_b[1]);
idx_a.push_back(x.first);
idx_b.push_back(x.first);
// idx_a.push_back(builder.getInt32(0));
// idx_b.push_back(builder.getInt32(0));
Value *ha = TA->get_value(idx_a);
Value *hb = TB->get_value(idx_b);
for(unsigned ii = 0; ii < pack_size_0_; ii++)
@@ -1158,9 +1159,9 @@ void selection::lower_tile_instruction(ir::instruction *ins, llvm::IRBuilder<> &
// write back
unsigned i = 0;
result->for_each([&](indices_t idx){
if(i >= fcs.at(idx[2]).size())
if(i >= fcs.at({builder.getInt32(0)}).size())
i = 0;
result->set_value(idx, fcs.at(idx[2])[i++]);
result->set_value(idx, fcs.at({builder.getInt32(0)})[i++]);
});
TA->set_return_mode(false);

View File

@@ -257,56 +257,54 @@ void tune::run(ir::module &mod) {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
if(node.second == 2)
fpw->set_value(1);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 2, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
}
}
// Simplify metaparameters
for(ir::function *fn: mod.get_function_list())
for(ir::basic_block *block: fn->blocks())
for(ir::instruction *i : block->get_inst_list()){
if(fragments_.find({i, 0}) != fragments_.end() && fragments_.at({i, 0}) != STRIDED_SCAN)
if(!i->get_type()->is_tile_ty())
continue;
auto shapes = i->get_type()->get_tile_shapes();
if(auto *x = dynamic_cast<ir::load_inst*>(i))
if(i->get_type()->is_tile_ty()){
if(auto *x = dynamic_cast<ir::load_inst*>(i)){
ir::type *ptr_ty = x->get_pointer_operand()->get_type()->get_scalar_ty();
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 4));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 4));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 4));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
// std::unique_ptr<ir::metaparameter> mts_2(ir::metaparameter::create(ctx, ty, 1, 4));
// *params_.at(i->get_operand(0)).at("mts.d2") = *mts_2;
// *params_.at(i->get_operand(1)).at("mts.d2") = *mts_2;
if(fragments_.at({i, 0}) == STRIDED_SCAN){
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
// for(size_t k = 2; k < shapes.size(); k++)
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
// *params_.at(i).at("mts.d" + std::to_string(k)) = *x;
// else
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(shapes[k]->get_value());
}
else{
// for(size_t k = 2; k < shapes.size(); k++)
// if(auto *x = dynamic_cast<ir::metaparameter*>(shapes[k]))
// *params_.at(i).at("wpt.d" + std::to_string(k)) = *x;
// else
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(shapes[k]->get_value());
}
}
}
// initialize grids
// for(ir::instruction *i: grids_){
// auto shapes = i->get_type()->get_tile_shapes();
// for(size_t k = 0; k < shapes.size(); k++)
// if(shapes[k]->get_value() == 1) {
// if(fragments_.at({i, k}) == STRIDED_SCAN){
// params_.at(i).at("nts.d" + std::to_string(k))->set_value(1);
// params_.at(i).at("mts.d" + std::to_string(k))->set_value(1);
// }
// if(fragments_.at({i, k}) == HMMA_FRAGMENT_C){
// params_.at(i).at("fpw.d" + std::to_string(k))->set_value(1);
// params_.at(i).at("wpt.d" + std::to_string(k))->set_value(1);
// }
// }
// }
}
void tune::init(ir::module &mod) {

View File

@@ -62,11 +62,7 @@ std::pair<base*, rt::jit*> base::get_profile_impl(driver::stream *stream, std::v
jit->add_module(name_.c_str(), src.c_str(), best.params);
}
else{
// params_t params = heuristics();
// params_t params = jit->get_valid(name_.c_str(), src.c_str());
// params_t params = {4, 1, 32, 4, 1, 32, 4, 4, 4, 1, 1, 16, 32, 16, 4, 4, 4, 4, 1}; //NT
// params_t params = {4, 1, 32, 4, 32, 4, 4, 4, 1, 1, 16, 32, 16, 1, 4, 4, 4, 4, 4, 1}; //NN
params_t params = {4, 2, 16, 4, 2, 16, 2, 2, 1, 1, 2, 16, 32, 16, 4, 4, 4, 4, 1}; // TT
params_t params = heuristics();
jit->add_module(name_.c_str(), src.c_str(), params);
}
triton::driver::kernel* kernel = jit->get_function(name_.c_str());

View File

@@ -72,11 +72,11 @@ void dot::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
}
void dot::triton_c_src(std::ostream &os) const {
std::string ZS = "4";
std::string ZS = "1";
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
std::string XAS0 = "TM", XAS1 = "TK / " + ZS, XAS2 = ZS;
std::string XBS0 = "TK / " + ZS, XBS1 = ZS, XBS2 = "TN";
std::string XAS0 = "TM", XAS1 = "TK", XAS2 = ZS;
std::string XBS0 = "TK", XBS1 = ZS, XBS2 = "TN";
std::string bca0 = "[newaxis, :]", bca1 = "[:, newaxis]";
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
std::string lda0 = "*lda", lda1 = "";
@@ -131,9 +131,9 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
bool checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
)" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
)" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
for(int k = K; k > 0; k = k - TK){
)" + a_ty_ + R"( xa[)" + XAS + "] = __reshape(a, " + XAS + R"();
)" + b_ty_ + R"( xb[)" + XBS + "] = __reshape(b, " + XBS + R"();
xc = dot()" + usea + ", " + useb + R"(, xc);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
@@ -141,12 +141,17 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
bool checkb[)" + BS + R"(] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
xa = __reshape(a, )" + XAS + R"();
xb = __reshape(b, )" + XBS + R"();
}
int rxc[TM] = ridx * TM + (0 ... TM);
int ryc[TN] = ridy * TN + (0 ... TN);
)" + c_ty_ + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
)" + c_ty_ + R"( c[TM, TN] = __sum(xc, 2);
*pc = c;
bool checkc0[TM] = rxc < M;
bool checkc1[TN] = ryc < N;
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
@checkc *pc = c;
}
)";

View File

@@ -255,6 +255,7 @@ std::string cu_module::compile_llvm_module(llvm::Module* module) {
cu_module::cu_module(driver::context * context, llvm::Module* ll_module): cu_module(context, compile_llvm_module(ll_module)) { }
cu_module::cu_module(driver::context * context, std::string const & source) : module(context, CUmodule(), true), source_(source){
// std::cout << source << std::endl;
cu_context::context_switcher ctx_switch(*context);
// JIT compile source-code
CUjit_option opt[] = {CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES, CU_JIT_ERROR_LOG_BUFFER};

View File

@@ -37,13 +37,13 @@ void parallel_loop_nest(std::vector<size_t> const & ranges,
size_t D = ranges.size();
std::vector<size_t> values(D, 0);
// thread pools
// ThreadPool pool(nthreads);
ThreadPool pool(nthreads);
// Start with innermost loop
size_t i = D - 1;
while(true){
// Execute function
// pool.enqueue(f,values);
f(values);
pool.enqueue(f,values);
// f(values);
while(values[i]++ == ranges[i] - 1){
if(i == 0)
return;
@@ -174,9 +174,9 @@ jit::tune_res_t jit::autotune(const char *name, const char *src, benchmark_t ben
std::lock_guard<std::mutex> lock(mutex);
for(ir::metaparameter *mp: mps)
mp->set_value(params[i++]);
for(size_t i = 0; i < params.size(); i++)
std::cout << ((i==0)?"":", ") << params[i] << std::flush;
std::cout << std::endl;
// for(size_t i = 0; i < params.size(); i++)
// std::cout << ((i==0)?"":", ") << params[i] << std::flush;
// std::cout << std::endl;
passes_0.tune.init(tt_module_0);
passes_0.tune.check_constraints(errors);
// for(auto x: errors)