[codegen/reassociation] now recursively takes pointer arguments into account as well
This commit is contained in:
@@ -88,8 +88,8 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
std::string bca1 = "newaxis, :";
|
||||
std::string bcb0 = (op_ == FPROP) ? ":, newaxis" : "newaxis, :";
|
||||
std::string bcb1 = (op_ == FPROP) ? "newaxis, :" : ":, newaxis";
|
||||
std::string ldb0 = (op_ == FPROP) ? "1" : "TK";
|
||||
std::string ldb1 = (op_ == FPROP) ? "TK" : "1" ;
|
||||
std::string ldb0 = (op_ == FPROP) ? "" : "*TK";
|
||||
std::string ldb1 = (op_ == FPROP) ? "*TK" : "" ;
|
||||
std::string result =
|
||||
R"(
|
||||
const tunable int32 TM = {16, 32, 64, 128};
|
||||
@@ -110,7 +110,7 @@ void dot::triton_c_src(std::ostream &os) const {
|
||||
int32 rkb[TK] = 0 ... TK;
|
||||
int1 checka[TM, TK] = (rxa < N)[:, newaxis];
|
||||
int32 offa[)" + sizea + "] = rxa[" + bca0 + "] + rka[" + bca1 + R"(]*lda;
|
||||
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]*" + ldb0 + " + rkb[" + bcb1 + "]*" + ldb1 + R"(;
|
||||
int32 offb[)" + sizeb + "] = ryb[" + bcb0 + "]" + ldb0 + " + rkb[" + bcb1 + "]" + ldb1 + R"(;
|
||||
int32 *header = lut + ridy * 4;
|
||||
int32 offset = *(header + 0);
|
||||
int32 K = *(header + 1);
|
||||
|
Reference in New Issue
Block a user