From 1bf0c8adeb80587624a80cf1c1c226d2ceaa686b Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Wed, 2 Oct 2019 15:08:32 -0400 Subject: [PATCH] [test] re-added bounds checking in dot test --- tests/common/dot.h | 6 +++--- tests/common/src/dot.h | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/common/dot.h b/tests/common/dot.h index f96ce17f2..599784570 100644 --- a/tests/common/dot.h +++ b/tests/common/dot.h @@ -106,10 +106,10 @@ bool triton_dot(drv::stream* stream, bool AT, bool BT, opt.num_warps = {nwarp}; } if(mode == BENCH) { - opt.defines.push_back({"TM", {"128"}}); - opt.defines.push_back({"TN", {"128"}}); + opt.defines.push_back({"TM", {"64", "128"}}); + opt.defines.push_back({"TN", {"64", "128"}}); opt.defines.push_back({"TK", {"8"}}); - opt.num_warps = {4}; + opt.num_warps = {2, 4, 8}; } // kernels diff --git a/tests/common/src/dot.h b/tests/common/src/dot.h index c360edbfe..dc71d86bb 100644 --- a/tests/common/src/dot.h +++ b/tests/common/src/dot.h @@ -26,8 +26,10 @@ void dot(TYPE * A, TYPE * B, TYPE * C, c += USEA @ USEB; pa = pa + TK * STRIDE_AK; pb = pb + TK * STRIDE_BK; - a = *pa; - b = *pb; + bool checka[SHAPE_A] = k > TK; + bool checkb[SHAPE_B] = k > TK; + a = checka ? *pa : 0; + b = checkb ? *pb : 0; } // epilogue int rxc[TM] = ridx * TM + 0 ... TM;