[triton-c] added implicit conversion to bool in while/for loops

This commit is contained in:
Philippe Tillet
2019-04-28 00:31:08 -04:00
parent af58b8bd81
commit 93f53501c6
3 changed files with 10 additions and 10 deletions

View File

@@ -10,7 +10,7 @@ R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TK = {8};
const tunable int32 GZ = {2};
const tunable int32 GZ = {1};
void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 M, int32 N, int32 K,
@@ -57,7 +57,7 @@ void matmul(restrict read_only fp32 *A, restrict read_only fp32 *B, fp32 *C,
int32 ridy = get_range_id(1);
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
int32 *plock = locks + ridx + ridy*grid0;
while(__atomic_cas(plock, 0, 1) == 1);
while(__atomic_cas(plock, 0, 1));
int32 *pcount = plock + grid0*grid1;
int32 count = *pcount;
int32 countp1 = select(count == GZ - 1, 0, count + 1);
@@ -82,7 +82,7 @@ int main() {
triton::jit jit(context);
// matrix multiplication parameters
int32_t M = 256, N = 256, K = 2048;
int32_t M = 512, N = 512, K = 512;
std::vector<float> hc(M*N);
std::vector<float> rc(M*N);
std::vector<float> ha(M*K);
@@ -144,7 +144,7 @@ int main() {
// just-in-time compile source-code
std::vector<unsigned> params = {
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 4
16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1
};
// jit.autotune("matmul",src, benchmark);
jit.add_module("matmul", src, params);