[codegen/reassociation] now recursively takes pointer arguments into account as well

This commit is contained in:
Philippe Tillet
2019-07-31 18:41:56 -07:00
parent f7bd976fc7
commit 3b92ddf7e6
5 changed files with 24 additions and 8 deletions

View File

@@ -88,8 +88,8 @@ void dot::triton_c_src(std::ostream &os) const {
std::string bca1 = "newaxis, :";
std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
std::string ldb0 = (op_ == FPROP) ? "1" : "TK";
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
std::string ldb0 = (op_ == FPROP) ? "" : "*TK";
std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
std::string result =
R"(
const tunable int32 TM = {16, 32, 64, 128};
@@ -110,7 +110,7 @@ void dot::triton_c_src(std::ostream &os) const {
int32 rkb[TK] = 0 ... TK;
int1 checka[TM, TK] = (rxa < N)[:, newaxis];
int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(;
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
int32 *header = lut + ridy * 4;
int32 offset = *(header + 0);
int32 K = *(header + 1);