[test] re-added bounds checking in dot test
This commit is contained in:
@@ -106,10 +106,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
||||
opt.num_warps = {nwarp};
|
||||
}
|
||||
if(mode == BENCH) {
|
||||
opt.defines.push_back({"TM", {"128"}});
|
||||
opt.defines.push_back({"TN", {"128"}});
|
||||
opt.defines.push_back({"TM", {"64", "128"}});
|
||||
opt.defines.push_back({"TN", {"64", "128"}});
|
||||
opt.defines.push_back({"TK", {"8"}});
|
||||
opt.num_warps = {4};
|
||||
opt.num_warps = {2, 4, 8};
|
||||
}
|
||||
|
||||
// kernels
|
||||
|
@@ -26,8 +26,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
||||
c += USEA @ USEB;
|
||||
pa = pa + TK * STRIDE_AK;
|
||||
pb = pb + TK * STRIDE_BK;
|
||||
a = *pa;
|
||||
b = *pb;
|
||||
bool checka[SHAPE_A] = k > TK;
|
||||
bool checkb[SHAPE_B] = k > TK;
|
||||
a = checka ? *pa : 0;
|
||||
b = checkb ? *pb : 0;
|
||||
}
|
||||
// epilogue
|
||||
int rxc[TM] = ridx * TM + 0 ... TM;
|
||||
|
Reference in New Issue
Block a user