[codegen/reassociation] now recursively takes pointer arguments into account as well
This commit is contained in:
@@ -130,7 +130,7 @@ public:
|
||||
// create profile
|
||||
triton::dnn::blocksparse::dot dot(N, params_.K, params_.segments, params_.C, "fp16", params_.bsize, params_.locks, params_.blocks, OP);
|
||||
// blocksparse matmul
|
||||
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::NO_TUNING);
|
||||
triton::dnn::base* op = dot.enqueue(stream, {&da, &db, &dc, &dlut}, triton::dnn::FULL_TUNING);
|
||||
triton::driver::buffer* locks_buffer = ((triton::dnn::blocksparse::dot*)op)->get_locks();
|
||||
Tensor *tmp = nullptr;
|
||||
TensorShape tmp_shapes;
|
||||
|
@@ -77,6 +77,7 @@ public:
|
||||
|
||||
void target_dependent(ir::module &module) {
|
||||
alignment_info.run(module);
|
||||
// ir::print(module, std::cout);
|
||||
reassociate.run(module);
|
||||
if(target_->is_gpu()){
|
||||
shmem_info.run(module);
|
||||
|
@@ -215,6 +215,21 @@ void reassociate::run(ir::module &mod) {
|
||||
infos[sta_ptr].dyn_ptr = (ir::getelementptr_inst*)dyn_ptr;
|
||||
infos[sta_ptr].sta_ptr = (ir::getelementptr_inst*)sta_ptr;
|
||||
}
|
||||
// reassociate pointer argument
|
||||
if(ir::getelementptr_inst* gepy = dynamic_cast<ir::getelementptr_inst*>(py))
|
||||
if(infos.find(gepy) != infos.end()){
|
||||
builder.set_insert_point(pz);
|
||||
ir::getelementptr_inst *sta = infos[gepy].sta_ptr;
|
||||
ir::getelementptr_inst *dyn = infos[gepy].dyn_ptr;
|
||||
ir::value *cst = *sta->idx_begin();
|
||||
ir::value *off = *pz->idx_begin();
|
||||
ir::value *new_dyn = builder.create_gep(dyn, {off});
|
||||
ir::value *new_pz = builder.create_gep(new_dyn, {cst}, pz->get_name());
|
||||
params_->copy(new_dyn, pz);
|
||||
params_->copy(new_pz, pz);
|
||||
align_->copy(new_pz, pz);
|
||||
pz->replace_all_uses_with(new_pz);
|
||||
}
|
||||
// reassociate phi-node pointer
|
||||
if(ir::phi_node* phi = dynamic_cast<ir::phi_node*>(py)){
|
||||
// only optimize the case where py = phi pa, pz for now
|
||||
|
@@ -228,7 +228,7 @@ void tune::run(ir::module &mod) {
|
||||
nts->set_value(1);
|
||||
}
|
||||
else {
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
ir::metaparameter *fpw = ir::metaparameter::create(ctx, ty, 2, 2);
|
||||
ir::metaparameter *wpt = ir::metaparameter::create(ctx, ty, 1, 4);
|
||||
connected_components(node, {fpw, wpt}, {"fpw", "wpt"}, nodes_, dependencies_, group_id++);
|
||||
}
|
||||
@@ -247,14 +247,14 @@ void tune::run(ir::module &mod) {
|
||||
size_t addr_space = ptr_ty->get_pointer_address_space();
|
||||
if(addr_space < 4){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 2, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
|
||||
*params_.at(i).at("nts.d0") = *tmp;
|
||||
}
|
||||
}
|
||||
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){
|
||||
ir::type *ty = mod.get_builder().get_int32_ty();
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 2, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 2, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp1(ir::metaparameter::create(ctx, ty, 4, 8));
|
||||
std::unique_ptr<ir::metaparameter> tmp2(ir::metaparameter::create(ctx, ty, 4, 8));
|
||||
*params_.at(i).at("nts.d0") = *tmp1;
|
||||
*params_.at(i).at("nts.d1") = *tmp2;
|
||||
}
|
||||
|
@@ -88,8 +88,8 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string bca1 = "newaxis, :";
|
||||
std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
|
||||
std::string ldb0 = (op_ == FPROP) ? "1" : "TK";
|
||||
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
|
||||
std::string ldb0 = (op_ == FPROP) ? "" : "*TK";
|
||||
std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
@@ -110,7 +110,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
int1 checka[TM, TK] = (rxa < N)[:, newaxis];
|
||||
int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
|
||||
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(;
|
||||
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
|
||||
int32 *header = lut + ridy * 4;
|
||||
int32 offset = *(header + 0);
|
||||
int32 K = *(header + 1);
|
||||
|
Reference in New Issue
Block a user