basic parsing doesn't throw error
This commit is contained in:
@@ -55,8 +55,8 @@ std::string src(bool AT, bool BT, std::string a_ty, std::string b_ty, std::strin
|
||||
std::string bcb0 = "[:, newaxis]", bcb1 = "[newaxis, :]";
|
||||
std::string lda0 = "*lda", lda1 = "";
|
||||
std::string ldb0 = "", ldb1 = "*ldb";
|
||||
std::string usea = AT ? "trans(a)" : "a";
|
||||
std::string useb = BT ? "trans(b)" : "b";
|
||||
std::string usea = AT ? "^a" : "a";
|
||||
std::string useb = BT ? "^b" : "b";
|
||||
if(AT){
|
||||
std::swap(AS0, AS1);
|
||||
std::swap(XAS0, XAS1);
|
||||
@@ -82,6 +82,11 @@ R"(
|
||||
#define TN 128
|
||||
#define TK 32
|
||||
|
||||
#define bool _Bool
|
||||
#define true 1
|
||||
#define false 0
|
||||
#define __bool_true_false_are_defined 1
|
||||
|
||||
extern int get_program_id(int);
|
||||
|
||||
void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
@@ -94,28 +99,28 @@ void matmul(restrict )" + a_ty + R"( * A __attribute__((readonly, aligned(16))),
|
||||
int ridx = get_program_id(0);
|
||||
int ridy = get_program_id(1);
|
||||
int rxa[{TM, TN}] = ridx * TM + 0 ... TM;
|
||||
int ryb[TN] = ridy * TN + 0 ... TN;
|
||||
int rka[TK] = 0 ... TK;
|
||||
int rkb[TK] = 0 ... TK;
|
||||
float xc[)" + XCS + R"(] = 0;
|
||||
)" + a_ty + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[)" + AS + R"(] = *pa;
|
||||
)" + b_ty + R"( b[)" + BS + R"(] = *pb;
|
||||
int ryb[{TN}] = ridy * TN + 0 ... TN;
|
||||
int rka[{TK}] = 0 ... TK;
|
||||
int rkb[{TK}] = 0 ... TK;
|
||||
float xc[{)" + XCS + R"(}] = 0;
|
||||
)" + a_ty + R"(* pa[{)" + AS + "}] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
|
||||
)" + b_ty + R"(* pb[{)" + BS + "}] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
|
||||
)" + a_ty + R"( a[{)" + AS + R"(}] = *pa;
|
||||
)" + b_ty + R"( b[{)" + BS + R"(}] = *pb;
|
||||
for(int k = K; k > 0; k = k - TK){
|
||||
xc = dot()" + usea + ", " + useb + R"(, xc);
|
||||
xc = )" + usea + " @ " + useb + R"( + xc;
|
||||
pa = pa + TK)" + lda0 + R"(;
|
||||
pb = pb + TK)" + ldb0 + R"(;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
}
|
||||
int rxc[TM] = ridx * TM + (0 ... TM);
|
||||
int ryc[TN] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[TM, TN] = xc;
|
||||
bool checkc0[TM] = rxc < M;
|
||||
bool checkc1[TN] = ryc < N;
|
||||
bool checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
int rxc[{TM}] = ridx * TM + (0 ... TM);
|
||||
int ryc[{TN}] = ridy * TN + (0 ... TN);
|
||||
)" + c_ty + R"(* pc[{TM, TN}] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
|
||||
)" + c_ty + R"( c[{TM, TN}] = xc;
|
||||
bool checkc0[{TM}] = rxc < M;
|
||||
bool checkc1[{TN}] = ryc < N;
|
||||
bool checkc[{TM, TN}] = checkc0[:, newaxis] && checkc1[newaxis, :];
|
||||
*pc = c;
|
||||
}
|
||||
)";
|
||||
|
Reference in New Issue
Block a user