[test] re-added bounds checking in dot test

This commit is contained in:
Philippe Tillet
2019-10-02 15:08:32 -04:00
parent adbc56d10a
commit 1bf0c8adeb
2 changed files with 7 additions and 5 deletions

View File

@@ -106,10 +106,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
opt.num_warps = {nwarp}; opt.num_warps = {nwarp};
} }
if(mode == BENCH) { if(mode == BENCH) {
opt.defines.push_back({"TM", {"128"}}); opt.defines.push_back({"TM", {"64", "128"}});
opt.defines.push_back({"TN", {"128"}}); opt.defines.push_back({"TN", {"64", "128"}});
opt.defines.push_back({"TK", {"8"}}); opt.defines.push_back({"TK", {"8"}});
opt.num_warps = {4}; opt.num_warps = {2, 4, 8};
} }
// kernels // kernels

View File

@@ -26,8 +26,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
c += USEA @ USEB; c += USEA @ USEB;
pa = pa + TK * STRIDE_AK; pa = pa + TK * STRIDE_AK;
pb = pb + TK * STRIDE_BK; pb = pb + TK * STRIDE_BK;
a = *pa; bool checka[SHAPE_A] = k > TK;
b = *pb; bool checkb[SHAPE_B] = k > TK;
a = checka ? *pa : 0;
b = checkb ? *pb : 0;
} }
// epilogue // epilogue
int rxc[TM] = ridx * TM + 0 ... TM; int rxc[TM] = ridx * TM + 0 ... TM;