[dnn/gemm] added some bounds checking

This commit is contained in:
Philippe Tillet
2019-07-19 21:32:55 -07:00
parent 5215fb0424
commit 28c250216c
5 changed files with 36 additions and 37 deletions

View File

@@ -8,15 +8,15 @@
int main() {
bool AT = true;
bool BT = false;
bool AT = false;
bool BT = true;
typedef float T;
std::string ty = "fp16";
size_t dt_nbytes = sizeof(T);
// initialize default compute device
auto context = triton::driver::backend::contexts::get_default();
// matrix multiplication parameters
int32_t M = 65536, N = 2048, K = 2048;
int32_t M = 4096, N = 4096, K = 4096;
std::vector<T> hc(M*N);
std::vector<T> rc(M*N);
std::vector<T> ha(M*K);

View File

@@ -31,9 +31,6 @@ public:
// clone
base* clone() const;
// default params
std::vector<unsigned> default_params();
// CPU reference implementation
template<class T, bool AT, bool BT>
static void cpu_ref(std::vector<T> &c, const std::vector<T> &a, const std::vector<T> &b,

View File

@@ -235,7 +235,7 @@ void tune::run(ir::module &mod) {
continue;
if(dynamic_cast<ir::load_inst*>(i) && i->get_type()->is_tile_ty()){
ir::type *ty = mod.get_builder().get_int32_ty();
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 8, 8));
std::unique_ptr<ir::metaparameter> tmp(ir::metaparameter::create(ctx, ty, 4, 8));
*params_.at(i).at("nts.d0") = *tmp;
}
if(dynamic_cast<ir::dot_inst*>(i) && i->get_type()->is_tile_ty()){

View File

@@ -51,6 +51,7 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
driver::buffer *a = args[0], *b = args[1], *c = args[2];
unsigned TM = info.globals.at("TM");
unsigned TN = info.globals.at("TN");
unsigned TK = info.globals.at("TK");
unsigned grid_0 = (M_ + TM - 1)/TM;
unsigned grid_1 = (N_ + TN - 1)/TN;
unsigned grid_2 = 1;
@@ -67,23 +68,13 @@ void gemm::enqueue_impl(driver::stream *stream, driver::kernel *kernel,
kernel->setArg(6, lda);
kernel->setArg(7, ldb);
kernel->setArg(8, ldc);
kernel->setArg(9, locks_);
kernel->setArg(10, grid_0);
kernel->setArg(11, grid_1);
kernel->setArg(9, TK);
kernel->setArg(10, locks_);
kernel->setArg(11, grid_0);
kernel->setArg(12, grid_1);
stream->enqueue(kernel, grid, {info.num_threads, 1, 1});
}
std::vector<unsigned> gemm::default_params() {
if(AT_ && BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(AT_ && !BT_)
return {32, 64, 32, 64, 16, 8, 2, 2, 4, 2, 8, 4, 2, 1};
else if(!AT_ && BT_)
return {16, 2, 64, 16, 2, 64, 16, 8, 2, 2, 8, 8, 8, 1};
else
return {16, 2, 128, 32, 32, 32, 4, 2, 2, 8, 8, 4, 2, 1};
}
void gemm::triton_c_src(std::ostream &os) const {
std::string AS0 = "TM", AS1 = "TK";
std::string BS0 = "TK", BS1 = "TN";
@@ -103,12 +94,14 @@ void gemm::triton_c_src(std::ostream &os) const {
std::swap(bcb0, bcb1);
std::swap(ldb0, ldb1);
}
std::string AS = AS0 + ", " + AS1;
std::string BS = BS0 + ", " + BS1;
std::string align_lda_str = "multiple_of(" + std::to_string(align_lda_) + ")";
std::string align_ldb_str = "multiple_of(" + std::to_string(align_ldb_) + ")";
std::string res =
R"(
const tunable int32 TM = {16, 32, 64, 128};
const tunable int32 TN = {16, 32, 64, 128};
const tunable int32 TM = {32, 64, 128, 256};
const tunable int32 TN = {32, 64, 128, 256};
const tunable int32 TK = {32};
const tunable int32 GZ = {1};
@@ -117,27 +110,36 @@ void matmul(restrict read_only align(16) )" + a_ty_ + R"( *A,
fp32 *C,
int32 M, int32 N, int32 K,
)" + align_lda_str + R"( int32 lda, )" + align_ldb_str + R"(" int32 ldb, int32 ldc,
int32 *locks, int32 grid0, int32 grid1) {
int32 rxa[TM] = get_global_range[TM](0);
int32 ryb[TN] = get_global_range[TN](1);
int32 bound, int32 *locks, int32 grid0, int32 grid1) {
int32 ridx = get_range_id(0);
int32 ridy = get_range_id(1);
int32 rxa[TM] = ridx*TM + (0 ... TM);
int32 ryb[TN] = ridy*TN + (0 ... TN);
int32 rka[TK] = 0 ... TK;
int32 rkb[TK] = 0 ... TK;
fp32 c[TM, TN] = 0;
)" + a_ty_ + R"(* pa[)" + AS0 + ", " + AS1 + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS0 + ", " + BS1 + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
)" + a_ty_ + R"( a[)" + AS0 + ", " + AS1 + R"(] = *pa;
)" + b_ty_ + R"( b[)" + BS0 + ", " + BS1 + R"(] = *pb;
for(int32 k = K; k > TK; k = k - TK){
)" + a_ty_ + R"(* pa[)" + AS + "] = A + rka" + bca0 + lda0 + " + rxa" + bca1 + lda1 + R"(;
)" + b_ty_ + R"(* pb[)" + BS + "] = B + rkb" + bcb0 + ldb0 + " + ryb" + bcb1 + ldb1 + R"(;
int1 checka[)" + AS + R"(] = (rka < K))" + bca0 + " && (rxa < M)" + bca1 + R"(;
int1 checkb[)" + BS + R"(] = (rkb < K))" + bcb0 + " && (ryb < N)" + bcb1 + R"(;
)" + a_ty_ + R"( a[)" + AS + R"(] = checka ? *pa : 0;
)" + b_ty_ + R"( b[)" + BS + R"(] = checkb ? *pb : 0;
for(int32 k = K; k > 0; k = k - TK){
c = dot()" + usea + ", " + useb + R"(, c);
pa = pa + TK)" + lda0 + R"(;
pb = pb + TK)" + ldb0 + R"(;
a = *pa;
b = *pb;
int1 checka[)" + AS + R"(] = k > bound;
int1 checkb[)" + BS + R"(] = k > bound;
@checka a = *pa;
@checkb b = *pb;
}
int32 rxc[TM] = get_global_range[TM](0);
int32 ryc[TN] = get_global_range[TN](1);
int32 rxc[TM] = ridx*TM + (0 ... TM);
int32 ryc[TN] = ridy*TN + (0 ... TN);
int1 checkc0[TM] = rxc < M;
int1 checkc1[TN] = ryc < N;
int1 checkc[TM, TN] = checkc0[:, newaxis] && checkc1[newaxis, :];
fp32* pc[TM, TN] = C + ryc[newaxis, :]*ldc + rxc[:, newaxis];
*pc = c;
@checkc *pc = c;
}
)";
os << res;

View File

@@ -49,7 +49,7 @@ void loop_nest(std::vector<size_t> const & ranges,
values[i--] = 0;
}
i = D - 1;
// Small sleep so that the thread pool doesn't grow too big
// Short sleep so that the thread pool doesn't grow too big
std::this_thread::sleep_for(std::chrono::microseconds(1));
}
}