[codegen/reassociation] now recursively takes pointer arguments into account as well

This commit is contained in:
Philippe Tillet
2019-07-31 18:41:56 -07:00
parent f7bd976fc7
commit 3b92ddf7e6
5 changed files with 24 additions and 8 deletions

View File

@@ -130,7 +130,7 @@ public:
// create profile
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
// blocksparse matmul
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
Tensor *tmp = nullptr;
TensorShape tmp_shapes;

View File

@@ -77,6 +77,7 @@ public:
void target_dependent(ir::module &module) {
alignment_info.run(module);
// ir::print(module, std::cout);
reassociate.run(module);
if(target_->is_gpu()){
shmem_info.run(module);

View File

@@ -215,6 +215,21 @@ void reassociate::run(ir::module &mod) {
infos[sta_ptr].dyn_ptr = (ir::getelementptr_inst*)dyn_ptr;
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
}
// reassociate pointer argument
if(ir::getelementptr_inst* gepy = dynamic_cast<ir::getelementptr_inst*>(py))
if(infos.find(gepy) != infos.end()){
builder.set_insert_point(pz);
ir::getelementptr_inst *sta = infos[gepy].sta_ptr;
ir::getelementptr_inst *dyn = infos[gepy].dyn_ptr;
ir::value *cst = *sta->idx_begin();
ir::value *off = *pz->idx_begin();
ir::value *new_dyn = builder.create_gep(dyn, {off});
ir::value *new_pz = builder.create_gep(new_dyn, {cst}, pz->get_name());
params_->copy(new_dyn, pz);
params_->copy(new_pz, pz);
align_->copy(new_pz, pz);
pz->replace_all_uses_with(new_pz);
}
// reassociate phi-node pointer
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
// only optimize the case where py = phi pa, pz for now

View File

@@ -228,7 +228,7 @@ void tune::run(ir::module &mod) {
nts->set_value(1);
}
else {
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
}
@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
size_t addr_space = ptr_ty->get_pointer_address_space();
if(addr_space < 4){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 8));
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp1;
*params_.at(i).at("nts.d1") = *tmp2;
}

View File

@@ -88,8 +88,8 @@ void dot::triton_c_src(std::ostream &os) const {
std::string bca1 = "newaxis, :";
std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
std::string ldb0 = (op_ == FPROP) ? "1" : "TK";
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
std::string ldb0 = (op_ == FPROP) ? "" : "*TK";
std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
std::string result =
R"(
const tunable int32 TM = {16, 32, 64, 128};
@@ -110,7 +110,7 @@ void dot::triton_c_src(std::ostream &os) const {
int32 rkb[TK] = 0 ... TK;
int1 checka[TM, TK] = (rxa < N)[:, newaxis];
int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(;
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
int32 *header = lut + ridy * 4;
int32 offset = *(header + 0);
int32 K = *(header + 1);