[test] re-added bounds checking in dot test
This commit is contained in:
@@ -106,10 +106,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT,
|
|||||||
opt.num_warps = {nwarp};
|
opt.num_warps = {nwarp};
|
||||||
}
|
}
|
||||||
if(mode == BENCH) {
|
if(mode == BENCH) {
|
||||||
opt.defines.push_back({"TM", {"128"}});
|
opt.defines.push_back({"TM", {"64", "128"}});
|
||||||
opt.defines.push_back({"TN", {"128"}});
|
opt.defines.push_back({"TN", {"64", "128"}});
|
||||||
opt.defines.push_back({"TK", {"8"}});
|
opt.defines.push_back({"TK", {"8"}});
|
||||||
opt.num_warps = {4};
|
opt.num_warps = {2, 4, 8};
|
||||||
}
|
}
|
||||||
|
|
||||||
// kernels
|
// kernels
|
||||||
|
@@ -26,8 +26,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C,
|
|||||||
c += USEA @ USEB;
|
c += USEA @ USEB;
|
||||||
pa = pa + TK * STRIDE_AK;
|
pa = pa + TK * STRIDE_AK;
|
||||||
pb = pb + TK * STRIDE_BK;
|
pb = pb + TK * STRIDE_BK;
|
||||||
a = *pa;
|
bool checka[SHAPE_A] = k > TK;
|
||||||
b = *pb;
|
bool checkb[SHAPE_B] = k > TK;
|
||||||
|
a = checka ? *pa : 0;
|
||||||
|
b = checkb ? *pb : 0;
|
||||||
}
|
}
|
||||||
// epilogue
|
// epilogue
|
||||||
int rxc[TM] = ridx * TM + 0 ... TM;
|
int rxc[TM] = ridx * TM + 0 ... TM;
|
||||||
|
Reference in New Issue
Block a user